using SimpleHttpServer.Types;
using SimpleHttpServer.Types.Exceptions;
using SimpleHttpServer.Types.ParameterConverters;
using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
using static SimpleHttpServer.Types.EndpointInvocationInfo;

namespace SimpleHttpServer;

/// <summary>
/// Minimal HTTP server on top of <see cref="HttpListener"/>. Requests are dispatched to public methods tagged with
/// <see cref="HttpEndpointAttribute"/> (see <see cref="RegisterEndpointsFromType{T}"/>); GET requests that match no
/// endpoint can be answered from directories registered via <see cref="RegisterStaticServePath"/>.
/// </summary>
public sealed class HttpServer {
    /// <summary>Port the server listens on; the listener prefix is <c>http://localhost:{Port}/</c>.</summary>
    public int Port { get; }

    private readonly HttpListener listener;
    private Task? listenerTask;
    private readonly Logger mainLogger;
    private readonly Logger requestLogger;
    private readonly SimpleHttpServerConfiguration conf;

    // Set by StopAsync and polled by the accept loop, which runs on a thread-pool thread (Task.Run in Start).
    private volatile bool shutdown = false;

    /// <summary>
    /// Creates a server for <paramref name="port"/>. The listener is only configured here; it starts listening in
    /// <see cref="Start"/>.
    /// </summary>
    public HttpServer(int port, SimpleHttpServerConfiguration configuration) {
        Port = port;
        conf = configuration;
        listener = new HttpListener();
        listener.Prefixes.Add($"http://localhost:{port}/");
        mainLogger = new(LogOutputTopic.Main, conf);
        requestLogger = new(LogOutputTopic.Request, conf);
    }

    /// <summary>Starts the listener and the background accept loop. May only be called once.</summary>
    public void Start() {
        mainLogger.Information($"Starting on port {Port}...");
        Assert(listenerTask == null, "Server was already started!");
        listener.Start();
        listenerTask = Task.Run(GetContextLoopAsync);
        mainLogger.Information("Ready to handle requests!");
    }

    /// <summary>
    /// Stops the listener and waits for the accept loop to finish. Requests that are already being processed are
    /// not awaited (they are started fire-and-forget by the accept loop).
    /// </summary>
    /// <param name="ctok">Cancels the wait for the accept loop.</param>
    public async Task StopAsync(CancellationToken ctok) {
        mainLogger.Information("Stopping server...");
        Assert(listenerTask != null, "Server was not started!");
        shutdown = true;
        listener.Stop();
        await listenerTask.WaitAsync(ctok);
    }

    /// <summary>
    /// Accepts incoming connections until <see cref="StopAsync"/> is called and hands each one to
    /// <see cref="ProcessRequestAsync"/> without awaiting it.
    /// </summary>
    public async Task GetContextLoopAsync() {
        while (!shutdown) {
            try {
                var ctx = await listener.GetContextAsync();
                _ = ProcessRequestAsync(ctx);
            } catch (HttpListenerException ex) when (ex.ErrorCode == 995) {
                // The I/O operation has been aborted because of either a thread exit or an application request
            } catch (Exception) when (shutdown) {
                // listener.Stop() was called; the pending GetContextAsync may fail with a different exception than the
                // one above (platform dependent). This is an expected shutdown, not a fatal error.
                break;
            } catch (Exception ex) {
                mainLogger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
            }
        }
    }

    /// <summary>Registers the built-in string-to-type converters for string, bool and the common numeric types.</summary>
    private void RegisterDefaultConverters() {
        void RegisterConverter<T>() where T : IParsable<T> {
            stringToTypeParameterConverters.Add(typeof(T), new ParsableParameterConverter<T>());
        }

        stringToTypeParameterConverters.Add(typeof(string), new StringParameterConverter());
        stringToTypeParameterConverters.Add(typeof(bool), new BoolParsableParameterConverter());
        RegisterConverter<char>();
        RegisterConverter<byte>();
        RegisterConverter<sbyte>();
        RegisterConverter<ushort>();
        RegisterConverter<short>();
        RegisterConverter<uint>();
        RegisterConverter<int>();
        RegisterConverter<ulong>();
        RegisterConverter<long>();
        RegisterConverter<UInt128>();
        RegisterConverter<Int128>();
        RegisterConverter<BigInteger>();
        RegisterConverter<Half>();
        RegisterConverter<float>();
        RegisterConverter<double>();
        RegisterConverter<decimal>();
    }

