Compare commits

..

2 Commits

Author SHA1 Message Date
GHXX 7bc6086509 Huge refactor; Passing tests 2024-01-13 01:25:32 +01:00
GHXX a03fafebcf wip refactor 2024-01-09 06:05:15 +01:00
20 changed files with 371 additions and 1014 deletions

21
LICENSE
View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 00asdf, GHXX
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,12 +1,9 @@
global using static SimpleHttpServer.GlobalUsings;
using SimpleHttpServer.Types.Exceptions;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer;
internal static class GlobalUsings {
[DebuggerHidden]
internal static void Assert([DoesNotReturnIf(false)] bool b, string? message = null) {
if (!b) {
if (message == null)
@ -16,6 +13,5 @@ internal static class GlobalUsings {
}
}
[DebuggerHidden]
internal static void AssertImplies(bool x, bool y, string? message = null) => Assert(!x || y, message);
}

View File

@ -1,15 +1,22 @@
using SimpleHttpServer.Types;
using SimpleHttpServer.Internal;
namespace SimpleHttpServer;
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class HttpEndpointAttribute : Attribute {
public class HttpEndpointAttribute<T> : Attribute where T : IAuthorizer {
public HttpRequestType RequestMethod { get; private set; }
public string[] Locations { get; private set; }
public Type Authorizer { get; private set; }
public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) {
RequestMethod = requestMethod;
Locations = locations;
Authorizer = typeof(T);
}
}
[AttributeUsage(AttributeTargets.Method)]
public class HttpEndpointAttribute : HttpEndpointAttribute<DefaultAuthorizer> {
public HttpEndpointAttribute(HttpRequestType type, params string[] locations) : base(type, locations) { }
}

View File

@ -1,4 +1,4 @@
namespace SimpleHttpServer.Types;
namespace SimpleHttpServer;
public enum HttpRequestType {
GET,

View File

@ -4,8 +4,6 @@ using SimpleHttpServer.Types.ParameterConverters;
using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
using static SimpleHttpServer.Types.EndpointInvocationInfo;
namespace SimpleHttpServer;
@ -15,8 +13,7 @@ public sealed class HttpServer {
private readonly HttpListener listener;
private Task? listenerTask;
private readonly Logger mainLogger;
private readonly Logger requestLogger;
private readonly Logger logger;
private readonly SimpleHttpServerConfiguration conf;
private bool shutdown = false;
@ -25,20 +22,19 @@ public sealed class HttpServer {
conf = configuration;
listener = new HttpListener();
listener.Prefixes.Add($"http://localhost:{port}/");
mainLogger = new(LogOutputTopic.Main, conf);
requestLogger = new(LogOutputTopic.Request, conf);
logger = new(LogOutputTopic.Main, conf);
}
public void Start() {
mainLogger.Information($"Starting on port {Port}...");
logger.Information($"Starting on port {Port}...");
Assert(listenerTask == null, "Server was already started!");
listener.Start();
listenerTask = Task.Run(GetContextLoopAsync);
mainLogger.Information($"Ready to handle requests!");
logger.Information($"Ready to handle requests!");
}
public async Task StopAsync(CancellationToken ctok) {
mainLogger.Information("Stopping server...");
logger.Information("Stopping server...");
Assert(listenerTask != null, "Server was not started!");
shutdown = true;
listener.Stop();
@ -50,9 +46,8 @@ public sealed class HttpServer {
try {
var ctx = await listener.GetContextAsync();
_ = ProcessRequestAsync(ctx);
} catch (HttpListenerException ex) when (ex.ErrorCode == 995) { //The I/O operation has been aborted because of either a thread exit or an application request
} catch (Exception ex) {
mainLogger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
logger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
}
}
}
@ -61,7 +56,6 @@ public sealed class HttpServer {
void RegisterConverter<T>() where T : IParsable<T> {
stringToTypeParameterConverters.Add(typeof(T), new ParsableParameterConverter<T>());
}
stringToTypeParameterConverters.Add(typeof(string), new StringParameterConverter());
stringToTypeParameterConverters.Add(typeof(bool), new BoolParsableParameterConverter());
RegisterConverter<char>();
@ -84,335 +78,133 @@ public sealed class HttpServer {
RegisterConverter<decimal>();
}
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new(); // requestmethod, path
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new(); // requestmethod, path
private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
if (stringToTypeParameterConverters.Count == 0)
public void RegisterEndpointsFromType<T>() {
if (simpleEndpointMethodInfos.Count == 0)
RegisterDefaultConverters();
var t = typeof(T);
var mis = t.GetMethods()
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single());
var isStatic = mis.All(x => x.Key.IsStatic); // if all are static then there is no point in having a constructor as no instance data is accessible, but we allow passing a factory anyway
Assert(isStatic || (instanceFactory != null), $"You must provide an instance factory if any methods of the given type ({typeof(T).FullName}) are non-static");
T? classInstance = instanceFactory?.Invoke();
foreach (var (mi, attrib) in mis) {
foreach (var (mi, attrib) in t.GetMethods()
.ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>)))
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single())) {
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
//Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");
var methodParams = mi.GetParameters();
// check the mandatory prefix parameters
Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
$"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
}
// check return type
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
// check the rest of the method parameters
var qparams = new List<QueryParameterInfo>();
var pparams = new List<PathParameterInfo>();
int mParamIndex = expectedEndpointParameterTypes.Length;
var qparams = new List<(string, (Type type, bool isOptional))>();
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false);
var pathAttr = par.GetCustomAttribute<PathParameterAttribute>(false);
if (attr != null && pathAttr != null) {
throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
}
qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
(par.GetType(), attr?.IsOptional ?? false)));
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
}
if (pathAttr != null) { // parameter is a path param
pparams.Add(new(
pathAttr?.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++
)
);
} else { // parameter is a normal query param
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
mParamIndex++,
attr?.IsOptional ?? false)
);
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
}
}
// stores the check attributes that are defined on the method and on the containing class
InternalEndpointCheckAttribute[] requiredChecks = mi.GetCustomAttributes<InternalEndpointCheckAttribute>(true)
.Concat(mi.DeclaringType?.GetCustomAttributes<InternalEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
.Where(a => a.GetType().IsAssignableTo(typeof(InternalEndpointCheckAttribute)))
.Cast<InternalEndpointCheckAttribute>().ToArray();
InternalEndpointCheckAttribute.Initialize(classInstance, requiredChecks);
foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location);
int idx = location.IndexOf('{');
if (idx >= 0) {
// this path contains path parameters
throw new NotImplementedException("Path parameters are not yet implemented!");
}
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
var pparamsCopy = new List<PathParameterInfo>(pparams);
var splittedLocation = location[1..].Split('/');
for (int i = 0; i < pparamsCopy.Count; i++) {
var pp = pparamsCopy[i];
var idx = Array.IndexOf(splittedLocation, pp.Name);
Assert(idx != -1, "Path parameter name was incorrect?");
pp.SegmentStartPos = idx;
pparamsCopy[i] = pp;
}
var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
if (pparams.Any()) {
mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
Assert(normLocation[0] == '/');
pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
} else {
mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
}
simpleEndpointMethodInfos.Add((location, reqMethod), new EndpointInvocationInfo(mi, qparams));
}
}
// rebuild path trees
pathEndpointMethodInfosTrees.Clear();
foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
}
/// <summary>
/// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to <paramref name="requestPath"/>,
/// while restricting requests to inside the local filesystem directory. Static serving has a lower priority than registering an endpoint.
/// </summary>
/// <param name="requestPath"></param>
/// <param name="filesystemDirectory"></param>
public void RegisterStaticServePath(string requestPath, string filesystemDirectory) {
var absPath = Path.GetFullPath(filesystemDirectory);
string npath = NormalizeUrlPath(requestPath);
mainLogger.Information($"Registered static serve path: '{npath}' --> '{absPath}'");
staticServePaths.Add(npath, absPath);
}
private readonly Dictionary<string, string> staticServePaths = new();
private readonly Dictionary<Type, IParameterConverter> stringToTypeParameterConverters = new();
private string NormalizeUrlPath(string url) {
var fwdSlashUrl = url.Replace('\\', '/');
var segments = fwdSlashUrl.Trim('/').Split('/', StringSplitOptions.RemoveEmptyEntries).ToList();
List<string> simplifiedSegmentsReversed = new List<string>();
int doubleDotsEncountered = 0;
for (int i = segments.Count - 1; i >= 0; i--) {
var segment = segments[i];
if (segment == ".") {
continue; // remove single dot segments
}
if (segment == "..") {
doubleDotsEncountered++; // if we encounter a doubledot, keep track of that and dont add it to the output yet
continue;
}
// otherwise only keep the segment if doubleDotsEncountered > 0
if (doubleDotsEncountered > 0) {
doubleDotsEncountered--;
continue;
}
simplifiedSegmentsReversed.Add(segment);
}
var rv = new StringBuilder();
for (int i = 0; i < doubleDotsEncountered; i++) {
rv.Append("../");
}
rv.AppendJoin('/', simplifiedSegmentsReversed.Reverse<string>());
var suffix = (rv.ToString().TrimEnd('/') + (fwdSlashUrl.EndsWith('/') ? "/" : "")).TrimStart('/');
if (conf.TrimTrailingSlash) {
suffix = suffix.TrimEnd('/');
}
return '/' + suffix;
}
private async Task ProcessRequestAsync(HttpListenerContext ctx) {
using RequestContext rc = new RequestContext(ctx);
// TODO add path escape countermeasure-unittests
var splitted = (ctx.Request.RawUrl ?? "").Split('?', 2, StringSplitOptions.None);
var reqPath = NormalizeUrlPath(WebUtility.UrlDecode(splitted.First()));
string requestMethod = ctx.Request.HttpMethod.ToUpperInvariant();
bool wasStaticlyServed = false;
void LogRequest() {
requestLogger.Information($"{rc.ListenerContext.Response.StatusCode} {(wasStaticlyServed ? "static" : "endpnt")} {requestMethod} {ctx.Request.Url}");
}
try {
var decUri = WebUtility.UrlDecode(ctx.Request.RawUrl)!; // TODO add path escape countermeasures+unittests
var splitted = decUri.Split('?', 2, StringSplitOptions.None);
var path = WebUtility.UrlDecode(splitted.First());
/* Finding the endpoint that should process the request:
* 1. Try to see if there is a simple endpoint where request method and path match
* 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup)
* 3. Otherwise, check if it is inside a static serve path
* 4. Otherwise, show 404 page */
EndpointInvocationInfo? pathEndpointInvocationInfo = null;
if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) ||
pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) { // try to find simple or pathparam-endpoint
var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new Exception("retrieved endpoint is somehow null");
using var rc = new RequestContext(ctx);
if (simpleEndpointMethodInfos.TryGetValue((decUri, ctx.Request.HttpMethod.ToUpperInvariant()), out var endpointInvocationInfo)) {
var mi = endpointInvocationInfo.methodInfo;
var qparams = endpointInvocationInfo.queryParameters;
var pparams = endpointInvocationInfo.pathParameters;
var args = splitted.Length == 2 ? splitted[1] : null;
var parsedQParams = new Dictionary<string, string>();
var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];
var convertedQParamValues = new object[qparams.Count + 1];
// run the checks to see if the client is allowed to make this request
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
return;
}
// TODO add authcheck here
if (args != null) {
var queryStringArgs = args.Split('&', StringSplitOptions.None);
foreach (var queryKV in queryStringArgs) {
var queryKVSplitted = queryKV.Split('=');
if (queryKVSplitted.Length != 2) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Malformed request URL parameters");
rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, "Malformed request URL parameters");
return;
}
if (!parsedQParams.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1]))) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Duplicate request URL parameters");
rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, "Duplicate request URL parameters");
return;
}
}
for (int i = 0; i < qparams.Count;) {
var qparam = qparams[i];
var (qparamName, qparamInfo) = qparams[i];
i++;
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
convertedMParamValues[qparam.ArgPos] = objRes;
if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) {
if (stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest);
return;
}
} else {
if (qparam.IsOptional) {
convertedMParamValues[qparam.ArgPos] = null!;
if (qparamInfo.isOptional) {
convertedQParamValues[i] = null!;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
return;
}
}
}
} else { // check for missing query parameters
var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList();
if (requiredParams.Any()) {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}");
return;
}
}
if (pparams.Count != 0) {
var splittedReqPath = reqPath[1..].Split('/');
for (int i = 0; i < pparams.Count; i++) {
var pparam = pparams[i];
string paramValue;
if (pparam.IsCatchAll)
paramValue = string.Join('/', splittedReqPath[pparam.SegmentStartPos..]);
else
paramValue = splittedReqPath[pparam.SegmentStartPos];
if (stringToTypeParameterConverters[pparam.Type].TryConvertFromString(paramValue, out var res))
convertedMParamValues[pparam.ArgPos] = res;
else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return;
}
}
}
convertedMParamValues[0] = rc;
rc.ParsedParameters = parsedQParams.AsReadOnly();
// todo read and convert pathparams
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else { // try to find suitable static serve path
if (requestMethod == "GET")
foreach (var (k, v) in staticServePaths) {
if (reqPath.StartsWith(k)) { // do a static serve
wasStaticlyServed = true;
var relativeStaticReqPath = reqPath[k.Length..];
var staticResponsePath = Path.GetFullPath(Path.Join(v, relativeStaticReqPath.TrimStart('/')));
if (Path.GetRelativePath(v, staticResponsePath).Contains("..")) {
requestLogger.Warning($"Blocked GET request to {reqPath} as somehow the target file does not lie inside the static serve folder? Are you using symlinks?");
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
return;
}
if (File.Exists(staticResponsePath)) {
rc.SetStatusCode(HttpStatusCode.OK);
if (staticResponsePath.EndsWith(".svg")) {
rc.ListenerContext.Response.AddHeader("Content-Type", "image/svg+xml");
}
using var f = File.OpenRead(staticResponsePath);
await f.CopyToAsync(rc.ListenerContext.Response.OutputStream);
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
}
return;
}
}
convertedQParamValues[0] = rc;
await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else {
// invoke 404
await HandleDefaultErrorPageAsync(rc, 404);
}
} catch (Exception ex) {
await HandleDefaultErrorPageAsync(rc, 500);
mainLogger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
} finally {
try { await rc.RespWriter.FlushAsync(); } catch (ObjectDisposedException) { }
rc.ListenerContext.Response.Close();
LogRequest();
logger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
}
}
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, HttpStatusCode errorCode, string? statusDescription = null) => await HandleDefaultErrorPageAsync(ctx, (int) errorCode, statusDescription);
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode, string? statusDescription = null) {
ctx.SetStatusCode(errorCode);
string desc = statusDescription != null ? $"\r\n{statusDescription}" : "";
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode) {
await ctx.WriteLineToRespAsync($"""
<body>
<h1>Oh no, an error occurred!</h1>
<p>Code: {errorCode}</p>{desc}
<h1>Oh no, and error occurred!</h1>
<p>Code: {errorCode}</p>
</body>
""");
try {
if (statusDescription == null) {
await ctx.SetStatusCodeAndDisposeAsync(errorCode);
} else {
await ctx.SetStatusCodeAndDisposeAsync(errorCode, statusDescription);
}
} catch (ObjectDisposedException) { }
}
}

