CSharpHttpServer/SimpleHttpServer/Login/LoginProvider.cs
2024-01-07 23:14:56 +01:00

247 lines
8.6 KiB
C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using Konscious.Security.Cryptography;
using Newtonsoft.Json;
namespace SimpleHttpServer.Login;
/// <summary>
/// JSON-persisted form of one user's login record: every binary field is stored as a Base64 string.
/// </summary>
internal struct SerialLoginData {
    /// <summary>Base64 of the per-user salt.</summary>
    public string salt;
    /// <summary>Base64 of the stored password hash.</summary>
    public string pwd;
    /// <summary>Base64 of the encrypted additional user data.</summary>
    public string additionalData;

    /// <summary>
    /// Decodes this record back into its binary in-memory form.
    /// </summary>
    /// <returns>A <see cref="LoginData"/> with salt, password hash and encrypted data decoded.</returns>
    /// <exception cref="ArgumentNullException"><see cref="salt"/> or <see cref="pwd"/> is null.</exception>
    /// <exception cref="FormatException">A field is not valid Base64.</exception>
    public LoginData toPlainData() {
        return new LoginData {
            salt = Convert.FromBase64String(salt),
            password = Convert.FromBase64String(pwd),
            // Must be decoded too: otherwise every record reloaded from disk loses its encrypted data,
            // and the next save fails on the null array. A missing field (older files) becomes an
            // empty array so the store can still be re-serialised.
            encryptedData = additionalData is null
                ? Array.Empty<byte>()
                : Convert.FromBase64String(additionalData)
        };
    }
}
/// <summary>
/// In-memory login record for one user, holding the raw byte arrays.
/// </summary>
internal struct LoginData {
    /// <summary>Random per-user salt.</summary>
    public byte[] salt;
    /// <summary>Stored password hash.</summary>
    public byte[] password;
    /// <summary>Additional user data in encrypted form.</summary>
    public byte[] encryptedData;

    /// <summary>
    /// Encodes every byte array as Base64 so the record can be written out as JSON.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any of the three arrays is null.</exception>
    public SerialLoginData toSerial() => new SerialLoginData {
        salt = Convert.ToBase64String(salt),
        pwd = Convert.ToBase64String(password),
        additionalData = Convert.ToBase64String(encryptedData),
    };
}
/// <summary>
/// Tunable parameters for the login provider: salt size, AES key length, Argon2id cost
/// settings, Argon2 concurrency limit and PBKDF2 iteration count.
/// The field names are also the JSON keys of the persisted config file, so keep them stable.
/// </summary>
internal struct LoginDataProviderConfig {
    public int SALT_SIZE;
    // 256 / 8: a 256-bit key written as a byte count.
    public int KEY_LENGTH;
    public int A2_ITERATIONS;
    public int A2_MEMORY_SIZE;
    public int A2_PARALLELISM;
    public int A2_HASH_LENGTH;
    public int A2_MAX_CONCURRENT;
    public int PBKDF2_ITERATIONS;

