using SimpleHttpServer.Types;
using SimpleHttpServer.Types.Exceptions;
using SimpleHttpServer.Types.ParameterConverters;
using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
using static SimpleHttpServer.Types.EndpointInvocationInfo;

namespace SimpleHttpServer;

// NOTE(review): in this copy of the file the generic type arguments (on RegisterConverter, List, Dictionary,
// GetCustomAttributes, GetCustomAttribute, Cast, Enumerable.Empty, RegisterEndpointsFromType) and the markup of
// the error-page template appear to be missing. They are reproduced verbatim below and must be restored from
// version control before this compiles.
/// <summary>
/// A small HTTP server on top of <see cref="HttpListener"/> (bound to <c>http://localhost:{port}/</c>).
/// It dispatches requests to static endpoint methods registered via <c>RegisterEndpointsFromType</c>.
/// It can also serve files from directories registered via <see cref="RegisterStaticServePath"/>.
/// </summary>
public sealed class HttpServer {
    /// <summary>The port the listener prefix was registered with.</summary>
    public int Port { get; }

    private readonly HttpListener listener;
    private Task? listenerTask;
    private readonly Logger mainLogger;
    private readonly Logger requestLogger;
    private readonly SimpleHttpServerConfiguration conf;
    // Written by StopAsync and read by the accept loop, which runs on a thread-pool thread.
    private volatile bool shutdown = false;

    /// <summary>Creates a server listening on <c>http://localhost:{port}/</c>; call <see cref="Start"/> to begin.</summary>
    /// <param name="port">Port for the listener prefix.</param>
    /// <param name="configuration">Configuration handed to the internal loggers.</param>
    public HttpServer(int port, SimpleHttpServerConfiguration configuration) {
        Port = port;
        conf = configuration;
        listener = new HttpListener();
        listener.Prefixes.Add($"http://localhost:{port}/");
        mainLogger = new(LogOutputTopic.Main, conf);
        requestLogger = new(LogOutputTopic.Request, conf);
    }

    /// <summary>Starts the listener and the background accept loop. May only be called once.</summary>
    public void Start() {
        mainLogger.Information($"Starting on port {Port}...");
        Assert(listenerTask == null, "Server was already started!");
        listener.Start();
        listenerTask = Task.Run(GetContextLoopAsync);
        mainLogger.Information($"Ready to handle requests!");
    }

    /// <summary>Signals the accept loop to stop, stops the listener and waits for the loop to finish.</summary>
    /// <param name="ctok">Cancels the wait for the accept loop (not the shutdown itself).</param>
    public async Task StopAsync(CancellationToken ctok) {
        mainLogger.Information("Stopping server...");
        Assert(listenerTask != null, "Server was not started!");
        shutdown = true;
        listener.Stop();
        await listenerTask.WaitAsync(ctok);
    }

    /// <summary>
    /// Accept loop: receives contexts until shutdown and hands each one to <see cref="ProcessRequestAsync"/>.
    /// Requests are processed fire-and-forget.
    /// </summary>
    public async Task GetContextLoopAsync() {
        while (!shutdown) {
            try {
                var ctx = await listener.GetContextAsync();
                _ = ProcessRequestAsync(ctx);
            } catch (HttpListenerException ex) when (ex.ErrorCode == 995) {
                //The I/O operation has been aborted because of either a thread exit or an application request
            } catch (Exception ex) {
                mainLogger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
                // Once the listener no longer listens, every further GetContextAsync call fails immediately;
                // leave the loop instead of spinning and flooding the log.
                if (!listener.IsListening) {
                    break;
                }
            }
        }
    }

    /// <summary>Fills the string-to-type converter table with the built-in converters.</summary>
    private void RegisterDefaultConverters() {
        void RegisterConverter() where T : IParsable {
            stringToTypeParameterConverters.Add(typeof(T), new ParsableParameterConverter());
        }

        stringToTypeParameterConverters.Add(typeof(string), new StringParameterConverter());
        stringToTypeParameterConverters.Add(typeof(bool), new BoolParsableParameterConverter());
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
    }

