feat: differentiate between publisher and author

This commit is contained in:
Vlad Stan 2023-02-17 09:38:49 +02:00
parent b5f7aa0c78
commit a1d7c474b0
3 changed files with 18 additions and 8 deletions

View file

@ -142,6 +142,10 @@ class NostrClientConnection:
return False return False
def _is_direct_message_for_other(self, event: NostrEvent) -> bool: def _is_direct_message_for_other(self, event: NostrEvent) -> bool:
"""
Direct messages are not inteded to be boradcast (even if encrypted).
If the server requires AUTH for kind '4' then direct message will be sent only to the intended client.
"""
if not event.is_direct_message: if not event.is_direct_message:
return False return False
if not self.client_config.event_requires_auth(event.kind): if not self.client_config.event_requires_auth(event.kind):
@ -208,7 +212,7 @@ class NostrClientConnection:
await delete_events( await delete_events(
self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey]) self.relay_id, NostrFilter(kinds=[e.kind], authors=[e.pubkey])
) )
await create_event(self.relay_id, e) await create_event(self.relay_id, e, self.pubkey)
await self._broadcast_event(e) await self._broadcast_event(e)
if e.is_delete_event: if e.is_delete_event:
@ -257,6 +261,7 @@ class NostrClientConnection:
filter.enforce_limit(self.client_config.limit_per_filter) filter.enforce_limit(self.client_config.limit_per_filter)
self.filters.append(filter) self.filters.append(filter)
events = await get_events(self.relay_id, filter) events = await get_events(self.relay_id, filter)
events = [e for e in events if not self._is_direct_message_for_other(e)]
serialized_events = [ serialized_events = [
event.serialize_response(subscription_id) for event in events event.serialize_response(subscription_id) for event in events
] ]
@ -290,7 +295,7 @@ class NostrClientConnection:
return False, "error: NIP42 tags are missing for auth event" return False, "error: NIP42 tags are missing for auth event"
if self.client_config.domain != extract_domain(relay_tag[0]): if self.client_config.domain != extract_domain(relay_tag[0]):
return False, "error: wrong relay domain for auth event" return False, "error: wrong relay domain for auth event"
if self._auth_challenge != challenge_tag[0]: if self._auth_challenge != challenge_tag[0]:
return False, "error: wrong chanlange value for auth event" return False, "error: wrong chanlange value for auth event"
@ -302,7 +307,8 @@ class NostrClientConnection:
if not valid: if not valid:
return (valid, message) return (valid, message)
valid, message = await self._validate_storage(e.pubkey, e.size_bytes) publisher_pubkey = self.pubkey if self.pubkey else e.pubkey
valid, message = await self._validate_storage(publisher_pubkey, e.size_bytes)
if not valid: if not valid:
return (valid, message) return (valid, message)

13
crud.py
View file

@ -132,11 +132,13 @@ async def delete_relay(user_id: str, relay_id: str):
########################## EVENTS #################### ########################## EVENTS ####################
async def create_event(relay_id: str, e: NostrEvent): async def create_event(relay_id: str, e: NostrEvent, publisher: Optional[str]):
publisher = publisher if publisher else e.pubkey
await db.execute( await db.execute(
""" """
INSERT INTO nostrrelay.events ( INSERT INTO nostrrelay.events (
relay_id, relay_id,
publisher,
id, id,
pubkey, pubkey,
created_at, created_at,
@ -145,10 +147,11 @@ async def create_event(relay_id: str, e: NostrEvent):
sig, sig,
size size
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
relay_id, relay_id,
publisher,
e.id, e.id,
e.pubkey, e.pubkey,
e.created_at, e.created_at,
@ -199,14 +202,14 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
return event return event
async def get_storage_for_public_key(relay_id: str, pubkey: str) -> int: async def get_storage_for_public_key(relay_id: str, publisher_pubkey: str) -> int:
"""Returns the storage space in bytes for all the events of a public key. Deleted events are also counted""" """Returns the storage space in bytes for all the events of a public key. Deleted events are also counted"""
row = await db.fetchone( row = await db.fetchone(
"SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey", "SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND publisher = ? GROUP BY publisher",
( (
relay_id, relay_id,
pubkey, publisher_pubkey,
), ),
) )
if not row: if not row:

View file

@ -23,6 +23,7 @@ async def m001_initial(db):
CREATE TABLE nostrrelay.events ( CREATE TABLE nostrrelay.events (
relay_id TEXT NOT NULL, relay_id TEXT NOT NULL,
deleted BOOLEAN DEFAULT false, deleted BOOLEAN DEFAULT false,
publisher TEXT NOT NULL,
id TEXT NOT NULL, id TEXT NOT NULL,
pubkey TEXT NOT NULL, pubkey TEXT NOT NULL,
created_at {db.big_int} NOT NULL, created_at {db.big_int} NOT NULL,