Compare commits

..

4 Commits

Author SHA1 Message Date
GHXX 61e2d39bee Merge branch 'feature/pathparams' 2024-08-11 04:47:39 +02:00
GHXX 7a516668bf Merge branch 'master' into feature/pathparams 2024-08-11 04:45:55 +02:00
GHXX fecd40cd57 finish implementing path parameters 2024-08-11 04:43:20 +02:00
GHXX a24543063b work towards path parameters 2024-07-30 08:42:46 +02:00
6 changed files with 272 additions and 30 deletions

View File

@ -84,8 +84,11 @@ public sealed class HttpServer {
RegisterConverter<decimal>();
}
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new(); // requestmethod, path
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new(); // requestmethod, path
private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
if (stringToTypeParameterConverters.Count == 0)
@ -107,27 +110,48 @@ public sealed class HttpServer {
Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");
var methodParams = mi.GetParameters();
// check the mandatory prefix parameters
Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
$"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
}
// check return type
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
// check the rest of the method parameters
var qparams = new List<QueryParameterInfo>();
var pparams = new List<PathParameterInfo>();
int mParamIndex = expectedEndpointParameterTypes.Length;
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false);
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
attr?.IsOptional ?? false)
);
var pathAttr = par.GetCustomAttribute<PathParameterAttribute>(false);
if (attr != null && pathAttr != null) {
throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
}
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
}
if (pathAttr != null) { // parameter is a path param
pparams.Add(new(
pathAttr?.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++
)
);
} else { // parameter is a normal query param
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++,
attr?.IsOptional ?? false)
);
}
}
@ -141,17 +165,35 @@ public sealed class HttpServer {
foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location);
int idx = normLocation.IndexOf('{');
if (idx >= 0) {
// this path contains path parameters
throw new NotImplementedException("Path parameters are not yet implemented!");
}
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance));
var pparamsCopy = new List<PathParameterInfo>(pparams);
var splittedLocation = location[1..].Split('/');
for (int i = 0; i < pparamsCopy.Count; i++) {
var pp = pparamsCopy[i];
var idx = Array.IndexOf(splittedLocation, pp.Name);
Assert(idx != -1, "Path parameter name was incorrect?");
pp.SegmentStartPos = idx;
pparamsCopy[i] = pp;
}
var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
if (pparams.Any()) {
mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
Assert(normLocation[0] == '/');
pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
} else {
mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
}
}
}
// rebuild path trees
pathEndpointMethodInfosTrees.Clear();
foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
}
/// <summary>
@ -217,13 +259,23 @@ public sealed class HttpServer {
}
try {
if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) {
/* Finding the endpoint that should process the request:
* 1. Try to see if there is a simple endpoint where request method and path match
* 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup)
* 3. Otherwise, check if it is inside a static serve path
* 4. Otherwise, show 404 page */
EndpointInvocationInfo? pathEndpointInvocationInfo = null;
if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) ||
pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) { // try to find simple or pathparam-endpoint
var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new Exception("retrieved endpoint is somehow null");
var mi = endpointInvocationInfo.methodInfo;
var qparams = endpointInvocationInfo.queryParameters;
var pparams = endpointInvocationInfo.pathParameters;
var args = splitted.Length == 2 ? splitted[1] : null;
var parsedQParams = new Dictionary<string, string>();
var convertedQParamValues = new object[qparams.Count + 1];
var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];
// run the checks to see if the client is allowed to make this request
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
@ -251,32 +303,45 @@ public sealed class HttpServer {
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes;
convertedMParamValues[qparam.ArgPos] = objRes;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return;
}
} else {
if (qparam.IsOptional) {
convertedQParamValues[i] = null!;
convertedMParamValues[qparam.ArgPos] = null!;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
return;
}
}
}
} else {
} else { // check for missing query parameters
var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList();
if (requiredParams.Any()) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}");
return;
}
}
convertedQParamValues[0] = rc;
if (pparams.Count != 0) {
var splittedReqPath = reqPath[1..].Split('/');
for (int i = 0; i < pparams.Count; i++) {
var pparam = pparams[i];
if (pparam.IsCatchAll)
convertedMParamValues[pparam.ArgPos] = string.Join('/', splittedReqPath[pparam.SegmentStartPos..]);
else
convertedMParamValues[pparam.ArgPos] = splittedReqPath[pparam.SegmentStartPos];
}
}
convertedMParamValues[0] = rc;
rc.ParsedParameters = parsedQParams.AsReadOnly();
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else {
// todo read and convert pathparams
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else { // try to find suitable static serve path
if (requestMethod == "GET")
foreach (var (k, v) in staticServePaths) {
if (reqPath.StartsWith(k)) { // do a static serve

View File

@ -2,23 +2,46 @@
using System.Reflection;
namespace SimpleHttpServer.Types;
internal readonly struct EndpointInvocationInfo {
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional);
internal record EndpointInvocationInfo {
//internal record struct QueryParameterInfo(string Name, Type Type, bool isPathParam, bool Path_isCatchAll, bool Query_IsOptional) {
// public static QueryParameterInfo CreatePathParam(string name, Type type) => new(name, type, false, name == "$*", false);
// public static QueryParameterInfo CreateQueryParam(string name, Type type, bool isOptional) => new(name, type, false, false, isOptional);
//}
internal record struct PathParameterInfo(string Name, Type Type, int ArgPos, int SegmentStartPos, bool IsCatchAll) {
public PathParameterInfo(string name, Type type, int argPos) : this(name, type, argPos, -1, name == "$*") { }
}
internal record struct QueryParameterInfo(string Name, Type Type, int ArgPos, bool IsOptional);
internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters;
internal readonly List<PathParameterInfo> pathParameters;
internal readonly InternalEndpointCheckAttribute[] requiredChecks;
/// <summary>
/// a reference to the object in which this method is defined (or null if the class is static)
/// </summary>
internal readonly object? typeInstanceReference;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) {
public EndpointInvocationInfo(MethodInfo methodInfo, List<PathParameterInfo> pathParameters, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks,
object? typeInstanceReference) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.pathParameters = pathParameters ?? throw new ArgumentNullException(nameof(pathParameters));
this.requiredChecks = requiredChecks;
this.typeInstanceReference = typeInstanceReference;
if (pathParameters.Any()) {
Assert(pathParameters.Count(x => x.IsCatchAll) <= 1); // at most one catchall parameter
var argPoses = pathParameters.Select(x => x.ArgPos).Concat(queryParameters.Select(x => x.ArgPos)).ToArray();
var argCnt = pathParameters.Count + queryParameters.Count;
Assert(argPoses.Distinct().Count() == argCnt); // ArgPoses must be unique
Assert(argPoses.Min() == HttpServer.expectedEndpointParameterPrefixCount); // ArgPoses must start from just after the prefix
Assert(argPoses.Max() == HttpServer.expectedEndpointParameterPrefixCount + argCnt - 1); // ArgPoses must be contiguous
Assert(pathParameters.All(x => x.SegmentStartPos != -1));
}
}
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
public bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
}

View File

@ -0,0 +1,22 @@
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class MultiKeyDictionary<K1, K2, V> where K1 : notnull where K2 : notnull {
internal readonly Dictionary<K1, Dictionary<K2, V>> backingDict = new();
public MultiKeyDictionary() { }
public void Add(K1 k1, K2 k2, V value) {
if (!backingDict.TryGetValue(k1, out var d2))
d2 = new();
d2.Add(k2, value);
backingDict[k1] = d2;
}
public bool TryGetValue(K1 k1, K2 k2, [MaybeNullWhen(false)] out V value) {
if (backingDict.TryGetValue(k1, out var d2) && d2.TryGetValue(k2, out value))
return true;
value = default;
return false;
}
}

View File

@ -5,9 +5,6 @@
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class ParameterAttribute : Attribute {
// See the attribute guidelines at
// http://go.microsoft.com/fwlink/?LinkId=85236
public string Name { get; }
public bool IsOptional { get; }
public ParameterAttribute(string name, bool isOptional = false) {

View File

@ -0,0 +1,28 @@
namespace SimpleHttpServer.Types;
/// <summary>
/// Specifies the name of a http endpoint path parameter. Path parameter names must be in the format $1, $2, $3, ..., and the end of the path may be $*
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class PathParameterAttribute : Attribute {
public string Name { get; }
public PathParameterAttribute(string name) {
if (string.IsNullOrWhiteSpace(name)) {
throw new ArgumentException($"'{nameof(name)}' cannot be null or whitespace.", nameof(name));
}
if (!name.StartsWith('$')) {
throw new ArgumentException($"'{nameof(name)}' must start with $.", nameof(name));
}
if (name.Contains(' ')) {
throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name));
}
if (!uint.TryParse(name[1..], out _) && name != "$*") {
throw new ArgumentException($"'{nameof(name)}' must only consist of spaces or be exactly '$*'.", nameof(name));
}
Name = name;
}
}

View File

@ -0,0 +1,107 @@
using System.Data;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class PathTree<T> where T : class {
private readonly Node? rootNode = null;
public PathTree() : this(new()) { }
public PathTree(Dictionary<string, T> dict) {
if (dict == null || dict.Count == 0)
return;
rootNode = new();
var currNode = rootNode;
var unpackedPaths = dict.Keys.Select(p => p.Split('/').ToArray()).ToArray();
var unpackedLeafData = dict.Values.ToArray();
for (int i = 0; i < unpackedPaths.Length; i++) {
var path = unpackedPaths[i];
var catchallidx = Array.IndexOf(path, "$*");
if (catchallidx != -1 && catchallidx != path.Length - 1) {
throw new Exception($"Found illegal catchall-wildcard in path: '{string.Join('/', path)}'");
}
var leafdata = unpackedLeafData[i];
rootNode.AddSuccessor(path, leafdata);
}
}
internal bool TryGetPath(string reqPath, [MaybeNullWhen(false)] out T? endpoint) {
if (rootNode == null) {
endpoint = null;
return false;
}
// try to find path-match
Node currNode = rootNode;
Assert(reqPath[0] == '/');
var splittedPath = reqPath[1..].Split("/");
Node? lastCatchallNode = null;
for (int i = 0; i < splittedPath.Length; i++) {
// keep track of the current best catchallNode
if (currNode.catchAllNext != null) {
lastCatchallNode = currNode.catchAllNext;
}
var seg = splittedPath[i];
if (currNode.next?.TryGetValue(seg, out var next) == true) { // look for an explicit path to follow greedily
currNode = next;
} else if (currNode.pathWildcardNext != null) { // otherwise look for a single-wildcard to follow
currNode = currNode.pathWildcardNext;
} else { // otherwise we are done, there is no valid path --> fall back to the most specific catchall
endpoint = lastCatchallNode?.leafData;
return lastCatchallNode != null;
}
}
// return found path
endpoint = currNode.leafData;
return true;
}
private class Node {
public T? leafData = null;
public Dictionary<string, Node>? next = null;
public Node? pathWildcardNext = null; // path wildcard
public Node? catchAllNext = null; // trailing-catchall wildcard
public void AddSuccessor(string[] segments, T? newLeafData) {
if (segments.Length == 0) { // actually add the data to this node
Assert(leafData == null);
leafData = newLeafData;
return;
}
var seg = segments[0];
bool newIsWildcard = seg.Length > 1 && seg[0] == '$';
if (newIsWildcard) {
bool newIsCatchallWildcard = newIsWildcard && seg.Length == 2 && seg[1] == '*';
if (newIsCatchallWildcard) { // this is a catchall wildcard
Assert(catchAllNext == null);
catchAllNext = new();
catchAllNext.AddSuccessor(segments[1..], newLeafData);
return;
} else { // must be single wildcard otherwise
Assert(pathWildcardNext == null);
pathWildcardNext = new();
pathWildcardNext.AddSuccessor(segments[1..], newLeafData);
return;
}
}
// otherwise we want to add a new constant path successor
if (next == null) {
next = new();
}
if (next.TryGetValue(seg, out var existingNode)) {
existingNode.AddSuccessor(segments[1..], newLeafData);
} else {
var newNode = next[seg] = new();
newNode.AddSuccessor(segments[1..], newLeafData);
}
}
}
}