From dcb8b9b9e124b14651fb20ca6599620ecea8a24c Mon Sep 17 00:00:00 2001 From: 00asdf Date: Fri, 14 Jul 2023 23:47:07 +0200 Subject: [PATCH] better automatic rate limiting --- .../general/utils/discord/DiscordHook.java | 120 ++++++++-------- .../discord/DiscordWebhookException.java | 17 +++ .../internal/DiscordRateLimitBucket.java | 135 ++++++++++++++++++ .../InternalDiscordDataContainer.java | 36 +++++ .../asdf00/general/utils/extras/Triple.java | 13 ++ .../utils/discord/TestDiscordHook.java | 15 ++ 6 files changed, 275 insertions(+), 61 deletions(-) create mode 100644 src/dev/asdf00/general/utils/discord/DiscordWebhookException.java create mode 100644 src/dev/asdf00/general/utils/discord/internal/DiscordRateLimitBucket.java create mode 100644 src/dev/asdf00/general/utils/discord/internal/InternalDiscordDataContainer.java create mode 100644 src/dev/asdf00/general/utils/extras/Triple.java create mode 100644 test/dev/asdf00/general/utils/discord/TestDiscordHook.java diff --git a/src/dev/asdf00/general/utils/discord/DiscordHook.java b/src/dev/asdf00/general/utils/discord/DiscordHook.java index b8261f8..c54cdf9 100644 --- a/src/dev/asdf00/general/utils/discord/DiscordHook.java +++ b/src/dev/asdf00/general/utils/discord/DiscordHook.java @@ -1,37 +1,34 @@ package dev.asdf00.general.utils.discord; +import dev.asdf00.general.utils.discord.internal.DiscordRateLimitBucket; +import dev.asdf00.general.utils.discord.internal.InternalDiscordDataContainer; + import java.io.IOException; import java.net.URI; -import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Map; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; +import java.util.*; +import java.util.function.Consumer; public final class DiscordHook { - private static final HttpClient httpClient = HttpClient.newBuilder() - .version(HttpClient.Version.HTTP_2) - .build(); private static final Map currentHooks = new HashMap<>(); - private final String webhook; + private final DiscordRateLimitBucket bucket; - private final Queue msgQueue = new ConcurrentLinkedQueue<>(); + private InternalDiscordDataContainer container; - private final Executor thread = Executors.newSingleThreadExecutor(); + private DiscordHook(String webhook, DiscordRateLimitBucket bucket) { + this.bucket = bucket; + container = new InternalDiscordDataContainer(webhook); + } - private final Object lockObject = new Object(); - - private int timeoutInMillis = 5000; - - private DiscordHook(String webhook) { - this.webhook = webhook; + /** + * Shortcut for {@link DiscordHook#sendMsg(boolean, String, Object...)} with + * splitMessage = false. + */ + public void sendMsg(String msg, Object... args) { + sendMsg(false, msg, args); } /** @@ -44,56 +41,40 @@ public final class DiscordHook { */ public void sendMsg(boolean splitMessage, String msg, Object... args) { var pmsg = String.format(msg, args); - int cnt = 1; + var msgQueue = new ArrayList(); if (!splitMessage && pmsg.length() > 1994) { // large message with pruning - msgQueue.add(pmsg.substring(0, 1995) + " [...]"); + bucket.sendMsg(container, pmsg.substring(0, 1995) + " [...]"); } else if (pmsg.length() > 2000) { // large message with splitting - var ms = new ArrayList(pmsg.length() / 2000 + 1); - for (int i = 0; i < ms.size(); i++) { - ms.add(pmsg.substring(i * 2000, Math.min(pmsg.length(), (i + 1) * 2000 + 1))); + for (int i = 0; i < (pmsg.length() / 2000) + 1; i++) { + bucket.sendMsg(container, pmsg.substring(i * 2000, Math.min(pmsg.length(), (i + 1) * 2000 + 1))); } - msgQueue.addAll(ms); - cnt = ms.size(); } else { // small message - msgQueue.add(pmsg); - } - - // schedule all messages inserted into the queue - for (; cnt > 0; cnt--) { - thread.execute(this::scheduleMsg); - } - } - - private static String wrapIntoJson(String msg) { - return String.format("{\"content\": \"%s\"}", msg); - } - - private void scheduleMsg() { - try { - var json = msgQueue.remove(); - synchronized (lockObject) { - var postRequest = HttpRequest.newBuilder() - .POST(HttpRequest.BodyPublishers.ofString(json)) - .uri(URI.create(webhook)) - .setHeader("User-Agent", "JavaCrawler") - .header("Content-Type", "application/json") - .build(); - httpClient.send(postRequest, HttpResponse.BodyHandlers.ofString()); - Thread.sleep(timeoutInMillis); - } - } catch (InterruptedException | IOException e) { - throw new RuntimeException(e); + bucket.sendMsg(container, pmsg); } } /** - * Sets the timeout for rate limiting. Default is 5 seconds. + * Waits until all messages scheduled via this hook have been sent. + * + * @throws InterruptedException */ - public void setTimeout(int timeout) { - timeoutInMillis = timeout; + public void waitForRemainingMessages() throws InterruptedException { + container.waitForLessThan(1); + } + + /** + * Sets the error handler for this Discord webhook. + * + * @param handler error handler + * @return true if no handler was set previously + */ + public boolean setErrorHandler(Consumer handler) { + var ret = container.handler == null; + container.handler = handler; + return ret; } /** @@ -103,9 +84,26 @@ public final class DiscordHook { * @return instance associated with the given webhook */ public static synchronized DiscordHook getInstance(String webhook) { - if (!currentHooks.containsKey(webhook)) { - currentHooks.put(webhook, new DiscordHook(webhook)); + synchronized (currentHooks) { + if (currentHooks.containsKey(webhook)) { + return currentHooks.get(webhook); + } + var postRequest = HttpRequest.newBuilder() + .POST(HttpRequest.BodyPublishers.ofString("{}")) + .uri(URI.create(webhook)) + .setHeader("User-Agent", "JavaApplication") + .header("Content-Type", "application/json") + .build(); + try { + // try to send invalid message to get the associated rate limit bucket + var response = DiscordRateLimitBucket.httpClient.send(postRequest, HttpResponse.BodyHandlers.ofString()); + var limit = DiscordRateLimitBucket.RateLimit.fromHeaders(response.headers()); + var hook = new DiscordHook(webhook, DiscordRateLimitBucket.getBucket(limit)); + currentHooks.put(webhook, hook); + return hook; + } catch (IOException | InterruptedException e) { + throw new DiscordWebhookException(e); + } } - return currentHooks.get(webhook); } } diff --git a/src/dev/asdf00/general/utils/discord/DiscordWebhookException.java b/src/dev/asdf00/general/utils/discord/DiscordWebhookException.java new file mode 100644 index 0000000..1659bfa --- /dev/null +++ b/src/dev/asdf00/general/utils/discord/DiscordWebhookException.java @@ -0,0 +1,17 @@ +package dev.asdf00.general.utils.discord; + +import java.net.http.HttpResponse; + +public class DiscordWebhookException extends RuntimeException { + public final HttpResponse erroneousResponse; + + public DiscordWebhookException(Throwable e) { + super(e); + this.erroneousResponse = null; + } + + public DiscordWebhookException(HttpResponse erroneousResponse, String msg) { + super(msg); + this.erroneousResponse = erroneousResponse; + } +} diff --git a/src/dev/asdf00/general/utils/discord/internal/DiscordRateLimitBucket.java b/src/dev/asdf00/general/utils/discord/internal/DiscordRateLimitBucket.java new file mode 100644 index 0000000..85905fb --- /dev/null +++ b/src/dev/asdf00/general/utils/discord/internal/DiscordRateLimitBucket.java @@ -0,0 +1,135 @@ +package dev.asdf00.general.utils.discord.internal; + +import dev.asdf00.general.utils.discord.DiscordHook; +import dev.asdf00.general.utils.discord.DiscordWebhookException; +import dev.asdf00.general.utils.extras.Tuple; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.HashMap; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +public class DiscordRateLimitBucket { + public static final HttpClient httpClient = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .build(); + private static final Map knownBuckets = new HashMap<>(); + + private final Executor thread = Executors.newSingleThreadExecutor(); + + private final Queue> messages = new ConcurrentLinkedQueue<>(); + + private final Object waiter = new Object(); + + private DiscordRateLimitBucket() { + } + + public void sendMsg(InternalDiscordDataContainer container, String msg) { + container.addRemaining(); + messages.add(new Tuple<>(container, wrapIntoMsgJson(msg))); + thread.execute(this::sendingTask); + } + + private void sendingTask() { + synchronized (waiter) { + try { + var msg = messages.remove(); + var response = sendDiscordMessage(msg.a.webhook, msg.b); + var l = RateLimit.fromHeaders(response.headers()); + while (!Thread.interrupted() && response.statusCode() == 429) { + System.out.printf("we hit the rate limit\n"); + Thread.sleep((long) ((l.resetAfter + 0.5) * 1000)); + response = sendDiscordMessage(msg.a.webhook, msg.b); + l = RateLimit.fromHeaders(response.headers()); + } + if (Thread.interrupted()) { + throw new InterruptedException(); + } + if ((response.statusCode() | 4) != 204) { + msg.a.handler.accept(response); + } + if (l.remaining <= 1) { + System.out.println("chillax"); + Thread.sleep((long) ((l.resetAfter + 0.5) * 1000)); + } + msg.a.decrementRemaining(); + } catch (IOException | InterruptedException e) { + throw new DiscordWebhookException(e); + } + } + } + + private HttpResponse sendDiscordMessage(String webhook, String msg) throws IOException, InterruptedException { + var postRequest = HttpRequest.newBuilder() + .POST(HttpRequest.BodyPublishers.ofString(msg)) + .uri(URI.create(webhook)) + .setHeader("User-Agent", "JavaApplication") + .header("Content-Type", "application/json") + .build(); + return httpClient.send(postRequest, HttpResponse.BodyHandlers.ofString()); + } + + private static String wrapIntoMsgJson(String msg) { + return String.format("{\"content\": \"%s\"}", msg); + } + + public static DiscordRateLimitBucket getBucket(RateLimit limit) { + synchronized (knownBuckets) { + var bucket = knownBuckets.get(limit.bucket); + if (bucket == null) { + bucket = new DiscordRateLimitBucket(); + knownBuckets.put(limit.bucket, bucket); + } + return bucket; + } + } + + public static class RateLimit { + public final int limit; + public final int remaining; + public final int reset; + public final double resetAfter; + public final String bucket; + + private RateLimit(int limit, int remaining, int reset, double resetAfter, String bucket) { + this.limit = limit; + this.remaining = remaining; + this.reset = reset; + this.resetAfter = resetAfter; + this.bucket = bucket; + } + + public static RateLimit fromHeaders(HttpHeaders headers) { + int l = -1; + try { + l = (int) headers.firstValueAsLong("X-RateLimit-Limit").orElse(-1); + } catch (NumberFormatException ignore) { + } + int rem = -1; + try { + rem = (int) headers.firstValueAsLong("X-RateLimit-Remaining").orElse(-1); + } catch (NumberFormatException ignore) { + } + int res = -1; + try { + res = (int) headers.firstValueAsLong("X-RateLimit-Reset").orElse(-1); + } catch (NumberFormatException ignore) { + } + double resa = -1; + try { + resa = Double.parseDouble(headers.firstValue("X-RateLimit-Reset-After").orElse("-1")); + } catch (NumberFormatException ignore) { + } + String bucket = headers.firstValue("X_RateLimit-Bucket").orElse(""); + return new RateLimit(l, rem, res, resa, bucket); + } + } +} diff --git a/src/dev/asdf00/general/utils/discord/internal/InternalDiscordDataContainer.java b/src/dev/asdf00/general/utils/discord/internal/InternalDiscordDataContainer.java new file mode 100644 index 0000000..911d225 --- /dev/null +++ b/src/dev/asdf00/general/utils/discord/internal/InternalDiscordDataContainer.java @@ -0,0 +1,36 @@ +package dev.asdf00.general.utils.discord.internal; + +import java.net.http.HttpResponse; +import java.util.function.Consumer; + +public class InternalDiscordDataContainer { + public final String webhook; + public Consumer handler; + private final Object notifier = new Object(); + private long remaining; + + public InternalDiscordDataContainer(String webhook) { + this.webhook = webhook; + } + + public void addRemaining() { + synchronized (notifier) { + remaining++; + } + } + + public void decrementRemaining() { + synchronized (notifier) { + remaining--; + notifier.notifyAll(); + } + } + + public void waitForLessThan(int val) throws InterruptedException { + synchronized (notifier) { + while (remaining >= val) { + notifier.wait(); + } + } + } +} diff --git a/src/dev/asdf00/general/utils/extras/Triple.java b/src/dev/asdf00/general/utils/extras/Triple.java new file mode 100644 index 0000000..7f7fe68 --- /dev/null +++ b/src/dev/asdf00/general/utils/extras/Triple.java @@ -0,0 +1,13 @@ +package dev.asdf00.general.utils.extras; + +public class Triple { + public A a; + public B b; + public C c; + public Triple() { } + public Triple(A a, B b, C c) { + this.a = a; + this.b = b; + this.c = c; + } +} diff --git a/test/dev/asdf00/general/utils/discord/TestDiscordHook.java b/test/dev/asdf00/general/utils/discord/TestDiscordHook.java new file mode 100644 index 0000000..79b7538 --- /dev/null +++ b/test/dev/asdf00/general/utils/discord/TestDiscordHook.java @@ -0,0 +1,15 @@ +package dev.asdf00.general.utils.discord; + +import org.junit.Test; + +public class TestDiscordHook { + @Test + public void sendTestMsg() throws InterruptedException { + var testUri = "https://discord.com/api/webhooks/1129400851128123402/kDYC4SeT9lWVDsO_S0FF1ugW5k-VqyTuBsGIbEcCSFkxrC9fvQpGlT5DNcDlS785nohw"; + var hook = DiscordHook.getInstance(testUri); + for (int i = 0; i < 20; i++) { + hook.sendMsg("spam %s", i); + } + hook.waitForRemainingMessages(); + } +}