finish implementing path parameters

This commit is contained in:
GHXX 2024-08-11 04:43:20 +02:00
parent a24543063b
commit fecd40cd57
5 changed files with 236 additions and 27 deletions

View File

@ -84,8 +84,11 @@ public sealed class HttpServer {
RegisterConverter<decimal>(); RegisterConverter<decimal>();
} }
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new(); private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new(); // requestmethod, path
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new(); // requestmethod, path
private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) }; private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
if (stringToTypeParameterConverters.Count == 0) if (stringToTypeParameterConverters.Count == 0)
@ -119,6 +122,8 @@ public sealed class HttpServer {
// check the rest of the method parameters // check the rest of the method parameters
var qparams = new List<QueryParameterInfo>(); var qparams = new List<QueryParameterInfo>();
var pparams = new List<PathParameterInfo>();
int mParamIndex = expectedEndpointParameterTypes.Length;
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) { for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i]; var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false); var attr = par.GetCustomAttribute<ParameterAttribute>(false);
@ -128,18 +133,25 @@ public sealed class HttpServer {
throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}"); throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
} }
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
}
if (pathAttr != null) { // parameter is a path param if (pathAttr != null) { // parameter is a path param
} else { // parameter is a normal one pparams.Add(new(
qparams.Add(new( pathAttr?.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType, par.ParameterType,
mParamIndex++
)
);
} else { // parameter is a normal query param
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++,
attr?.IsOptional ?? false) attr?.IsOptional ?? false)
); );
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
}
} }
} }
@ -153,19 +165,37 @@ public sealed class HttpServer {
foreach (var location in attrib.Locations) { foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location); var normLocation = NormalizeUrlPath(location);
int idx = normLocation.IndexOf('{');
if (idx >= 0) {
// this path contains path parameters
throw new NotImplementedException("Path parameters are not yet implemented!");
}
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined"); var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance)); var pparamsCopy = new List<PathParameterInfo>(pparams);
var splittedLocation = location[1..].Split('/');
for (int i = 0; i < pparamsCopy.Count; i++) {
var pp = pparamsCopy[i];
var idx = Array.IndexOf(splittedLocation, pp.Name);
Assert(idx != -1, "Path parameter name was incorrect?");
pp.SegmentStartPos = idx;
pparamsCopy[i] = pp;
}
var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
if (pparams.Any()) {
mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
Assert(normLocation[0] == '/');
pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
} else {
mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
} }
} }
} }
// rebuild path trees
pathEndpointMethodInfosTrees.Clear();
foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
}
/// <summary> /// <summary>
/// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to <paramref name="requestPath"/>, /// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to <paramref name="requestPath"/>,
/// while restricting requests to inside the local filesystem directory. Static serving has a lower priority than registering an endpoint. /// while restricting requests to inside the local filesystem directory. Static serving has a lower priority than registering an endpoint.
@ -229,13 +259,23 @@ public sealed class HttpServer {
} }
try { try {
if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) { /* Finding the endpoint that should process the request:
* 1. Try to see if there is a simple endpoint where request method and path match
* 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup)
* 3. Otherwise, check if it is inside a static serve path
* 4. Otherwise, show 404 page */
EndpointInvocationInfo? pathEndpointInvocationInfo = null;
if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) ||
pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) { // try to find simple or pathparam-endpoint
var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new Exception("retrieved endpoint is somehow null");
var mi = endpointInvocationInfo.methodInfo; var mi = endpointInvocationInfo.methodInfo;
var qparams = endpointInvocationInfo.queryParameters; var qparams = endpointInvocationInfo.queryParameters;
var pparams = endpointInvocationInfo.pathParameters;
var args = splitted.Length == 2 ? splitted[1] : null; var args = splitted.Length == 2 ? splitted[1] : null;
var parsedQParams = new Dictionary<string, string>(); var parsedQParams = new Dictionary<string, string>();
var convertedQParamValues = new object[qparams.Count + 1]; var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];
// run the checks to see if the client is allowed to make this request // run the checks to see if the client is allowed to make this request
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
@ -263,32 +303,45 @@ public sealed class HttpServer {
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) { if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) { if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes; convertedMParamValues[qparam.ArgPos] = objRes;
} else { } else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest); await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return; return;
} }
} else { } else {
if (qparam.IsOptional) { if (qparam.IsOptional) {
convertedQParamValues[i] = null!; convertedMParamValues[qparam.ArgPos] = null!;
} else { } else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}"); await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
return; return;
} }
} }
} }
} else { } else { // check for missing query parameters
var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList(); var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList();
if (requiredParams.Any()) { if (requiredParams.Any()) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}"); await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}");
return; return;
} }
} }
convertedQParamValues[0] = rc; if (pparams.Count != 0) {
var splittedReqPath = reqPath[1..].Split('/');
for (int i = 0; i < pparams.Count; i++) {
var pparam = pparams[i];
if (pparam.IsCatchAll)
convertedMParamValues[pparam.ArgPos] = string.Join('/', splittedReqPath[pparam.SegmentStartPos..]);
else
convertedMParamValues[pparam.ArgPos] = splittedReqPath[pparam.SegmentStartPos];
}
}
convertedMParamValues[0] = rc;
rc.ParsedParameters = parsedQParams.AsReadOnly(); rc.ParsedParameters = parsedQParams.AsReadOnly();
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly")); // todo read and convert pathparams
} else {
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else { // try to find suitable static serve path
if (requestMethod == "GET") if (requestMethod == "GET")
foreach (var (k, v) in staticServePaths) { foreach (var (k, v) in staticServePaths) {
if (reqPath.StartsWith(k)) { // do a static serve if (reqPath.StartsWith(k)) { // do a static serve

View File

@ -2,23 +2,46 @@
using System.Reflection; using System.Reflection;
namespace SimpleHttpServer.Types; namespace SimpleHttpServer.Types;
internal readonly struct EndpointInvocationInfo { internal record EndpointInvocationInfo {
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional); //internal record struct QueryParameterInfo(string Name, Type Type, bool isPathParam, bool Path_isCatchAll, bool Query_IsOptional) {
// public static QueryParameterInfo CreatePathParam(string name, Type type) => new(name, type, false, name == "$*", false);
// public static QueryParameterInfo CreateQueryParam(string name, Type type, bool isOptional) => new(name, type, false, false, isOptional);
//}
internal record struct PathParameterInfo(string Name, Type Type, int ArgPos, int SegmentStartPos, bool IsCatchAll) {
public PathParameterInfo(string name, Type type, int argPos) : this(name, type, argPos, -1, name == "$*") { }
}
internal record struct QueryParameterInfo(string Name, Type Type, int ArgPos, bool IsOptional);
internal readonly MethodInfo methodInfo; internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters; internal readonly List<QueryParameterInfo> queryParameters;
internal readonly List<PathParameterInfo> pathParameters;
internal readonly InternalEndpointCheckAttribute[] requiredChecks; internal readonly InternalEndpointCheckAttribute[] requiredChecks;
/// <summary> /// <summary>
/// a reference to the object in which this method is defined (or null if the class is static) /// a reference to the object in which this method is defined (or null if the class is static)
/// </summary> /// </summary>
internal readonly object? typeInstanceReference; internal readonly object? typeInstanceReference;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) { public EndpointInvocationInfo(MethodInfo methodInfo, List<PathParameterInfo> pathParameters, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks,
object? typeInstanceReference) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo)); this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters)); this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.pathParameters = pathParameters ?? throw new ArgumentNullException(nameof(pathParameters));
this.requiredChecks = requiredChecks; this.requiredChecks = requiredChecks;
this.typeInstanceReference = typeInstanceReference; this.typeInstanceReference = typeInstanceReference;
if (pathParameters.Any()) {
Assert(pathParameters.Count(x => x.IsCatchAll) <= 1); // at most one catchall parameter
var argPoses = pathParameters.Select(x => x.ArgPos).Concat(queryParameters.Select(x => x.ArgPos)).ToArray();
var argCnt = pathParameters.Count + queryParameters.Count;
Assert(argPoses.Distinct().Count() == argCnt); // ArgPoses must be unique
Assert(argPoses.Min() == HttpServer.expectedEndpointParameterPrefixCount); // ArgPoses must start from just after the prefix
Assert(argPoses.Max() == HttpServer.expectedEndpointParameterPrefixCount + argCnt - 1); // ArgPoses must be contiguous
Assert(pathParameters.All(x => x.SegmentStartPos != -1));
}
} }
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req)); public bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
} }

