Compare commits
3 Commits
29eecc7887
...
2e4570a560
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e4570a560 | ||
|
|
30daf382ba | ||
|
|
2cf6cd4a7d |
|
|
@ -86,18 +86,24 @@ public sealed class HttpServer {
|
|||
|
||||
private readonly Dictionary<(string path, string rType), EndpointInvocationInfo> simpleEndpointMethodInfos = new();
|
||||
private static readonly Type[] expectedEndpointParameterTypes = new[] { typeof(RequestContext) };
|
||||
public void RegisterEndpointsFromType<T>() {
|
||||
|
||||
public void RegisterEndpointsFromType<T>(Func<T>? instanceFactory = null) where T : class { // T cannot be static, as generic args must be nonstatic
|
||||
if (stringToTypeParameterConverters.Count == 0)
|
||||
RegisterDefaultConverters();
|
||||
|
||||
var t = typeof(T);
|
||||
foreach (var (mi, attrib) in t.GetMethods()
|
||||
var mis = t.GetMethods()
|
||||
.ToDictionary(x => x, x => x.GetCustomAttributes<HttpEndpointAttribute>())
|
||||
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single())) {
|
||||
.Where(x => x.Value.Any()).ToDictionary(x => x.Key, x => x.Value.Single());
|
||||
|
||||
var isStatic = mis.All(x => x.Key.IsStatic); // if all are static then there is no point in having a constructor as no instance data is accessible, but we allow passing a factory anyway
|
||||
Assert(isStatic || (instanceFactory != null), $"You must provide an instance factory if any methods of the given type ({typeof(T).FullName}) are non-static");
|
||||
T? classInstance = instanceFactory?.Invoke();
|
||||
foreach (var (mi, attrib) in mis) {
|
||||
|
||||
string GetFancyMethodName() => mi.DeclaringType!.FullName + "#" + mi.Name;
|
||||
|
||||
Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
|
||||
//Assert(mi.IsStatic, $"Method tagged with HttpEndpointAttribute must be static! ({GetFancyMethodName()})");
|
||||
Assert(mi.IsPublic, $"Method tagged with HttpEndpointAttribute must be public! ({GetFancyMethodName()})");
|
||||
|
||||
var methodParams = mi.GetParameters();
|
||||
|
|
@ -126,8 +132,12 @@ public sealed class HttpServer {
|
|||
}
|
||||
|
||||
// stores the check attributes that are defined on the method and on the containing class
|
||||
var requiredChecks = mi.GetCustomAttributes<BaseEndpointCheckAttribute>(true).Concat(mi.DeclaringType?.GetCustomAttributes<BaseEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
|
||||
.Where(a => a.GetType().IsAssignableTo(typeof(BaseEndpointCheckAttribute))).Cast<BaseEndpointCheckAttribute>().ToArray();
|
||||
InternalEndpointCheckAttribute[] requiredChecks = mi.GetCustomAttributes<InternalEndpointCheckAttribute>(true)
|
||||
.Concat(mi.DeclaringType?.GetCustomAttributes<InternalEndpointCheckAttribute>(true) ?? Enumerable.Empty<Attribute>())
|
||||
.Where(a => a.GetType().IsAssignableTo(typeof(InternalEndpointCheckAttribute)))
|
||||
.Cast<InternalEndpointCheckAttribute>().ToArray();
|
||||
|
||||
InternalEndpointCheckAttribute.Initialize(classInstance, requiredChecks);
|
||||
|
||||
foreach (var location in attrib.Locations) {
|
||||
var normLocation = NormalizeUrlPath(location);
|
||||
|
|
@ -139,7 +149,7 @@ public sealed class HttpServer {
|
|||
|
||||
var reqMethod = Enum.GetName(attrib.RequestMethod) ?? throw new ArgumentException("Request method was undefined");
|
||||
mainLogger.Information($"Registered endpoint: '{reqMethod} {normLocation}'");
|
||||
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks));
|
||||
simpleEndpointMethodInfos.Add((normLocation, reqMethod), new EndpointInvocationInfo(mi, qparams, requiredChecks, classInstance));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -265,7 +275,7 @@ public sealed class HttpServer {
|
|||
convertedQParamValues[0] = rc;
|
||||
rc.ParsedParameters = parsedQParams.AsReadOnly();
|
||||
|
||||
await (Task) (mi.Invoke(null, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
|
||||
await (Task) (mi.Invoke(endpointInvocationInfo.typeInstanceReference, convertedQParamValues) ?? throw new NullReferenceException("Website func returned null unexpectedly"));
|
||||
} else {
|
||||
if (requestMethod == "GET")
|
||||
foreach (var (k, v) in staticServePaths) {
|
||||
|
|
|
|||
|
|
@ -1,15 +1,118 @@
|
|||
using System.Net;
|
||||
using System.Reflection;
|
||||
|
||||
namespace SimpleHttpServer.Types;
|
||||
|
||||
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
|
||||
public abstract class BaseEndpointCheckAttribute : Attribute {
|
||||
public abstract class InternalEndpointCheckAttribute : Attribute {
|
||||
public InternalEndpointCheckAttribute() {
|
||||
CheckSharedVariables();
|
||||
}
|
||||
|
||||
public BaseEndpointCheckAttribute() { }
|
||||
private void CheckSharedVariables() {
|
||||
foreach (var f in GetType().GetRuntimeFields()) {
|
||||
if (f.FieldType.IsAssignableTo(typeof(SharedVariable))) {
|
||||
if (!f.IsInitOnly) {
|
||||
throw new Exception($"Found non-readonly global field {f}!");
|
||||
}
|
||||
if (f.GetValue(this) == null) {
|
||||
throw new Exception("Global fields must be assigned in the CCTOR!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void Initialize(object? instance, Dictionary<FieldInfo, List<(InternalEndpointCheckAttribute, SharedVariable)>> globals) {
|
||||
SetInstance(instance);
|
||||
foreach (var f in GetType().GetRuntimeFields()) {
|
||||
if (f.FieldType.IsAssignableTo(typeof(SharedVariable))) {
|
||||
SharedVariable origVal = (SharedVariable) f.GetValue(this)!;
|
||||
if (globals.TryGetValue(f, out var options)) {
|
||||
bool foundMatch = false;
|
||||
foreach ((var checker, var gv) in options) {
|
||||
if (Match(checker)) {
|
||||
foundMatch = true;
|
||||
// we need to unify their global variables
|
||||
f.SetValue(this, gv);
|
||||
}
|
||||
}
|
||||
if (!foundMatch) {
|
||||
options.Add((this, origVal));
|
||||
}
|
||||
} else {
|
||||
globals.Add(f, new List<(InternalEndpointCheckAttribute, SharedVariable)>() { (this, origVal) });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void Initialize(object? instance, IEnumerable<InternalEndpointCheckAttribute> endPointChecks) {
|
||||
Dictionary<FieldInfo, List<(InternalEndpointCheckAttribute, SharedVariable)>> globals = new();
|
||||
foreach (var check in endPointChecks) {
|
||||
check.Initialize(instance, globals);
|
||||
}
|
||||
}
|
||||
|
||||
private interface SharedVariable {
|
||||
// Tagging interface
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents a Mutable Shared Variable. Fields of this type need to be initialized in the CCtor.
|
||||
/// </summary>
|
||||
protected sealed class MSV<V> : SharedVariable {
|
||||
private readonly V __default;
|
||||
|
||||
public V Val { get; set; } = default!;
|
||||
|
||||
public MSV() : this(default!) { }
|
||||
|
||||
public MSV(V _default) {
|
||||
__default = _default;
|
||||
}
|
||||
|
||||
public static implicit operator V(MSV<V> v) => v.Val;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents an Immutable Shared Variable. Fields of this type need to be initialized in the CCtor.
|
||||
/// </summary>
|
||||
protected sealed class ISV<V> : SharedVariable {
|
||||
private readonly V __default;
|
||||
|
||||
public V Val { get; } = default!;
|
||||
|
||||
public ISV() : this(default!) { }
|
||||
|
||||
public ISV(V _default) {
|
||||
__default = _default;
|
||||
}
|
||||
|
||||
public static implicit operator V(ISV<V> v) => v.Val;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Executed when the endpoint is invoked. The endpoint invocation is skipped if any of the checks fail.
|
||||
/// </summary>
|
||||
/// <returns>True to allow invocation, false to prevent.</returns>
|
||||
public abstract bool Check(HttpListenerRequest req);
|
||||
|
||||
protected virtual bool Match(InternalEndpointCheckAttribute other) => true;
|
||||
|
||||
internal abstract void SetInstance(object? instance);
|
||||
}
|
||||
|
||||
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, Inherited = true, AllowMultiple = true)]
|
||||
public abstract class BaseEndpointCheckAttribute<T> : InternalEndpointCheckAttribute {
|
||||
/// <summary>
|
||||
/// A reference to the instance of the class that this attribute is attached to.
|
||||
/// Will be null iff an class factory was passed in <see cref="HttpServer.RegisterEndpointsFromType{T}(Func{T}?)"/>.
|
||||
/// </summary>
|
||||
protected internal T? EndpointClassInstance { get; internal set; } = default;
|
||||
|
||||
public BaseEndpointCheckAttribute() : base() { }
|
||||
|
||||
internal override void SetInstance(object? instance) {
|
||||
if (instance != null)
|
||||
EndpointClassInstance = (T?) instance;
|
||||
}
|
||||
}
|
||||
|
|
@ -7,12 +7,17 @@ internal readonly struct EndpointInvocationInfo {
|
|||
|
||||
internal readonly MethodInfo methodInfo;
|
||||
internal readonly List<QueryParameterInfo> queryParameters;
|
||||
internal readonly BaseEndpointCheckAttribute[] requiredChecks;
|
||||
internal readonly InternalEndpointCheckAttribute[] requiredChecks;
|
||||
/// <summary>
|
||||
/// a reference to the object in which this method is defined (or null if the class is static)
|
||||
/// </summary>
|
||||
internal readonly object? typeInstanceReference;
|
||||
|
||||
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, BaseEndpointCheckAttribute[] requiredChecks) {
|
||||
public EndpointInvocationInfo(MethodInfo methodInfo, List<QueryParameterInfo> queryParameters, InternalEndpointCheckAttribute[] requiredChecks, object? typeInstanceReference) {
|
||||
this.methodInfo = methodInfo ?? throw new ArgumentNullException(nameof(methodInfo));
|
||||
this.queryParameters = queryParameters ?? throw new ArgumentNullException(nameof(queryParameters));
|
||||
this.requiredChecks = requiredChecks;
|
||||
this.typeInstanceReference = typeInstanceReference;
|
||||
}
|
||||
|
||||
public readonly bool CheckAll(HttpListenerRequest req) => requiredChecks.All(x => x.Check(req));
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user