// Source listing metadata: 418 lines, 21 KiB, C#
using SimpleHttpServer.Types;
|
|
using SimpleHttpServer.Types.Exceptions;
|
|
using SimpleHttpServer.Types.ParameterConverters;
|
|
using System.Net;
|
|
using System.Numerics;
|
|
using System.Reflection;
|
|
using System.Text;
|
|
using static SimpleHttpServer.Types.EndpointInvocationInfo;
|
|
|
|
namespace SimpleHttpServer;
|
|
|
|
public sealed class HttpServer {
|
|
|
|
/// <summary>Port the server listens on; the listener prefix is <c>http://localhost:{Port}/</c>.</summary>
public int Port { get; }

// Underlying listener; configured in the constructor, started by Start() and stopped by StopAsync().
private readonly HttpListener listener;

// Accept-loop task created by Start(); null until the server has been started.
private Task? listenerTask;

// Logger for server lifecycle and registration messages.
private readonly Logger mainLogger;

// Logger for per-request access lines.
private readonly Logger requestLogger;

private readonly SimpleHttpServerConfiguration conf;

// Set by StopAsync() on the caller's thread and polled by the accept loop running on a thread-pool
// thread (Task.Run in Start), so it is volatile to guarantee the write becomes visible to the loop.
private volatile bool shutdown;
|
|
|
|
/// <summary>
/// Creates a server for <c>http://localhost:{port}/</c>. The listener is only started by <see cref="Start"/>.
/// </summary>
/// <param name="port">Port used in the listener prefix.</param>
/// <param name="configuration">Server configuration, also handed to both loggers.</param>
/// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
public HttpServer(int port, SimpleHttpServerConfiguration configuration) {
    // Fail fast: conf is dereferenced for every request (NormalizeUrlPath), so a null here would
    // otherwise only surface later as a NullReferenceException inside request handling.
    ArgumentNullException.ThrowIfNull(configuration);

    Port = port;
    conf = configuration;

    listener = new HttpListener();
    listener.Prefixes.Add($"http://localhost:{port}/");

    mainLogger = new(LogOutputTopic.Main, conf);
    requestLogger = new(LogOutputTopic.Request, conf);
}
|
|
|
|
/// <summary>
/// Starts the listener and launches the accept loop on the thread pool, returning immediately.
/// May only be called once per instance.
/// </summary>
public void Start() {
    mainLogger.Information($"Starting on port {Port}...");
    Assert(listenerTask is null, "Server was already started!");

    listener.Start();

    // The accept loop runs in the background; Start() itself does not wait for any request.
    listenerTask = Task.Run(() => GetContextLoopAsync());

    mainLogger.Information("Ready to handle requests!");
}
|
|
|
|
/// <summary>
/// Signals the accept loop to exit, stops the listener and waits for the loop task to complete.
/// Requests that were already dispatched (fire-and-forget in the accept loop) are not awaited.
/// </summary>
/// <param name="ctok">Cancels only the wait for the accept loop; the flag and listener.Stop() have already taken effect.</param>
public async Task StopAsync(CancellationToken ctok) {
    mainLogger.Information("Stopping server...");
    Assert(listenerTask != null, "Server was not started!");

    // Raise the flag before stopping, so that when Stop() aborts the pending GetContextAsync
    // the accept loop already sees that a shutdown was requested.
    shutdown = true;
    listener.Stop();

    await listenerTask.WaitAsync(ctok);
}
|
|
|
|
/// <summary>
/// Accept loop: waits for incoming requests and dispatches each one to <see cref="ProcessRequestAsync"/> without
/// awaiting it. Runs until <see cref="StopAsync"/> sets the shutdown flag, or until the listener is no longer listening.
/// </summary>
public async Task GetContextLoopAsync() {
    while (!shutdown) {
        try {
            var ctx = await listener.GetContextAsync();
            // Fire-and-forget so a slow request does not block accepting the next one;
            // ProcessRequestAsync catches exceptions raised while handling the request itself.
            _ = ProcessRequestAsync(ctx);
        } catch (HttpListenerException ex) when (ex.ErrorCode == 995) {
            // 995 = "The I/O operation has been aborted because of either a thread exit or an application request".
            // The pending accept was cancelled (e.g. by listener.Stop()); the loop condition decides whether to go on.
        } catch (Exception) when (shutdown) {
            // StopAsync() aborted the pending accept. This may surface as an exception other than 995
            // (NOTE(review): platform-dependent); it is expected during shutdown and is not an error.
            break;
        } catch (Exception ex) {
            mainLogger.Fatal($"Caught otherwise uncaught exception in GetContextLoop:\n{ex}");
            if (!listener.IsListening) {
                // GetContextAsync fails immediately on a listener that is not listening, so retrying
                // would spin forever, logging the same fatal error in a tight loop.
                break;
            }
        }
    }
}
|
|
|
|
/// <summary>
/// Installs the built-in string-to-type converters for <see cref="string"/>, <see cref="bool"/> and the numeric
/// primitives. Uses Dictionary.Add, so it throws if any of these types already has a converter.
/// </summary>
private void RegisterDefaultConverters() {
    var converters = stringToTypeParameterConverters;

    // Every remaining default type is handled by the generic IParsable<T>-based converter.
    void AddParsable<TValue>() where TValue : IParsable<TValue> =>
        converters.Add(typeof(TValue), new ParsableParameterConverter<TValue>());

    // Types with dedicated converter implementations.
    converters.Add(typeof(string), new StringParameterConverter());
    converters.Add(typeof(bool), new BoolParsableParameterConverter());

    // char and integer types, up to arbitrary precision.
    AddParsable<char>();
    AddParsable<byte>();
    AddParsable<short>();
    AddParsable<int>();
    AddParsable<long>();
    AddParsable<Int128>();
    AddParsable<UInt128>();
    AddParsable<BigInteger>();

    // Remaining signed/unsigned integer widths.
    AddParsable<sbyte>();
    AddParsable<ushort>();
    AddParsable<uint>();
    AddParsable<ulong>();

    // Floating point and decimal.
    AddParsable<Half>();
    AddParsable<float>();
    AddParsable<double>();
    AddParsable<decimal>();
}
|
|
|
|
/// <summary>Endpoints without path parameters, keyed by (request method, normalized path including its leading '/').</summary>
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> simpleEndpointMethodInfos = new();

/// <summary>
/// Endpoints with path parameters, keyed by (request method, normalized location without its leading '/').
/// Source data from which <see cref="pathEndpointMethodInfosTrees"/> is rebuilt.
/// </summary>
private readonly MultiKeyDictionary<string, string, EndpointInvocationInfo> pathEndpointMethodInfos = new();

/// <summary>One path tree per request method, rebuilt after every RegisterEndpointsFromType call.</summary>
private readonly Dictionary<string, PathTree<EndpointInvocationInfo>> pathEndpointMethodInfosTrees = new();

/// <summary>Parameter types every endpoint method must declare first, in this order.</summary>
private static readonly Type[] expectedEndpointParameterTypes = { typeof(RequestContext) };

/// <summary>Count of the framework-supplied leading parameters; query/path parameters start at this index.</summary>
internal static readonly int expectedEndpointParameterPrefixCount = expectedEndpointParameterTypes.Length;
|
|
|
|
/// <summary>
/// Registers every public method of <typeparamref name="T"/> carrying an <see cref="HttpEndpointAttribute"/> as an
/// HTTP endpoint, once per location listed on the attribute. Endpoints with path parameters are routed through
/// per-request-method path trees; all others are matched by exact normalized path.
/// </summary>
/// <typeparam name="T">Type declaring the endpoints. Must be a class (static classes cannot be generic arguments).</typeparam>
/// <param name="instanceFactory">
/// Produces the instance non-static endpoint methods are invoked on. Required when any tagged method is non-static;
/// when supplied it is invoked exactly once per call, even if all tagged methods are static.
/// </param>
/// <exception cref="ArgumentException">
/// A parameter carries both <see cref="ParameterAttribute"/> and <see cref="PathParameterAttribute"/>, a parameter
/// name cannot be determined, or the attribute's request method is undefined.
/// </exception>
/// <exception cref="MissingParameterConverterException">No converter is registered for a parameter's type.</exception>
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
    // Default converters are installed lazily on the first registration.
    if (stringToTypeParameterConverters.Count == 0)
        RegisterDefaultConverters();

