diff --git a/SimpleHttpServer/HttpServer.cs b/SimpleHttpServer/HttpServer.cs index b738e30..1473e8a 100644 --- a/SimpleHttpServer/HttpServer.cs +++ b/SimpleHttpServer/HttpServer.cs @@ -84,8 +84,11 @@ public sealed class HttpServer { RegisterConverter(); } - private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new(); + private readonly MultiKeyDictionary simpleEndpointMethodInfos = new(); // requestmethod, path + private readonly MultiKeyDictionary pathEndpointMethodInfos = new(); // requestmethod, path + private readonly Dictionary> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) }; + internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length; public void RegisterEndpointsFromType(Func? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic if (stringToTypeParameterConverters.Count == 0) @@ -107,27 +110,48 @@ public sealed class HttpServer { Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})"); var methodParams = mi.GetParameters(); + // check the mandatory prefix parameters Assert(methodParams.Length >= expectedEndpointParameterTypes.Length); for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) { Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]), $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}."); } + // check return type Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!"); - + // check the rest of the method parameters var qparams = new List(); + var pparams = new List(); + int mParamIndex = expectedEndpointParameterTypes.Length; for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) { var par = methodParams[i]; var attr = par.GetCustomAttribute(false); - qparams.Add(new( - attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"), - par.ParameterType, - attr?.IsOptional ?? false) - ); + var pathAttr = par.GetCustomAttribute(false); + + if (attr != null && pathAttr != null) { + throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}"); + } if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) { - throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!"); + throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!"); + } + + if (pathAttr != null) { // parameter is a path param + + pparams.Add(new( + pathAttr?.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"), + par.ParameterType, + mParamIndex++ + ) + ); + } else { // parameter is a normal query param + qparams.Add(new( + attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"), + par.ParameterType, + mParamIndex++, + attr?.IsOptional ?? false) + ); } } @@ -141,17 +165,35 @@ public sealed class HttpServer { foreach (var location in attrib.Locations) { var normLocation = NormalizeUrlPath(location); - int idx = normLocation.IndexOf('{'); - if (idx >= 0) { - // this path contains path parameters - throw new NotImplementedException("Path parameters are not yet implemented!"); - } var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined"); - mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'"); - simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance)); + + var pparamsCopy = new List(pparams); + var splittedLocation = location[1..].Split('/'); + for (int i = 0; i < pparamsCopy.Count; i++) { + var pp = pparamsCopy[i]; + var idx = Array.IndexOf(splittedLocation, pp.Name); + Assert(idx != -1, "Path parameter name was incorrect?"); + pp.SegmentStartPos = idx; + pparamsCopy[i] = pp; + } + + var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance); + if (pparams.Any()) { + mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'"); + Assert(normLocation[0] == '/'); + pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo); + } else { + mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'"); + simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo); + } } } + + // rebuild path trees + pathEndpointMethodInfosTrees.Clear(); + foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict) + pathEndpointMethodInfosTrees.Add(reqMethod, new(d2)); } /// @@ -217,13 +259,23 @@ public sealed class HttpServer { } try { - if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) { + /* Finding the endpoint that should process the request: + * 1. Try to see if there is a simple endpoint where request method and path match + * 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup) + * 3. Otherwise, check if it is inside a static serve path + * 4. Otherwise, show 404 page */ + + EndpointInvocationInfo? pathEndpointInvocationInfo = null; + if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) || + pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) { // try to find simple or pathparam-endpoint + var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new Exception("retrieved endpoint is somehow null"); var mi = endpointInvocationInfo.methodInfo; var qparams = endpointInvocationInfo.queryParameters; + var pparams = endpointInvocationInfo.pathParameters; var args = splitted.Length == 2 ? splitted[1] : null; var parsedQParams = new Dictionary(); - var convertedQParamValues = new object[qparams.Count + 1]; + var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count]; // run the checks to see if the client is allowed to make this request if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden @@ -251,32 +303,45 @@ public sealed class HttpServer { if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) { if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) { - convertedQParamValues[i] = objRes; + convertedMParamValues[qparam.ArgPos] = objRes; } else { await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest); return; } } else { if (qparam.IsOptional) { - convertedQParamValues[i] = null!; + convertedMParamValues[qparam.ArgPos] = null!; } else { await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}"); return; } } } - } else { + } else { // check for missing query parameters var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList(); if (requiredParams.Any()) { await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}"); return; } } - convertedQParamValues[0] = rc; + if (pparams.Count != 0) { + var splittedReqPath = reqPath[1..].Split('/'); + for (int i = 0; i < pparams.Count; i++) { + var pparam = pparams[i]; + if (pparam.IsCatchAll) + convertedMParamValues[pparam.ArgPos] = string.Join('/', splittedReqPath[pparam.SegmentStartPos..]); + else + convertedMParamValues[pparam.ArgPos] = splittedReqPath[pparam.SegmentStartPos]; + } + } + + convertedMParamValues[0] = rc; rc.ParsedParameters = parsedQParams.AsReadOnly(); - await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly")); - } else { + // todo read and convert pathparams + + await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly")); + } else { // try to find suitable static serve path if (requestMethod == "GET") foreach (var (k, v) in staticServePaths) { if (reqPath.StartsWith(k)) { // do a static serve diff --git a/SimpleHttpServer/Types/EndpointInvocationInfo.cs b/SimpleHttpServer/Types/EndpointInvocationInfo.cs index ea4424b..adc6185 100644 --- a/SimpleHttpServer/Types/EndpointInvocationInfo.cs +++ b/SimpleHttpServer/Types/EndpointInvocationInfo.cs @@ -2,23 +2,46 @@ using System.Reflection; namespace SimpleHttpServer.Types; -internal readonly struct EndpointInvocationInfo { - internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional); +internal record EndpointInvocationInfo { + //internal record struct QueryParameterInfo(string Name, Type Type, bool isPathParam, bool Path_isCatchAll, bool Query_IsOptional) { + // public static QueryParameterInfo CreatePathParam(string name, Type type) => new(name, type, false, name == "$*", false); + // public static QueryParameterInfo CreateQueryParam(string name, Type type, bool isOptional) => new(name, type, false, false, isOptional); + //} + internal record struct PathParameterInfo(string Name, Type Type, int ArgPos, int SegmentStartPos, bool IsCatchAll) { + public PathParameterInfo(string name, Type type, int argPos) : this(name, type, argPos, -1, name == "$*") { } + } + + internal record struct QueryParameterInfo(string Name, Type Type, int ArgPos, bool IsOptional); internal readonly MethodInfo methodInfo; internal readonly List queryParameters; + internal readonly List pathParameters; internal readonly InternalEndpointCheckAttribute[] requiredChecks; /// /// a reference to the object in which this method is defined (or null if the class is static) /// internal readonly object? typeInstanceReference; - public EndpointInvocationInfo(MethodInfo methodInfo, List queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) { + public EndpointInvocationInfo(MethodInfo methodInfo, List pathParameters, List queryParameters, InternalEndpointCheckAttribute[] requiredChecks, + object? typeInstanceReference) { + this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo)); this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters)); + this.pathParameters = pathParameters ?? throw new ArgumentNullException(nameof(pathParameters)); this.requiredChecks = requiredChecks; this.typeInstanceReference = typeInstanceReference; + + if (pathParameters.Any()) { + Assert(pathParameters.Count(x => x.IsCatchAll) <= 1); // at most one catchall parameter + var argPoses = pathParameters.Select(x => x.ArgPos).Concat(queryParameters.Select(x => x.ArgPos)).ToArray(); + var argCnt = pathParameters.Count + queryParameters.Count; + Assert(argPoses.Distinct().Count() == argCnt); // ArgPoses must be unique + Assert(argPoses.Min() == HttpServer.expectedEndpointParameterPrefixCount); // ArgPoses must start from just after the prefix + Assert(argPoses.Max() == HttpServer.expectedEndpointParameterPrefixCount + argCnt - 1); // ArgPoses must be contiguous + + Assert(pathParameters.All(x => x.SegmentStartPos != -1)); + } } - public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req)); + public bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req)); } diff --git a/SimpleHttpServer/Types/MultiKeyDictionary.cs b/SimpleHttpServer/Types/MultiKeyDictionary.cs new file mode 100644 index 0000000..fccf966 --- /dev/null +++ b/SimpleHttpServer/Types/MultiKeyDictionary.cs @@ -0,0 +1,22 @@ +using System.Diagnostics.CodeAnalysis; + +namespace SimpleHttpServer.Types; +internal class MultiKeyDictionary where K1 : notnull where K2 : notnull { + internal readonly Dictionary> backingDict = new(); + public MultiKeyDictionary() { } + + public void Add(K1 k1, K2 k2, V value) { + if (!backingDict.TryGetValue(k1, out var d2)) + d2 = new(); + d2.Add(k2, value); + backingDict[k1] = d2; + } + + public bool TryGetValue(K1 k1, K2 k2, [MaybeNullWhen(false)] out V value) { + if (backingDict.TryGetValue(k1, out var d2) && d2.TryGetValue(k2, out value)) + return true; + + value = default; + return false; + } +} diff --git a/SimpleHttpServer/Types/ParameterAttribute.cs b/SimpleHttpServer/Types/ParameterAttribute.cs index b3f4148..07bfb70 100644 --- a/SimpleHttpServer/Types/ParameterAttribute.cs +++ b/SimpleHttpServer/Types/ParameterAttribute.cs @@ -5,9 +5,6 @@ /// [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)] public sealed class ParameterAttribute : Attribute { - // See the attribute guidelines at - // http://go.microsoft.com/fwlink/?LinkId=85236 - public string Name { get; } public bool IsOptional { get; } public ParameterAttribute(string name, bool isOptional = false) { diff --git a/SimpleHttpServer/Types/PathParameterAttribute.cs b/SimpleHttpServer/Types/PathParameterAttribute.cs new file mode 100644 index 0000000..f74b770 --- /dev/null +++ b/SimpleHttpServer/Types/PathParameterAttribute.cs @@ -0,0 +1,28 @@ +namespace SimpleHttpServer.Types; + +/// +/// Specifies the name of a http endpoint path parameter. Path parameter names must be in the format $1, $2, $3, ..., and the end of the path may be $* +/// +[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)] +public sealed class PathParameterAttribute : Attribute { + public string Name { get; } + public PathParameterAttribute(string name) { + if (string.IsNullOrWhiteSpace(name)) { + throw new ArgumentException($"'{nameof(name)}' cannot be null or whitespace.", nameof(name)); + } + + if (!name.StartsWith('$')) { + throw new ArgumentException($"'{nameof(name)}' must start with $.", nameof(name)); + } + + if (name.Contains(' ')) { + throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name)); + } + + if (!uint.TryParse(name[1..], out _) && name != "$*") { + throw new ArgumentException($"'{nameof(name)}' must only consist of spaces or be exactly '$*'.", nameof(name)); + } + + Name = name; + } +} diff --git a/SimpleHttpServer/Types/PathTree.cs b/SimpleHttpServer/Types/PathTree.cs new file mode 100644 index 0000000..1d90615 --- /dev/null +++ b/SimpleHttpServer/Types/PathTree.cs @@ -0,0 +1,107 @@ +using System.Data; +using System.Diagnostics.CodeAnalysis; + +namespace SimpleHttpServer.Types; + +internal class PathTree where T : class { + private readonly Node? rootNode = null; + + public PathTree() : this(new()) { } + public PathTree(Dictionary dict) { + if (dict == null || dict.Count == 0) + return; + + rootNode = new(); + var currNode = rootNode; + var unpackedPaths = dict.Keys.Select(p => p.Split('/').ToArray()).ToArray(); + var unpackedLeafData = dict.Values.ToArray(); + for (int i = 0; i < unpackedPaths.Length; i++) { + var path = unpackedPaths[i]; + var catchallidx = Array.IndexOf(path, "$*"); + if (catchallidx != -1 && catchallidx != path.Length - 1) { + throw new Exception($"Found illegal catchall-wildcard in path: '{string.Join('/', path)}'"); + } + + var leafdata = unpackedLeafData[i] ?? throw new ArgumentNullException("Leafdata must not be null!"); + rootNode.AddSuccessor(path, leafdata); + } + } + + internal bool TryGetPath(string reqPath, [MaybeNullWhen(false)] out T endpoint) { + if (rootNode == null) { + endpoint = null; + return false; + } + + // try to find path-match + Node currNode = rootNode; + Assert(reqPath[0] == '/'); + var splittedPath = reqPath[1..].Split("/"); + Node? lastCatchallNode = null; + for (int i = 0; i < splittedPath.Length; i++) { + + // keep track of the current best catchallNode + if (currNode.catchAllNext != null) { + lastCatchallNode = currNode.catchAllNext; + } + + var seg = splittedPath[i]; + if (currNode.next?.TryGetValue(seg, out var next) == true) { // look for an explicit path to follow greedily + currNode = next; + } else if (currNode.pathWildcardNext != null) { // otherwise look for a single-wildcard to follow + currNode = currNode.pathWildcardNext; + } else { // otherwise we are done, there is no valid path --> fall back to the most specific catchall + endpoint = lastCatchallNode?.leafData; + return lastCatchallNode != null; + } + } + + // return found path + endpoint = currNode.leafData; + return endpoint != null; + } + + private class Node { + public T? leafData = null; // null means that this is a node without a value (e.g. when it is just part of a path) + public Dictionary? next = null; + public Node? pathWildcardNext = null; // path wildcard + public Node? catchAllNext = null; // trailing-catchall wildcard + + public void AddSuccessor(string[] segments, T newLeafData) { + if (segments.Length == 0) { // actually add the data to this node + Assert(leafData == null); + leafData = newLeafData; + return; + } + + var seg = segments[0]; + bool newIsWildcard = seg.Length > 1 && seg[0] == '$'; + if (newIsWildcard) { + bool newIsCatchallWildcard = newIsWildcard && seg.Length == 2 && seg[1] == '*'; + if (newIsCatchallWildcard) { // this is a catchall wildcard + Assert(catchAllNext == null); + catchAllNext = new(); + catchAllNext.AddSuccessor(segments[1..], newLeafData); + return; + } else { // must be single wildcard otherwise + Assert(pathWildcardNext == null); + pathWildcardNext = new(); + pathWildcardNext.AddSuccessor(segments[1..], newLeafData); + return; + } + } + + // otherwise we want to add a new constant path successor + if (next == null) { + next = new(); + } + + if (next.TryGetValue(seg, out var existingNode)) { + existingNode.AddSuccessor(segments[1..], newLeafData); + } else { + var newNode = next[seg] = new(); + newNode.AddSuccessor(segments[1..], newLeafData); + } + } + } +} \ No newline at end of file