    private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
    private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };

    /// <summary>
    /// Registers every method of type <c>T</c> tagged with <c>HttpEndpointAttribute</c> as an endpoint.
    /// Such methods must be public and static, return <see cref="Task"/>, and take a <see cref="RequestContext"/>
    /// first. Every further parameter is bound from the query string and needs a registered converter.
    /// </summary>
    /// <exception cref="MissingParameterConverterException">A parameter type has no registered converter.</exception>
    /// <exception cref="NotImplementedException">A location contains a path parameter (<c>{...}</c>).</exception>
    public void RegisterEndpointsFromType() {
        // Register the built-in converters exactly once. Checking the converter table (instead of the endpoint
        // table) keeps a second call from re-adding them - and throwing on duplicate keys - when an earlier
        // registered type contained no endpoints.
        if (stringToTypeParameterConverters.Count == 0) {
            RegisterDefaultConverters();
        }

        var t = typeof(T);
        foreach (var (mi, attrib) in t.GetMethods()
                     .ToDictionary(x => x, x => x.GetCustomAttributes())
                     .Where(x => x.Value.Any())
                     .ToDictionary(x => x.Key, x => x.Value.Single())) {
            string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;

            Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
            Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");

            var methodParams = mi.GetParameters();
            Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
            for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
                Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
                    $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
            }

            Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");

            // Every parameter after the fixed leading ones is bound from the query string.
            var qparams = new List();
            for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
                var par = methodParams[i];
                var attr = par.GetCustomAttribute(false);
                qparams.Add(new(
                    attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
                    par.ParameterType,
                    attr?.IsOptional ?? false));

                if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
                    throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
                }
            }

            // stores the check attributes that are defined on the method and on the containing class
            var requiredChecks = mi.GetCustomAttributes(true)
                .Concat(mi.DeclaringType?.GetCustomAttributes(true) ?? Enumerable.Empty())
                .Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute)))
                .Cast()
                .ToArray();

            foreach (var location in attrib.Locations) {
                var normLocation = NormalizeUrlPath(location);
                int idx = normLocation.IndexOf('{');
                if (idx >= 0) {
                    // this path contains path parameters
                    throw new NotImplementedException("Path parameters are not yet implemented!");
                }

                var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
                mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
                simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks));
            }
        }
    }

    /// <summary>
    /// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to
    /// <paramref name="requestPath"/>. Requests are restricted to inside that local filesystem directory.
    /// Static serving has a lower priority than registering an endpoint.
    /// </summary>
    /// <param name="requestPath">URL path prefix under which the files are exposed.</param>
    /// <param name="filesystemDirectory">Directory whose contents are served; resolved to an absolute path.</param>
    public void RegisterStaticServePath(string requestPath, string filesystemDirectory) {
        var absPath = Path.GetFullPath(filesystemDirectory);
        string npath = NormalizeUrlPath(requestPath);
        mainLogger.Information($"Registered static serve path: '{npath}' --> '{absPath}'");
        staticServePaths.Add(npath, absPath);
    }

    private readonly Dictionary staticServePaths = new();
    private readonly Dictionary stringToTypeParameterConverters = new();

    /// <summary>
    /// Normalizes a URL path to a canonical form:
    /// <list type="bullet">
    /// <item>backslashes become '/' and empty and "." segments are dropped;</item>
    /// <item>".." removes the preceding segment, and unmatched ".." segments are kept at the front;</item>
    /// <item>the result always starts with '/', and a trailing '/' is kept if the input had one.</item>
    /// </list>
    /// </summary>
    private static string NormalizeUrlPath(string url) {
        var fwdSlashUrl = url.Replace('\\', '/');
        var segments = fwdSlashUrl.Trim('/').Split('/', StringSplitOptions.RemoveEmptyEntries).ToList();
        List simplifiedSegmentsReversed = new List();
        int doubleDotsEncountered = 0;
        // Walk backwards so each ".." can cancel the segment that precedes it.
        for (int i = segments.Count - 1; i >= 0; i--) {
            var segment = segments[i];
            if (segment == ".") {
                continue; // remove single dot segments
            }
            if (segment == "..") {
                doubleDotsEncountered++; // if we encounter a doubledot, keep track of that and dont add it to the output yet
                continue;
            }
            // otherwise only keep the segment if no pending doubledot cancels it
            if (doubleDotsEncountered > 0) {
                doubleDotsEncountered--;
                continue;
            }
            simplifiedSegmentsReversed.Add(segment);
        }

        var rv = new StringBuilder();
        for (int i = 0; i < doubleDotsEncountered; i++) {
            rv.Append("../");
        }
        // Enumerable.Reverse explicitly: List<T>.Reverse() is an in-place void method and cannot be an argument.
        rv.AppendJoin('/', Enumerable.Reverse(simplifiedSegmentsReversed));
        return '/' + (rv.ToString().TrimEnd('/') + (fwdSlashUrl.EndsWith('/') ? "/" : "")).TrimStart('/');
    }

    /// <summary>
    /// Returns whether the normalized <paramref name="path"/> equals or lies below the normalized
    /// <paramref name="prefix"/>. Only whole segments match, so "/static" does not match "/staticfoo".
    /// </summary>
    private static bool IsWithinUrlPrefix(string path, string prefix) =>
        path.StartsWith(prefix, StringComparison.Ordinal)
        && (path.Length == prefix.Length || prefix.EndsWith('/') || path[prefix.Length] == '/');

    /// <summary>
    /// Returns whether a result of <see cref="Path.GetRelativePath(string, string)"/> points outside its base
    /// directory. That is the case when it begins with a ".." segment or is rooted (e.g. a different volume).
    /// </summary>
    private static bool EscapesBaseDirectory(string relativePath) =>
        relativePath == ".."
        || relativePath.StartsWith(".." + Path.DirectorySeparatorChar, StringComparison.Ordinal)
        || relativePath.StartsWith(".." + Path.AltDirectorySeparatorChar, StringComparison.Ordinal)
        || Path.IsPathRooted(relativePath);

    /// <summary>
    /// Handles one request:
    /// <list type="bullet">
    /// <item>a registered endpoint (path + method) is invoked with its converted query parameters;</item>
    /// <item>otherwise a GET is tried against the static serve paths;</item>
    /// <item>otherwise a 404 page is written.</item>
    /// </list>
    /// The response is always flushed, closed and logged.
    /// </summary>
    private async Task ProcessRequestAsync(HttpListenerContext ctx) {
        using RequestContext rc = new RequestContext(ctx);
        // TODO add path escape countermeasure-unittests
        var splitted = (ctx.Request.RawUrl ?? "").Split('?', 2, StringSplitOptions.None);
        var reqPath = NormalizeUrlPath(WebUtility.UrlDecode(splitted[0]));
        string requestMethod = ctx.Request.HttpMethod.ToUpperInvariant();
        bool wasStaticlyServed = false;

        void LogRequest() {
            requestLogger.Information($"{rc.ListenerContext.Response.StatusCode} {(wasStaticlyServed ? "static" : "endpnt")} {requestMethod} {ctx.Request.Url}");
        }

        try {
            if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) {
                var mi = endpointInvocationInfo.methodInfo;
                var qparams = endpointInvocationInfo.queryParameters;
                var args = splitted.Length == 2 ? splitted[1] : null;
                var parsedQParams = new Dictionary();
                // Slot 0 holds the RequestContext; query parameters follow in declaration order.
                var convertedQParamValues = new object[qparams.Count + 1];

                // run the checks to see if the client is allowed to make this request
                if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) {
                    // if any check failed return Forbidden
                    await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
                    return;
                }

                if (args != null) {
                    // Empty segments (a bare "?" or "a=1&&b=2") carry no parameter and are skipped.
                    foreach (var queryKV in args.Split('&', StringSplitOptions.RemoveEmptyEntries)) {
                        // Split on the first '=' only so values may themselves contain '='.
                        var queryKVSplitted = queryKV.Split('=', 2);
                        if (queryKVSplitted.Length != 2) {
                            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Malformed request URL parameters");
                            return;
                        }
                        if (!parsedQParams.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1]))) {
                            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, "Duplicate request URL parameters");
                            return;
                        }
                    }
                }

                // Resolve every declared parameter even when the URL had no query string at all, so a missing
                // required parameter is always reported instead of silently becoming null/default.
                for (int i = 0; i < qparams.Count; i++) {
                    var qparam = qparams[i];
                    int argIndex = i + 1;
                    if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
                        if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
                            convertedQParamValues[argIndex] = objRes;
                        } else {
                            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Malformed value for query parameter {qparam.Name}");
                            return;
                        }
                    } else if (qparam.IsOptional) {
                        convertedQParamValues[argIndex] = null!;
                    } else {
                        await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
                        return;
                    }
                }

                convertedQParamValues[0] = rc;
                rc.ParsedParameters = parsedQParams.AsReadOnly();
                await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new InvalidOperationException("Website func returned null unexpectedly"));
            } else {
                if (requestMethod == "GET") {
                    foreach (var (k, v) in staticServePaths) {
                        if (!IsWithinUrlPrefix(reqPath, k)) {
                            continue;
                        }

                        // do a static serve
                        wasStaticlyServed = true;
                        var relativeStaticReqPath = reqPath[k.Length..];
                        var staticResponsePath = Path.GetFullPath(Path.Join(v, relativeStaticReqPath.TrimStart('/')));
                        if (EscapesBaseDirectory(Path.GetRelativePath(v, staticResponsePath))) {
                            requestLogger.Warning($"Blocked GET request to {reqPath} as somehow the target file does not lie inside the static serve folder? Are you using symlinks?");
                            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
                            return;
                        }

                        if (File.Exists(staticResponsePath)) {
                            rc.SetStatusCode(HttpStatusCode.OK);
                            if (staticResponsePath.EndsWith(".svg", StringComparison.OrdinalIgnoreCase)) {
                                rc.ListenerContext.Response.AddHeader("Content-Type", "image/svg+xml");
                            }
                            using var f = File.OpenRead(staticResponsePath);
                            await f.CopyToAsync(rc.ListenerContext.Response.OutputStream);
                        } else {
                            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
                        }
                        return;
                    }
                }

                // invoke 404
                await HandleDefaultErrorPageAsync(rc, 404);
            }
        } catch (Exception ex) {
            // Log first: writing the error page can itself throw (e.g. when the response has already been
            // partially sent), which would otherwise lose the original exception.
            mainLogger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
            try {
                await HandleDefaultErrorPageAsync(rc, 500);
            } catch (Exception pageEx) {
                mainLogger.Warning($"Could not send the 500 error page:\n{pageEx}");
            }
        } finally {
            try {
                await rc.RespWriter.FlushAsync();
            } catch (ObjectDisposedException) {
                // The writer may already have been disposed by HandleDefaultErrorPageAsync.
            }
            rc.ListenerContext.Response.Close();
            LogRequest();
        }
    }

    /// <summary>Writes the default error page with the given status code; see the <see cref="int"/> overload.</summary>
    private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, HttpStatusCode errorCode, string? statusDescription = null)
        => await HandleDefaultErrorPageAsync(ctx, (int) errorCode, statusDescription);

    /// <summary>
    /// Sets <paramref name="errorCode"/> and writes the default error page, including
    /// <paramref name="statusDescription"/> when given. It then sets the status again via
    /// <c>SetStatusCodeAndDisposeAsync</c>, tolerating an already disposed context.
    /// </summary>
    private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode, string? statusDescription = null) {
        ctx.SetStatusCode(errorCode);
        string desc = statusDescription != null ? $"\r\n{statusDescription}" : "";
        // NOTE(review): the template markup below appears to be missing in this copy; a multi-line raw string's
        // closing """ must also sit on its own line. Restore from version control.
        await ctx.WriteLineToRespAsync($"""

Oh no, an error occurred!

Code: {errorCode}

{desc} """);
        try {
            if (statusDescription == null) {
                await ctx.SetStatusCodeAndDisposeAsync(errorCode);
            } else {
                await ctx.SetStatusCodeAndDisposeAsync(errorCode, statusDescription);
            }
        } catch (ObjectDisposedException) {
            // The context was already disposed; nothing left to update.
        }
    }
}