    private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new(); // requestmethod, path
    private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new(); // requestmethod, path
    private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new(); // reqmethod : pathtree

    // Types of the leading parameters every endpoint method must accept, in order.
    private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
    internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;

    /// <summary>
    /// Registers every public method of <typeparamref name="T"/> tagged with <see cref="HttpEndpointAttribute"/> as an
    /// endpoint for each of the attribute's locations. An endpoint method must return <see cref="Task"/> and take a
    /// <see cref="RequestContext"/> as first parameter. Each further parameter is bound from the path when tagged with
    /// <see cref="PathParameterAttribute"/>; otherwise it is bound from the query string.
    /// </summary>
    /// <param name="instanceFactory">
    /// Creates the instance that non-static endpoints are invoked on. It is required when any tagged method is
    /// non-static and is called once per invocation of this method.
    /// </param>
    /// <exception cref="ArgumentException">A parameter declaration or the request method is invalid.</exception>
    /// <exception cref="MissingParameterConverterException">No converter is registered for a parameter type.</exception>
    public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
        if (stringToTypeParameterConverters.Count == 0)
            RegisterDefaultConverters();

        var t = typeof(T);
        var mis = t.GetMethods()
            .ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
            .Where(x => x.Value.Any())
            .ToDictionary(x => x.Key, x => x.Value.Single());

        // if all are static then there is no point in having a constructor as no instance data is accessible, but we allow passing a factory anyway
        var isStatic = mis.All(x => x.Key.IsStatic);
        Assert(isStatic || (instanceFactory != null), $"You must provide an instance factory if any methods of the given type ({typeof(T).FullName}) are non-static");
        T? classInstance = instanceFactory?.Invoke();

        foreach (var (mi, attrib) in mis) {
            string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;

            // Non-static endpoints are allowed; they are invoked on classInstance.
            Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");

            var methodParams = mi.GetParameters();

            // check the mandatory prefix parameters
            Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
            for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
                Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
                    $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
            }

            // check return type
            Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");

            // check the rest of the method parameters
            var qparams = new List<QueryParameterInfo>();
            var pparams = new List<PathParameterInfo>();
            int mParamIndex = expectedEndpointParameterTypes.Length;
            for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
                var par = methodParams[i];
                var attr = par.GetCustomAttribute<ParameterAttribute>(false);
                var pathAttr = par.GetCustomAttribute<PathParameterAttribute>(false);

                if (attr != null && pathAttr != null) {
                    throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
                }

