Compare commits

..

No commits in common. "2cf6cd4a7d447b289dff76c7041224a211f0795a" and "a4ae359df0581b224d7cb7b1ef200a7334dade74" have entirely different histories.

3 changed files with 16 additions and 41 deletions

View File

@ -86,24 +86,18 @@ public sealed class HttpServer {
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new(); private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) }; private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
public void RegisterEndpointsFromType<T>() {
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic if (simpleEndpointMethodInfos.Count == 0)
if (stringToTypeParameterConverters.Count == 0)
RegisterDefaultConverters(); RegisterDefaultConverters();
var t = typeof(T); var t = typeof(T);
var mis = t.GetMethods() foreach (var (mi, attrib) in t.GetMethods()
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>()) .ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single()); .Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) {
var isStatic = mis.All(x => x.Key.IsStatic); // if all are static then there is no point in having a constructor as no instance data is accessible, but we allow passing a factory anyway
Assert(isStatic || (instanceFactory != null), $"You must provide an instance factory if any methods of the given type ({typeof(T).FullName}) are non-static");
T? classInstance = instanceFactory?.Invoke();
foreach (var (mi, attrib) in mis) {
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name; string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
//Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})"); Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})"); Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");
var methodParams = mi.GetParameters(); var methodParams = mi.GetParameters();
@ -132,11 +126,8 @@ public sealed class HttpServer {
} }
// stores the check attributes that are defined on the method and on the containing class // stores the check attributes that are defined on the method and on the containing class
var requiredChecks = mi.GetCustomAttributes<InternalEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<InternalEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>()) var requiredChecks = mi.GetCustomAttributes<BaseEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<BaseEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
.Where(a => a.GetType().IsAssignableTo(typeof(InternalEndpointCheckAttribute))).Cast<InternalEndpointCheckAttribute>().ToArray(); .Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast<BaseEndpointCheckAttribute>().ToArray();
foreach (var requiredCheck in requiredChecks)
requiredCheck.SetInstance(classInstance);
foreach (var location in attrib.Locations) { foreach (var location in attrib.Locations) {
var normLocation = NormalizeUrlPath(location); var normLocation = NormalizeUrlPath(location);
@ -148,7 +139,7 @@ public sealed class HttpServer {
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined"); var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'"); mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance)); simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks));
} }
} }
} }
@ -274,7 +265,7 @@ public sealed class HttpServer {
convertedQParamValues[0] = rc; convertedQParamValues[0] = rc;
rc.ParsedParameters = parsedQParams.AsReadOnly(); rc.ParsedParameters = parsedQParams.AsReadOnly();
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly")); await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
} else { } else {
if (requestMethod == "GET") if (requestMethod == "GET")
foreach (var (k, v) in staticServePaths) { foreach (var (k, v) in staticServePaths) {

View File

@ -2,25 +2,14 @@
namespace SimpleHttpServer.Types; namespace SimpleHttpServer.Types;
public abstract class InternalEndpointCheckAttribute : Attribute { [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
public abstract class BaseEndpointCheckAttribute : Attribute {
public BaseEndpointCheckAttribute() { }
/// <summary> /// <summary>
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail. /// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
/// </summary> /// </summary>
/// <returns>True to allow invocation, false to prevent.</returns> /// <returns>True to allow invocation, false to prevent.</returns>
public abstract bool Check(HttpListenerRequest req); public abstract bool Check(HttpListenerRequest req);
internal abstract void SetInstance(object? instance);
}
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
public abstract class BaseEndpointCheckAttribute<T> : InternalEndpointCheckAttribute {
/// <summary>
/// A reference to the instance of the class that this attribute is attached to.
/// Will be null iff an class factory was passed in <see cref="HttpServer.RegisterEndpointsFromType{T}(Func{T}?)"/>.
/// </summary>
protected internal T? EndpointClassInstance { get; internal set; } = default;
public BaseEndpointCheckAttribute() { }
internal override void SetInstance(object? instance) {
if (instance != null)
EndpointClassInstance = (T?) instance;
}
} }

View File

@ -7,17 +7,12 @@ internal readonly struct EndpointInvocationInfo {
internal readonly MethodInfo methodInfo; internal readonly MethodInfo methodInfo;
internal readonly List<QueryParameterInfo> queryParameters; internal readonly List<QueryParameterInfo> queryParameters;
internal readonly InternalEndpointCheckAttribute[] requiredChecks; internal readonly BaseEndpointCheckAttribute[] requiredChecks;
/// <summary>
/// a reference to the object in which this method is defined (or null if the class is static)
/// </summary>
internal readonly object? typeInstanceReference;
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) { public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, BaseEndpointCheckAttribute[] requiredChecks) {
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo)); this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters)); this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
this.requiredChecks = requiredChecks; this.requiredChecks = requiredChecks;
this.typeInstanceReference = typeInstanceReference;
} }
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req)); public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));