CSharpHttpServer/SimpleHttpServer/HttpServer.cs
2024-01-09 04:28:45 +01:00

134 lines
5.4 KiB
C#

using SimpleHttpServer.Internal;
using System.Net;
using System.Reflection;
namespace SimpleHttpServer;
/// <summary>
/// Small HTTP server on top of <see cref="HttpListener"/>. Requests are accepted on a
/// dedicated thread and each one is dispatched to a thread-pool work item that looks up
/// a matching <see cref="HttpEndpointHandler"/> by request path and request type.
/// </summary>
/// <remarks>
/// NOTE(review): nothing in this class assigns <c>_listener</c>, <see cref="Url"/> or
/// <see cref="Default404"/> yet, and <c>Create</c> still returns <c>null</c>; the type is
/// work in progress.
/// </remarks>
public sealed class HttpServer {

    private Thread? _listenerThread;
    private readonly HttpListener _listener;

    // Set by Shutdown before the listener is stopped so the accept loop can tell an
    // intentional stop apart from a genuine listener failure.
    private volatile bool _stopRequested;

    // Endpoints matched by exact (path, request type) lookup.
    private readonly Dictionary<(string path, HttpRequestType rType), HttpEndpointHandler> _plainEndpoints = new();

    // Endpoints matched by path prefix; consulted only when no exact match exists.
    private readonly Dictionary<(string path, HttpRequestType rType), HttpEndpointHandler> _pparamEndpoints = new();

    public string Url { get; private set; }

    /// <summary>Builds the response that is sent when no endpoint matches a request.</summary>
    public Func<HttpListenerContext, HttpResponseBuilder> Default404 { get; private set; }

    /// <summary>
    /// Convenience overload: reports invalid endpoints to <see cref="Console.Error"/>
    /// instead of throwing.
    /// </summary>
    public static HttpServer Create(int port, string url, params Type[] apiDefinitions) =>
        Create(Console.Error, port, url, false, apiDefinitions);

    /// <summary>
    /// Scans the public methods of <paramref name="apiDefinitions"/> for endpoint attributes
    /// and validates the annotated methods.
    /// </summary>
    /// <param name="error">Receives a message for every invalid endpoint that is skipped.</param>
    /// <param name="port">Currently unused.</param>
    /// <param name="url">Currently unused.</param>
    /// <param name="throwOnInvalidEndpoint">
    /// When <c>true</c>, an invalid endpoint raises an exception instead of being skipped.
    /// </param>
    /// <param name="apiDefinitions">Types whose public methods are scanned.</param>
    /// <returns>Currently always <c>null</c>; construction of the server is not implemented yet.</returns>
    /// <exception cref="NotImplementedException">An endpoint location contains a <c>{</c>.</exception>
    /// <exception cref="InvalidOperationException">A method carries more than one matching attribute.</exception>
    public static HttpServer Create(TextWriter error, int port, string url, bool throwOnInvalidEndpoint, params Type[] apiDefinitions) {
        // TODO: validated handlers are presumably meant to be collected here; nothing is registered yet.
        var epDict = new Dictionary<(string, HttpRequestType), HttpEndpointHandler>();

        foreach (var definition in apiDefinitions) {
            foreach (var endpoint in definition.GetMethods()) {
                // A closed attribute type is never assignable to the open generic definition
                // HttpEndpoint<>, so testing against typeof(HttpEndpoint<>) rejects everything.
                // OfType keeps exactly the attributes that can be used as HttpEndpoint<IAuthorizer>.
                // NOTE(review): attributes closed over another authorizer type are only matched
                // if HttpEndpoint<T> is variance-compatible or derives from HttpEndpoint<IAuthorizer>
                // - confirm against its declaration.
                var attrib = endpoint.GetCustomAttributes()
                    .OfType<HttpEndpoint<IAuthorizer>>()
                    .SingleOrDefault();
                if (attrib is null) {
                    continue;
                }

                // sanity checks - every failed check skips the method (or throws)
                if (!endpoint.IsStatic) {
                    PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, "HttpEndpointAttribute is only valid on static methods!");
                    continue;
                }

                if (!endpoint.IsPublic) {
                    PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to be public!");
                    continue;
                }

                var myParams = endpoint.GetParameters();
                // ParameterType is the declared type of the argument; GetType() would yield the
                // runtime type of the ParameterInfo object itself.
                if (myParams.Length <= 0 || !myParams[0].ParameterType.IsAssignableFrom(typeof(HttpListenerContext))) {
                    PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to have a HttpListenerContext as its first argument!");
                    continue;
                }

                if (!endpoint.ReturnParameter.ParameterType.IsAssignableTo(typeof(HttpResponseBuilder))) {
                    PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to have a HttpResponseBuilder as the return type!");
                    continue;
                }

                var path = attrib.Location;
                int idx = path.IndexOf('{');
                if (idx >= 0) {
                    // this path contains path parameters
                    throw new NotImplementedException("Implement path parameters!");
                }

                // TODO: not populated or consumed yet.
                var qparams = new List<(string, Type)>();
            }
        }