                if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
                    throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
                }

                if (pathAttr != null) { // parameter is a path param
                    pparams.Add(new(
                        pathAttr.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
                        par.ParameterType,
                        mParamIndex++));
                } else { // parameter is a normal query param
                    qparams.Add(new(
                        attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
                        par.ParameterType,
                        mParamIndex++,
                        attr?.IsOptional ?? false));
                }
            }

            // stores the check attributes that are defined on the method and on the containing class
            InternalEndpointCheckAttribute[] requiredChecks = mi.GetCustomAttributes(true)
                .Concat(mi.DeclaringType?.GetCustomAttributes(true) ?? Enumerable.Empty<object>())
                .OfType<InternalEndpointCheckAttribute>()
                .ToArray();

            InternalEndpointCheckAttribute.Initialize(classInstance, requiredChecks);

            foreach (var location in attrib.Locations) {
                var normLocation = NormalizeUrlPath(location);
                var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");

                // Segment positions must be computed on the normalized location: that is the path that gets registered
                // and that normalized request paths are split against in InvokeEndpointAsync. The raw location may lack
                // the leading '/', use backslashes or contain '.'/'..' segments, which would shift the indices.
                var pparamsCopy = new List<PathParameterInfo>(pparams);
                var splittedLocation = normLocation[1..].Split('/');
                for (int i = 0; i < pparamsCopy.Count; i++) {
                    var pp = pparamsCopy[i];
                    var idx = Array.IndexOf(splittedLocation, pp.Name);
                    Assert(idx != -1, "Path parameter name was incorrect?");
                    pp.SegmentStartPos = idx;
                    pparamsCopy[i] = pp;
                }

                var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
                if (pparams.Count > 0) {
                    mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
                    Assert(normLocation[0] == '/');
                    pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
                } else {
                    mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
                    simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
                }
            }
        }

        // rebuild path trees
        pathEndpointMethodInfosTrees.Clear();
        foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
            pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
    }

    /// <summary>
    /// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to
    /// <paramref name="requestPath"/>, while restricting requests to inside the local filesystem directory.
    /// Static serving has a lower priority than registering an endpoint. When several registered request paths match,
    /// the longest one wins.
    /// </summary>
    /// <param name="requestPath">URL path prefix, normalized with the same rules as request paths.</param>
    /// <param name="filesystemDirectory">Directory to serve; converted to an absolute path.</param>
    public void RegisterStaticServePath(string requestPath, string filesystemDirectory) {
        var absPath = Path.GetFullPath(filesystemDirectory);
        string npath = NormalizeUrlPath(requestPath);
        mainLogger.Information($"Registered static serve path: '{npath}' --> '{absPath}'");
        staticServePaths.Add(npath, absPath);
    }

    private readonly Dictionary<string, string> staticServePaths = new(); // normalized request path -> absolute directory
    private readonly Dictionary<Type, IParameterConverter> stringToTypeParameterConverters = new();

    /// <summary>
    /// Normalizes a URL path: converts '\' to '/', drops empty and '.' segments, and resolves '..' against the
    /// preceding segment. '..' segments that cannot be resolved are kept at the front. The result always starts with
    /// '/' and keeps a trailing '/' if the input had one.
    /// </summary>
    private static string NormalizeUrlPath(string url) {
        var fwdSlashUrl = url.Replace('\\', '/');
        var segments = fwdSlashUrl.Trim('/').Split('/', StringSplitOptions.RemoveEmptyEntries);

        // Walk from the end so every '..' can drop the next real segment to its left.
        var keptSegments = new List<string>();
        int doubleDotsEncountered = 0;
        for (int i = segments.Length - 1; i >= 0; i--) {
            var segment = segments[i];
            if (segment == ".") {
                continue; // remove single dot segments
            }
            if (segment == "..") {
                doubleDotsEncountered++; // remember it; it cancels the next real segment to the left
                continue;
            }
            if (doubleDotsEncountered > 0) {
                doubleDotsEncountered--; // this segment is cancelled by a later '..'
                continue;
            }
            keptSegments.Add(segment);
        }
        keptSegments.Reverse(); // collected right-to-left; restore original order (in place)

        var rv = new StringBuilder();
        for (int i = 0; i < doubleDotsEncountered; i++) {
            rv.Append("../");
        }
        rv.AppendJoin('/', keptSegments);
        return '/' + (rv.ToString().TrimEnd('/') + (fwdSlashUrl.EndsWith('/') ? "/" : "")).TrimStart('/');
    }

    /// <summary>
    /// Handles one request and always closes the response.
    /// Lookup order:
    /// 1. A simple endpoint where request method and path match.
    /// 2. A path-parameter endpoint.
    /// 3. For GET requests, a static serve path.
    /// 4. Otherwise a 404 page.
    /// Unhandled exceptions are logged and answered with a 500 page where still possible. Errors while writing that
    /// page or closing the response are caught and logged as well.
    /// </summary>
    private async Task ProcessRequestAsync(HttpListenerContext ctx) {
        using RequestContext rc = new RequestContext(ctx);
        // TODO add path escape countermeasure-unittests
        var splitted = (ctx.Request.RawUrl ?? "").Split('?', 2, StringSplitOptions.None);
        // Path segments use RFC 3986 percent-encoding, where '+' is a literal character (unlike form-encoded queries).
        var reqPath = NormalizeUrlPath(Uri.UnescapeDataString(splitted[0]));
        string requestMethod = ctx.Request.HttpMethod.ToUpperInvariant();
        bool wasStaticlyServed = false;

        void LogRequest() {
            requestLogger.Information($"{rc.ListenerContext.Response.StatusCode} {(wasStaticlyServed ? "static" : "endpnt")} {requestMethod} {ctx.Request.Url}");
        }

        try {
            EndpointInvocationInfo? pathEndpointInvocationInfo = null;
            if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo)
                || pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) {
                // duplicate path-parameter endpoints are rejected at registration, so at most one can match
                var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo
                    ?? throw new InvalidOperationException("retrieved endpoint is somehow null");
                var args = splitted.Length == 2 ? splitted[1] : null;
                await InvokeEndpointAsync(rc, endpointInvocationInfo, reqPath, args);
            } else if (requestMethod == "GET" && TryFindStaticServeRoot(reqPath, out var requestPrefix, out var filesystemRoot)) {
                wasStaticlyServed = true;
                await ServeStaticFileAsync(rc, reqPath, requestPrefix, filesystemRoot);
            } else {
                await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
            }
        } catch (Exception ex) {
            // Log before touching the response: if the endpoint already started the response, setting the status code
            // for the error page throws, and the original exception would otherwise never be logged.
            mainLogger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
            try {
                await HandleDefaultErrorPageAsync(rc, HttpStatusCode.InternalServerError);
            } catch (Exception pageEx) {
                requestLogger.Warning($"Could not send error page for {requestMethod} {ctx.Request.Url}: {pageEx.Message}");
            }
        } finally {
            try {
                await rc.RespWriter.FlushAsync();
            } catch (ObjectDisposedException) {
                // already flushed and disposed by the error page helper
            } catch (Exception flushEx) {
                requestLogger.Warning($"Could not flush response for {requestMethod} {ctx.Request.Url}: {flushEx.Message}");
            }
            try {
                rc.ListenerContext.Response.Close();
            } catch (Exception closeEx) {
                // e.g. the client disconnected; this task is fire-and-forget, so nothing may escape it
                requestLogger.Warning($"Could not close response for {requestMethod} {ctx.Request.Url}: {closeEx.Message}");
            }
            LogRequest();
        }
    }

    /// <summary>
    /// Runs the endpoint's access checks, binds query and path parameters and invokes the endpoint method.
    /// If a check fails, a 403 page is written instead of invoking the endpoint. If parameter binding fails,
    /// a 400 page is written instead.
    /// </summary>
    private async Task InvokeEndpointAsync(RequestContext rc, EndpointInvocationInfo endpointInvocationInfo, string reqPath, string? args) {
        var mi = endpointInvocationInfo.methodInfo;
        var qparams = endpointInvocationInfo.queryParameters;
        var pparams = endpointInvocationInfo.pathParameters;
        var parsedQParams = new Dictionary<string, string>();
        var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];

        // run the checks to see if the client is allowed to make this request
        if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) {
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
            return;
        }

        if (args != null) {
            // Empty pairs ("/x?", "a=1&", "a=1&&b=2") carry no parameter and are skipped instead of rejected.
            foreach (var queryKV in args.Split('&', StringSplitOptions.RemoveEmptyEntries)) {
                // Split on the first '=' only so that values may themselves contain '=' (e.g. base64 padding).
                var queryKVSplitted = queryKV.Split('=', 2);
                if (queryKVSplitted.Length != 2) {
                    await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Malformed request URL parameters");
                    return;
                }
                if (!parsedQParams.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1]))) {
                    await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Duplicate request URL parameters");
                    return;
                }
            }
        }

        // check for missing query parameters (reports all of them at once)
        var missingRequired = qparams
            .Where(x => !x.IsOptional && !parsedQParams.ContainsKey(x.Name))
            .Select(x => $"'{x.Name}'")
            .ToList();
        if (missingRequired.Count > 0) {
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", missingRequired)}");
            return;
        }

        foreach (var qparam in qparams) {
            if (!parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
                // Optional and absent. MethodBase.Invoke passes a zero-initialized value for null value-type arguments.
                convertedMParamValues[qparam.ArgPos] = null!;
                continue;
            }
            if (!stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
                await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Invalid value for query parameter {qparam.Name}");
                return;
            }
            convertedMParamValues[qparam.ArgPos] = objRes;
        }

        if (pparams.Count != 0) {
            // reqPath is normalized, so it starts with '/' and its segments line up with SegmentStartPos
            var splittedReqPath = reqPath[1..].Split('/');
            for (int i = 0; i < pparams.Count; i++) {
                var pparam = pparams[i];
                string paramValue = pparam.IsCatchAll
                    ? string.Join('/', splittedReqPath[pparam.SegmentStartPos..])
                    : splittedReqPath[pparam.SegmentStartPos];

                if (!stringToTypeParameterConverters[pparam.Type].TryConvertFromString(paramValue, out var res)) {
                    await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Invalid value for path parameter {pparam.Name}");
                    return;
                }
                convertedMParamValues[pparam.ArgPos] = res;
            }
        }

        convertedMParamValues[0] = rc;
        rc.ParsedParameters = parsedQParams.AsReadOnly();
        var endpointTask = mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues)
            ?? throw new InvalidOperationException("Website func returned null unexpectedly");
        await (Task) endpointTask;
    }

    /// <summary>
    /// Finds the registered static serve path that <paramref name="reqPath"/> lies under. A match must end on a path
    /// segment boundary (so "/static" does not capture "/staticfoo"); when several match, the longest prefix wins.
    /// </summary>
    private bool TryFindStaticServeRoot(string reqPath, out string requestPrefix, out string filesystemRoot) {
        string? bestPrefix = null;
        string? bestRoot = null;
        foreach (var (prefix, root) in staticServePaths) {
            if (!IsPathPrefix(prefix, reqPath))
                continue;
            if (bestPrefix == null || prefix.Length > bestPrefix.Length) {
                bestPrefix = prefix;
                bestRoot = root;
            }
        }
        requestPrefix = bestPrefix ?? "";
        filesystemRoot = bestRoot ?? "";
        return bestPrefix != null;
    }

    /// <summary>True if <paramref name="prefix"/> is <paramref name="path"/> itself or one of its leading segments.</summary>
    private static bool IsPathPrefix(string prefix, string path) =>
        path.StartsWith(prefix, StringComparison.Ordinal)
        && (path.Length == prefix.Length || prefix.EndsWith('/') || path[prefix.Length] == '/');

    /// <summary>
    /// Streams the file addressed by <paramref name="reqPath"/> below <paramref name="filesystemRoot"/>, or writes a 404
    /// page if it does not exist or resolves outside that directory.
    /// </summary>
    private async Task ServeStaticFileAsync(RequestContext rc, string reqPath, string requestPrefix, string filesystemRoot) {
        var relativeStaticReqPath = reqPath[requestPrefix.Length..];
        var staticResponsePath = Path.GetFullPath(Path.Join(filesystemRoot, relativeStaticReqPath.TrimStart('/')));

        // Unresolvable leading '..' segments from NormalizeUrlPath end up here; refuse anything outside the root.
        if (!IsInsideDirectory(filesystemRoot, staticResponsePath)) {
            requestLogger.Warning($"Blocked GET request to {reqPath} as somehow the target file does not lie inside the static serve folder? Are you using symlinks?");
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
            return;
        }

        if (!File.Exists(staticResponsePath)) {
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
            return;
        }

        rc.SetStatusCode(HttpStatusCode.OK);
        if (staticResponsePath.EndsWith(".svg", StringComparison.OrdinalIgnoreCase)) {
            rc.ListenerContext.Response.AddHeader("Content-Type", "image/svg+xml");
        }
        using var f = File.OpenRead(staticResponsePath);
        await f.CopyToAsync(rc.ListenerContext.Response.OutputStream);
    }

    /// <summary>
    /// True if the absolute <paramref name="fullPath"/> is <paramref name="directory"/> or lies below it. Unlike a plain
    /// substring test for "..", this accepts file names such as "a..b.txt" and rejects paths on another root or drive.
    /// </summary>
    private static bool IsInsideDirectory(string directory, string fullPath) {
        var relative = Path.GetRelativePath(directory, fullPath);
        if (Path.IsPathRooted(relative))
            return false; // no relative path exists, e.g. a different drive
        var firstSegment = relative.Split(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar)[0];
        return firstSegment != "..";
    }

    private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, HttpStatusCode errorCode, string? statusDescription = null)
        => await HandleDefaultErrorPageAsync(ctx, (int) errorCode, statusDescription);

    /// <summary>
    /// Writes a small HTML error page with <paramref name="errorCode"/> and the optional description. Then it sets the
    /// status (and description) and disposes the request context's response. A response that was already disposed is
    /// ignored.
    /// </summary>
    private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode, string? statusDescription = null) {
        ctx.SetStatusCode(errorCode);
        // Encoded so the description is always rendered as text, even if it ever includes request-derived data.
        string desc = statusDescription != null ? $"\r\n{WebUtility.HtmlEncode(statusDescription)}" : "";
        await ctx.WriteLineToRespAsync($"""
            <!DOCTYPE html>
            <html>
            <head><title>Error {errorCode}</title></head>
            <body>
            <h1>Oh no, an error occurred!</h1>
            <p>Code: {errorCode}</p>
            <p>{desc}</p>
            </body>
            </html>
            """);
        try {
            if (statusDescription == null) {
                await ctx.SetStatusCodeAndDisposeAsync(errorCode);
            } else {
                await ctx.SetStatusCodeAndDisposeAsync(errorCode, statusDescription);
            }
        } catch (ObjectDisposedException) { }
    }
}