feat: limit max events per second

This commit is contained in:
Vlad Stan 2023-02-09 10:28:56 +02:00
parent bddab70677
commit 868e02d3c2
3 changed files with 64 additions and 14 deletions

View file

@ -1,4 +1,5 @@
import json
import time
from typing import Any, Awaitable, Callable, List, Optional
from fastapi import WebSocket
@ -89,6 +90,9 @@ class NostrClientConnection:
self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None
self.get_client_config: Optional[Callable[[], ClientConfig]] = None
self._last_event_timestamp = 0 # in seconds
self._event_count_per_timestamp = 0
async def start(self):
await self.websocket.accept()
while True:
@ -146,6 +150,11 @@ class NostrClientConnection:
async def _handle_event(self, e: NostrEvent):
resp_nip20: List[Any] = ["OK", e.id]
if self._exceeded_max_events_per_second():
resp_nip20 += [False, f"Exceeded max events per second limit'!"]
await self._send_msg(resp_nip20)
return None
if not self.client_config.is_author_allowed(e.pubkey):
resp_nip20 += [False, f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!"]
await self._send_msg(resp_nip20)
@ -197,7 +206,7 @@ class NostrClientConnection:
async def _handle_request(self, subscription_id: str, filter: NostrFilter) -> List:
filter.subscription_id = subscription_id
self.remove_filter(subscription_id)
self._remove_filter(subscription_id)
if self._can_add_filter():
return [["NOTICE", f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded."]]
@ -211,11 +220,24 @@ class NostrClientConnection:
serialized_events.append(resp_nip15)
return serialized_events
def remove_filter(self, subscription_id: str):
def _remove_filter(self, subscription_id: str):
self.filters = [f for f in self.filters if f.subscription_id != subscription_id]
def _handle_close(self, subscription_id: str):
self.remove_filter(subscription_id)
self._remove_filter(subscription_id)
def _can_add_filter(self) -> bool:
return self.client_config.max_client_filters != 0 and len(self.filters) >= self.client_config.max_client_filters
return self.client_config.max_client_filters != 0 and len(self.filters) >= self.client_config.max_client_filters
def _exceeded_max_events_per_second(self) -> bool:
if self.client_config.max_events_per_second == 0:
return False
current_time = round(time.time())
if self._last_event_timestamp == current_time:
self._event_count_per_timestamp += 1
else:
self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0
return self._event_count_per_timestamp > self.client_config.max_events_per_second