195 lines
7.2 KiB
C#
195 lines
7.2 KiB
C#
using SimpleHttpServer.Internal;
|
|
using System.Net;
|
|
using System.Reflection;
|
|
|
|
namespace SimpleHttpServer;
|
|
|
|
public sealed class HttpServer {
|
|
|
|
/// <summary>Port number passed to the constructor; it is used to build the listener prefix.</summary>
public int Port { get; }

/// <summary>Cancellation source created in the constructor.</summary>
private readonly CancellationTokenSource ctokSrc;

/// <summary>Listener that the constructor binds to <c>http://localhost:{Port}/</c>.</summary>
private readonly HttpListener listener;

/// <summary>
/// Task running <see cref="GetContextLoop"/>. It stays <c>null</c> until
/// <see cref="StartAsync"/> assigns it, hence the nullable annotation.
/// </summary>
private Task? listenerTask;

/// <summary>Logger created in the constructor and never reassigned afterwards.</summary>
private readonly Logger logger;
|
|
|
|
/// <summary>
/// Creates a server configured for <c>http://localhost:{port}/</c>. Nothing is started here.
/// </summary>
/// <param name="port">Port number used in the listener prefix and exposed through <see cref="Port"/>.</param>
/// <param name="logRedirect">Optional writer handed to the logger; <c>null</c> by default.</param>
public HttpServer(int port, TextWriter? logRedirect = null) {
    Port = port;
    ctokSrc = new CancellationTokenSource();

    // The collection initializer calls Prefixes.Add with the localhost prefix.
    listener = new HttpListener {
        Prefixes = { $"http://localhost:{port}/" }
    };

    logger = new Logger("HttpServer", logRedirect);
}
|
|
|
|
/// <summary>
/// Starts the listener and schedules <see cref="GetContextLoop"/> on the thread pool.
/// Calling it again while the accept loop is still running is a no-op.
/// </summary>
/// <returns>A task that completes once the listener is started and the loop is scheduled.</returns>
public async Task StartAsync() {
    // Guard against a second call spawning a second accept loop on the same listener.
    if (listenerTask is { IsCompleted: false }) {
        return;
    }

    logger.Information($"Starting on port {Port}...");
    listener.Start();
    listenerTask = Task.Run(GetContextLoop);
    logger.Information($"Ready to handle requests!");

    // Nothing above is awaited; this yield is the only await in the method.
    await Task.Yield();
}
|
|
|
|
/// <summary>
/// Accept loop: waits for incoming requests and hands each one to
/// <see cref="ProcessRequestAsync"/> without awaiting it. The loop ends when the listener is
/// no longer listening or when cancellation has been requested on <c>ctokSrc</c>.
/// </summary>
/// <returns>A task that completes when the loop has ended.</returns>
public async Task GetContextLoop() {
    while (listener.IsListening && !ctokSrc.IsCancellationRequested) {
        try {
            var ctx = await listener.GetContextAsync();
            // Fire-and-forget so a slow request does not delay accepting the next one.
            _ = ProcessRequestAsync(ctx);
        } catch (Exception ex) {
            // A stopped listener makes GetContextAsync fail immediately on every call;
            // leave the loop instead of spinning on the same exception forever.
            if (!listener.IsListening || ctokSrc.IsCancellationRequested) {
                break;
            }
            // Still listening: report the failure rather than swallowing it, then keep accepting.
            logger.Information($"Failed to accept a request: {ex.Message}");
        }
    }
}
|
|
|
|
/// <summary>Placeholder request handler: it currently does nothing with <paramref name="ctx"/>.</summary>
/// <param name="ctx">The accepted request context.</param>
/// <returns>An already-completed task.</returns>
// NOTE(review): the context is neither answered nor closed here - presumably still to be implemented.
private Task ProcessRequestAsync(HttpListenerContext ctx) => Task.CompletedTask;
|
|
|
|
|
|
/// <summary>Handlers registered through <see cref="RegisterRoutesFromType{T}"/>, keyed by route location and request method.</summary>
private readonly Dictionary<(string path, HttpRequestType rType), Action<RequestContext>> simpleEndpoints = new();

/// <summary>
/// Registers every public method of <typeparamref name="T"/> that carries an <c>HttpRoute</c>
/// attribute as a handler in <c>simpleEndpoints</c>.
/// </summary>
/// <typeparam name="T">Type whose public methods are scanned.</typeparam>
/// <exception cref="InvalidCastException">An attribute found on a method is not an <c>HttpRoute&lt;IAuthorizer&gt;</c>.</exception>
public void RegisterRoutesFromType<T>() {
    // Phase 1: read the attributes of every method up front.
    var discovered = typeof(T).GetMethods()
        .Select(method => (method, found: method.GetCustomAttributes(typeof(HttpRoute<>))))
        .ToList();

    // Phase 2: keep the annotated methods and convert their attribute. All conversions
    // happen before anything is registered, so a failure here leaves simpleEndpoints untouched.
    // NOTE(review): a method annotated with HttpRoute<SomeAuthorizer> may fail this type
    // test unless HttpRoute<T> is variance-compatible with HttpRoute<IAuthorizer> - confirm.
    var routes = new List<(MethodInfo method, HttpRoute<IAuthorizer> route)>();
    foreach (var (method, found) in discovered) {
        if (!found.Any()) {
            continue;
        }
        if (found.Single() is not HttpRoute<IAuthorizer> route) {
            throw new InvalidCastException();
        }
        routes.Add((method, route));
    }

    // Phase 3: bind each method to a delegate and register it.
    foreach (var (method, route) in routes) {
        var handler = method.CreateDelegate<Action<RequestContext>>();
        simpleEndpoints.Add((route.Location, route.RequestMethod), handler);
    }
}
|
|
|
|
// Previously a readonly field that no constructor assigned, so it was always null. It now
// aliases the listener that the constructor actually creates and configures.
/// <summary>Listener used by <see cref="Start"/>, <see cref="Shutdown(int)"/> and <see cref="RunServer"/>.</summary>
private HttpListener _listener => listener;

/// <summary>
/// Thread running <see cref="RunServer"/>. It is set by <see cref="Start"/> and reset to
/// <c>null</c> by <see cref="Shutdown(int)"/>.
/// </summary>
private Thread? _listenerThread;

/// <summary>Endpoints looked up by exact path and request method in <see cref="RunServer"/>.</summary>
private readonly Dictionary<(string path, HttpRequestType rType), HttpEndpointHandler> _plainEndpoints = new();

/// <summary>Endpoints matched by path prefix and request method in <see cref="RunServer"/>.</summary>
private readonly Dictionary<(string path, HttpRequestType rType), HttpEndpointHandler> _pparamEndpoints = new();

// NOTE(review): nothing in this class assigns Url or Default404, although RunServer invokes
// Default404 - confirm where they are meant to be initialised.
public string Url { get; private set; }

/// <summary>Factory for the response sent when no endpoint matches a request.</summary>
public Func<HttpListenerContext, HttpResponseBuilder> Default404 { get; private set; }
|
|
|
|
/// <summary>
/// Convenience overload of <see cref="Create(TextWriter, int, string, bool, Type[])"/> that
/// reports problems to <see cref="Console.Error"/> and does not throw on invalid endpoints.
/// </summary>
/// <param name="port">Forwarded unchanged.</param>
/// <param name="url">Forwarded unchanged.</param>
/// <param name="apiDefinitions">Forwarded unchanged.</param>
/// <returns>Whatever the full overload returns.</returns>
public static HttpServer Create(int port, string url, params Type[] apiDefinitions) {
    return Create(
        error: Console.Error,
        port: port,
        url: url,
        throwOnInvalidEndpoint: false,
        apiDefinitions: apiDefinitions);
}
|
|
|
|
/// <summary>
/// Scans the public methods of the given types for an <c>HttpRoute</c> attribute and
/// validates every annotated method as an endpoint candidate.
/// </summary>
/// <remarks>
/// NOTE(review): this factory is unfinished - validated endpoints are not registered anywhere
/// yet, <paramref name="port"/> and <paramref name="url"/> are unused, and the method still
/// returns <c>null</c>.
/// </remarks>
/// <param name="error">Receives a message for each rejected endpoint when <paramref name="throwOnInvalidEndpoint"/> is <c>false</c>.</param>
/// <param name="port">Currently unused.</param>
/// <param name="url">Currently unused.</param>
/// <param name="throwOnInvalidEndpoint">When <c>true</c>, an invalid endpoint raises an exception instead of being reported and skipped.</param>
/// <param name="apiDefinitions">Types whose public methods are scanned.</param>
/// <returns>Currently always <c>null</c>.</returns>
/// <exception cref="NotImplementedException">A valid endpoint's route location contains <c>'{'</c>.</exception>
public static HttpServer Create(TextWriter error, int port, string url, bool throwOnInvalidEndpoint, params Type[] apiDefinitions) {
    foreach (var definition in apiDefinitions) {
        foreach (var endpoint in definition.GetMethods()) {
            // OfType filters and casts in one step. The previous filter tested the attribute's
            // closed type against the open generic HttpRoute<>, which never matches, so no
            // method was ever recognised as an endpoint.
            // NOTE(review): this only finds attributes that are convertible to
            // HttpRoute<IAuthorizer>; confirm that HttpRoute<T> allows this for other T.
            var attrib = endpoint.GetCustomAttributes()
                .OfType<HttpRoute<IAuthorizer>>()
                .SingleOrDefault();
            if (attrib is null) {
                continue;
            }

            // sanity checks
            if (!endpoint.IsStatic) {
                PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, "HttpEndpointAttribute is only valid on static methods!");
                continue;
            }
            if (!endpoint.IsPublic) {
                PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to be public!");
                // The reported message says the endpoint is skipped, so actually skip it.
                continue;
            }

            // ParameterType is the declared type of the argument; GetType() would return the
            // reflection object's own type and reject every endpoint.
            var myParams = endpoint.GetParameters();
            if (myParams.Length <= 0 || !myParams[0].ParameterType.IsAssignableFrom(typeof(HttpListenerContext))) {
                PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to have a HttpListenerContext as its first argument!");
                continue;
            }
            if (!endpoint.ReturnParameter.ParameterType.IsAssignableTo(typeof(HttpResponseBuilder))) {
                PrintErrorOrThrow(error, endpoint, throwOnInvalidEndpoint, $"{GetFancyMethodName(endpoint)} needs to have a HttpResponseBuilder as the return type!");
                continue;
            }

            if (attrib.Location.IndexOf('{') >= 0) {
                // this path contains path parameters
                throw new NotImplementedException("Implement path parameters!");
            }
        }
    }
    return null!;
}
|
|
|
|
|
|
// Two parameterless overloads that differ only in return type cannot coexist, and the
// bool-returning draft had no return statement. This single method replaces both; callers
// that ignore the result keep working.
/// <summary>
/// Stops the server and waits without a time limit for the listener thread to exit.
/// </summary>
/// <returns>The result of <see cref="Shutdown(int)"/> called with a timeout of <c>-1</c>.</returns>
/// <exception cref="InvalidOperationException">The server has not been started.</exception>
public bool Shutdown() {
    // Shutdown(int) performs the "not started" check itself.
    return Shutdown(-1);
}
|
|
|
|
/// <summary>
/// Interrupts the listener thread, waits for it to exit, then stops the listener.
/// </summary>
/// <param name="timeout">
/// Value passed to <c>Thread.Join(int)</c>; any negative value waits without a time limit.
/// </param>
/// <returns>
/// <c>true</c> when the thread exited (always the case for a negative timeout), otherwise the
/// result of the timed join.
/// </returns>
/// <exception cref="InvalidOperationException">The server has not been started.</exception>
public bool Shutdown(int timeout) {
    var worker = _listenerThread
        ?? throw new InvalidOperationException("Cannot shutdown HttpServer that has not been started");

    worker.Interrupt();

    bool exited;
    if (timeout >= 0) {
        exited = worker.Join(timeout);
    } else {
        worker.Join();
        exited = true;
    }

    // The thread reference is cleared and the listener stopped even when the join timed out.
    _listenerThread = null;
    _listener.Stop();
    return exited;
}
|
|
|
|
/// <summary>
/// Starts the listener and a dedicated thread running <see cref="RunServer"/>.
/// Calling it while a listener thread is already registered is a no-op.
/// </summary>
public void Start() {
    // A second thread would replace the stored reference and could never be shut down.
    if (_listenerThread is not null) {
        return;
    }

    // Start the listener first: if this throws, no thread reference is left behind for a
    // later Shutdown to interrupt and join.
    _listener.Start();

    var worker = new Thread(RunServer);
    worker.Start();
    _listenerThread = worker;
}
|
|
|
|
|
|
/// <summary>
/// Body of the listener thread: blocks in <c>GetContext</c> and queues every accepted request
/// to the thread pool until the thread is interrupted or the listener stops.
/// </summary>
private void RunServer() {
    try {
        while (true) {
            var ctx = _listener.GetContext();
            ThreadPool.QueueUserWorkItem<HttpListenerContext>(DispatchRequest, ctx, false);
        }
    } catch (ThreadInterruptedException) {
        // Shutdown interrupts this thread while it waits in GetContext; exit the loop quietly.
    } catch (HttpListenerException) when (!_listener.IsListening) {
        // The listener was stopped while GetContext was pending; treat it as a normal exit
        // instead of letting the exception escape the thread.
    } catch (ObjectDisposedException) {
        // The listener was closed while GetContext was pending.
    }
}