    var t = typeof(T);
    // Public method -> its (single) HttpEndpointAttribute, for every method that has one.
    var mis = t.GetMethods()
        .ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
        .Where(x => x.Value.Any())
        .ToDictionary(x => x.Key, x => x.Value.Single());

    var isStatic = mis.All(x => x.Key.IsStatic); // if all are static then there is no point in having a constructor as no instance data is accessible, but we allow passing a factory anyway
    Assert(isStatic || (instanceFactory != null), $"You must provide an instance factory if any methods of the given type ({typeof(T).FullName}) are non-static");
    T? classInstance = instanceFactory?.Invoke();

    foreach (var (mi, attrib) in mis) {
        string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;

        Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");

        var methodParams = mi.GetParameters();

        // check the mandatory prefix parameters
        Assert(methodParams.Length >= expectedEndpointParameterTypes.Length,
            $"{GetFancyMethodName()} must declare at least {expectedEndpointParameterTypes.Length} leading parameter(s): {string.Join(", ", expectedEndpointParameterTypes.Select(x => x.FullName))}");
        for (int i = 0; i < expectedEndpointParameterTypes.Length; i++) {
            Assert(methodParams[i].ParameterType.IsAssignableFrom(expectedEndpointParameterTypes[i]),
                $"Parameter at index {i} of {GetFancyMethodName()} is of a type that cannot contain the expected type {expectedEndpointParameterTypes[i].FullName}.");
        }

        // check return type
        Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");

        // check the rest of the method parameters
        var qparams = new List<QueryParameterInfo>();
        var pparams = new List<PathParameterInfo>();
        int mParamIndex = expectedEndpointParameterTypes.Length;
        for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
            var par = methodParams[i];
            var attr = par.GetCustomAttribute<ParameterAttribute>(false);
            var pathAttr = par.GetCustomAttribute<PathParameterAttribute>(false);

            if (attr != null && pathAttr != null) {
                throw new ArgumentException($"A method argument cannot be tagged with both {nameof(ParameterAttribute)} and {nameof(PathParameterAttribute)}");
            }

            if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
                throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} for parameter at index {i} of method {GetFancyMethodName()} has not been registered (yet)!");
            }

            if (pathAttr != null) { // parameter is a path param
                pparams.Add(new(
                    pathAttr.Name ?? throw new ArgumentException($"C# variable name of path parameter at index {i} of method {GetFancyMethodName()} is null!"),
                    par.ParameterType,
                    mParamIndex++
                ));
            } else { // parameter is a normal query param
                qparams.Add(new(
                    attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of query parameter at index {i} of method {GetFancyMethodName()} is null!"),
                    par.ParameterType,
                    mParamIndex++,
                    attr?.IsOptional ?? false
                ));
            }
        }

        // Check attributes declared on the method and on its declaring type. GetCustomAttributes<T> already
        // filters to InternalEndpointCheckAttribute instances, so no extra type filtering is needed.
        InternalEndpointCheckAttribute[] requiredChecks = mi.GetCustomAttributes<InternalEndpointCheckAttribute>(true)
            .Concat(mi.DeclaringType?.GetCustomAttributes<InternalEndpointCheckAttribute>(true) ?? Enumerable.Empty<InternalEndpointCheckAttribute>())
            .ToArray();

        InternalEndpointCheckAttribute.Initialize(classInstance, requiredChecks);

        foreach (var location in attrib.Locations) {
            var normLocation = NormalizeUrlPath(location);
            var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");

            // Resolve each path parameter to its segment index. This must use the *normalized* location, because
            // request paths are normalized the same way before their segments are indexed (reqPath[1..].Split('/')).
            // Slicing the raw location would drop its first character when it lacks a leading '/', and would
            // miscount segments for inputs such as "//a/{id}" or "./a/{id}".
            var splittedLocation = normLocation[1..].Split('/');
            var pparamsCopy = new List<PathParameterInfo>(pparams);
            for (int i = 0; i < pparamsCopy.Count; i++) {
                var pp = pparamsCopy[i];
                var idx = Array.IndexOf(splittedLocation, pp.Name);
                Assert(idx != -1, $"Path parameter name was incorrect? '{pp.Name}' of {GetFancyMethodName()} is not a segment of '{normLocation}'");
                pp.SegmentStartPos = idx;
                pparamsCopy[i] = pp;
            }

            var epInvocInfo = new EndpointInvocationInfo(mi, pparamsCopy, qparams, requiredChecks, classInstance);
            if (pparams.Count != 0) {
                mainLogger.Information($"Registered path endpoint: '{reqMethod} {normLocation}'");
                Assert(normLocation[0] == '/');
                // Path endpoints are stored without the leading '/'.
                pathEndpointMethodInfos.Add(reqMethod, normLocation[1..], epInvocInfo);
            } else {
                mainLogger.Information($"Registered simple endpoint: '{reqMethod} {normLocation}'");
                simpleEndpointMethodInfos.Add(reqMethod, normLocation, epInvocInfo);
            }
        }
    }

    // rebuild path trees
    pathEndpointMethodInfosTrees.Clear();
    foreach (var (reqMethod, d2) in pathEndpointMethodInfos.backingDict)
        pathEndpointMethodInfosTrees.Add(reqMethod, new(d2));
}
|
|
|
|
/// <summary>
/// Serves all files located in <paramref name="filesystemDirectory"/> on a website path that is relative to <paramref name="requestPath"/>,
/// while restricting requests to inside the local filesystem directory. Static serving has a lower priority than registering an endpoint
/// and is only attempted for GET requests.
/// </summary>
/// <param name="requestPath">URL path prefix under which the files are exposed; normalized like incoming request paths.</param>
/// <param name="filesystemDirectory">Directory whose files are served; resolved to an absolute path at registration time.</param>
/// <exception cref="ArgumentException">The normalized <paramref name="requestPath"/> is already registered (Dictionary.Add).</exception>
public void RegisterStaticServePath(string requestPath, string filesystemDirectory) {
    // Resolve the directory once, so later containment checks compare against a stable absolute root.
    string rootDirectory = Path.GetFullPath(filesystemDirectory);
    string urlPrefix = NormalizeUrlPath(requestPath);

    mainLogger.Information($"Registered static serve path: '{urlPrefix}' --> '{rootDirectory}'");
    staticServePaths.Add(urlPrefix, rootDirectory);
}
|
|
|
|
/// <summary>Normalized URL path prefix → absolute filesystem directory served under it.</summary>
private readonly Dictionary<string, string> staticServePaths = new Dictionary<string, string>();