View File

@ -0,0 +1,7 @@
using System.Net;
namespace SimpleHttpServer;
public interface IAuthorizer {
public abstract (bool auth, object? data) IsAuthenticated(HttpListenerContext contect);
}

View File

@ -0,0 +1,7 @@
using System.Net;
namespace SimpleHttpServer.Internal;
public sealed class DefaultAuthorizer : IAuthorizer {
public (bool auth, object? data) IsAuthenticated(HttpListenerContext contect) => (true, null);
}

View File

@ -1,73 +1,73 @@
//using Newtonsoft.Json;
//using System.Collections;
//using System.Net;
//using System.Reflection;
using Newtonsoft.Json;
using System.Collections;
using System.Net;
using System.Reflection;
//namespace SimpleHttpServer.Internal;
namespace SimpleHttpServer.Internal;
//internal class HttpEndpointHandler {
// private static readonly DefaultAuthorizer defaultAuth = new();
internal class HttpEndpointHandler {
private static readonly DefaultAuthorizer defaultAuth = new();
// private readonly IAuthorizer auth;
// private readonly MethodInfo handler;
// private readonly Dictionary<string, (int pindex, Type type, int pparamIdx)> @params;
// private readonly Func<Exception, HttpResponseBuilder> errorPageBuilder;
private readonly IAuthorizer auth;
private readonly MethodInfo handler;
private readonly Dictionary<string, (int pindex, Type type, int pparamIdx)> @params;
private readonly Func<Exception, HttpResponseBuilder> errorPageBuilder;
// public HttpEndpointHandler() {
// auth = defaultAuth;
// }
public HttpEndpointHandler() {
auth = defaultAuth;
}
// public HttpEndpointHandler(IAuthorizer auth) {
public HttpEndpointHandler(IAuthorizer auth) {
// }
}
// public virtual void Handle(HttpListenerContext ctx) {
// try {
// var (isAuth, authData) = auth.IsAuthenticated(ctx);
// if (!isAuth) {
// throw new HttpHandlingException(401, "Authorization required!");
// }
public virtual void Handle(HttpListenerContext ctx) {
try {
var (isAuth, authData) = auth.IsAuthenticated(ctx);
if (!isAuth) {
throw new HttpHandlingException(401, "Authorization required!");
}
// // collect parameters
// var invokeParams = new object?[@params.Count + 1];
// var set = new BitArray(@params.Count);
// invokeParams[0] = ctx;
// collect parameters
var invokeParams = new object?[@params.Count + 1];
var set = new BitArray(@params.Count);
invokeParams[0] = ctx;
// // read pparams
// read pparams
// // read qparams
// var qst = ctx.Request.QueryString;
// foreach (var qelem in ctx.Request.QueryString.AllKeys) {
// if (@params.ContainsKey(qelem!)) {
// var (pindex, type, isPParam) = @params[qelem!];
// if (type == typeof(string)) {
// invokeParams[pindex] = ctx.Request.QueryString[qelem!];
// set.Set(pindex - 1, true);
// } else {
// var elem = JsonConvert.DeserializeObject(ctx.Request.QueryString[qelem!]!, type);
// if (elem != null) {
// invokeParams[pindex] = elem;
// set.Set(pindex - 1, true);
// }
// }
// }
// }
// read qparams
var qst = ctx.Request.QueryString;
foreach (var qelem in ctx.Request.QueryString.AllKeys) {
if (@params.ContainsKey(qelem!)) {
var (pindex, type, isPParam) = @params[qelem!];
if (type == typeof(string)) {
invokeParams[pindex] = ctx.Request.QueryString[qelem!];
set.Set(pindex - 1, true);
} else {
var elem = JsonConvert.DeserializeObject(ctx.Request.QueryString[qelem!]!, type);
if (elem != null) {
invokeParams[pindex] = elem;
set.Set(pindex - 1, true);
}
}
}
}
// // fill with defaults
// foreach (var p in @params) {
// if (!set.Get(p.Value.pindex)) {
// invokeParams[p.Value.pindex] = p.Value.type.IsValueType ? Activator.CreateInstance(p.Value.type) : null;
// }
// }
// fill with defaults
foreach (var p in @params) {
if (!set.Get(p.Value.pindex)) {
invokeParams[p.Value.pindex] = p.Value.type.IsValueType ? Activator.CreateInstance(p.Value.type) : null;
}
}
// var builder = handler.Invoke(null, invokeParams) as HttpResponseBuilder;
// builder!.SendResponse(ctx.Response);
// } catch (Exception e) {
// if (e is TargetInvocationException tex) {
// e = tex.InnerException!;
// }
// errorPageBuilder(e).SendResponse(ctx.Response);
// }
// }
//}
var builder = handler.Invoke(null, invokeParams) as HttpResponseBuilder;
builder!.SendResponse(ctx.Response);
} catch (Exception e) {
if (e is TargetInvocationException tex) {
e = tex.InnerException!;
}
errorPageBuilder(e).SendResponse(ctx.Response);
}
}
}