    /// <summary>
    /// Creates a config populated with the built-in defaults.
    /// Note that <c>default(LoginDataProviderConfig)</c> bypasses this and yields all zeros.
    /// </summary>
    public LoginDataProviderConfig() {
        SALT_SIZE = 32;
        KEY_LENGTH = 256 / 8;
        A2_ITERATIONS = 5;
        A2_MEMORY_SIZE = 500_000;
        A2_PARALLELISM = 8;
        A2_HASH_LENGTH = 256 / 8;
        A2_MAX_CONCURRENT = 4;
        PBKDF2_ITERATIONS = 600_000;
    }
}
/// <summary>
/// File-backed credential store. Passwords are hashed with Argon2id; a per-user value of type
/// <typeparamref name="T"/> is encrypted with an AES-CBC key derived via PBKDF2-SHA256 from the
/// user's password, and is only decrypted after the password hash has been verified.
/// </summary>
/// <typeparam name="T">Type of the additional per-user data.</typeparam>
public class LoginProvider<T> {
    private static readonly Func<T, byte[]> JsonSerialize = t => Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(t));
    private static readonly Func<byte[], T> JsonDeserialize = b => JsonConvert.DeserializeObject<T>(Encoding.UTF8.GetString(b))!;

    // AES always uses a 128-bit block, so the IV stored in front of every ciphertext is 16 bytes.
    private const int AesBlockBytes = 16;

    private readonly LoginDataProviderConfig config;
    private readonly ReaderWriterLockSlim ldLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
    private readonly string ldPath;
    private readonly Dictionary<string, LoginData> loginData;
    private readonly SemaphoreSlim argon2Limit;
    private Func<T, byte[]> DataSerializer = JsonSerialize;
    private Func<byte[], T> DataDeserializer = JsonDeserialize;

    /// <summary>
    /// Loads the login data from <paramref name="ldPath"/> and the config from <paramref name="confPath"/>,
    /// creating either file with default content if it does not exist.
    /// </summary>
    /// <exception cref="InvalidDataException">The login data file deserialises to null.</exception>
    public LoginProvider(string ldPath, string confPath) {
        this.ldPath = ldPath;
        loginData = LoadLoginData(ldPath);
        config = LoadArgon2Config(confPath);
        argon2Limit = new SemaphoreSlim(config.A2_MAX_CONCURRENT);
    }

    // Reads the login data file, or creates it containing "{}" if absent.
    private static Dictionary<string, LoginData> LoadLoginData(string path) {
        if (!File.Exists(path)) {
            File.WriteAllText(path, "{}", Encoding.UTF8);
            return new Dictionary<string, LoginData>();
        }
        var serial = JsonConvert.DeserializeObject<Dictionary<string, SerialLoginData>>(File.ReadAllText(path))
            ?? throw new InvalidDataException($"could not read login data from file {path}");
        return serial.ToDictionary(pair => pair.Key, pair => pair.Value.toPlainData());
    }

    // Reads the config file, or writes out the defaults if absent.
    private static LoginDataProviderConfig LoadArgon2Config(string path) {
        if (!File.Exists(path)) {
            var conf = new LoginDataProviderConfig();
            File.WriteAllText(path, JsonConvert.SerializeObject(conf));
            return conf;
        }
        return JsonConvert.DeserializeObject<LoginDataProviderConfig>(File.ReadAllText(path));
    }

    /// <summary>
    /// Replaces the (de)serialiser used for the additional data; passing null restores the JSON default.
    /// </summary>
    public void SetDataSerialization(Func<T, byte[]> serializer, Func<byte[], T> deserializer) {
        DataSerializer = serializer ?? JsonSerialize;
        DataDeserializer = deserializer ?? JsonDeserialize;
    }

    // Writes the whole store to ldPath. Callers already hold the write lock; re-entering it
    // (recursion is enabled) keeps the method safe on its own and serialises the file write.
    private void StoreLoginData() {
        ldLock.EnterWriteLock();
        try {
            var serial = loginData.ToDictionary(pair => pair.Key, pair => pair.Value.toSerial());
            File.WriteAllText(ldPath, JsonConvert.SerializeObject(serial));
        } finally {
            ldLock.ExitWriteLock();
        }
    }

    // Persists the current state; if persisting fails, applies `undo` so memory keeps matching disk, then rethrows.
    // Must be called while holding the write lock.
    private void CommitOrRollback(Action undo) {
        try {
            StoreLoginData();
        } catch {
            undo();
            throw;
        }
    }

    // Read-locked existence check, used to skip expensive hashing for requests that will fail anyway.
    private bool UserExists(string username) {
        ldLock.EnterReadLock();
        try {
            return loginData.ContainsKey(username);
        } finally {
            ldLock.ExitReadLock();
        }
    }

    // Builds a complete record for a credential: fresh random salt, Argon2id hash, encrypted additional data.
    private LoginData CreateLoginData(string password, T additional) {
        var salt = RandomNumberGenerator.GetBytes(config.SALT_SIZE);
        return new LoginData {
            salt = salt,
            password = HashPwd(password, salt),
            encryptedData = EncryptAdditionalData(password, salt, additional)
        };
    }

    /// <summary>
    /// Registers a new user and persists the store.
    /// </summary>
    /// <returns><c>false</c> if <paramref name="username"/> is already registered; otherwise <c>true</c>.</returns>
    /// <remarks>If writing the file fails, the user is not added and the exception propagates.</remarks>
    public bool AddUser(string username, string password, T additional) {
        if (UserExists(username)) {
            return false;
        }
        // Hash and encrypt without holding the lock so authentications are not blocked meanwhile.
        var record = CreateLoginData(password, additional);
        ldLock.EnterWriteLock();
        try {
            // Another thread may have registered the same name while we were hashing.
            if (!loginData.TryAdd(username, record)) {
                return false;
            }
            CommitOrRollback(() => loginData.Remove(username));
            return true;
        } finally {
            ldLock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Removes a user and persists the store.
    /// </summary>
    /// <returns><c>true</c> if the user existed and was removed.</returns>
    public bool RemoveUser(string username) {
        ldLock.EnterWriteLock();
        try {
            if (!loginData.Remove(username, out var removed)) {
                return false;
            }
            CommitOrRollback(() => loginData[username] = removed);
            return true;
        } finally {
            ldLock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Replaces an existing user's password and additional data, then persists the store.
    /// A new salt is generated for the new credential.
    /// </summary>
    /// <returns><c>false</c> if the user does not exist; otherwise <c>true</c>.</returns>
    public bool ModifyUser(string username, string newPassword, T newAdditional) {
        if (!UserExists(username)) {
            return false;
        }
        // Fresh salt per credential; it also means nothing has to be read before the (lock-free) hashing.
        var updated = CreateLoginData(newPassword, newAdditional);
        ldLock.EnterWriteLock();
        try {
            // The user may have been removed while we were hashing.
            if (!loginData.TryGetValue(username, out var previous)) {
                return false;
            }
            loginData[username] = updated;
            CommitOrRollback(() => loginData[username] = previous);
            return true;
        } finally {
            ldLock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Verifies a username/password pair.
    /// </summary>
    /// <returns>
    /// <c>(true, data)</c> with the decrypted additional data on success;
    /// <c>(false, default)</c> if the user is unknown or the password is wrong.
    /// </returns>
    /// <exception cref="InvalidDataException">The stored encrypted data is missing or truncated.</exception>
    public (bool, T) Authenticate(string username, string password) {
        LoginData data;
        ldLock.EnterReadLock();
        try {
            if (!loginData.TryGetValue(username, out data)) {
                return (false, default(T)!);
            }
        } finally {
            ldLock.ExitReadLock();
        }
        var hash = HashPwd(password, data.salt);
        // Constant-time comparison so the check does not leak how many leading bytes matched.
        if (!CryptographicOperations.FixedTimeEquals(hash, data.password)) {
            return (false, default(T)!);
        }
        return (true, DecryptAdditionalData(password, data.salt, data.encryptedData));
    }

    // Argon2id hash of the password; concurrent runs are bounded by A2_MAX_CONCURRENT via argon2Limit.
    private byte[] HashPwd(string pwd, byte[] salt) {
        byte[] hash;
        argon2Limit.Wait();
        try {
            using (var argon2 = new Argon2id(Encoding.UTF8.GetBytes(pwd))) {
                argon2.Iterations = config.A2_ITERATIONS;
                argon2.MemorySize = config.A2_MEMORY_SIZE;
                argon2.DegreeOfParallelism = config.A2_PARALLELISM;
                argon2.Salt = salt;
                hash = argon2.GetBytes(config.A2_HASH_LENGTH);
            }
            // force collection to reduce sustained memory usage if many hashes are done in close time proximity to each other
            GC.Collect();
        } finally {
            argon2Limit.Release();
        }
        return hash;
    }

    // Derives the AES key from the password with PBKDF2-SHA256.
    // KEY_LENGTH is accepted as a byte count (16/24/32; the default is 256 / 8) or as a bit count
    // (128/192/256). Treating the byte default 32 as bits would request an invalid 32-bit AES key.
    private byte[] DeriveAesKey(string pwd, byte[] salt) {
        int keyBytes = config.KEY_LENGTH is 16 or 24 or 32 ? config.KEY_LENGTH : config.KEY_LENGTH / 8;
        return Rfc2898DeriveBytes.Pbkdf2(Encoding.UTF8.GetBytes(pwd), salt, config.PBKDF2_ITERATIONS, HashAlgorithmName.SHA256, keyBytes);
    }

    // Serialises and encrypts the additional data with AES-CBC/PKCS7. Output layout: IV || ciphertext.
    private byte[] EncryptAdditionalData(string pwd, byte[] salt, T data) {
        var key = DeriveAesKey(pwd, salt);
        var plainBytes = DataSerializer(data);
        using var aes = Aes.Create();
        aes.Key = key; // setting Key also sets KeySize to match
        var iv = RandomNumberGenerator.GetBytes(AesBlockBytes);
        var cipherBytes = aes.EncryptCbc(plainBytes, iv, PaddingMode.PKCS7);
        var encryptedBytes = new byte[iv.Length + cipherBytes.Length];
        iv.CopyTo(encryptedBytes, 0);
        cipherBytes.CopyTo(encryptedBytes, iv.Length);
        return encryptedBytes;
    }

    // Inverse of EncryptAdditionalData: splits off the IV prefix, decrypts, then deserialises.
    private T DecryptAdditionalData(string pwd, byte[] salt, byte[] encryptedData) {
        // Anything shorter than one IV cannot have been produced by EncryptAdditionalData.
        if (encryptedData is null || encryptedData.Length < AesBlockBytes) {
            throw new InvalidDataException("stored additional data is missing or truncated");
        }
        using var aes = Aes.Create();
        aes.Key = DeriveAesKey(pwd, salt);
        var plainBytes = aes.DecryptCbc(
            encryptedData.AsSpan(AesBlockBytes),
            encryptedData.AsSpan(0, AesBlockBytes),
            PaddingMode.PKCS7);
        return DataDeserializer(plainBytes);
    }
}