/// <summary>Converters from raw query/path strings to endpoint parameter values, keyed by target parameter type.</summary>
private readonly Dictionary<Type, IParameterConverter> stringToTypeParameterConverters = new Dictionary<Type, IParameterConverter>();
|
|
|
|
/// <summary>
/// Normalizes a URL path into a canonical form that always starts with '/'.
/// Backslashes are treated as separators, empty and "." segments are dropped, and each ".." removes the closest
/// preceding segment. ".." segments with nothing left to remove are kept as leading "../" segments.
/// A trailing '/' of the input is kept unless <c>conf.TrimTrailingSlash</c> is set.
/// </summary>
/// <param name="url">Raw path (already percent-decoded by callers that handle requests).</param>
/// <returns>The normalized path, e.g. "/a/./b/../c/" → "/a/c/" (or "/a/c" with TrimTrailingSlash).</returns>
private string NormalizeUrlPath(string url) {
    string forwardSlashed = url.Replace('\\', '/');
    string[] parts = forwardSlashed.Trim('/').Split('/', StringSplitOptions.RemoveEmptyEntries);

    // Walk from the end so that each ".." can swallow the nearest real segment before it.
    var kept = new List<string>();
    int pendingParentRefs = 0;
    for (int idx = parts.Length - 1; idx >= 0; idx--) {
        switch (parts[idx]) {
            case ".":
                break;
            case "..":
                pendingParentRefs++;
                break;
            default:
                if (pendingParentRefs > 0)
                    pendingParentRefs--;
                else
                    kept.Add(parts[idx]);
                break;
        }
    }
    kept.Reverse(); // back into original order

    var builder = new StringBuilder();
    builder.Insert(0, "../", pendingParentRefs); // unresolved ".." segments stay at the front
    builder.AppendJoin('/', kept);

    string trailing = forwardSlashed.EndsWith('/') ? "/" : "";
    string result = (builder.ToString().TrimEnd('/') + trailing).TrimStart('/');
    if (conf.TrimTrailingSlash)
        result = result.TrimEnd('/');

    return "/" + result;
}
|
|
|
|
/// <summary>
/// Handles one accepted request end to end: resolves it to an endpoint or a static file, writes the response (or an
/// error page), then flushes, closes and logs it. Exceptions raised while handling are logged as fatal and answered
/// with a 500 page when the response can still be written.
/// </summary>
/// <param name="ctx">Listener context returned by <see cref="HttpListener.GetContextAsync"/>.</param>
private async Task ProcessRequestAsync(HttpListenerContext ctx) {
    using RequestContext rc = new RequestContext(ctx);

    // TODO add path escape countermeasure-unittests
    var splitted = (ctx.Request.RawUrl ?? "").Split('?', 2, StringSplitOptions.None);
    // Percent-decode the path with Uri.UnescapeDataString: WebUtility.UrlDecode implements form decoding and
    // would turn a literal '+' in a path (e.g. "/files/c++.html") into a space.
    var reqPath = NormalizeUrlPath(Uri.UnescapeDataString(splitted[0]));
    string requestMethod = ctx.Request.HttpMethod.ToUpperInvariant();
    bool wasStaticlyServed = false;

    void LogRequest() {
        requestLogger.Information($"{rc.ListenerContext.Response.StatusCode} {(wasStaticlyServed ? "static" : "endpnt")} {requestMethod} {ctx.Request.Url}");
    }

    try {
        /* Finding the endpoint that should process the request:
         * 1. Try to see if there is a simple endpoint where request method and path match
         * 2. Otherwise, try to see if a path-parameter-endpoint matches (duplicates throw an error on startup)
         * 3. Otherwise, check if it is inside a static serve path
         * 4. Otherwise, show 404 page */
        EndpointInvocationInfo? pathEndpointInvocationInfo = null;
        if (simpleEndpointMethodInfos.TryGetValue(requestMethod, reqPath, out var simpleEndpointInvocationInfo) ||
            pathEndpointMethodInfosTrees.TryGetValue(requestMethod, out var pt) && pt.TryGetPath(reqPath, out pathEndpointInvocationInfo)) {
            var endpointInvocationInfo = simpleEndpointInvocationInfo ?? pathEndpointInvocationInfo ?? throw new InvalidOperationException("retrieved endpoint is somehow null");
            await InvokeEndpointAsync(rc, endpointInvocationInfo, reqPath, splitted.Length == 2 ? splitted[1] : null);
        } else if (requestMethod == "GET" && TryFindStaticServeTarget(reqPath, out var staticDirectory, out var relativeStaticReqPath)) {
            wasStaticlyServed = true;
            await ServeStaticFileAsync(rc, reqPath, staticDirectory, relativeStaticReqPath);
        } else {
            // invoke 404
            await HandleDefaultErrorPageAsync(rc, 404);
        }
    } catch (Exception ex) {
        // Log first: writing the error page can itself fail (e.g. headers already sent, client disconnected),
        // and that secondary failure must not prevent the original exception from being recorded.
        mainLogger.Fatal($"Caught otherwise uncaught exception while ProcessingRequest:\n{ex}");
        try {
            await HandleDefaultErrorPageAsync(rc, 500);
        } catch (Exception pageEx) {
            mainLogger.Warning($"Could not send error page for failed request: {pageEx.Message}");
        }
    } finally {
        try {
            try { await rc.RespWriter.FlushAsync(); } catch (ObjectDisposedException) { }
            rc.ListenerContext.Response.Close();
        } finally {
            // Always produce the access-log line, even if closing the response fails.
            LogRequest();
        }
    }
}

