Reputation: 11
Is there a way to implement an anti-flood middleware in aiogram-3.x (beta), like the one in the aiogram-2 [example]?
I searched the aiogram-3 documentation for examples, but I could not find a solution to my problem there.
Upvotes: 1
Views: 3377
Reputation: 71
Antiflood middleware for aiogram 3.4.1 using redis as storage:
from __future__ import annotations

import logging
import time
from typing import *

import redis.asyncio.client
from aiogram import BaseMiddleware
from aiogram.types import Message
def rate_limit(limit: float, key: str | None = None):
    """
    Decorator for configuring the rate limit and bucket key of a handler.

    The values are stored as attributes on the decorated function itself
    (``throttling_rate_limit`` and, when given, ``throttling_key``); the
    function is returned unwrapped.

    :param limit: minimum number of seconds between two calls; fractional
        values such as ``0.5`` are valid.
    :param key: optional bucket name; when falsy, no ``throttling_key``
        attribute is set on the function.
    :return: a decorator that tags the function and returns it unchanged.
    """
    def decorator(func):
        setattr(func, 'throttling_rate_limit', limit)
        if key:
            setattr(func, 'throttling_key', key)
        return func

    return decorator
class ThrottlingMiddleware(BaseMiddleware):
    """
    Anti-flood middleware that rate-limits messages per user and chat.

    A handler callback may carry ``throttling_rate_limit`` and
    ``throttling_key`` attributes (as set by the ``rate_limit`` decorator) to
    override the default limit and bucket key.
    """

    def __init__(self, redis: redis.asyncio.client.Redis, limit=.5, key_prefix='antiflood_'):
        """
        :param redis: async Redis client handed to the ThrottleManager.
        :param limit: default minimum number of seconds between two events.
        :param key_prefix: prefix of the default bucket key.
        """
        self.rate_limit = limit
        self.prefix = key_prefix
        self.throttle_manager = ThrottleManager(redis=redis)
        super().__init__()

    async def __call__(
        self,
        handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
        event: Message,
        data: Dict[str, Any]
    ) -> Any:
        """
        Run the throttling check, then the wrapped handler.

        :return: the handler's result, or ``None`` when the event was
            throttled or the handler raised (the error is logged).
        """
        try:
            await self.on_process_event(event, data)
        except CancelHandler:
            # Event was throttled: skip the handler entirely.
            return None

        try:
            return await handler(event, data)
        except Exception as exc:
            # Log and swallow, as the original intended; returning from inside
            # the ``try`` avoids referencing an unbound ``result`` here.
            logging.getLogger(__name__).exception(exc)
            return None

    async def on_process_event(
        self,
        event: Message,
        data: dict,
    ) -> Any:
        """
        Apply the rate limit to ``event``.

        :raises CancelHandler: when the event exceeds the rate limit.
        """
        # ``data["handler"]`` may be absent (e.g. no handler resolved yet);
        # fall back to the middleware defaults in that case.
        callback = getattr(data.get("handler"), "callback", None)
        limit = getattr(callback, "throttling_rate_limit", self.rate_limit)
        # Read the key attribute, not the limit attribute, so that a custom
        # key set on the callback is honoured.
        key = getattr(callback, "throttling_key", f"{self.prefix}_message")

        try:
            await self.throttle_manager.throttle(
                key, rate=limit, user_id=event.from_user.id, chat_id=event.chat.id
            )
        except Throttled as t:
            # Notify the user, then cancel the current handler.
            await self.event_throttled(event, t)
            raise CancelHandler()

    async def event_throttled(self, event: Message, throttled: Throttled):
        """Answer the throttled event with the time left until the block ends."""
        # Time left until the block ends.
        delta = throttled.rate - throttled.delta
        # Only answer the first couple of times so the warning itself does
        # not flood the chat.
        if throttled.exceeded_count <= 2:
            await event.answer(f'Too many events.\nTry again in {delta:.2f} seconds.')
