From c75d29a1ba41c6ea14ed87479fb70c2f0c00a1e1 Mon Sep 17 00:00:00 2001 From: GHXX Date: Fri, 19 Jul 2024 03:30:31 +0200 Subject: [PATCH] Switch over to a check-based system with multi attribute support --- SimpleHttpServer/HttpEndpointAttribute.cs | 12 ++----- SimpleHttpServer/HttpServer.cs | 36 ++++++++++++------- .../Types/BaseEndpointCheckAttribute.cs | 15 ++++++++ .../Types/EndpointInvocationInfo.cs | 17 ++++++--- 4 files changed, 53 insertions(+), 27 deletions(-) create mode 100644 SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs diff --git a/SimpleHttpServer/HttpEndpointAttribute.cs b/SimpleHttpServer/HttpEndpointAttribute.cs index 8e58e17..147cf6c 100644 --- a/SimpleHttpServer/HttpEndpointAttribute.cs +++ b/SimpleHttpServer/HttpEndpointAttribute.cs @@ -1,23 +1,15 @@ -using SimpleHttpServer.Internal; -using SimpleHttpServer.Types; +using SimpleHttpServer.Types; namespace SimpleHttpServer; [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] -public class HttpEndpointAttribute : Attribute where T : IAuthorizer { +public class HttpEndpointAttribute : Attribute { public HttpRequestType RequestMethod { get; private set; } public string[] Locations { get; private set; } - public Type Authorizer { get; private set; } public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) { RequestMethod = requestMethod; Locations = locations; - Authorizer = typeof(T); } } - -[AttributeUsage(AttributeTargets.Method)] -public class HttpEndpointAttribute : HttpEndpointAttribute { - public HttpEndpointAttribute(HttpRequestType type, params string[] locations) : base(type, locations) { } -} diff --git a/SimpleHttpServer/HttpServer.cs b/SimpleHttpServer/HttpServer.cs index a084aa7..4b67aa3 100644 --- a/SimpleHttpServer/HttpServer.cs +++ b/SimpleHttpServer/HttpServer.cs @@ -5,6 +5,7 @@ using System.Net; using System.Numerics; using System.Reflection; using System.Text; +using static SimpleHttpServer.Types.EndpointInvocationInfo; namespace SimpleHttpServer; @@ -91,8 +92,8 @@ public sealed class HttpServer { var t = typeof(T); foreach (var (mi, attrib) in t.GetMethods() - .ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>))) - .Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single())) { + .ToDictionary(x => x, x => x.GetCustomAttributes()) + .Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) { string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name; @@ -109,18 +110,25 @@ public sealed class HttpServer { Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!"); - var qparams = new List<(string, (Type type, bool isOptional))>(); + var qparams = new List(); for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) { var par = methodParams[i]; var attr = par.GetCustomAttribute(false); - qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"), - (par.ParameterType, attr?.IsOptional ?? false))); + qparams.Add(new( + attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"), + par.ParameterType, + attr?.IsOptional ?? false) + ); if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) { throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!"); } } + // stores the check attributes that are defined on the method and on the containing class + var requiredChecks = mi.GetCustomAttributes(true).Concat(mi.DeclaringType?.GetCustomAttributes(true) ?? Enumerable.Empty()) + .Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast().ToArray(); + foreach (var location in attrib.Locations) { var normLocation = NormalizeUrlPath(location); int idx = normLocation.IndexOf('{'); @@ -131,7 +139,7 @@ public sealed class HttpServer { var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined"); mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'"); - simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams)); + simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks)); } } } @@ -207,7 +215,11 @@ public sealed class HttpServer { var parsedQParams = new Dictionary(); var convertedQParamValues = new object[qparams.Count + 1]; - // TODO add authcheck here + // run the checks to see if the client is allowed to make this request + if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden + await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource"); + return; + } if (args != null) { var queryStringArgs = args.Split('&', StringSplitOptions.None); @@ -224,21 +236,21 @@ public sealed class HttpServer { } for (int i = 0; i < qparams.Count;) { - var (qparamName, qparamInfo) = qparams[i]; + var qparam = qparams[i]; i++; - if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) { - if (stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) { + if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) { + if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) { convertedQParamValues[i] = objRes; } else { await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest); return; } } else { - if (qparamInfo.isOptional) { + if (qparam.IsOptional) { convertedQParamValues[i] = null!; } else { - await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}"); + await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}"); return; } } diff --git a/SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs b/SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs new file mode 100644 index 0000000..84fe930 --- /dev/null +++ b/SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs @@ -0,0 +1,15 @@ +using System.Net; + +namespace SimpleHttpServer.Types; + +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)] +public abstract class BaseEndpointCheckAttribute : Attribute { + + public BaseEndpointCheckAttribute() { } + + /// + /// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail. + /// + /// True to allow invocation, false to prevent. + public abstract bool Check(HttpListenerRequest req); +} \ No newline at end of file diff --git a/SimpleHttpServer/Types/EndpointInvocationInfo.cs b/SimpleHttpServer/Types/EndpointInvocationInfo.cs index ed0e80c..803b75c 100644 --- a/SimpleHttpServer/Types/EndpointInvocationInfo.cs +++ b/SimpleHttpServer/Types/EndpointInvocationInfo.cs @@ -1,12 +1,19 @@ -using System.Reflection; +using System.Net; +using System.Reflection; namespace SimpleHttpServer.Types; -internal struct EndpointInvocationInfo { - internal readonly MethodInfo methodInfo; - internal readonly List<(string, (Type type, bool isOptional))> queryParameters; +internal readonly struct EndpointInvocationInfo { + internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional); - public EndpointInvocationInfo(MethodInfo methodInfo, List<(string, (Type type, bool isOptional))> queryParameters) { + internal readonly MethodInfo methodInfo; + internal readonly List queryParameters; + internal readonly BaseEndpointCheckAttribute[] requiredChecks; + + public EndpointInvocationInfo(MethodInfo methodInfo, List queryParameters, BaseEndpointCheckAttribute[] requiredChecks) { this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo)); this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters)); + this.requiredChecks = requiredChecks; } + + public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req)); }