using SimpleHttpServer.Types;
using SimpleHttpServer.Types.Exceptions;
using SimpleHttpServer.Types.ParameterConverters;
using System.Net;
using System.Numerics;
using System.Reflection;

namespace SimpleHttpServer;

/// <summary>
/// Small HTTP server built on <see cref="HttpListener"/>. Requests are dispatched to public static endpoint
/// methods registered via <c>RegisterEndpointsFromType</c>, keyed by URL path and HTTP method; endpoint
/// parameters after the leading <c>RequestContext</c> are bound from the query string.
/// </summary>
public sealed class HttpServer
{
    /// <summary>Port the server listens on (prefix <c>http://localhost:{Port}/</c>).</summary>
    public int Port { get; }

    private readonly HttpListener listener;
    private Task? listenerTask;
    private readonly Logger logger;
    private readonly SimpleHttpServerConfiguration conf;

    // Written by StopAsync and polled by the accept loop, which runs on another thread (Task.Run);
    // volatile guarantees the loop observes the write.
    private volatile bool shutdown = false;

    // Whether RegisterDefaultConverters already ran, so the built-in converters are added exactly once
    // (Dictionary.Add would throw on the second attempt).
    private bool defaultConvertersRegistered = false;

    /// <summary>Creates a server bound to <c>localhost</c>; call <c>Start</c> to begin listening.</summary>
    /// <param name="port">Port to listen on.</param>
    /// <param name="configuration">Server configuration; also passed to the server's logger.</param>
    public HttpServer(int port, SimpleHttpServerConfiguration configuration)
    {
        Port = port;
        conf = configuration;
        listener = new HttpListener();
        listener.Prefixes.Add($"http://localhost:{port}/");
        logger = new(LogOutputTopic.Main, conf);
    }

    /// <summary>Starts the listener and runs the accept loop on the thread pool.</summary>
    /// <remarks>The <c>Assert</c> fails if the server was already started.</remarks>
    public void Start()
    {
        logger.Information($"Starting on port {Port}...");
        Assert(listenerTask == null, "Server was already started!");
        listener.Start();
        listenerTask = Task.Run(GetContextLoopAsync);
        logger.Information($"Ready to handle requests!");
    }

    /// <summary>Stops the listener and waits for the accept loop to finish.</summary>
    /// <param name="ctok">Cancels the wait for the accept loop (the listener is stopped regardless).</param>
    /// <remarks>The <c>Assert</c> fails if the server was never started.</remarks>
    public async Task StopAsync(CancellationToken ctok)
    {
        logger.Information("Stopping server...");
        Assert(listenerTask != null, "Server was not started!");
        shutdown = true;
        listener.Stop();
        await listenerTask.WaitAsync(ctok);
    }

    /// <summary>
    /// Accept loop: waits for requests until shutdown is requested and starts <c>ProcessRequestAsync</c>
    /// for each one without awaiting it, so the next request can be accepted while earlier ones are handled.
    /// </summary>
    public async Task GetContextLoopAsync()
    {
        while (!shutdown)
        {
            try
            {
                var ctx = await listener.GetContextAsync();
                _ = ProcessRequestAsync(ctx);
            }
            catch (HttpListenerException ex) when (ex.ErrorCode == 995)
            {
                //The I/O operation has been aborted because of either a thread exit or an application request
            }
            catch (Exception) when (shutdown)
            {
                // listener.Stop() makes the pending GetContextAsync fail, and that may not surface as
                // error 995 on every platform. After shutdown was requested this is expected, not fatal;
                // the loop condition ends the loop.
            }
            catch (Exception ex)
            {
                logger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
            }
        }
    }

    /// <summary>Registers the built-in string-to-type converters used to bind query parameters.</summary>
    private void RegisterDefaultConverters()
    {
        void RegisterConverter() where T : IParsable
        {
            stringToTypeParameterConverters.Add(typeof(T), new ParsableParameterConverter());
        }

        stringToTypeParameterConverters.Add(typeof(string), new StringParameterConverter());
        stringToTypeParameterConverters.Add(typeof(bool), new BoolParsableParameterConverter());
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        RegisterConverter();
        defaultConvertersRegistered = true;
    }