class ThrottleManager:
    """
    Keeps per-key/user/chat throttling state in Redis hashes.

    Each bucket is a hash named ``throttle_{key}_{user_id}_{chat_id}`` holding
    the fields listed in ``bucket_keys``.
    """

    bucket_keys = [
        "RATE_LIMIT", "DELTA",
        "LAST_CALL", "EXCEEDED_COUNT"
    ]

    def __init__(self, redis: redis.asyncio.client.Redis):
        self.redis = redis

    async def throttle(self, key: str, rate: float, user_id: int, chat_id: int):
        """
        Register a call for the bucket and check it against the rate limit.

        :param key: bucket name component.
        :param rate: minimum number of seconds between two calls.
        :param user_id: user identifier used in the bucket name.
        :param chat_id: chat identifier used in the bucket name.
        :return: ``True`` when the call is allowed.
        :raises Throttled: when the previous call was less than ``rate``
            seconds ago.
        """
        now = time.time()
        bucket_name = f'throttle_{key}_{user_id}_{chat_id}'

        raw_values = await self.redis.hmget(bucket_name, self.bucket_keys)
        # float() accepts both bytes and str, so this works whether or not the
        # client was created with decode_responses=True.
        data = {
            field: float(raw)
            for field, raw in zip(self.bucket_keys, raw_values)
            if raw is not None
        }

        # A missing LAST_CALL (first call) yields delta == 0, which is allowed.
        called = data.get("LAST_CALL", now)
        delta = now - called
        result = delta >= rate or delta <= 0

        data["RATE_LIMIT"] = rate
        data["LAST_CALL"] = now
        data["DELTA"] = delta
        if not result:
            # .get guards against a partially written hash; keep the counter
            # an integer instead of the float it was parsed as.
            data["EXCEEDED_COUNT"] = int(data.get("EXCEEDED_COUNT", 0)) + 1
        else:
            data["EXCEEDED_COUNT"] = 1

        # hset(mapping=...) replaces the deprecated hmset.
        await self.redis.hset(bucket_name, mapping=data)
        # Let idle buckets expire so Redis does not accumulate one hash per
        # user/chat forever. Once ``rate`` seconds have passed the next call is
        # allowed and resets the counter anyway, so a TTL >= rate is safe.
        await self.redis.expire(bucket_name, max(1, int(rate) + 1))

        if not result:
            raise Throttled(key=key, chat=chat_id, user=user_id, **data)
        return result
class Throttled(Exception):
    """Raised when a call exceeds the rate limit; carries the bucket state."""

    def __init__(self, **kwargs):
        # Unknown keyword arguments are ignored.
        lookup = kwargs.get
        self.key = lookup("key", '<None>')
        self.called_at = lookup("LAST_CALL", time.time())
        self.rate = lookup("RATE_LIMIT", None)
        self.exceeded_count = lookup("EXCEEDED_COUNT", 0)
        self.delta = lookup("DELTA", 0)
        self.user = lookup('user', None)
        self.chat = lookup('chat', None)

    def __str__(self):
        rounded_delta = round(self.delta, 3)
        parts = (
            f"Rate limit exceeded! (Limit: {self.rate} s, ",
            f"exceeded: {self.exceeded_count}, ",
            f"time delta: {rounded_delta} s)",
        )
        return "".join(parts)
class CancelHandler(Exception):
    """Signal raised to skip the handler for the current event."""
Upvotes: 2
Reputation: 3
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled
import asyncio
from aiogram import types
from loader import *
def rate_limit(limit: float, key=None):
    """Tag a handler with its throttling limit and, when given, a bucket key."""
    def decorator(func):
        # Attributes are attached to the function itself; it is not wrapped.
        func.throttling_rate_limit = limit
        if not key:
            return func
        func.throttling_key = key
        return func

    return decorator
class ThrottlingMiddleware(BaseMiddleware):
    """Message middleware that throttles handlers through the dispatcher."""

    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
        self.rate_limit = limit
        self.prefix = key_prefix
        super().__init__()

    async def on_process_message(self, message: types.Message, data: dict):
        """Throttle the message; raise CancelHandler when the limit is hit."""
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()

        # Start from the middleware defaults; a known handler may override
        # them through attributes set on it.
        limit = self.rate_limit
        key = f"{self.prefix}_message"
        if handler:
            limit = getattr(handler, 'throttling_rate_limit', limit)
            key = getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")

        try:
            await dispatcher.throttle(key, rate=limit)
        except Throttled as t:
            await self.message_throttled(message, t)
            raise CancelHandler()

    async def message_throttled(self, message: types.Message, throttled: Throttled):
        """Wait out the remaining block time, then query the key's state."""
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()

        key = (
            getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
            if handler
            else f"{self.prefix}_message"
        )

        # Time left until the block ends.
        remaining = throttled.rate - throttled.delta
        await asyncio.sleep(remaining)
        # The result of the check is not used by this snippet.
        await dispatcher.check_key(key)
Upvotes: 0