Compare commits

...

12 Commits

10 changed files with 319 additions and 32 deletions

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 00asdf, GHXX
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,9 +1,12 @@
global using static SimpleHttpServer.GlobalUsings;
using SimpleHttpServer.Types.Exceptions;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer;
internal static class GlobalUsings {
[DebuggerHidden]
internal static void Assert([DoesNotReturnIf(false)] bool b, string? message = null) {
if (!b) {
if (message == null)
@ -13,5 +16,6 @@ internal static class GlobalUsings {
}
}
[DebuggerHidden]
internal static void AssertImplies(bool x, bool y, string? message = null) => Assert(!x || y, message);
}

View File

@ -84,8 +84,11 @@ public sealed class HttpServer {
RegisterConverter<decimal>();
}
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new(); // requestmethod, path
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new(); // requestmethod, path
private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
if (stringToTypeParameterConverters.Count == 0)
@ -107,27 +110,48 @@ public sealed class HttpServer {
Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");
var methodParams = mi.GetParameters();
// check the mandatory prefix parameters
Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
$"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
}
// check return type
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
// check the rest of the method parameters
var qparams = new List<QueryParameterInfo>();
var pparams = new List<PathParameterInfo>();
int mParamIndex = expectedEndpointParameterTypes.Length;
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false);
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
attr?.IsOptional ?? false)
);
var pathAttr = par.GetCustomAttribute<PathParameterAttribute>(false);
if (attr != null && pathAttr != null) {
throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
}
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
}
if (pathAttr != null) { // parameter is a path param
pparams.Add(new(
pathAttr?.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++
)
);
} else { // parameter is a normal query param
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++,
attr?.IsOptional ?? false)
);
}
}
@ -141,17 +165,35 @@ public sealed class HttpServer {
foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location);
int idx = normLocation.IndexOf('{');
if (idx >= 0) {
// this path contains path parameters
throw new NotImplementedException("Path parameters are not yet implemented!");
}
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance));
var pparamsCopy = new List<PathParameterInfo>(pparams);
var splittedLocation = location[1..].Split('/');
for (int i = 0; i < pparamsCopy.Count; i++) {
var pp = pparamsCopy[i];
var idx = Array.IndexOf(splittedLocation, pp.Name);
Assert(idx != -1, "Path parameter name was incorrect?");
pp.SegmentStartPos = idx;
pparamsCopy[i] = pp;
}
var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
if (pparams.Any()) {
mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
Assert(normLocation[0] == '/');
pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
} else {
mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
}
}
}
// rebuild path trees
pathEndpointMethodInfosTrees.Clear();
foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
}
/// <summary>
@ -171,7 +213,7 @@ public sealed class HttpServer {
private readonly Dictionary<Type, IParameterConverter> stringToTypeParameterConverters = new();
private static string NormalizeUrlPath(string url) {
private string NormalizeUrlPath(string url) {
var fwdSlashUrl = url.Replace('\\', '/');
var segments = fwdSlashUrl.Trim('/').Split('/', StringSplitOptions.RemoveEmptyEntries).ToList();
@ -200,7 +242,12 @@ public sealed class HttpServer {
}
rv.AppendJoin('/', simplifiedSegmentsReversed.Reverse<string>());
return '/' + (rv.ToString().TrimEnd('/') + (fwdSlashUrl.EndsWith('/') ? "/" : "")).TrimStart('/');
var suffix = (rv.ToString().TrimEnd('/') + (fwdSlashUrl.EndsWith('/') ? "/" : "")).TrimStart('/');
if (conf.TrimTrailingSlash) {
suffix = suffix.TrimEnd('/');
}
return '/' + suffix;
}
private async Task ProcessRequestAsync(HttpListenerContext ctx) {
@ -217,13 +264,23 @@ public sealed class HttpServer {
}
try {
if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) {
/* Finding the endpoint that should process the request:
* 1. Try to see if there is a simple endpoint where request method and path match
* 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup)
* 3. Otherwise, check if it is inside a static serve path
* 4. Otherwise, show 404 page */
EndpointInvocationInfo? pathEndpointInvocationInfo = null;
if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) ||
pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) { // try to find simple or pathparam-endpoint
var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new Exception("retrieved endpoint is somehow null");
var mi = endpointInvocationInfo.methodInfo;
var qparams = endpointInvocationInfo.queryParameters;
var pparams = endpointInvocationInfo.pathParameters;
var args = splitted.Length == 2 ? splitted[1] : null;
var parsedQParams = new Dictionary<string, string>();
var convertedQParamValues = new object[qparams.Count + 1];
var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];
// run the checks to see if the client is allowed to make this request
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
@ -251,32 +308,53 @@ public sealed class HttpServer {
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes;
convertedMParamValues[qparam.ArgPos] = objRes;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return;
}
} else {
if (qparam.IsOptional) {
convertedQParamValues[i] = null!;
convertedMParamValues[qparam.ArgPos] = null!;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
return;
}
}
}
} else {
} else { // check for missing query parameters
var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList();
if (requiredParams.Any()) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}");
return;
}
}
convertedQParamValues[0] = rc;
if (pparams.Count != 0) {
var splittedReqPath = reqPath[1..].Split('/');
for (int i = 0; i < pparams.Count; i++) {
var pparam = pparams[i];
string paramValue;
if (pparam.IsCatchAll)
paramValue = string.Join('/', splittedReqPath[pparam.SegmentStartPos..]);
else
paramValue = splittedReqPath[pparam.SegmentStartPos];
if (stringToTypeParameterConverters[pparam.Type].TryConvertFromString(paramValue, out var res))
convertedMParamValues[pparam.ArgPos] = res;
else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return;
}
}
}
convertedMParamValues[0] = rc;
rc.ParsedParameters = parsedQParams.AsReadOnly();
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else {
// todo read and convert pathparams
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else { // try to find suitable static serve path
if (requestMethod == "GET")
foreach (var (k, v) in staticServePaths) {
if (reqPath.StartsWith(k)) { // do a static serve

View File

@ -13,6 +13,10 @@ public class SimpleHttpServerConfiguration {
/// See description of <see cref="DisableLogMessagePrinting"/>
/// </summary>
public CustomLogMessageHandler? LogMessageHandler { get; init; } = null;
/// <summary>
/// If set to true, paths ending with / are identical to paths without said trailing slash. E.g. /index is then the same as /index/
/// </summary>
public bool TrimTrailingSlash { get; init; } = true;
public SimpleHttpServerConfiguration() { }

View File

@ -2,23 +2,46 @@
using System.Reflection;
namespace SimpleHttpServer.Types;
internal readonly struct EndpointInvocationInfo {
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional);
internal record EndpointInvocationInfo {
//internal record struct QueryParameterInfo(string Name, Type Type, bool isPathParam, bool Path_isCatchAll, bool Query_IsOptional) {
// public static QueryParameterInfo CreatePathParam(string name, Type type) => new(name, type, false, name == "$*", false);
// public static QueryParameterInfo CreateQueryParam(string name, Type type, bool isOptional) => new(name, type, false, false, isOptional);
//}
internal record struct PathParameterInfo(string Name, Type Type, int ArgPos, int SegmentStartPos, bool IsCatchAll) {
public PathParameterInfo(string name, Type type, int argPos) : this(name, type, argPos, -1, name == "$*") { }
}
internal record struct QueryParameterInfo(string Name, Type Type, int ArgPos, bool IsOptional);
internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters;
internal readonly List<PathParameterInfo> pathParameters;
internal readonly InternalEndpointCheckAttribute[] requiredChecks;
/// <summary>
/// a reference to the object in which this method is defined (or null if the class is static)
/// </summary>
internal readonly object? typeInstanceReference;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) {
public EndpointInvocationInfo(MethodInfo methodInfo, List<PathParameterInfo> pathParameters, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks,
object? typeInstanceReference) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.pathParameters = pathParameters ?? throw new ArgumentNullException(nameof(pathParameters));
this.requiredChecks = requiredChecks;
this.typeInstanceReference = typeInstanceReference;
if (pathParameters.Any()) {
Assert(pathParameters.Count(x => x.IsCatchAll) <= 1); // at most one catchall parameter
var argPoses = pathParameters.Select(x => x.ArgPos).Concat(queryParameters.Select(x => x.ArgPos)).ToArray();
var argCnt = pathParameters.Count + queryParameters.Count;
Assert(argPoses.Distinct().Count() == argCnt); // ArgPoses must be unique
Assert(argPoses.Min() == HttpServer.expectedEndpointParameterPrefixCount); // ArgPoses must start from just after the prefix
Assert(argPoses.Max() == HttpServer.expectedEndpointParameterPrefixCount + argCnt - 1); // ArgPoses must be contiguous
Assert(pathParameters.All(x => x.SegmentStartPos != -1));
}
}
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
public bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
}

View File

@ -0,0 +1,22 @@
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class MultiKeyDictionary<K1, K2, V> where K1 : notnull where K2 : notnull {
internal readonly Dictionary<K1, Dictionary<K2, V>> backingDict = new();
public MultiKeyDictionary() { }
public void Add(K1 k1, K2 k2, V value) {
if (!backingDict.TryGetValue(k1, out var d2))
d2 = new();
d2.Add(k2, value);
backingDict[k1] = d2;
}
public bool TryGetValue(K1 k1, K2 k2, [MaybeNullWhen(false)] out V value) {
if (backingDict.TryGetValue(k1, out var d2) && d2.TryGetValue(k2, out value))
return true;
value = default;
return false;
}
}

View File

@ -5,9 +5,6 @@
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class ParameterAttribute : Attribute {
// See the attribute guidelines at
// http://go.microsoft.com/fwlink/?LinkId=85236
public string Name { get; }
public bool IsOptional { get; }
public ParameterAttribute(string name, bool isOptional = false) {

View File

@ -0,0 +1,28 @@
namespace SimpleHttpServer.Types;
/// <summary>
/// Specifies the name of a http endpoint path parameter. Path parameter names must be in the format $1, $2, $3, ..., and the end of the path may be $*
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class PathParameterAttribute : Attribute {
public string Name { get; }
public PathParameterAttribute(string name) {
if (string.IsNullOrWhiteSpace(name)) {
throw new ArgumentException($"'{nameof(name)}' cannot be null or whitespace.", nameof(name));
}
if (!name.StartsWith('$')) {
throw new ArgumentException($"'{nameof(name)}' must start with $.", nameof(name));
}
if (name.Contains(' ')) {
throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name));
}
if (!uint.TryParse(name[1..], out _) && name != "$*") {
throw new ArgumentException($"'{nameof(name)}' must only consist of spaces or be exactly '$*'.", nameof(name));
}
Name = name;
}
}

View File

@ -0,0 +1,104 @@
using System.Data;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class PathTree<T> where T : class {
private readonly Node? rootNode = null;
public PathTree() : this(new()) { }
public PathTree(Dictionary<string, T> dict) {
if (dict == null || dict.Count == 0)
return;
rootNode = new();
var currNode = rootNode;
var unpackedPaths = dict.Keys.Select(p => p.Split('/').ToArray()).ToArray();
var unpackedLeafData = dict.Values.ToArray();
for (int i = 0; i < unpackedPaths.Length; i++) {
var path = unpackedPaths[i];
var catchallidx = Array.IndexOf(path, "$*");
if (catchallidx != -1 && catchallidx != path.Length - 1) {
throw new Exception($"Found illegal catchall-wildcard in path: '{string.Join('/', path)}'");
}
var leafdata = unpackedLeafData[i] ?? throw new ArgumentNullException("Leafdata must not be null!");
rootNode.AddSuccessor(path, leafdata);
}
}
internal bool TryGetPath(string reqPath, [MaybeNullWhen(false)] out T endpoint) {
if (rootNode == null) {
endpoint = null;
return false;
}
// try to find path-match
Node currNode = rootNode;
Assert(reqPath[0] == '/');
var splittedPath = reqPath[1..].Split("/");
Node? lastCatchallNode = null;
for (int i = 0; i < splittedPath.Length; i++) {
// keep track of the current best catchallNode
if (currNode.catchAllNext != null) {
lastCatchallNode = currNode.catchAllNext;
}
var seg = splittedPath[i];
if (currNode.next?.TryGetValue(seg, out var next) == true) { // look for an explicit path to follow greedily
currNode = next;
} else if (currNode.pathWildcardNext != null) { // otherwise look for a single-wildcard to follow
currNode = currNode.pathWildcardNext;
} else { // otherwise we are done, there is no valid path --> fall back to the most specific catchall
endpoint = lastCatchallNode?.leafData;
return lastCatchallNode != null;
}
}
// return found path
endpoint = currNode.leafData;
return endpoint != null;
}
private class Node {
public T? leafData = null; // null means that this is a node without a value (e.g. when it is just part of a path)
public Dictionary<string, Node>? next = null;
public Node? pathWildcardNext = null; // path wildcard
public Node? catchAllNext = null; // trailing-catchall wildcard
public void AddSuccessor(string[] segments, T newLeafData) {
if (segments.Length == 0) { // actually add the data to this node
Assert(leafData == null);
leafData = newLeafData;
return;
}
var seg = segments[0];
bool newIsWildcard = seg.Length > 1 && seg[0] == '$';
if (newIsWildcard) {
bool newIsCatchallWildcard = newIsWildcard && seg.Length == 2 && seg[1] == '*';
if (newIsCatchallWildcard) { // this is a catchall wildcard
Assert(catchAllNext == null);
catchAllNext = new();
catchAllNext.AddSuccessor(segments[1..], newLeafData);
return;
} else { // must be single wildcard otherwise
pathWildcardNext ??= new();
pathWildcardNext.AddSuccessor(segments[1..], newLeafData);
return;
}
}
// otherwise we want to add a new constant path successor
next ??= new();
if (next.TryGetValue(seg, out var existingNode)) {
existingNode.AddSuccessor(segments[1..], newLeafData);
} else {
var newNode = next[seg] = new();
newNode.AddSuccessor(segments[1..], newLeafData);
}
}
}
}

View File

@ -35,6 +35,12 @@ public class RequestContext : IDisposable {
public void SetStatusCode(HttpStatusCode status) => SetStatusCode((int) status);
public async Task SetStatusCodeWriteLineDisposeAsync(HttpStatusCode status, string message) {
SetStatusCode(status);
await WriteLineToRespAsync(message);
await RespWriter.FlushAsync();
}
public async Task SetStatusCodeAndDisposeAsync(int status) {
using (this) {
SetStatusCode(status);