View File

@ -1,245 +1,243 @@
//using Newtonsoft.Json;
//using System.Diagnostics.CodeAnalysis;
//using System.Security.Cryptography;
//using System.Text;
using Konscious.Security.Cryptography;
using Newtonsoft.Json;
using System.Security.Cryptography;
using System.Text;
//namespace SimpleHttpServer.Login;
namespace SimpleHttpServer.Login;
//internal struct SerialLoginData {
// public string passwordSalt;
// public string extraDataSalt;
// public string pwd;
// public string extraData;
internal struct SerialLoginData {
public string salt;
public string pwd;
public string additionalData;
// public LoginData ToPlainData() {
// return new LoginData {
// passwordSalt = Convert.FromBase64String(passwordSalt),
// extraDataSalt = Convert.FromBase64String(extraDataSalt)
// };
// }
//}
public LoginData toPlainData() {
return new LoginData {
salt = Convert.FromBase64String(salt),
password = Convert.FromBase64String(pwd)
};
}
}
//internal struct LoginData {
// public byte[] passwordSalt;
// public byte[] extraDataSalt;
// public byte[] passwordHash;
// public byte[] encryptedExtraData;
internal struct LoginData {
public byte[] salt;
public byte[] password;
public byte[] encryptedData;
// public SerialLoginData ToSerial() {
// return new SerialLoginData {
// passwordSalt = Convert.ToBase64String(passwordSalt),
// extraDataSalt = Convert.ToBase64String(extraDataSalt),
// pwd = Convert.ToBase64String(passwordHash),
// extraData = Convert.ToBase64String(encryptedExtraData)
// };
// }
//}
public SerialLoginData toSerial() {
return new SerialLoginData {
salt = Convert.ToBase64String(salt),
pwd = Convert.ToBase64String(password),
additionalData = Convert.ToBase64String(encryptedData)
};
}
}
//internal struct LoginDataProviderConfig {
internal struct LoginDataProviderConfig {
// /// <summary>
// /// Size of the password salt and the extradata salt. So each salt will be of size <see cref="SALT_SIZE"/>.
// /// </summary>
// public int SALT_SIZE = 32;
// public int KEY_LENGTH = 256 / 8;
// public int PBKDF2_ITERATIONS = 600_000;
public int SALT_SIZE = 32;
public int KEY_LENGTH = 256 / 8;
public int A2_ITERATIONS = 5;
public int A2_MEMORY_SIZE = 500_000;
public int A2_PARALLELISM = 8;
public int A2_HASH_LENGTH = 256 / 8;
public int A2_MAX_CONCURRENT = 4;
public int PBKDF2_ITERATIONS = 600_000;
// public LoginDataProviderConfig() { }
//}
public LoginDataProviderConfig() { }
}
//public class LoginProvider<TExtraData> {
public class LoginProvider<T> {
// private static readonly Func<TExtraData, byte[]> JsonSerialize = t => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(t));
// private static readonly Func<byte[], TExtraData> JsonDeserialize = b => JsonConvert.DeserializeObject<TExtraData>(Encoding.UTF8.GetString(b))!;
private static readonly Func<T, byte[]> JsonSerialize = t => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(t));
private static readonly Func<byte[], T> JsonDeserialize = b => JsonConvert.DeserializeObject<T>(Encoding.UTF8.GetString(b))!;
// [ThreadStatic]
// private static SHA256? _sha256PerThread;
// private static SHA256 Sha256PerThread { get => _sha256PerThread ??= SHA256.Create(); }
private readonly LoginDataProviderConfig config;
private readonly ReaderWriterLockSlim ldLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
private readonly string ldPath;
private readonly Dictionary<string, LoginData> loginData;
private readonly SemaphoreSlim argon2Limit;
// private readonly LoginDataProviderConfig config;
// private readonly ReaderWriterLockSlim ldLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
// private readonly string ldPath;
// private readonly Dictionary<string, LoginData> loginDatas;
private Func<T, byte[]> DataSerializer = JsonSerialize;
private Func<byte[], T> DataDeserializer = JsonDeserialize;
// private Func<TExtraData, byte[]> DataSerializer = JsonSerialize;
// private Func<byte[], TExtraData> DataDeserializer = JsonDeserialize;
// public void SetDataSerializers(Func<TExtraData, byte[]> serializer, Func<byte[], TExtraData> deserializer) {
// DataSerializer = serializer ?? JsonSerialize;
// DataDeserializer = deserializer ?? JsonDeserialize;
// }
public LoginProvider(string ldPath, string confPath) {
this.ldPath = ldPath;
loginData = LoadLoginData(ldPath);
config = LoadArgon2Config(confPath);
argon2Limit = new SemaphoreSlim(config.A2_MAX_CONCURRENT);
}
private static Dictionary<string, LoginData> LoadLoginData(string path) {
Dictionary<string, SerialLoginData> tempData;
if (!File.Exists(path)) {
File.WriteAllText(path, "{}", Encoding.UTF8);
tempData = new();
} else {
tempData = JsonConvert.DeserializeObject<Dictionary<string, SerialLoginData>>(File.ReadAllText(path))!;
if (tempData == null) {
throw new InvalidDataException($"could not read login data from file {path}");
}
}
var ld = new Dictionary<string, LoginData>();
foreach (var pair in tempData!) {
ld.Add(pair.Key, pair.Value.toPlainData());
}
return ld;
}
// public LoginProvider(string ldPath, string confPath) {
// this.ldPath = ldPath;
// loginDatas = LoadLoginDatas(ldPath);
// config = LoadLoginProviderConfig(confPath);
// }
private static LoginDataProviderConfig LoadArgon2Config(string path) {
if (!File.Exists(path)) {
var conf = new LoginDataProviderConfig();
File.WriteAllText(path, JsonConvert.SerializeObject(conf));
return conf;
}
return JsonConvert.DeserializeObject<LoginDataProviderConfig>(File.ReadAllText(path));
}
// private static Dictionary<string, LoginData> LoadLoginDatas(string path) {
// Dictionary<string, SerialLoginData> tempData;
// if (!File.Exists(path)) {
// File.WriteAllText(path, "{}", Encoding.UTF8);
// tempData = new();
// } else {
// tempData = JsonConvert.DeserializeObject<Dictionary<string, SerialLoginData>>(File.ReadAllText(path))!;
// if (tempData == null) {
// throw new InvalidDataException($"could not read login data from file {path}");
// }
// }
// var ld = new Dictionary<string, LoginData>();
// foreach (var pair in tempData) {
// ld.Add(pair.Key, pair.Value.ToPlainData());
// }
// return ld;
// }
public void SetDataSerialization(Func<T, byte[]> serializer, Func<byte[], T> deserializer) {
DataSerializer = serializer ?? JsonSerialize;
DataDeserializer = deserializer ?? JsonDeserialize;
}
// private void SaveLoginData() {
// var serial = new Dictionary<string, SerialLoginData>();
// ldLock.EnterWriteLock();
// try {
// foreach (var pair in loginDatas) {
// serial.Add(pair.Key, pair.Value.ToSerial());
// }
// } finally {
// ldLock.ExitWriteLock();
// }
// File.WriteAllText(ldPath, JsonConvert.SerializeObject(serial));
// }
private void StoreLoginData() {
var serial = new Dictionary<string, SerialLoginData>();
ldLock.EnterWriteLock();
try {
foreach (var pair in loginData!) {
serial.Add(pair.Key, pair.Value.toSerial());
}
} finally {
ldLock.ExitWriteLock();
}
File.WriteAllText(ldPath, JsonConvert.SerializeObject(serial));
}
// private static LoginDataProviderConfig LoadLoginProviderConfig(string path) {
// if (!File.Exists(path)) {
// var conf = new LoginDataProviderConfig();
// File.WriteAllText(path, JsonConvert.SerializeObject(conf));
// return conf;
// }
// return JsonConvert.DeserializeObject<LoginDataProviderConfig>(File.ReadAllText(path));
// }
public bool AddUser(string username, string password, T additional) {
ldLock.EnterWriteLock();
try {
if (loginData.ContainsKey(username)) {
return false;
}
var salt = RandomNumberGenerator.GetBytes(config.SALT_SIZE);
var pwdHash = HashPwd(password, salt);
LoginData ld = new LoginData() {
salt = salt,
password = pwdHash,
encryptedData = EncryptAdditionalData(password, salt, additional)
};
loginData.Add(username, ld);
StoreLoginData();
} finally {
ldLock.ExitWriteLock();
}
return true;
}
// public bool AddUser(string username, string password, TExtraData additional) {
// ldLock.EnterWriteLock();
// try {
// if (loginDatas.ContainsKey(username)) {
// return false;
// }
// var passwordSalt = RandomNumberGenerator.GetBytes(config.SALT_SIZE);
// var extraDataSalt = RandomNumberGenerator.GetBytes(config.SALT_SIZE);
// LoginData ld = new LoginData() {
// passwordSalt = passwordSalt,
// extraDataSalt = extraDataSalt,
// passwordHash = ComputeSaltedSha256Hash(password, passwordSalt),
// encryptedExtraData = EncryptExtraData(password, extraDataSalt, additional),
// };
// loginDatas.Add(username, ld);
// SaveLoginData();
// } finally {
// ldLock.ExitWriteLock();
// }
// return true;
// }
public bool RemoveUser(string username) {
ldLock.EnterWriteLock();
try {
var removed = loginData.Remove(username);
if (removed) {
StoreLoginData();
}
return removed;
} finally {
ldLock.ExitWriteLock();
}
}
// public bool RemoveUser(string username) {
// ldLock.EnterWriteLock();
// try {
// var removed = loginDatas.Remove(username);
// if (removed) {
// SaveLoginData();
// }
// return removed;
// } finally {
// ldLock.ExitWriteLock();
// }
// }
public bool ModifyUser(string username, string newPassword, T newAdditional) {
ldLock.EnterWriteLock();
try {
if (!loginData.ContainsKey(username)) {
return false;
}
loginData.Remove(username, out var data);
data.password = HashPwd(newPassword, data.salt);
data.encryptedData = EncryptAdditionalData(newPassword, data.salt, newAdditional);
loginData.Add(username, data);
StoreLoginData();
} finally {
ldLock.ExitWriteLock();
}
return true;
}
// public bool ModifyUser(string username, string newPassword, TExtraData newExtraData) {
// ldLock.EnterWriteLock();
// try {
// if (!loginDatas.ContainsKey(username)) {
// return false;
// }
// loginDatas.Remove(username, out var data);
// data.passwordHash = ComputeSaltedSha256Hash(newPassword, data.passwordSalt);
// data.encryptedExtraData = EncryptExtraData(newPassword, data.extraDataSalt, newExtraData);
// loginDatas.Add(username, data);
// SaveLoginData();
// } finally {
// ldLock.ExitWriteLock();
// }
// return true;
// }
public (bool, T) Authenticate(string username, string password) {
LoginData data;
ldLock.EnterReadLock();
try {
if (!loginData.TryGetValue(username, out data)) {
return (false, default(T)!);
}
} finally {
ldLock.ExitReadLock();
}
var hash = HashPwd(password, data.salt);
if (!hash.SequenceEqual(data.password)) {
return (false, default(T)!);
}
return (true, DecryptAdditionalData(password, data.salt, data.encryptedData));
}
// public bool TryAuthenticate(string username, string password, [MaybeNullWhen(false)] out TExtraData extraData) {
// LoginData data;
// ldLock.EnterReadLock();
// try {
// if (!loginDatas.TryGetValue(username, out data)) {
// extraData = default;
// return false;
// }
// } finally {
// ldLock.ExitReadLock();
// }
// var hash = ComputeSaltedSha256Hash(password, data.passwordSalt);
// if (!hash.SequenceEqual(data.passwordHash)) {
// extraData = default;
// return false;
// }
// extraData = DecryptExtraData(password, data.extraDataSalt, data.encryptedExtraData);
// return true;
// }
private byte[] HashPwd(string pwd, byte[] salt) {
byte[] hash;
argon2Limit.Wait();
try {
using (var argon2 = new Argon2id(Encoding.UTF8.GetBytes(pwd))) {
argon2.Iterations = config.A2_ITERATIONS;
argon2.MemorySize = config.A2_MEMORY_SIZE;
argon2.DegreeOfParallelism = config.A2_PARALLELISM;
argon2.Salt = salt;
hash = argon2.GetBytes(config.A2_HASH_LENGTH);
}
// force collection to reduce sustained memory usage if many hashes are done in close time proximity to each other
GC.Collect();
} finally {
argon2Limit.Release();
}
return hash;
}
// /// <summary>
// /// Threadsafe as the SHA256 instance (<see cref="Sha256PerThread"/>) is per thread.
// /// </summary>
// /// <param name="data"></param>
// /// <param name="salt"></param>
// /// <returns></returns>
// private static byte[] ComputeSaltedSha256Hash(string data, byte[] salt) {
// var dataBytes = Encoding.UTF8.GetBytes(data);
// var buf = new byte[data.Length + salt.Length];
// Buffer.BlockCopy(dataBytes, 0, buf, 0, dataBytes.Length);
// Buffer.BlockCopy(salt, 0, buf, dataBytes.Length, salt.Length);
// return Sha256PerThread.ComputeHash(buf);
// }
private byte[] EncryptAdditionalData(string pwd, byte[] salt, T data) {
var pbkdf2 = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(pwd), salt, config.PBKDF2_ITERATIONS, HashAlgorithmName.SHA256);
var key = pbkdf2.GetBytes(config.KEY_LENGTH / 8);
// private byte[] EncryptExtraData(string pwd, byte[] salt, TExtraData extraData) {
// var pbkdf2 = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(pwd), salt, config.PBKDF2_ITERATIONS, HashAlgorithmName.SHA256);
// var key = pbkdf2.GetBytes(config.KEY_LENGTH / 8);
var plainBytes = DataSerializer(data);
using var aes = Aes.Create();
aes.KeySize = config.KEY_LENGTH;
aes.Key = key;
aes.Mode = CipherMode.CBC;
aes.Padding = PaddingMode.PKCS7;
ICryptoTransform encryptor = aes.CreateEncryptor(aes.Key, aes.IV);
byte[] cipherBytes = encryptor.TransformFinalBlock(plainBytes, 0, plainBytes.Length);
// var plainBytes = DataSerializer(extraData);
// using var aes = Aes.Create();
// aes.KeySize = config.KEY_LENGTH;
// aes.Key = key;
// aes.Mode = CipherMode.CBC;
// aes.Padding = PaddingMode.PKCS7;
// ICryptoTransform encryptor = aes.CreateEncryptor(aes.Key, aes.IV);
// byte[] cipherBytes = encryptor.TransformFinalBlock(plainBytes, 0, plainBytes.Length);
var encryptedBytes = new byte[aes.IV.Length + cipherBytes.Length];
Array.Copy(aes.IV, 0, encryptedBytes, 0, aes.IV.Length);
Array.Copy(cipherBytes, 0, encryptedBytes, aes.IV.Length, cipherBytes.Length);
// var encryptedBytes = new byte[aes.IV.Length + cipherBytes.Length];
// Array.Copy(aes.IV, 0, encryptedBytes, 0, aes.IV.Length);
// Array.Copy(cipherBytes, 0, encryptedBytes, aes.IV.Length, cipherBytes.Length);
return encryptedBytes;
}
// return encryptedBytes;
// }
private T DecryptAdditionalData(string pwd, byte[] salt, byte[] encryptedData) {
var pbkdf2 = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(pwd), salt, config.PBKDF2_ITERATIONS, HashAlgorithmName.SHA256);
var key = pbkdf2.GetBytes(config.KEY_LENGTH / 8);
// private TExtraData DecryptExtraData(string pwd, byte[] salt, byte[] encryptedData) {
// var pbkdf2 = new Rfc2898DeriveBytes(Encoding.UTF8.GetBytes(pwd), salt, config.PBKDF2_ITERATIONS, HashAlgorithmName.SHA256);
// var key = pbkdf2.GetBytes(config.KEY_LENGTH / 8);
using var aes = Aes.Create();
aes.KeySize = config.KEY_LENGTH;
aes.Key = key;
aes.Mode = CipherMode.CBC;
aes.Padding = PaddingMode.PKCS7;
var iv = new byte[aes.BlockSize / 8];
var cipherBytes = new byte[encryptedData.Length - iv.Length];
// using var aes = Aes.Create();
// aes.KeySize = config.KEY_LENGTH;
// aes.Key = key;
// aes.Mode = CipherMode.CBC;
// aes.Padding = PaddingMode.PKCS7;
// var iv = new byte[aes.BlockSize / 8];
// var cipherBytes = new byte[encryptedData.Length - iv.Length];
Array.Copy(encryptedData, 0, iv, 0, iv.Length);
Array.Copy(encryptedData, iv.Length, cipherBytes, 0, cipherBytes.Length);
// Array.Copy(encryptedData, 0, iv, 0, iv.Length);
// Array.Copy(encryptedData, iv.Length, cipherBytes, 0, cipherBytes.Length);
aes.IV = iv;
ICryptoTransform decryptor = aes.CreateDecryptor(aes.Key, aes.IV);
byte[] plainBytes = decryptor.TransformFinalBlock(cipherBytes, 0, cipherBytes.Length);
// aes.IV = iv;
// ICryptoTransform decryptor = aes.CreateDecryptor(aes.Key, aes.IV);
// byte[] plainBytes = decryptor.TransformFinalBlock(cipherBytes, 0, cipherBytes.Length);
// return DataDeserializer(plainBytes);
// }
//}
return DataDeserializer(plainBytes);
}
}

