CSharpHttpServer/SimpleHttpServer/HttpServer.cs
2024-01-15 19:56:14 +01:00

295 lines
14 KiB
C#

using SimpleHttpServer.Types;
using SimpleHttpServer.Types.Exceptions;
using SimpleHttpServer.Types.ParameterConverters;
using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
namespace SimpleHttpServer;
public sealed class HttpServer {
/// <summary>Port handed to the constructor; it is part of the localhost prefix the listener is bound to.</summary>
public int Port { get; }
/// <summary>Listener whose request contexts are accepted in <see cref="GetContextLoopAsync"/>.</summary>
private readonly HttpListener listener;
/// <summary>Task running the accept loop; stays null until <see cref="Start"/> was called.</summary>
private Task? listenerTask;
/// <summary>Logger created for <c>LogOutputTopic.Main</c>; receives lifecycle, registration and error messages.</summary>
private readonly Logger mainLogger;
/// <summary>Logger created for <c>LogOutputTopic.Request</c>; receives the per-request log lines.</summary>
private readonly Logger requestLogger;
/// <summary>Configuration passed to the constructor; it is handed to both loggers.</summary>
private readonly SimpleHttpServerConfiguration conf;
/// <summary>Set to true by <see cref="StopAsync"/>; the accept loop ends once it observes the flag.</summary>
private bool shutdown = false;
/// <summary>
/// Creates a server whose listener is bound to <c>http://localhost:{port}/</c>.
/// The listener is only configured here; nothing is accepted until <see cref="Start"/> is called.
/// </summary>
/// <param name="port">Port used for the localhost listener prefix.</param>
/// <param name="configuration">Configuration handed to the main and the request logger.</param>
public HttpServer(int port, SimpleHttpServerConfiguration configuration) {
    Port = port;
    conf = configuration;

    var httpListener = new HttpListener();
    httpListener.Prefixes.Add($"http://localhost:{port}/");
    listener = httpListener;

    mainLogger = new Logger(LogOutputTopic.Main, configuration);
    requestLogger = new Logger(LogOutputTopic.Request, configuration);
}
/// <summary>
/// Starts the listener and launches the background loop that accepts incoming requests.
/// The assertion fails if the server was started before.
/// </summary>
public void Start() {
    mainLogger.Information($"Starting on port {Port}...");
    Assert(listenerTask is null, "Server was already started!");

    listener.Start();
    listenerTask = Task.Run(() => GetContextLoopAsync());

    mainLogger.Information("Ready to handle requests!");
}
/// <summary>
/// Stops the listener and waits until the accept loop has finished.
/// The assertion fails if the server was never started.
/// </summary>
/// <param name="ctok">Token that cancels waiting for the accept loop to end.</param>
public async Task StopAsync(CancellationToken ctok) {
    mainLogger.Information("Stopping server...");
    Assert(listenerTask is not null, "Server was not started!");

    // The flag is raised before the listener is stopped, so the accept loop
    // already sees it when its pending accept call is interrupted.
    shutdown = true;
    listener.Stop();

    await listenerTask.WaitAsync(ctok);
}
/// <summary>
/// Accept loop of the server: waits for the next request context and hands it to
/// <see cref="ProcessRequestAsync"/> until a shutdown was requested.
/// </summary>
/// <remarks>
/// Requests are processed without being awaited, so the loop can accept the next request immediately.
/// </remarks>
public async Task GetContextLoopAsync() {
    while (!shutdown) {
        try {
            var ctx = await listener.GetContextAsync();
            _ = ProcessRequestAsync(ctx);
        } catch (HttpListenerException ex) when (ex.ErrorCode == 995) {
            // The I/O operation has been aborted because of either a thread exit or an application request
        } catch (Exception) when (shutdown) {
            // Stopping the listener aborts the pending accept call. That abort is not guaranteed to
            // surface as error 995 (other platforms/implementations may report a different exception),
            // so anything thrown after a shutdown was requested is part of the normal stop, not a fatal error.
        } catch (Exception ex) {
            mainLogger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
        }
    }
}
/// <summary>
/// Fills the converter table with the built-in converters: dedicated ones for <see cref="string"/> and
/// <see cref="bool"/>, and a <c>ParsableParameterConverter</c> for each of the parsable types listed below.
/// </summary>
private void RegisterDefaultConverters() {
    var converters = stringToTypeParameterConverters;

    // Adds a ParsableParameterConverter for a type that implements IParsable of itself.
    void AddParsable<TValue>() where TValue : IParsable<TValue> =>
        converters.Add(typeof(TValue), new ParsableParameterConverter<TValue>());

    converters.Add(typeof(string), new StringParameterConverter());
    converters.Add(typeof(bool), new BoolParsableParameterConverter());

    AddParsable<char>();

    AddParsable<byte>();
    AddParsable<short>();
    AddParsable<int>();
    AddParsable<long>();
    AddParsable<Int128>();
    AddParsable<UInt128>();
    AddParsable<BigInteger>();

    AddParsable<sbyte>();
    AddParsable<ushort>();
    AddParsable<uint>();
    AddParsable<ulong>();

    AddParsable<Half>();
    AddParsable<float>();
    AddParsable<double>();
    AddParsable<decimal>();
}
/// <summary>
/// Registered endpoints, keyed by normalized request path and request method name.
/// Filled by <see cref="RegisterEndpointsFromType{T}"/> and queried for every incoming request.
/// </summary>
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
/// <summary>
/// Types the leading parameters of every endpoint method must be able to hold, in order.
/// All parameters after these are treated as query parameters.
/// </summary>
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
/// <summary>
/// Registers every public method of <typeparamref name="T"/> that is tagged with an
/// <c>HttpEndpointAttribute</c> as an endpoint for each location named by the attribute.
/// </summary>
/// <remarks>
/// An endpoint method must be public and static, return <see cref="Task"/> and take a parameter that can hold a
/// <see cref="RequestContext"/> first. Every further parameter becomes a query parameter; its name and optionality
/// come from its <c>ParameterAttribute</c> if present, otherwise the C# parameter name is used and it is required.
/// </remarks>
/// <typeparam name="T">Type whose public methods are scanned for endpoints.</typeparam>
/// <exception cref="MissingParameterConverterException">A query parameter has a type without a registered converter.</exception>
/// <exception cref="NotImplementedException">A location contains a path parameter (<c>{</c>).</exception>
/// <exception cref="ArgumentException">A parameter name or request method name is missing, or the endpoint is already registered.</exception>
public void RegisterEndpointsFromType<T>() {
    // Key the one-time converter setup on the converter table itself. Keying it on the endpoint table
    // re-registered the defaults (and failed on the duplicate keys) whenever a previously registered
    // type had not contributed any endpoint.
    if (stringToTypeParameterConverters.Count == 0) {
        RegisterDefaultConverters();
    }

    foreach (var mi in typeof(T).GetMethods()) {
        var endpointAttributes = mi.GetCustomAttributes(typeof(HttpEndpointAttribute<>)).ToArray();
        if (endpointAttributes.Length == 0) {
            continue; // not an endpoint
        }
        // Single() rejects methods that carry more than one endpoint attribute.
        var attrib = (HttpEndpointAttribute) endpointAttributes.Single();

        string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;

        Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
        Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");

        var methodParams = mi.GetParameters();
        Assert(methodParams.Length >= expectedEndpointParameterTypes.Length,
            $"{GetFancyMethodName()} must declare at least {expectedEndpointParameterTypes.Length} parameter(s)!");
        for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
            Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
                $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
        }
        Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");

        // Everything after the fixed leading parameters is filled from the query string.
        var qparams = new List<(string, (Type type, bool isOptional))>();
        for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
            var par = methodParams[i];
            var attr = par.GetCustomAttribute<ParameterAttribute>(false);
            qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
                (par.ParameterType, attr?.IsOptional ?? false)));
            if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
                throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
            }
        }

        foreach (var location in attrib.Locations) {
            var normLocation = NormalizeUrlPath(location);
            if (normLocation.IndexOf('{') >= 0) {
                // this path contains path parameters
                throw new NotImplementedException("Path parameters are not yet implemented!");
            }
            var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
            // Report a clash with a descriptive message and log the endpoint only once it really is registered.
            if (!simpleEndpointMethodInfos.TryAdd((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams))) {
                throw new ArgumentException($"Endpoint '{reqMethod} {normLocation}' is already registered! ({GetFancyMethodName()})");
            }
            mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
        }
    }
}
/// <summary>
/// Maps a request path onto a directory of the local filesystem, so that the files inside that directory can be
/// requested relative to <paramref name="requestPath"/>. Requests are confined to the directory, and a registered
/// endpoint always takes precedence over a statically served file.
/// </summary>
/// <param name="requestPath">Request path under which the files are reachable; it is normalized before it is stored.</param>
/// <param name="filesystemDirectory">Directory to serve files from; it is resolved to an absolute path.</param>
public void RegisterStaticServePath(string requestPath, string filesystemDirectory) {
    string servedDirectory = Path.GetFullPath(filesystemDirectory);
    string urlPrefix = NormalizeUrlPath(requestPath);

    mainLogger.Information($"Registered static serve path: '{urlPrefix}' --> '{servedDirectory}'");
    staticServePaths.Add(urlPrefix, servedDirectory);
}
/// <summary>Static serve registrations: normalized request path mapped to the absolute directory that is served.</summary>
private readonly Dictionary<string, string> staticServePaths = new Dictionary<string, string>();
/// <summary>Converters that turn the string value of a query parameter into the parameter's type, keyed by that type.</summary>
private readonly Dictionary<Type, IParameterConverter> stringToTypeParameterConverters = new();
/// <summary>
/// Brings a URL path into a canonical form: backslashes become forward slashes, empty and <c>.</c> segments are
/// dropped, and every <c>..</c> removes the segment before it. A <c>..</c> that has no segment left to remove is
/// kept at the front of the result. The result always starts with a slash and keeps a trailing slash of the input.
/// </summary>
/// <param name="url">Path to normalize.</param>
/// <returns>The normalized path, e.g. <c>/a/c</c> for <c>/a/b/../c</c>.</returns>
private static string NormalizeUrlPath(string url) {
    var unified = url.Replace('\\', '/');

    var kept = new List<string>();
    int unresolvedParents = 0; // ".." segments that found nothing to cancel
    foreach (var part in unified.Split('/', StringSplitOptions.RemoveEmptyEntries)) {
        switch (part) {
            case ".":
                break; // refers to the current segment, contributes nothing
            case "..":
                if (kept.Count > 0) {
                    kept.RemoveAt(kept.Count - 1);
                } else {
                    unresolvedParents++;
                }
                break;
            default:
                kept.Add(part);
                break;
        }
    }

    var body = string.Join('/', Enumerable.Repeat("..", unresolvedParents).Concat(kept));
    if (unified.EndsWith('/')) {
        body += "/";
    }
    // An empty body with a trailing slash would otherwise produce "//".
    return '/' + body.TrimStart('/');
}
/// <summary>
/// Answers a single request. A registered endpoint matching the normalized path and the request method is invoked
/// with its converted query parameters; otherwise GET requests are matched against the static serve paths, and
/// everything else gets the default 404 page. Unhandled exceptions are logged and answered with the default 500 page.
/// Every request is written to the request log, no matter which branch answered it.
/// </summary>
/// <param name="ctx">Listener context of the request to answer.</param>
private async Task ProcessRequestAsync(HttpListenerContext ctx) {
    using RequestContext rc = new RequestContext(ctx);
    var rawUrl = ctx.Request.RawUrl ?? "/";
    var decUri = WebUtility.UrlDecode(rawUrl); // only used for the request log line
    // Split the raw URL first and decode each part exactly once afterwards: decoding up front turns encoded
    // '?', '&' and '=' into delimiters, and decoding the parts again corrupts values such as "%2541".
    var splitted = rawUrl.Split('?', 2, StringSplitOptions.None);
    var reqPath = NormalizeUrlPath(WebUtility.UrlDecode(splitted[0])); // TODO add path escape countermeasure-unittests
    string requestMethod = ctx.Request.HttpMethod.ToUpperInvariant();
    bool wasStaticlyServed = false;

    void LogRequest() {
        requestLogger.Information($"{rc.ListenerContext.Response.StatusCode} {(wasStaticlyServed ? "static" : "endpnt")} {requestMethod} {decUri}");
    }

    try {
        if (simpleEndpointMethodInfos.TryGetValue((reqPath, requestMethod), out var endpointInvocationInfo)) {
            var mi = endpointInvocationInfo.methodInfo;
            var qparams = endpointInvocationInfo.queryParameters;
            var parsedQParams = new Dictionary<string, string>();

            // TODO add authcheck here

            if (splitted.Length == 2) {
                foreach (var queryKV in splitted[1].Split('&', StringSplitOptions.None)) {
                    // Only the first '=' separates key and value, the value itself may contain more of them.
                    var queryKVSplitted = queryKV.Split('=', 2);
                    if (queryKVSplitted.Length != 2) {
                        await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.BadRequest, "Malformed request URL parameters");
                        return;
                    }
                    if (!parsedQParams.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1]))) {
                        await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.BadRequest, "Duplicate request URL parameters");
                        return;
                    }
                }
            }

            // Slot 0 holds the request context, the query parameters follow in declaration order.
            // This runs even without a query string, so a missing required parameter is always rejected.
            var convertedQParamValues = new object[qparams.Count + 1];
            convertedQParamValues[0] = rc;
            for (int i = 0; i < qparams.Count; i++) {
                var (qparamName, qparamInfo) = qparams[i];
                if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) {
                    if (!stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) {
                        await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.BadRequest);
                        return;
                    }
                    convertedQParamValues[i + 1] = objRes;
                } else if (qparamInfo.isOptional) {
                    convertedQParamValues[i + 1] = null!; // an absent optional parameter is passed as null
                } else {
                    await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
                    return;
                }
            }

            await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new InvalidOperationException("Website func returned null unexpectedly"));
        } else {
            if (requestMethod == "GET") {
                foreach (var (k, v) in staticServePaths) {
                    if (!reqPath.StartsWith(k, StringComparison.Ordinal)) {
                        continue;
                    }
                    // The prefix has to end on a segment boundary, "/static" must not claim "/staticfoo".
                    bool endsOnSegmentBoundary = k.EndsWith('/') || reqPath.Length == k.Length || reqPath[k.Length] == '/';
                    if (!endsOnSegmentBoundary) {
                        continue;
                    }

                    // do a static serve
                    wasStaticlyServed = true;
                    // A leading slash would make the path rooted, and Path.Combine then discards the base directory.
                    var relativeStaticReqPath = reqPath[k.Length..].TrimStart('/');
                    var staticResponsePath = Path.Combine(v, relativeStaticReqPath);
                    if (Path.GetRelativePath(v, staticResponsePath).Contains("..")) {
                        requestLogger.Warning($"Blocked GET request to {reqPath} as somehow the target file does not lie inside the static serve folder? Are you using symlinks?");
                        await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.NotFound);
                        return;
                    }
                    if (File.Exists(staticResponsePath)) {
                        rc.SetStatusCode(HttpStatusCode.OK);
                        // Open the requested file, not the served directory.
                        using var f = File.OpenRead(staticResponsePath);
                        await f.CopyToAsync(rc.ListenerContext.Response.OutputStream);
                    } else {
                        await rc.SetStatusCodeAndDisposeAsync(HttpStatusCode.NotFound);
                    }
                    return;
                }
            }
            // invoke 404
            await HandleDefaultErrorPageAsync(rc, 404);
        }
    } catch (Exception ex) {
        // Log first: writing the error page can fail as well (e.g. when the response is already closed),
        // and that must not swallow the original exception.
        mainLogger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
        try {
            await HandleDefaultErrorPageAsync(rc, 500);
        } catch (Exception errorPageEx) {
            mainLogger.Warning($"Could not send the error page for the failed request:\n{errorPageEx}");
        }
    } finally {
        // Runs for the early returns above as well, so rejected and statically served requests are logged too.
        LogRequest();
    }
}
/// <summary>
/// Answers the request with the built-in error page and the matching HTTP status code.
/// </summary>
/// <param name="ctx">Request context whose response receives the error page.</param>
/// <param name="errorCode">HTTP status code to respond with; it is also printed on the page.</param>
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode) {
    // Writing the page alone leaves the response status untouched, so the error would be delivered
    // with whatever status was set before (presumably the 200 default - TODO confirm).
    ctx.SetStatusCode((HttpStatusCode) errorCode);
    await ctx.WriteLineToRespAsync($"""
        <body>
        <h1>Oh no, an error occurred!</h1>
        <p>Code: {errorCode}</p>
        </body>
        """);
}
}