View File

@ -0,0 +1,22 @@
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class MultiKeyDictionary<K1, K2, V> where K1 : notnull where K2 : notnull {
internal readonly Dictionary<K1, Dictionary<K2, V>> backingDict = new();
public MultiKeyDictionary() { }
public void Add(K1 k1, K2 k2, V value) {
if (!backingDict.TryGetValue(k1, out var d2))
d2 = new();
d2.Add(k2, value);
backingDict[k1] = d2;
}
public bool TryGetValue(K1 k1, K2 k2, [MaybeNullWhen(false)] out V value) {
if (backingDict.TryGetValue(k1, out var d2) && d2.TryGetValue(k2, out value))
return true;
value = default;
return false;
}
}

View File

@ -19,6 +19,10 @@ public sealed class PathParameterAttribute : Attribute {
throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name)); throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name));
} }
if (!uint.TryParse(name[1..], out _) && name != "$*") {
throw new ArgumentException($"'{nameof(name)}' must only consist of spaces or be exactly '$*'.", nameof(name));
}
Name = name; Name = name;
} }
} }

View File

@ -0,0 +1,107 @@
using System.Data;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class PathTree<T> where T : class {
private readonly Node? rootNode = null;
public PathTree() : this(new()) { }
public PathTree(Dictionary<string, T> dict) {
if (dict == null || dict.Count == 0)
return;
rootNode = new();
var currNode = rootNode;
var unpackedPaths = dict.Keys.Select(p => p.Split('/').ToArray()).ToArray();
var unpackedLeafData = dict.Values.ToArray();
for (int i = 0; i < unpackedPaths.Length; i++) {
var path = unpackedPaths[i];
var catchallidx = Array.IndexOf(path, "$*");
if (catchallidx != -1 && catchallidx != path.Length - 1) {
throw new Exception($"Found illegal catchall-wildcard in path: '{string.Join('/', path)}'");
}
var leafdata = unpackedLeafData[i];
rootNode.AddSuccessor(path, leafdata);
}
}
internal bool TryGetPath(string reqPath, [MaybeNullWhen(false)] out T? endpoint) {
if (rootNode == null) {
endpoint = null;
return false;
}
// try to find path-match
Node currNode = rootNode;
Assert(reqPath[0] == '/');
var splittedPath = reqPath[1..].Split("/");
Node? lastCatchallNode = null;
for (int i = 0; i < splittedPath.Length; i++) {
// keep track of the current best catchallNode
if (currNode.catchAllNext != null) {
lastCatchallNode = currNode.catchAllNext;
}
var seg = splittedPath[i];
if (currNode.next?.TryGetValue(seg, out var next) == true) { // look for an explicit path to follow greedily
currNode = next;
} else if (currNode.pathWildcardNext != null) { // otherwise look for a single-wildcard to follow
currNode = currNode.pathWildcardNext;
} else { // otherwise we are done, there is no valid path --> fall back to the most specific catchall
endpoint = lastCatchallNode?.leafData;
return lastCatchallNode != null;
}
}
// return found path
endpoint = currNode.leafData;
return true;
}
private class Node {
public T? leafData = null;
public Dictionary<string, Node>? next = null;
public Node? pathWildcardNext = null; // path wildcard
public Node? catchAllNext = null; // trailing-catchall wildcard
public void AddSuccessor(string[] segments, T? newLeafData) {
if (segments.Length == 0) { // actually add the data to this node
Assert(leafData == null);
leafData = newLeafData;
return;
}
var seg = segments[0];
bool newIsWildcard = seg.Length > 1 && seg[0] == '$';
if (newIsWildcard) {
bool newIsCatchallWildcard = newIsWildcard && seg.Length == 2 && seg[1] == '*';
if (newIsCatchallWildcard) { // this is a catchall wildcard
Assert(catchAllNext == null);
catchAllNext = new();
catchAllNext.AddSuccessor(segments[1..], newLeafData);
return;
} else { // must be single wildcard otherwise
Assert(pathWildcardNext == null);
pathWildcardNext = new();
pathWildcardNext.AddSuccessor(segments[1..], newLeafData);
return;
}
}
// otherwise we want to add a new constant path successor
if (next == null) {
next = new();
}
if (next.TryGetValue(seg, out var existingNode)) {
existingNode.AddSuccessor(segments[1..], newLeafData);
} else {
var newNode = next[seg] = new();
newNode.AddSuccessor(segments[1..], newLeafData);
}
}
}
}