    private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
    private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };

    /// <summary>
    /// Scans the public methods of <c>T</c> for <c>HttpEndpointAttribute</c> and registers every tagged method
    /// under each of the attribute's locations (normalized to a leading '/') and its request method.
    /// </summary>
    /// <remarks>
    /// A tagged method must be static, public and return <see cref="Task"/>; its first parameter must accept a
    /// <c>RequestContext</c>. Every further parameter is bound from the query string, using the name (and
    /// optionality) from its parameter attribute if present, otherwise the C# parameter name.
    /// The built-in converters are registered on the first call.
    /// </remarks>
    /// <exception cref="MissingParameterConverterException">A query parameter type has no registered converter.</exception>
    /// <exception cref="NotImplementedException">A location contains a path parameter (<c>{</c>).</exception>
    /// <exception cref="ArgumentException">A location/request-method pair is already registered, or the request method is undefined.</exception>
    public void RegisterEndpointsFromType()
    {
        if (!defaultConvertersRegistered)
        {
            RegisterDefaultConverters();
        }

        var t = typeof(T);
        foreach (var (mi, attrib) in t.GetMethods()
            .ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>)))
            .Where(x => x.Value.Any())
            .ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single()))
        {
            string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;

            Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
            Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");

            var methodParams = mi.GetParameters();
            Assert(methodParams.Length >= expectedEndpointParameterTypes.Length);
            for (int i = 0; i < expectedEndpointParameterTypes.Length; i++)
            {
                Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]), $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
            }
            Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");

            // Every parameter after the fixed leading ones is bound from the query string by name.
            var qparams = new List<(string, (Type type, bool isOptional))>();
            for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++)
            {
                var par = methodParams[i];
                var attr = par.GetCustomAttribute(false);
                qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"), (par.ParameterType, attr?.IsOptional ?? false)));
                if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType))
                {
                    throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
                }
            }

            foreach (var location in attrib.Locations)
            {
                var normLocation = NormalizeUrlPath(location);
                int idx = normLocation.IndexOf('{');
                if (idx >= 0)
                {
                    // this path contains path parameters
                    throw new NotImplementedException("Path parameters are not yet implemented!");
                }
                var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
                simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams));
            }
        }
    }

    private readonly Dictionary stringToTypeParameterConverters = new();

    /// <summary>Returns <paramref name="url"/> with all leading '/' replaced by exactly one '/'.</summary>
    private string NormalizeUrlPath(string url) => '/' + url.TrimStart('/');

    /// <summary>
    /// Handles one request: looks up the endpoint for (decoded path, upper-cased HTTP method), binds its query
    /// parameters and invokes it. Unknown endpoints get a 404 page; malformed, duplicate, unconvertible or
    /// missing required query parameters get a 400; any exception is logged and answered with a 500 page.
    /// </summary>
    /// <remarks>Started fire-and-forget by the accept loop, so request-handling exceptions are caught here.</remarks>
    private async Task ProcessRequestAsync(HttpListenerContext ctx)
    {
        using RequestContext rc = new RequestContext(ctx);
        try
        {
            // Split the *raw* URL before decoding anything. Decoding first would turn escaped '?', '&' and '='
            // into delimiters and decode every component twice (e.g. "%2B" would end up as a space).
            // TODO add path escape countermeasures+unittests
            var splitted = (ctx.Request.RawUrl ?? "/").Split('?', 2, StringSplitOptions.None);
            var path = NormalizeUrlPath(WebUtility.UrlDecode(splitted[0]));
            if (!simpleEndpointMethodInfos.TryGetValue((path, ctx.Request.HttpMethod.ToUpperInvariant()), out var endpointInvocationInfo))
            {
                // invoke 404
                await SendErrorPageAsync(ctx, rc, HttpStatusCode.NotFound);
                return;
            }

            var mi = endpointInvocationInfo.methodInfo;
            var qparams = endpointInvocationInfo.queryParameters;
            var args = splitted.Length == 2 ? splitted[1] : null;
            var parsedQParams = new Dictionary();
            // Slot 0 holds the RequestContext; query parameter i goes into slot i + 1.
            var convertedQParamValues = new object[qparams.Count + 1];

            // TODO add authcheck here
            if (args != null)
            {
                // Empty segments ("path?" or a trailing '&') carry no parameters and are skipped.
                foreach (var queryKV in args.Split('&', StringSplitOptions.RemoveEmptyEntries))
                {
                    // Split on the first '=' only: '=' is legal inside a query value.
                    var queryKVSplitted = queryKV.Split('=', 2);
                    if (queryKVSplitted.Length != 2)
                    {
                        rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, "Malformed request URL parameters");
                        return;
                    }
                    if (!parsedQParams.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1])))
                    {
                        rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, "Duplicate request URL parameters");
                        return;
                    }
                }
            }

            // Always bind, even without a query string, so missing required parameters are rejected.
            for (int i = 0; i < qparams.Count; i++)
            {
                var (qparamName, qparamInfo) = qparams[i];
                if (parsedQParams.TryGetValue(qparamName, out var qparamValue))
                {
                    if (!stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes))
                    {
                        rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest);
                        return;
                    }
                    convertedQParamValues[i + 1] = objRes;
                }
                else if (qparamInfo.isOptional)
                {
                    convertedQParamValues[i + 1] = null!;
                }
                else
                {
                    rc.SetStatusCodeAndDispose(HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
                    return;
                }
            }

            convertedQParamValues[0] = rc;
            await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new InvalidOperationException("Website func returned null unexpectedly"));
        }
        catch (Exception ex)
        {
            // Log first: sending the error page can fail too (e.g. the response was already closed),
            // and the original exception must not be lost in that case.
            logger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
            await SendErrorPageAsync(ctx, rc, HttpStatusCode.InternalServerError);
        }
    }

    /// <summary>
    /// Sets the HTTP status code on the underlying response and writes the default error page.
    /// Failures are logged rather than thrown, since this runs inside the fire-and-forget request task.
    /// </summary>
    /// <remarks>
    /// <see cref="HttpListenerResponse.StatusCode"/> can only be changed before the response headers are sent;
    /// if they already were, setting it throws and no error page is appended.
    /// </remarks>
    private async Task SendErrorPageAsync(HttpListenerContext ctx, RequestContext rc, HttpStatusCode statusCode)
    {
        try
        {
            ctx.Response.StatusCode = (int) statusCode;
            await HandleDefaultErrorPageAsync(rc, (int) statusCode);
        }
        catch (Exception ex)
        {
            logger.Fatal($"Failed to send the default error page ({(int) statusCode}):\n{ex}");
        }
    }

    /// <summary>Writes the default error page, showing <paramref name="errorCode"/>, to the response body.</summary>
    /// <remarks>Does not set the HTTP status code itself.</remarks>
    private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode)
    {
        await ctx.WriteLineToRespAsync($"""

            Oh no, an error occurred!

            Code: {errorCode}

            """);
    }
}