/// <summary>Resolves the endpoint for one request and invokes it, or sends the 404 response.</summary>
/// <param name="ctx">The accepted request context.</param>
// NOTE(review): an exception thrown by a handler is not caught here.
private void DispatchRequest(HttpListenerContext ctx) {
    // An HTTP method that does not parse as HttpRequestType is answered by the 404 handler too.
    if (!Enum.TryParse(ctx.Request.HttpMethod, out HttpRequestType type)) {
        Default404(ctx).SendResponse(ctx.Response);
        return;
    }

    var path = ctx.Request.Url!.LocalPath.Replace('\\', '/');
    var endpoint = FindEndpoint(path, type);
    if (endpoint is null) {
        Default404(ctx).SendResponse(ctx.Response);
        return;
    }

    endpoint.Handle(ctx);
}

/// <summary>
/// Looks for an exact match among the plain endpoints first, then for the first
/// path-parameter endpoint whose registered path is a prefix of <paramref name="path"/>.
/// </summary>
/// <returns>The matching handler, or <c>null</c> when there is none.</returns>
private HttpEndpointHandler? FindEndpoint(string path, HttpRequestType type) {
    if (_plainEndpoints.TryGetValue((path, type), out var exact)) {
        return exact;
    }

    foreach (var (key, handler) in _pparamEndpoints) {
        // URL paths are compared ordinally, not with culture-sensitive rules.
        if (key.rType == type && path.StartsWith(key.path, StringComparison.Ordinal)) {
            return handler;
        }
    }

    return null;
}
|
|
|
|
/// <summary>
/// Reports an invalid endpoint, either by throwing or by writing a message to
/// <paramref name="error"/>.
/// </summary>
/// <param name="error">Writer that receives the message when nothing is thrown.</param>
/// <param name="method">The offending method; its name is appended to the written message.</param>
/// <param name="forceThrow">When <c>true</c>, throws instead of writing.</param>
/// <param name="msg">Description of the problem.</param>
/// <exception cref="InvalidOperationException"><paramref name="forceThrow"/> is <c>true</c>.</exception>
private static void PrintErrorOrThrow(TextWriter error, MethodInfo method, bool forceThrow, string msg) {
    if (forceThrow) {
        // A specific exception type instead of the bare System.Exception; code that catches
        // Exception still catches this.
        throw new InvalidOperationException(msg);
    }

    error.WriteLine($"> {msg}\n skipping {GetFancyMethodName(method)} ...");
}
|
|
|
|
/// <summary>Formats a method as <c>DeclaringTypeName#MethodName</c>.</summary>
/// <param name="method">Method to describe; its declaring type is assumed to be non-null.</param>
/// <returns>The declaring type's simple name, a <c>#</c>, and the method name.</returns>
private static string GetFancyMethodName(MethodInfo method) {
    var owner = method.DeclaringType!.Name;
    return $"{owner}#{method.Name}";
}
|
|
} |