Compare commits

..

No commits in common. "c75d29a1ba41c6ea14ed87479fb70c2f0c00a1e1" and "cdab5151be10b787d1a819bbf1a9032e8784a109" have entirely different histories.

5 changed files with 27 additions and 53 deletions

View File

@ -1,15 +1,23 @@
using SimpleHttpServer.Types; using SimpleHttpServer.Internal;
using SimpleHttpServer.Types;
namespace SimpleHttpServer; namespace SimpleHttpServer;
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class HttpEndpointAttribute : Attribute { public class HttpEndpointAttribute<T> : Attribute where T : IAuthorizer {
public HttpRequestType RequestMethod { get; private set; } public HttpRequestType RequestMethod { get; private set; }
public string[] Locations { get; private set; } public string[] Locations { get; private set; }
public Type Authorizer { get; private set; }
public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) { public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) {
RequestMethod = requestMethod; RequestMethod = requestMethod;
Locations = locations; Locations = locations;
Authorizer = typeof(T);
} }
} }
[AttributeUsage(AttributeTargets.Method)]
public class HttpEndpointAttribute : HttpEndpointAttribute<DefaultAuthorizer> {
public HttpEndpointAttribute(HttpRequestType type, params string[] locations) : base(type, locations) { }
}

View File

@ -5,7 +5,6 @@ using System.Net;
using System.Numerics; using System.Numerics;
using System.Reflection; using System.Reflection;
using System.Text; using System.Text;
using static SimpleHttpServer.Types.EndpointInvocationInfo;
namespace SimpleHttpServer; namespace SimpleHttpServer;
@ -92,8 +91,8 @@ public sealed class HttpServer {
var t = typeof(T); var t = typeof(T);
foreach (var (mi, attrib) in t.GetMethods() foreach (var (mi, attrib) in t.GetMethods()
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>()) .ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>)))
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) { .Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single())) {
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name; string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
@ -110,25 +109,18 @@ public sealed class HttpServer {
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!"); Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
var qparams = new List<QueryParameterInfo>(); var qparams = new List<(string, (Type type, bool isOptional))>();
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) { for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
var par = methodParams[i]; var par = methodParams[i];
var attr = par.GetCustomAttribute<ParameterAttribute>(false); var attr = par.GetCustomAttribute<ParameterAttribute>(false);
qparams.Add(new( qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"), (par.ParameterType, attr?.IsOptional ?? false)));
par.ParameterType,
attr?.IsOptional ?? false)
);
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) { if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!"); throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
} }
} }
// stores the check attributes that are defined on the method and on the containing class
var requiredChecks = mi.GetCustomAttributes<BaseEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<BaseEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
.Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast<BaseEndpointCheckAttribute>().ToArray();
foreach (var location in attrib.Locations) { foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location); var normLocation = NormalizeUrlPath(location);
int idx = normLocation.IndexOf('{'); int idx = normLocation.IndexOf('{');
@ -139,7 +131,7 @@ public sealed class HttpServer {
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined"); var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'"); mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks)); simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams));
} }
} }
} }
@ -215,11 +207,7 @@ public sealed class HttpServer {
var parsedQParams = new Dictionary<string, string>(); var parsedQParams = new Dictionary<string, string>();
var convertedQParamValues = new object[qparams.Count + 1]; var convertedQParamValues = new object[qparams.Count + 1];
// run the checks to see if the client is allowed to make this request // TODO add authcheck here
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
return;
}
if (args != null) { if (args != null) {
var queryStringArgs = args.Split('&', StringSplitOptions.None); var queryStringArgs = args.Split('&', StringSplitOptions.None);
@ -236,21 +224,21 @@ public sealed class HttpServer {
} }
for (int i = 0; i < qparams.Count;) { for (int i = 0; i < qparams.Count;) {
var qparam = qparams[i]; var (qparamName, qparamInfo) = qparams[i];
i++; i++;
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) { if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) {
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) { if (stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) {
convertedQParamValues[i] = objRes; convertedQParamValues[i] = objRes;
} else { } else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest); await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
return; return;
} }
} else { } else {
if (qparam.IsOptional) { if (qparamInfo.isOptional) {
convertedQParamValues[i] = null!; convertedQParamValues[i] = null!;
} else { } else {
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}"); await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
return; return;
} }
} }

View File

@ -1,7 +1,7 @@
using System.Collections.ObjectModel; using System.Collections.ObjectModel;
using System.Net; using System.Net;
namespace SimpleHttpServer.Types; namespace SimpleHttpServer;
public class RequestContext : IDisposable { public class RequestContext : IDisposable {
public HttpListenerContext ListenerContext { get; } public HttpListenerContext ListenerContext { get; }

View File

@ -1,15 +0,0 @@
using System.Net;
namespace SimpleHttpServer.Types;
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
public abstract class BaseEndpointCheckAttribute : Attribute {
public BaseEndpointCheckAttribute() { }
/// <summary>
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
/// </summary>
/// <returns>True to allow invocation, false to prevent.</returns>
public abstract bool Check(HttpListenerRequest req);
}

View File

@ -1,19 +1,12 @@
using System.Net; using System.Reflection;
using System.Reflection;
namespace SimpleHttpServer.Types; namespace SimpleHttpServer.Types;
internal readonly struct EndpointInvocationInfo { internal struct EndpointInvocationInfo {
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional);
internal readonly MethodInfo methodInfo; internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters; internal readonly List<(string, (Type type, bool isOptional))> queryParameters;
internal readonly BaseEndpointCheckAttribute[] requiredChecks;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, BaseEndpointCheckAttribute[] requiredChecks) { public EndpointInvocationInfo(MethodInfo methodInfo, List<(string, (Type type, bool isOptional))> queryParameters) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo)); this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters)); this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.requiredChecks = requiredChecks;
} }
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
} }