/// <summary>
/// Runs the endpoint's access checks, binds query and path parameters from the request and invokes the endpoint
/// method. Responds with a 403 or 400 error page instead of invoking when checks or parameter binding fail.
/// </summary>
/// <param name="rc">Context of the current request; passed as the endpoint's first argument.</param>
/// <param name="endpointInvocationInfo">The matched endpoint.</param>
/// <param name="reqPath">Normalized request path, used to extract path-parameter segments.</param>
/// <param name="queryString">Raw text after the first '?', or null when the URL has no '?'.</param>
private async Task InvokeEndpointAsync(RequestContext rc, EndpointInvocationInfo endpointInvocationInfo, string reqPath, string? queryString) {
    var mi = endpointInvocationInfo.methodInfo;
    var qparams = endpointInvocationInfo.queryParameters;
    var pparams = endpointInvocationInfo.pathParameters;

    var parsedQParams = new Dictionary<string, string>();
    var convertedMParamValues = new object[expectedEndpointParameterTypes.Length + pparams.Count + qparams.Count];

    // run the checks to see if the client is allowed to make this request
    if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
        await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
        return;
    }

    if (queryString != null) {
        if (!TryParseQueryString(queryString, parsedQParams, out var queryError)) {
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, queryError);
            return;
        }

        foreach (var qparam in qparams) {
            if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
                if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
                    convertedMParamValues[qparam.ArgPos] = objRes;
                } else {
                    await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
                    return;
                }
            } else if (qparam.IsOptional) {
                convertedMParamValues[qparam.ArgPos] = null!;
            } else {
                await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
                return;
            }
        }
    } else { // no query string at all: every required query parameter is missing
        var requiredParams = qparams.Where(x => !x.IsOptional).Select(x => $"'{x.Name}'").ToList();
        if (requiredParams.Count != 0) {
            await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter(s): {string.Join(",", requiredParams)}");
            return;
        }
    }

    if (pparams.Count != 0) {
        var splittedReqPath = reqPath[1..].Split('/');
        foreach (var pparam in pparams) {
            // A catch-all parameter receives the rest of the path from its segment on, re-joined with '/'.
            string paramValue = pparam.IsCatchAll
                ? string.Join('/', splittedReqPath[pparam.SegmentStartPos..])
                : splittedReqPath[pparam.SegmentStartPos];

            if (stringToTypeParameterConverters[pparam.Type].TryConvertFromString(paramValue, out var res)) {
                convertedMParamValues[pparam.ArgPos] = res;
            } else {
                await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
                return;
            }
        }
    }

    convertedMParamValues[0] = rc;
    rc.ParsedParameters = parsedQParams.AsReadOnly();

    await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedMParamValues) ?? throw new InvalidOperationException("Website func returned null unexpectedly"));
}