        // TODO: build and return the configured server.
        return null!;
    }

    /// <summary>Stops the server and waits indefinitely for the listener thread to exit.</summary>
    /// <exception cref="InvalidOperationException">The server has not been started.</exception>
    public void Shutdown() {
        Shutdown(-1);
    }

    /// <summary>Stops the server and waits for the listener thread to exit.</summary>
    /// <param name="timeout">
    /// Maximum wait in milliseconds; a negative value waits indefinitely.
    /// </param>
    /// <returns><c>true</c> if the listener thread exited within the wait; otherwise <c>false</c>.</returns>
    /// <exception cref="InvalidOperationException">The server has not been started.</exception>
    public bool Shutdown(int timeout) {
        var thread = _listenerThread;
        if (thread is null) {
            throw new InvalidOperationException("Cannot shutdown HttpServer that has not been started");
        }
        _listenerThread = null;

        // Stop the listener before joining: Thread.Interrupt only wakes a thread that is in a
        // managed wait, so a GetContext() call blocked elsewhere would otherwise never return
        // and Join() would hang. Stopping makes the pending GetContext() fail instead.
        _stopRequested = true;
        _listener.Stop();
        thread.Interrupt();

        if (timeout < 0) {
            thread.Join();
            return true;
        }
        return thread.Join(timeout);
    }

    /// <summary>Starts the underlying listener and the thread that accepts requests.</summary>
    public void Start() {
        _stopRequested = false;
        _listenerThread = new Thread(RunServer);
        _listener.Start();
        _listenerThread.Start();
    }

    /// <summary>
    /// Accept loop: blocks for the next request and hands it to the thread pool.
    /// Returns once <see cref="Shutdown(int)"/> interrupts the thread or stops the listener.
    /// </summary>
    private void RunServer() {
        try {
            for (; ; ) {
                var ctx = _listener.GetContext();
                ThreadPool.QueueUserWorkItem(localCtx => HandleRequest(localCtx), ctx, false);
            }
        } catch (ThreadInterruptedException) {
            // raised when the blocking GetContext call is interrupted by Shutdown
            // safely exit main loop
        } catch (Exception ex) when (_stopRequested && (ex is HttpListenerException or InvalidOperationException)) {
            // GetContext fails this way once the listener has been stopped by Shutdown
            // (ObjectDisposedException is an InvalidOperationException); exit quietly.
            // Without a stop request the same exceptions indicate a real failure and propagate.
        }
    }

    /// <summary>
    /// Resolves the handler for one request and invokes it; answers with
    /// <see cref="Default404"/> when the method or the path is unknown.
    /// </summary>
    /// <remarks>
    /// NOTE(review): runs on a thread-pool thread, so an exception escaping a handler is
    /// unhandled there - consider catching and answering with an error response.
    /// </remarks>
    private void HandleRequest(HttpListenerContext ctx) {
        // NOTE(review): Enum.TryParse is case-sensitive and also accepts numeric text;
        // confirm the HttpRequestType member names match the method strings the listener reports.
        if (!Enum.TryParse(ctx.Request.HttpMethod, out HttpRequestType type)) {
            Default404(ctx).SendResponse(ctx.Response);
            return;
        }

        var path = ctx.Request.Url!.LocalPath.Replace('\\', '/');

        if (!_plainEndpoints.TryGetValue((path, type), out var ep)) {
            // not found among plain endpoints - fall back to the first prefix match
            foreach (var (key, handler) in _pparamEndpoints) {
                // URL paths are compared ordinally, never by culture rules
                if (key.rType == type && path.StartsWith(key.path, StringComparison.Ordinal)) {
                    ep = handler;
                    break;
                }
            }

            if (ep is null) {
                Default404(ctx).SendResponse(ctx.Response);
                return;
            }
        }

        ep.Handle(ctx);
    }

    /// <summary>
    /// Reports an invalid endpoint: throws when <paramref name="forceThrow"/> is set,
    /// otherwise writes the message and a "skipping" notice to <paramref name="error"/>.
    /// </summary>
    private static void PrintErrorOrThrow(TextWriter error, MethodInfo method, bool forceThrow, string msg) {
        if (forceThrow) {
            // NOTE(review): kept as the base Exception type so existing callers/tests that
            // expect exactly this type are unaffected; a more specific type would be preferable.
            throw new Exception(msg);
        }

        error.WriteLine($"> {msg}\n skipping {GetFancyMethodName(method)} ...");
    }

    /// <summary>Formats a method as <c>DeclaringType#Method</c> for diagnostics.</summary>
    private static string GetFancyMethodName(MethodInfo method) => method.DeclaringType!.Name + "#" + method.Name;
}