Compare commits
2 Commits
cdab5151be
...
c75d29a1ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c75d29a1ba | ||
|
|
fa79134d02 |
|
|
@ -1,23 +1,15 @@
|
|||
using SimpleHttpServer.Internal;
|
||||
using SimpleHttpServer.Types;
|
||||
using SimpleHttpServer.Types;
|
||||
|
||||
namespace SimpleHttpServer;
|
||||
|
||||
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
|
||||
public class HttpEndpointAttribute<T> : Attribute where T : IAuthorizer {
|
||||
public class HttpEndpointAttribute : Attribute {
|
||||
|
||||
public HttpRequestType RequestMethod { get; private set; }
|
||||
public string[] Locations { get; private set; }
|
||||
public Type Authorizer { get; private set; }
|
||||
|
||||
public HttpEndpointAttribute(HttpRequestType requestMethod, params string[] locations) {
|
||||
RequestMethod = requestMethod;
|
||||
Locations = locations;
|
||||
Authorizer = typeof(T);
|
||||
}
|
||||
}
|
||||
|
||||
[AttributeUsage(AttributeTargets.Method)]
|
||||
public class HttpEndpointAttribute : HttpEndpointAttribute<DefaultAuthorizer> {
|
||||
public HttpEndpointAttribute(HttpRequestType type, params string[] locations) : base(type, locations) { }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ using System.Net;
|
|||
using System.Numerics;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using static SimpleHttpServer.Types.EndpointInvocationInfo;
|
||||
|
||||
namespace SimpleHttpServer;
|
||||
|
||||
|
|
@ -91,8 +92,8 @@ public sealed class HttpServer {
|
|||
|
||||
var t = typeof(T);
|
||||
foreach (var (mi, attrib) in t.GetMethods()
|
||||
.ToDictionary(x => x, x => x.GetCustomAttributes(typeof(HttpEndpointAttribute<>)))
|
||||
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => (HttpEndpointAttribute) x.Value.Single())) {
|
||||
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
|
||||
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) {
|
||||
|
||||
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
|
||||
|
||||
|
|
@ -109,18 +110,25 @@ public sealed class HttpServer {
|
|||
Assert(mi.ReturnType == typeof(Task), $"Return type of {GetFancyMethodName()} is not {typeof(Task)}!");
|
||||
|
||||
|
||||
var qparams = new List<(string, (Type type, bool isOptional))>();
|
||||
var qparams = new List<QueryParameterInfo>();
|
||||
for (int i = expectedEndpointParameterTypes.Length; i < methodParams.Length; i++) {
|
||||
var par = methodParams[i];
|
||||
var attr = par.GetCustomAttribute<ParameterAttribute>(false);
|
||||
qparams.Add((attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
|
||||
(par.ParameterType, attr?.IsOptional ?? false)));
|
||||
qparams.Add(new(
|
||||
attr?.Name ?? par.Name ?? throw new ArgumentException($"C# variable name of parameter at index {i} of method {GetFancyMethodName()} is null!"),
|
||||
par.ParameterType,
|
||||
attr?.IsOptional ?? false)
|
||||
);
|
||||
|
||||
if (!stringToTypeParameterConverters.ContainsKey(par.ParameterType)) {
|
||||
throw new MissingParameterConverterException($"Parameter converter for type {par.ParameterType} has not been registered (yet)!");
|
||||
}
|
||||
}
|
||||
|
||||
// stores the check attributes that are defined on the method and on the containing class
|
||||
var requiredChecks = mi.GetCustomAttributes<BaseEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<BaseEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
|
||||
.Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast<BaseEndpointCheckAttribute>().ToArray();
|
||||
|
||||
foreach (var location in attrib.Locations) {
|
||||
var normLocation = NormalizeUrlPath(location);
|
||||
int idx = normLocation.IndexOf('{');
|
||||
|
|
@ -131,7 +139,7 @@ public sealed class HttpServer {
|
|||
|
||||
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
|
||||
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
|
||||
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams));
|
||||
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -207,7 +215,11 @@ public sealed class HttpServer {
|
|||
var parsedQParams = new Dictionary<string, string>();
|
||||
var convertedQParamValues = new object[qparams.Count + 1];
|
||||
|
||||
// TODO add authcheck here
|
||||
// run the checks to see if the client is allowed to make this request
|
||||
if (!endpointInvocationInfo.CheckAll(rc.ListenerContext.Request)) { // if any check failed return Forbidden
|
||||
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.Forbidden, "Client is not allowed to access this resource");
|
||||
return;
|
||||
}
|
||||
|
||||
if (args != null) {
|
||||
var queryStringArgs = args.Split('&', StringSplitOptions.None);
|
||||
|
|
@ -224,21 +236,21 @@ public sealed class HttpServer {
|
|||
}
|
||||
|
||||
for (int i = 0; i < qparams.Count;) {
|
||||
var (qparamName, qparamInfo) = qparams[i];
|
||||
var qparam = qparams[i];
|
||||
i++;
|
||||
|
||||
if (parsedQParams.TryGetValue(qparamName, out var qparamValue)) {
|
||||
if (stringToTypeParameterConverters[qparamInfo.type].TryConvertFromString(qparamValue, out object objRes)) {
|
||||
if (parsedQParams.TryGetValue(qparam.Name, out var qparamValue)) {
|
||||
if (stringToTypeParameterConverters[qparam.Type].TryConvertFromString(qparamValue, out object objRes)) {
|
||||
convertedQParamValues[i] = objRes;
|
||||
} else {
|
||||
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
if (qparamInfo.isOptional) {
|
||||
if (qparam.IsOptional) {
|
||||
convertedQParamValues[i] = null!;
|
||||
} else {
|
||||
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparamName}");
|
||||
await HandleDefaultErrorPageAsync(rc, HttpStatusCode.BadRequest, $"Missing required query parameter {qparam.Name}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
15
SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs
Normal file
15
SimpleHttpServer/Types/BaseEndpointCheckAttribute.cs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
using System.Net;
|
||||
|
||||
namespace SimpleHttpServer.Types;
|
||||
|
||||
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
|
||||
public abstract class BaseEndpointCheckAttribute : Attribute {
|
||||
|
||||
public BaseEndpointCheckAttribute() { }
|
||||
|
||||
/// <summary>
|
||||
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
|
||||
/// </summary>
|
||||
/// <returns>True to allow invocation, false to prevent.</returns>
|
||||
public abstract bool Check(HttpListenerRequest req);
|
||||
}
|
||||
|
|
@ -1,12 +1,19 @@
|
|||
using System.Reflection;
|
||||
using System.Net;
|
||||
using System.Reflection;
|
||||
|
||||
namespace SimpleHttpServer.Types;
|
||||
internal struct EndpointInvocationInfo {
|
||||
internal readonly MethodInfo methodInfo;
|
||||
internal readonly List<(string, (Type type, bool isOptional))> queryParameters;
|
||||
internal readonly struct EndpointInvocationInfo {
|
||||
internal record struct QueryParameterInfo(string Name, Type Type, bool IsOptional);
|
||||
|
||||
public EndpointInvocationInfo(MethodInfo methodInfo, List<(string, (Type type, bool isOptional))> queryParameters) {
|
||||
internal readonly MethodInfo methodInfo;
|
||||
internal readonly List<QueryParameterInfo> queryParameters;
|
||||
internal readonly BaseEndpointCheckAttribute[] requiredChecks;
|
||||
|
||||
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, BaseEndpointCheckAttribute[] requiredChecks) {
|
||||
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
|
||||
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
|
||||
this.requiredChecks = requiredChecks;
|
||||
}
|
||||
|
||||
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
using System.Collections.ObjectModel;
|
||||
using System.Net;
|
||||
|
||||
namespace SimpleHttpServer;
|
||||
namespace SimpleHttpServer.Types;
|
||||
public class RequestContext : IDisposable {
|
||||
|
||||
public HttpListenerContext ListenerContext { get; }
|
||||
Loading…
Reference in New Issue
Block a user