/// <summary>
/// Parses a raw query string ("a=1&amp;b=2") into <paramref name="into"/>, URL-decoding keys and values.
/// Empty pairs (from a bare "?", "&amp;&amp;" or a trailing '&amp;') are ignored, and only the first '=' separates key from value.
/// </summary>
/// <returns><c>true</c> on success; otherwise <c>false</c> with a client-facing reason in <paramref name="error"/>.</returns>
private static bool TryParseQueryString(string queryString, Dictionary<string, string> into, out string? error) {
    foreach (var queryKV in queryString.Split('&', StringSplitOptions.RemoveEmptyEntries)) {
        var queryKVSplitted = queryKV.Split('=', 2);
        if (queryKVSplitted.Length != 2) {
            error = "Malformed request URL parameters";
            return false;
        }
        if (!into.TryAdd(WebUtility.UrlDecode(queryKVSplitted[0]), WebUtility.UrlDecode(queryKVSplitted[1]))) {
            error = "Duplicate request URL parameters";
            return false;
        }
    }
    error = null;
    return true;
}

/// <summary>
/// Finds the first registered static serve prefix that covers <paramref name="reqPath"/> on a segment boundary.
/// </summary>
/// <param name="reqPath">Normalized request path.</param>
/// <param name="directory">Absolute directory registered for the matching prefix.</param>
/// <param name="relativePath">Remainder of <paramref name="reqPath"/> after the prefix.</param>
/// <returns>Whether a prefix matched.</returns>
private bool TryFindStaticServeTarget(string reqPath, out string directory, out string relativePath) {
    foreach (var (prefix, dir) in staticServePaths) {
        // Ordinal: URL paths must not be compared with culture-sensitive rules.
        if (!reqPath.StartsWith(prefix, StringComparison.Ordinal))
            continue;

        // Require a segment boundary so that the prefix "/static" does not also claim "/staticfoo/...".
        if (prefix.EndsWith('/') || reqPath.Length == prefix.Length || reqPath[prefix.Length] == '/') {
            directory = dir;
            relativePath = reqPath[prefix.Length..];
            return true;
        }
    }
    directory = relativePath = "";
    return false;
}

