Compare commits

...

2 Commits

Author SHA1 Message Date
GHXX c75d29a1ba Switch over to a check-based system with multi attribute support 2024-07-19 03:31:04 +02:00
GHXX fa79134d02 Move file 2024-07-19 03:28:56 +02:00
5 changed files with 54 additions and 28 deletions

View File

@ -1,23 +1,15 @@
using SimpleHttpServer.Internal;
using SimpleHttpServer.Types;
using SimpleHttpServer.Types;
namespace SimpleHttpServer;
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class HttpEndpointAttribute<T> : Attribute where T : IAuthorizer {
public class HttpEndpointAttribute : Attribute {
public HttpRequestType RequestMethod { get; private set; }
public string[] Locations { get; private set; }
public Type Authorizer { get; private set; }
public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) {
RequestMethod = requestMethod;
Locations = locations;
Authorizer = typeof(T);
}
}
[AttributeUsage(AttributeTargets.Method)]
public class HttpEndpointAttribute : HttpEndpointAttribute<DefaultAuthorizer> {
public HttpEndpointAttribute(HttpRequestType type, params string[] locations) : base(type, locations) { }
}

View File

@ -5,6 +5,7 @@ using System.Net;
using System.Numerics;
using System.Reflection;
using System.Text;
using static SimpleHttpServer.Types.EndpointInvocationInfo;
namespace SimpleHttpServer;
@ -91,8 +92,8 @@ public sealed class HttpServer {
var t = typeof(T);
foreach (var (mi, attrib) in t.GetMethods()
.ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>)))
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single())) {
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) {
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
@ -109,18 +110,25 @@ public sealed class HttpServer {
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
var qparams = new List<(string, (Type type, bool isOptional))>();
var qparams = new List<QueryParameterInfo>();
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false);
qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
(par.ParameterType, attr?.IsOptional ?? false)));
qparams.Add(new(
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
par.ParameterType,
attr?.IsOptional ?? false)
);
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
}
}
// stores the check attributes that are defined on the method and on the containing class
var requiredChecks = mi.GetCustomAttributes<BaseEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<BaseEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
.Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast<BaseEndpointCheckAttribute>().ToArray();
foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location);
int idx = normLocation.IndexOf('{');
@ -131,7 +139,7 @@ public sealed class HttpServer {
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams));
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks));
}
}
}
@ -207,7 +215,11 @@ public sealed class HttpServer {
var parsedQParams = new Dictionary<string, string>();
var convertedQParamValues = new object[qparams.Count + 1];
// TODO add authcheck here
// run the checks to see if the client is allowed to make this request
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
return;
}
if (args != null) {
var queryStringArgs = args.Split('&', StringSplitOptions.None);
@ -224,21 +236,21 @@ public sealed class HttpServer {
}
for (int i = 0; i < qparams.Count;) {
var (qparamName, qparamInfo) = qparams[i];
var qparam = qparams[i];
i++;
if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) {
if (stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) {
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return;
}
} else {
if (qparamInfo.isOptional) {
if (qparam.IsOptional) {
convertedQParamValues[i] = null!;
} else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
return;
}
}

View File

@ -0,0 +1,15 @@
using System.Net;
namespace SimpleHttpServer.Types;
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
public abstract class BaseEndpointCheckAttribute : Attribute {
public BaseEndpointCheckAttribute() { }
/// <summary>
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
/// </summary>
/// <returns>True to allow invocation, false to prevent.</returns>
public abstract bool Check(HttpListenerRequest req);
}

View File

@ -1,12 +1,19 @@
using System.Reflection;
using System.Net;
using System.Reflection;
namespace SimpleHttpServer.Types;
internal struct EndpointInvocationInfo {
internal readonly MethodInfo methodInfo;
internal readonly List<(string, (Type type, bool isOptional))> queryParameters;
internal readonly struct EndpointInvocationInfo {
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional);
public EndpointInvocationInfo(MethodInfo methodInfo, List<(string, (Type type, bool isOptional))> queryParameters) {
internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters;
internal readonly BaseEndpointCheckAttribute[] requiredChecks;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, BaseEndpointCheckAttribute[] requiredChecks) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.requiredChecks = requiredChecks;
}
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
}

View File

@ -1,7 +1,7 @@
using System.Collections.ObjectModel;
using System.Net;
namespace SimpleHttpServer;
namespace SimpleHttpServer.Types;
public class RequestContext : IDisposable {
public HttpListenerContext ListenerContext { get; }