View File

@ -1,29 +1,19 @@
using System.Collections.ObjectModel;
using System.Net;
using System.Net;
namespace SimpleHttpServer.Types;
namespace SimpleHttpServer;
public class RequestContext : IDisposable {
public HttpListenerContext ListenerContext { get; }
public ReadOnlyDictionary<string, string> ParsedParameters { get; internal set; }
private TextReader? reqReader;
/// <summary>
/// THREADSAFE
/// </summary>
public TextReader ReqReader => reqReader ??= TextReader.Synchronized(new StreamReader(ListenerContext.Request.InputStream));
private StreamReader? reqReader;
public StreamReader ReqReader => reqReader ??= new(ListenerContext.Request.InputStream);
private TextWriter? respWriter;
/// <summary>
/// THREADSAFE
/// </summary>
public TextWriter RespWriter => respWriter ??= TextWriter.Synchronized(new StreamWriter(ListenerContext.Response.OutputStream) { NewLine = "\n" });
private StreamWriter? respWriter;
public StreamWriter RespWriter => respWriter ??= new(ListenerContext.Response.OutputStream) { NewLine = "\n" };
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
public RequestContext(HttpListenerContext listenerContext) {
ListenerContext = listenerContext;
}
#pragma warning restore CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
public async Task WriteLineToRespAsync(string resp) => await RespWriter.WriteLineAsync(resp);
public async Task WriteToRespAsync(string resp) => await RespWriter.WriteAsync(resp);
@ -35,46 +25,27 @@ public class RequestContext : IDisposable {
public void SetStatusCode(HttpStatusCode status) => SetStatusCode((int) status);
public async Task SetStatusCodeWriteLineDisposeAsync(HttpStatusCode status, string message) {
SetStatusCode(status);
await WriteLineToRespAsync(message);
await RespWriter.FlushAsync();
}
public async Task SetStatusCodeAndDisposeAsync(int status) {
using (this) {
public void SetStatusCodeAndDispose(int status) {
using (this)
SetStatusCode(status);
await WriteToRespAsync("\n\n");
await RespWriter.FlushAsync();
}
}
public async Task SetStatusCodeAndDisposeAsync(HttpStatusCode status) {
using (this) {
public void SetStatusCodeAndDispose(HttpStatusCode status) {
using (this)
SetStatusCode((int) status);
await WriteToRespAsync("\n\n");
await RespWriter.FlushAsync();
}
}
public async Task SetStatusCodeAndDisposeAsync(int status, string description) {
public void SetStatusCodeAndDispose(int status, string description) {
using (this) {
ListenerContext.Response.StatusCode = status;
ListenerContext.Response.StatusDescription = description;
await WriteToRespAsync("\n\n");
await RespWriter.FlushAsync();
}
}
public async Task SetStatusCodeAndDisposeAsync(HttpStatusCode status, string description) => await SetStatusCodeAndDisposeAsync((int) status, description);
public void SetStatusCodeAndDispose(HttpStatusCode status, string description) => SetStatusCodeAndDispose((int) status, description);
public async Task WriteRedirect302AndDisposeAsync(string url) {
ListenerContext.Response.AddHeader("Location", url);
await SetStatusCodeAndDisposeAsync(HttpStatusCode.Redirect);
}
public void Dispose() {
void IDisposable.Dispose() {
reqReader?.Dispose();
respWriter?.Dispose();
GC.SuppressFinalize(this);

View File

@ -7,6 +7,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Konscious.Security.Cryptography.Argon2" Version="1.3.0" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
</ItemGroup>

View File

@ -13,10 +13,6 @@ public class SimpleHttpServerConfiguration {
/// See description of <see cref="DisableLogMessagePrinting"/>
/// </summary>
public CustomLogMessageHandler? LogMessageHandler { get; init; } = null;
/// <summary>
/// If set to true, paths ending with / are identical to paths without said trailing slash. E.g. /index is then the same as /index/
/// </summary>
public bool TrimTrailingSlash { get; init; } = true;
public SimpleHttpServerConfiguration() { }

View File

@ -1,118 +0,0 @@
using System.Net;
using System.Reflection;
namespace SimpleHttpServer.Types;
public abstract class InternalEndpointCheckAttribute : Attribute {
public InternalEndpointCheckAttribute() {
CheckSharedVariables();
}
private void CheckSharedVariables() {
foreach (var f in GetType().GetRuntimeFields()) {
if (f.FieldType.IsAssignableTo(typeof(SharedVariable))) {
if (!f.IsInitOnly) {
throw new Exception($"Found non-readonly global field {f}!");
}
if (f.GetValue(this) == null) {
throw new Exception("Global fields must be assigned in the CCTOR!");
}
}
}
}
private void Initialize(object? instance, Dictionary<FieldInfo, List<(InternalEndpointCheckAttribute, SharedVariable)>> globals) {
SetInstance(instance);
foreach (var f in GetType().GetRuntimeFields()) {
if (f.FieldType.IsAssignableTo(typeof(SharedVariable))) {
SharedVariable origVal = (SharedVariable) f.GetValue(this)!;
if (globals.TryGetValue(f, out var options)) {
bool foundMatch = false;
foreach ((var checker, var gv) in options) {
if (Match(checker)) {
foundMatch = true;
// we need to unify their global variables
f.SetValue(this, gv);
}
}
if (!foundMatch) {
options.Add((this, origVal));
}
} else {
globals.Add(f, new List<(InternalEndpointCheckAttribute, SharedVariable)>() { (this, origVal) });
}
}
}
}
public static void Initialize(object? instance, IEnumerable<InternalEndpointCheckAttribute> endPointChecks) {
Dictionary<FieldInfo, List<(InternalEndpointCheckAttribute, SharedVariable)>> globals = new();
foreach (var check in endPointChecks) {
check.Initialize(instance, globals);
}
}
private interface SharedVariable {
// Tagging interface
}
/// <summary>
/// Represents a Mutable Shared Variable. Fields of this type need to be initialized in the CCtor.
/// </summary>
protected sealed class MSV<V> : SharedVariable {
private readonly V __default;
public V Val { get; set; } = default!;
public MSV() : this(default!) { }
public MSV(V _default) {
__default = _default;
}
public static implicit operator V(MSV<V> v) => v.Val;
}
/// <summary>
/// Represents an Immutable Shared Variable. Fields of this type need to be initialized in the CCtor.
/// </summary>
protected sealed class ISV<V> : SharedVariable {
private readonly V __default;
public V Val { get; } = default!;
public ISV() : this(default!) { }
public ISV(V _default) {
__default = _default;
}
public static implicit operator V(ISV<V> v) => v.Val;
}
/// <summary>
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
/// </summary>
/// <returns>True to allow invocation, false to prevent.</returns>
public abstract bool Check(HttpListenerRequest req);
protected virtual bool Match(InternalEndpointCheckAttribute other) => true;
internal abstract void SetInstance(object? instance);
}
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
public abstract class BaseEndpointCheckAttribute<T> : InternalEndpointCheckAttribute {
/// <summary>
/// A reference to the instance of the class that this attribute is attached to.
/// Will be null iff an class factory was passed in <see cref="HttpServer.RegisterEndpointsFromType{T}(Func{T}?)"/>.
/// </summary>
protected internal T? EndpointClassInstance { get; internal set; } = default;
public BaseEndpointCheckAttribute() : base() { }
internal override void SetInstance(object? instance) {
if (instance != null)
EndpointClassInstance = (T?) instance;
}
}

View File

@ -1,47 +1,12 @@
using System.Net;
using System.Reflection;
using System.Reflection;
namespace SimpleHttpServer.Types;
internal record EndpointInvocationInfo {
//internal record struct QueryParameterInfo(string Name, Type Type, bool isPathParam, bool Path_isCatchAll, bool Query_IsOptional) {
// public static QueryParameterInfo CreatePathParam(string name, Type type) => new(name, type, false, name == "$*", false);
// public static QueryParameterInfo CreateQueryParam(string name, Type type, bool isOptional) => new(name, type, false, false, isOptional);
//}
internal record struct PathParameterInfo(string Name, Type Type, int ArgPos, int SegmentStartPos, bool IsCatchAll) {
public PathParameterInfo(string name, Type type, int argPos) : this(name, type, argPos, -1, name == "$*") { }
}
internal record struct QueryParameterInfo(string Name, Type Type, int ArgPos, bool IsOptional);
internal struct EndpointInvocationInfo {
internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters;
internal readonly List<PathParameterInfo> pathParameters;
internal readonly InternalEndpointCheckAttribute[] requiredChecks;
/// <summary>
/// a reference to the object in which this method is defined (or null if the class is static)
/// </summary>
internal readonly object? typeInstanceReference;
public EndpointInvocationInfo(MethodInfo methodInfo, List<PathParameterInfo> pathParameters, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks,
object? typeInstanceReference) {
internal readonly List<(string, (Type type, bool isOptional))> queryParameters;
public EndpointInvocationInfo(MethodInfo methodInfo, List<(string, (Type type, bool isOptional))> queryParameters) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.pathParameters = pathParameters ?? throw new ArgumentNullException(nameof(pathParameters));
this.requiredChecks = requiredChecks;
this.typeInstanceReference = typeInstanceReference;
if (pathParameters.Any()) {
Assert(pathParameters.Count(x => x.IsCatchAll) <= 1); // at most one catchall parameter
var argPoses = pathParameters.Select(x => x.ArgPos).Concat(queryParameters.Select(x => x.ArgPos)).ToArray();
var argCnt = pathParameters.Count + queryParameters.Count;
Assert(argPoses.Distinct().Count() == argCnt); // ArgPoses must be unique
Assert(argPoses.Min() == HttpServer.expectedEndpointParameterPrefixCount); // ArgPoses must start from just after the prefix
Assert(argPoses.Max() == HttpServer.expectedEndpointParameterPrefixCount + argCnt - 1); // ArgPoses must be contiguous
Assert(pathParameters.All(x => x.SegmentStartPos != -1));
}
}
public bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
}

View File

@ -1,22 +0,0 @@
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class MultiKeyDictionary<K1, K2, V> where K1 : notnull where K2 : notnull {
internal readonly Dictionary<K1, Dictionary<K2, V>> backingDict = new();
public MultiKeyDictionary() { }
public void Add(K1 k1, K2 k2, V value) {
if (!backingDict.TryGetValue(k1, out var d2))
d2 = new();
d2.Add(k2, value);
backingDict[k1] = d2;
}
public bool TryGetValue(K1 k1, K2 k2, [MaybeNullWhen(false)] out V value) {
if (backingDict.TryGetValue(k1, out var d2) && d2.TryGetValue(k2, out value))
return true;
value = default;
return false;
}
}

View File

@ -5,6 +5,9 @@
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class ParameterAttribute : Attribute {
// See the attribute guidelines at
// http://go.microsoft.com/fwlink/?LinkId=85236
public string Name { get; }
public bool IsOptional { get; }
public ParameterAttribute(string name, bool isOptional = false) {

View File

@ -1,7 +0,0 @@
namespace SimpleHttpServer.Types.ParameterConverters;
internal class StringParameterConverter : IParameterConverter {
public bool TryConvertFromString(string value, out object result) {
result = value;
return true;
}
}

View File

@ -1,28 +0,0 @@
namespace SimpleHttpServer.Types;
/// <summary>
/// Specifies the name of a http endpoint path parameter. Path parameter names must be in the format $1, $2, $3, ..., and the end of the path may be $*
/// </summary>
[AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
public sealed class PathParameterAttribute : Attribute {
public string Name { get; }
public PathParameterAttribute(string name) {
if (string.IsNullOrWhiteSpace(name)) {
throw new ArgumentException($"'{nameof(name)}' cannot be null or whitespace.", nameof(name));
}
if (!name.StartsWith('$')) {
throw new ArgumentException($"'{nameof(name)}' must start with $.", nameof(name));
}
if (name.Contains(' ')) {
throw new ArgumentException($"'{nameof(name)}' must not contain spaces.", nameof(name));
}
if (!uint.TryParse(name[1..], out _) && name != "$*") {
throw new ArgumentException($"'{nameof(name)}' must only consist of spaces or be exactly '$*'.", nameof(name));
}
Name = name;
}
}

View File

@ -1,104 +0,0 @@
using System.Data;
using System.Diagnostics.CodeAnalysis;
namespace SimpleHttpServer.Types;
internal class PathTree<T> where T : class {
private readonly Node? rootNode = null;
public PathTree() : this(new()) { }
public PathTree(Dictionary<string, T> dict) {
if (dict == null || dict.Count == 0)
return;
rootNode = new();
var currNode = rootNode;
var unpackedPaths = dict.Keys.Select(p => p.Split('/').ToArray()).ToArray();
var unpackedLeafData = dict.Values.ToArray();
for (int i = 0; i < unpackedPaths.Length; i++) {
var path = unpackedPaths[i];
var catchallidx = Array.IndexOf(path, "$*");
if (catchallidx != -1 && catchallidx != path.Length - 1) {
throw new Exception($"Found illegal catchall-wildcard in path: '{string.Join('/', path)}'");
}
var leafdata = unpackedLeafData[i] ?? throw new ArgumentNullException("Leafdata must not be null!");
rootNode.AddSuccessor(path, leafdata);
}
}
internal bool TryGetPath(string reqPath, [MaybeNullWhen(false)] out T endpoint) {
if (rootNode == null) {
endpoint = null;
return false;
}
// try to find path-match
Node currNode = rootNode;
Assert(reqPath[0] == '/');
var splittedPath = reqPath[1..].Split("/");
Node? lastCatchallNode = null;
for (int i = 0; i < splittedPath.Length; i++) {
// keep track of the current best catchallNode
if (currNode.catchAllNext != null) {
lastCatchallNode = currNode.catchAllNext;
}
var seg = splittedPath[i];
if (currNode.next?.TryGetValue(seg, out var next) == true) { // look for an explicit path to follow greedily
currNode = next;
} else if (currNode.pathWildcardNext != null) { // otherwise look for a single-wildcard to follow
currNode = currNode.pathWildcardNext;
} else { // otherwise we are done, there is no valid path --> fall back to the most specific catchall
endpoint = lastCatchallNode?.leafData;
return lastCatchallNode != null;
}
}
// return found path
endpoint = currNode.leafData;
return endpoint != null;
}
private class Node {
public T? leafData = null; // null means that this is a node without a value (e.g. when it is just part of a path)
public Dictionary<string, Node>? next = null;
public Node? pathWildcardNext = null; // path wildcard
public Node? catchAllNext = null; // trailing-catchall wildcard
public void AddSuccessor(string[] segments, T newLeafData) {
if (segments.Length == 0) { // actually add the data to this node
Assert(leafData == null);
leafData = newLeafData;
return;
}
var seg = segments[0];
bool newIsWildcard = seg.Length > 1 && seg[0] == '$';
if (newIsWildcard) {
bool newIsCatchallWildcard = newIsWildcard && seg.Length == 2 && seg[1] == '*';
if (newIsCatchallWildcard) { // this is a catchall wildcard
Assert(catchAllNext == null);
catchAllNext = new();
catchAllNext.AddSuccessor(segments[1..], newLeafData);
return;
} else { // must be single wildcard otherwise
pathWildcardNext ??= new();
pathWildcardNext.AddSuccessor(segments[1..], newLeafData);
return;
}
}
// otherwise we want to add a new constant path successor
next ??= new();
if (next.TryGetValue(seg, out var existingNode)) {
existingNode.AddSuccessor(segments[1..], newLeafData);
} else {
var newNode = next[seg] = new();
newNode.AddSuccessor(segments[1..], newLeafData);
}
}
}
}

View File

@ -1,6 +1,4 @@
using SimpleHttpServer;
using SimpleHttpServer.Types;
using System.Net;
namespace SimpleHttpServerTest;
@ -10,36 +8,19 @@ public class SimpleServerTest {
const int PORT = 8833;
private HttpServer? activeServer = null;
private HttpClient? activeHttpClient = null;
private bool failOnLogError = true;
private static string GetRequestPath(string url) => $"http://localhost:{PORT}/{url.TrimStart('/')}";
private async Task RequestGetStringAsync(string path) => await activeHttpClient!.GetStringAsync(GetRequestPath(path));
private async Task<HttpResponseMessage> AssertGetStatusCodeAsync(string path, HttpStatusCode statusCode) {
var resp = await activeHttpClient!.GetAsync(GetRequestPath(path));
Assert.AreEqual(statusCode, resp.StatusCode);
return resp;
}
[TestInitialize]
public void Init() {
var conf = new SimpleHttpServerConfiguration() {
DisableLogMessagePrinting = false,
LogMessageHandler = (LogOutputTopic topic, string message, LogOutputLevel logLevel) => {
if (failOnLogError && logLevel is LogOutputLevel.Error or LogOutputLevel.Fatal)
Assert.Fail($"An error was thrown in the log output:\n{topic} {message}");
}
};
var conf = new SimpleHttpServerConfiguration();
if (activeServer != null)
throw new InvalidOperationException("Tried to create another httpserver instance when an existing one was already running.");
Console.WriteLine("Starting server...");
failOnLogError = true;
activeServer = new HttpServer(PORT, conf);
activeServer.RegisterEndpointsFromType<TestEndpoints>();
activeServer.Start();
activeHttpClient = new HttpClient();
Console.WriteLine("Server started.");
}
@ -52,87 +33,20 @@ public class SimpleServerTest {
}
await Console.Out.WriteLineAsync("Shutting down server...");
await activeServer.StopAsync(ctokSrc.Token);
activeHttpClient?.Dispose();
activeHttpClient = null;
await Console.Out.WriteLineAsync("Shutdown finished.");
}
static string GetHttpPageContentFromPrefix(string page)
=> $"It works!!!!!!56sg5sdf46a4sd65a412f31sdfgdf89h74g9f8h4as56d4f56as2as1f3d24f87g9d87{page}";
[TestMethod]
public async Task CheckSimpleServe() {
var resp = await AssertGetStatusCodeAsync("/", HttpStatusCode.OK);
var str = await resp.Content.ReadAsStringAsync();
Assert.AreEqual("It works!", str);
}
[TestMethod]
public async Task CheckMultiServe() {
foreach (var item in "index2.html;testpage;testpage2;testpage3".Split(';')) {
await Console.Out.WriteLineAsync($"Checking page: /{item}");
var resp = await AssertGetStatusCodeAsync(item, HttpStatusCode.OK);
var str = await resp.Content.ReadAsStringAsync();
Assert.AreEqual(GetHttpPageContentFromPrefix(item), str);
}
}
[TestMethod]
public async Task CheckQueryArgs() {
foreach (var a1 in "test1;longstring2;something else with a space".Split(';')) {
foreach (var a2 in new[] { -10, 2, -2, 5, 0, 4 }) {
foreach (var a3 in new[] { -1, 9, 2, -20, 0 }) {
foreach (var a4 in new[] { -1, 9, 0 }) {
foreach (var page in "returnqueries;returnqueries2".Split(';')) {
var resp = await AssertGetStatusCodeAsync($"{page}?arg1={a1}&arg2={a2}&arg3={a3}&arg4={a4}", HttpStatusCode.OK);
var str = await resp.Content.ReadAsStringAsync();
Assert.AreEqual(TestEndpoints.GetReturnQueryPageResult(a1, a2, page == "returnqueries2" ? (a3 + a4) : a3), str);
}
}
}
}
}
using var hc = new HttpClient();
await hc.GetStringAsync(GetRequestPath("/"));
}
public class TestEndpoints {
[HttpEndpoint(HttpRequestType.GET, "/", "index.html")]
[HttpEndpoint(HttpRequestType.GET, "/", "index.html", "amogus.html")]
public static async Task Index(RequestContext req) {
await req.RespWriter.WriteAsync("It works!");
}
[HttpEndpoint(HttpRequestType.GET, "index2.html")]
public static async Task Index2(RequestContext req) {
await req.RespWriter.WriteAsync(GetHttpPageContentFromPrefix("index2.html"));
}
[HttpEndpoint(HttpRequestType.GET, "/testpage")]
public static async Task TestPage(RequestContext req) {
await req.RespWriter.WriteAsync(GetHttpPageContentFromPrefix("testpage"));
}
[HttpEndpoint(HttpRequestType.GET, "testpage2")]
public static async Task TestPage2(RequestContext req) {
await req.RespWriter.WriteAsync(GetHttpPageContentFromPrefix("testpage2"));
}
[HttpEndpoint(HttpRequestType.GET, "/testpage3")]
public static async Task TestPage3(RequestContext req) {
await req.RespWriter.WriteAsync(GetHttpPageContentFromPrefix("testpage3"));
}
public static string GetReturnQueryPageResult(string arg1, int arg2, int arg3) => $"{arg1};{arg2 * 2 - arg3 * 5}";
[HttpEndpoint(HttpRequestType.GET, "/returnqueries")]
public static async Task ReturnQueriesPage(RequestContext req, string arg1, int arg2, int arg3) {
await req.RespWriter.WriteAsync(GetReturnQueryPageResult(arg1, arg2, arg3));
}
[HttpEndpoint(HttpRequestType.GET, "/returnqueries2")]
public static async Task ReturnQueriesPage2(RequestContext req,
[Parameter("arg2")] int arg1, [Parameter("arg1")] string arg2, int arg3, [Parameter("arg4", true)] int arg4) {
// arg4 should be equal to zero as it should get the deafult value because it is not passed to the server
await req.RespWriter.WriteAsync(GetReturnQueryPageResult(arg2, arg1, arg3 + arg4));
await req.RespWriter.WriteLineAsync("It works!");
}
}
}