/// <summary>
/// Streams the file at <paramref name="relativeStaticReqPath"/> inside <paramref name="directory"/> to the response,
/// or answers 404 when the file does not exist or would resolve outside the directory.
/// </summary>
private async Task ServeStaticFileAsync(RequestContext rc, string reqPath, string directory, string relativeStaticReqPath) {
    var staticResponsePath = Path.GetFullPath(Path.Join(directory, relativeStaticReqPath.TrimStart('/')));

    if (!IsWithinDirectory(directory, staticResponsePath)) {
        requestLogger.Warning($"Blocked GET request to {reqPath} as somehow the target file does not lie inside the static serve folder? Are you using symlinks?");
        await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
        return;
    }

    if (!File.Exists(staticResponsePath)) {
        await HandleDefaultErrorPageAsync(rc, HttpStatusCode.NotFound);
        return;
    }

    rc.SetStatusCode(HttpStatusCode.OK);
    if (staticResponsePath.EndsWith(".svg", StringComparison.OrdinalIgnoreCase)) {
        rc.ListenerContext.Response.AddHeader("Content-Type", "image/svg+xml");
    }
    using var f = File.OpenRead(staticResponsePath);
    await f.CopyToAsync(rc.ListenerContext.Response.OutputStream);
}

/// <summary>
/// Whether the absolute <paramref name="fullPath"/> lies inside (or is) the absolute <paramref name="directory"/>.
/// Only a leading ".." segment means "outside"; file names merely containing ".." (e.g. "a..b.txt") are allowed.
/// </summary>
private static bool IsWithinDirectory(string directory, string fullPath) {
    var relative = Path.GetRelativePath(directory, fullPath);
    if (Path.IsPathRooted(relative))
        return false; // different root (e.g. another drive on Windows)
    return relative != ".."
        && !relative.StartsWith(".." + Path.DirectorySeparatorChar, StringComparison.Ordinal)
        && !relative.StartsWith(".." + Path.AltDirectorySeparatorChar, StringComparison.Ordinal);
}
|
|
|
|
/// <summary>Convenience overload taking an <see cref="HttpStatusCode"/>; forwards its numeric value to the int-based overload.</summary>
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, HttpStatusCode errorCode, string? statusDescription = null) {
    int numericCode = (int) errorCode;
    await HandleDefaultErrorPageAsync(ctx, numericCode, statusDescription);
}
|
|
|
|
/// <summary>
/// Sets <paramref name="errorCode"/> on the response, writes a minimal HTML error page, then calls
/// <c>SetStatusCodeAndDisposeAsync</c> with the code (and the description, when one is given).
/// </summary>
/// <param name="ctx">Context whose response receives the page.</param>
/// <param name="errorCode">Numeric HTTP status code, set on the response and shown on the page.</param>
/// <param name="statusDescription">Optional text placed on its own line after the code; inserted into the HTML as-is (not encoded).</param>
private static async Task HandleDefaultErrorPageAsync(RequestContext ctx, int errorCode, string? statusDescription = null) {
    ctx.SetStatusCode(errorCode);

    string desc = statusDescription is null ? "" : $"\r\n{statusDescription}";
    await ctx.WriteLineToRespAsync($"""
        <body>
        <h1>Oh no, an error occurred!</h1>
        <p>Code: {errorCode}</p>{desc}
        </body>
        """);

    try {
        if (statusDescription is not null) {
            await ctx.SetStatusCodeAndDisposeAsync(errorCode, statusDescription);
        } else {
            await ctx.SetStatusCodeAndDisposeAsync(errorCode);
        }
    } catch (ObjectDisposedException) {
        // Already disposed (presumably by an earlier call); there is nothing further to finalize.
    }
}
|
|
} |