diff --git a/client_manager.py b/client_manager.py index bab80ae..dccc75e 100644 --- a/client_manager.py +++ b/client_manager.py @@ -79,13 +79,12 @@ class NostrClientManager: return False return True - def _set_client_callbacks(self, client): - setattr(client, "broadcast_event", self.broadcast_event) - + def _set_client_callbacks(self, client: "NostrClientConnection"): def get_client_config() -> RelaySpec: return self.get_relay_config(client.relay_id) setattr(client, "get_client_config", get_client_config) + client.init_callbacks(self.broadcast_event, get_client_config) class NostrClientConnection: @@ -129,6 +128,12 @@ class NostrClientConnection: except: pass + def init_callbacks(self, broadcast_event: Callable, get_client_config: Callable): + setattr(self, "broadcast_event", broadcast_event) + setattr(self, "get_client_config", get_client_config) + setattr(self.event_validator, "get_client_config", get_client_config) + + async def notify_event(self, event: NostrEvent) -> bool: if self._is_direct_message_for_other(event): return False @@ -285,7 +290,6 @@ class NostrClientConnection: and len(self.filters) >= self.client_config.max_client_filters ) - def _auth_challenge_expired(self): if self._auth_challenge_created_at == 0: return True @@ -307,11 +311,12 @@ class EventValidator: def __init__(self, relay_id: str): self.relay_id = relay_id - self.client_config: RelaySpec self._last_event_timestamp = 0 # in hours self._event_count_per_timestamp = 0 + self.get_client_config: Optional[Callable[[], RelaySpec]] = None + async def validate_write(self, e: NostrEvent, publisher_pubkey: str) -> Tuple[bool, str]: valid, message = self._validate_event(e) if not valid: @@ -344,6 +349,12 @@ class EventValidator: return True, "" + @property + def client_config(self) -> RelaySpec: + if not self.get_client_config: + raise Exception("EventValidator not ready!") + return self.get_client_config() + def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]: if self._exceeded_max_events_per_hour(): return False, f"Exceeded max events per hour limit'!" diff --git a/tests/test_clients.py b/tests/test_clients.py index ec493d4..eff5902 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -96,7 +96,7 @@ async def alice_wires_meta_and_post01(ws_alice: MockWebSocket): assert ( len(ws_alice.sent_messages) == 4 - ), "Alice: Expected 3 confirmations to be sent" + ), "Alice: Expected 4 confirmations to be sent" assert ws_alice.sent_messages[0] == dumps( alice["meta_response"] ), "Alice: Wrong confirmation for meta"