chore: force py format

This commit is contained in:
Vlad Stan 2023-02-10 12:16:25 +02:00
parent 55f9142f3d
commit 1eda457067
8 changed files with 248 additions and 99 deletions

View file

@ -29,5 +29,5 @@ from .views_api import * # noqa
settings.lnbits_relay_information = { settings.lnbits_relay_information = {
"name": "LNbits Nostr Relay", "name": "LNbits Nostr Relay",
"description": "Multiple relays are supported", "description": "Multiple relays are supported",
**NostrRelay.info() **NostrRelay.info(),
} }

View file

@ -37,7 +37,6 @@ class NostrClientManager:
return True return True
def remove_client(self, c: "NostrClientConnection"): def remove_client(self, c: "NostrClientConnection"):
self.clients(c.relay_id).remove(c) self.clients(c.relay_id).remove(c)
@ -80,17 +79,21 @@ class NostrClientManager:
def _set_client_callbacks(self, client): def _set_client_callbacks(self, client):
setattr(client, "broadcast_event", self.broadcast_event) setattr(client, "broadcast_event", self.broadcast_event)
def get_client_config() -> ClientConfig: def get_client_config() -> ClientConfig:
return self.get_relay_config(client.relay_id) return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config) setattr(client, "get_client_config", get_client_config)
class NostrClientConnection:
class NostrClientConnection:
def __init__(self, relay_id: str, websocket: WebSocket): def __init__(self, relay_id: str, websocket: WebSocket):
self.websocket = websocket self.websocket = websocket
self.relay_id = relay_id self.relay_id = relay_id
self.filters: List[NostrFilter] = [] self.filters: List[NostrFilter] = []
self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None self.broadcast_event: Optional[
Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]
] = None
self.get_client_config: Optional[Callable[[], ClientConfig]] = None self.get_client_config: Optional[Callable[[], ClientConfig]] = None
self._last_event_timestamp = 0 # in seconds self._last_event_timestamp = 0 # in seconds
@ -185,7 +188,6 @@ class NostrClientConnection:
await self._send_msg(resp_nip20) await self._send_msg(resp_nip20)
@property @property
def client_config(self) -> ClientConfig: def client_config(self) -> ClientConfig:
if not self.get_client_config: if not self.get_client_config:
@ -207,7 +209,12 @@ class NostrClientConnection:
filter.subscription_id = subscription_id filter.subscription_id = subscription_id
self._remove_filter(subscription_id) self._remove_filter(subscription_id)
if self._can_add_filter(): if self._can_add_filter():
return [["NOTICE", f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded."]] return [
[
"NOTICE",
f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded.",
]
]
filter.enforce_limit(self.client_config.limit_per_filter) filter.enforce_limit(self.client_config.limit_per_filter)
self.filters.append(filter) self.filters.append(filter)
@ -226,14 +233,20 @@ class NostrClientConnection:
self._remove_filter(subscription_id) self._remove_filter(subscription_id)
def _can_add_filter(self) -> bool: def _can_add_filter(self) -> bool:
return self.client_config.max_client_filters != 0 and len(self.filters) >= self.client_config.max_client_filters return (
self.client_config.max_client_filters != 0
and len(self.filters) >= self.client_config.max_client_filters
)
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]: def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
if self._exceeded_max_events_per_second(): if self._exceeded_max_events_per_second():
return False, f"Exceeded max events per second limit'!" return False, f"Exceeded max events per second limit'!"
if not self.client_config.is_author_allowed(e.pubkey): if not self.client_config.is_author_allowed(e.pubkey):
return False, f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!" return (
False,
f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!",
)
try: try:
e.check_signature() e.check_signature()
@ -253,7 +266,6 @@ class NostrClientConnection:
# todo: handeld paid paid plan # todo: handeld paid paid plan
return True, "Temp OK" return True, "Temp OK"
stored_bytes = await get_storage_for_public_key(self.relay_id, e.pubkey) stored_bytes = await get_storage_for_public_key(self.relay_id, e.pubkey)
if self.client_config.is_paid_relay: if self.client_config.is_paid_relay:
# todo: handeld paid paid plan # todo: handeld paid paid plan
@ -263,7 +275,10 @@ class NostrClientConnection:
return True, "" return True, ""
if self.client_config.full_storage_action == "block": if self.client_config.full_storage_action == "block":
return False, f"Cannot write event, no more storage available for public key: '{e.pubkey}'" return (
False,
f"Cannot write event, no more storage available for public key: '{e.pubkey}'",
)
await prune_old_events(self.relay_id, e.pubkey, e.size_bytes) await prune_old_events(self.relay_id, e.pubkey, e.size_bytes)
@ -280,7 +295,9 @@ class NostrClientConnection:
self._last_event_timestamp = current_time self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0 self._event_count_per_timestamp = 0
return self._event_count_per_timestamp > self.client_config.max_events_per_second return (
self._event_count_per_timestamp > self.client_config.max_events_per_second
)
def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]: def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]:
current_time = round(time.time()) current_time = round(time.time())

116
crud.py
View file

@ -6,18 +6,28 @@ from .models import NostrEvent, NostrFilter, NostrRelay, RelayConfig
########################## RELAYS #################### ########################## RELAYS ####################
async def create_relay(user_id: str, r: NostrRelay) -> NostrRelay: async def create_relay(user_id: str, r: NostrRelay) -> NostrRelay:
await db.execute( await db.execute(
""" """
INSERT INTO nostrrelay.relays (user_id, id, name, description, pubkey, contact, meta) INSERT INTO nostrrelay.relays (user_id, id, name, description, pubkey, contact, meta)
VALUES (?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""", """,
(user_id, r.id, r.name, r.description, r.pubkey, r.contact, json.dumps(dict(r.config))), (
user_id,
r.id,
r.name,
r.description,
r.pubkey,
r.contact,
json.dumps(dict(r.config)),
),
) )
relay = await get_relay(user_id, r.id) relay = await get_relay(user_id, r.id)
assert relay, "Created relay cannot be retrieved" assert relay, "Created relay cannot be retrieved"
return relay return relay
async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay: async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay:
await db.execute( await db.execute(
""" """
@ -25,31 +35,59 @@ async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay:
SET (name, description, pubkey, contact, active, meta) = (?, ?, ?, ?, ?, ?) SET (name, description, pubkey, contact, active, meta) = (?, ?, ?, ?, ?, ?)
WHERE user_id = ? AND id = ? WHERE user_id = ? AND id = ?
""", """,
(r.name, r.description, r.pubkey, r.contact, r.active, json.dumps(dict(r.config)), user_id, r.id), (
r.name,
r.description,
r.pubkey,
r.contact,
r.active,
json.dumps(dict(r.config)),
user_id,
r.id,
),
) )
return r return r
async def get_relay(user_id: str, relay_id: str) -> Optional[NostrRelay]: async def get_relay(user_id: str, relay_id: str) -> Optional[NostrRelay]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,)) row = await db.fetchone(
"""SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""",
(
user_id,
relay_id,
),
)
return NostrRelay.from_row(row) if row else None return NostrRelay.from_row(row) if row else None
async def get_relays(user_id: str) -> List[NostrRelay]: async def get_relays(user_id: str) -> List[NostrRelay]:
rows = await db.fetchall("""SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""", (user_id,)) rows = await db.fetchall(
"""SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""",
(user_id,),
)
return [NostrRelay.from_row(row) for row in rows] return [NostrRelay.from_row(row) for row in rows]
async def get_config_for_all_active_relays() -> dict: async def get_config_for_all_active_relays() -> dict:
rows = await db.fetchall("SELECT id, meta FROM nostrrelay.relays WHERE active = true",) rows = await db.fetchall(
"SELECT id, meta FROM nostrrelay.relays WHERE active = true",
)
active_relay_configs = {} active_relay_configs = {}
for r in rows: for r in rows:
active_relay_configs[r["id"]] = RelayConfig(**json.loads(r["meta"])) #todo: from_json active_relay_configs[r["id"]] = RelayConfig(
**json.loads(r["meta"])
) # todo: from_json
return active_relay_configs return active_relay_configs
async def get_public_relay(relay_id: str) -> Optional[dict]: async def get_public_relay(relay_id: str) -> Optional[dict]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)) row = await db.fetchone(
"""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)
)
if not row: if not row:
return None return None
@ -61,12 +99,18 @@ async def get_public_relay(relay_id: str) -> Optional[dict]:
"name": relay.name, "name": relay.name,
"description": relay.description, "description": relay.description,
"pubkey": relay.pubkey, "pubkey": relay.pubkey,
"contact":relay.contact "contact": relay.contact,
} }
async def delete_relay(user_id: str, relay_id: str): async def delete_relay(user_id: str, relay_id: str):
await db.execute("""DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,)) await db.execute(
"""DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""",
(
user_id,
relay_id,
),
)
########################## EVENTS #################### ########################## EVENTS ####################
@ -85,7 +129,16 @@ async def create_event(relay_id: str, e: NostrEvent):
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
(relay_id, e.id, e.pubkey, e.created_at, e.kind, e.content, e.sig, e.size_bytes), (
relay_id,
e.id,
e.pubkey,
e.created_at,
e.kind,
e.content,
e.sig,
e.size_bytes,
),
) )
# todo: optimize with bulk insert # todo: optimize with bulk insert
@ -94,7 +147,10 @@ async def create_event(relay_id: str, e: NostrEvent):
extra = json.dumps(rest) if rest else None extra = json.dumps(rest) if rest else None
await create_event_tags(relay_id, e.id, name, value, extra) await create_event_tags(relay_id, e.id, name, value, extra)
async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) -> List[NostrEvent]:
async def get_events(
relay_id: str, filter: NostrFilter, include_tags=True
) -> List[NostrEvent]:
query, values = build_select_events_query(relay_id, filter) query, values = build_select_events_query(relay_id, filter)
rows = await db.fetchall(query, tuple(values)) rows = await db.fetchall(query, tuple(values))
@ -108,8 +164,15 @@ async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) ->
return events return events
async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]: async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
row = await db.fetchone("SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?", (relay_id, id,)) row = await db.fetchone(
"SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?",
(
relay_id,
id,
),
)
if not row: if not row:
return None return None
@ -117,15 +180,23 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
event.tags = await get_event_tags(relay_id, id) event.tags = await get_event_tags(relay_id, id)
return event return event
async def get_storage_for_public_key(relay_id: str, pubkey: str) -> int: async def get_storage_for_public_key(relay_id: str, pubkey: str) -> int:
"""Returns the storage space in bytes for all the events of a public key. Deleted events are also counted""" """Returns the storage space in bytes for all the events of a public key. Deleted events are also counted"""
row = await db.fetchone("SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey", (relay_id, pubkey,)) row = await db.fetchone(
"SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey",
(
relay_id,
pubkey,
),
)
if not row: if not row:
return 0 return 0
return round(row["sum"]) return round(row["sum"])
async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int]]: async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int]]:
"""Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small""" """Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small"""
query = """ query = """
@ -144,7 +215,11 @@ async def mark_events_deleted(relay_id: str, filter: NostrFilter):
return None return None
_, where, values = filter.to_sql_components(relay_id) _, where, values = filter.to_sql_components(relay_id)
await db.execute(f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""", tuple(values)) await db.execute(
f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""",
tuple(values),
)
async def delete_events(relay_id: str, filter: NostrFilter): async def delete_events(relay_id: str, filter: NostrFilter):
if filter.is_empty(): if filter.is_empty():
@ -155,6 +230,7 @@ async def delete_events(relay_id: str, filter: NostrFilter):
await db.execute(query, tuple(values)) await db.execute(query, tuple(values))
# todo: delete tags # todo: delete tags
async def prune_old_events(relay_id: str, pubkey: str, space_to_regain: int): async def prune_old_events(relay_id: str, pubkey: str, space_to_regain: int):
prunable_events = await get_prunable_events(relay_id, pubkey) prunable_events = await get_prunable_events(relay_id, pubkey)
prunable_event_ids = [] prunable_event_ids = []
@ -175,8 +251,13 @@ async def delete_all_events(relay_id: str):
await db.execute(query, (relay_id,)) await db.execute(query, (relay_id,))
# todo: delete tags # todo: delete tags
async def create_event_tags( async def create_event_tags(
relay_id: str, event_id: str, tag_name: str, tag_value: str, extra_values: Optional[str] relay_id: str,
event_id: str,
tag_name: str,
tag_value: str,
extra_values: Optional[str],
): ):
await db.execute( await db.execute(
""" """
@ -192,9 +273,8 @@ async def create_event_tags(
(relay_id, event_id, tag_name, tag_value, extra_values), (relay_id, event_id, tag_name, tag_value, extra_values),
) )
async def get_event_tags(
relay_id: str, event_id: str async def get_event_tags(relay_id: str, event_id: str) -> List[List[str]]:
) -> List[List[str]]:
rows = await db.fetchall( rows = await db.fetchall(
"SELECT * FROM nostrrelay.event_tags WHERE relay_id = ? and event_id = ?", "SELECT * FROM nostrrelay.event_tags WHERE relay_id = ? and event_id = ?",
(relay_id, event_id), (relay_id, event_id),

View file

@ -23,7 +23,6 @@ class ClientConfig(BaseModel):
created_at_minutes_future = Field(0, alias="createdAtMinutesFuture") created_at_minutes_future = Field(0, alias="createdAtMinutesFuture")
created_at_seconds_future = Field(0, alias="createdAtSecondsFuture") created_at_seconds_future = Field(0, alias="createdAtSecondsFuture")
is_paid_relay = Field(False, alias="isPaidRelay") is_paid_relay = Field(False, alias="isPaidRelay")
free_storage_value = Field(1, alias="freeStorageValue") free_storage_value = Field(1, alias="freeStorageValue")
free_storage_unit = Field("MB", alias="freeStorageUnit") free_storage_unit = Field("MB", alias="freeStorageUnit")
@ -32,7 +31,6 @@ class ClientConfig(BaseModel):
allowed_public_keys = Field([], alias="allowedPublicKeys") allowed_public_keys = Field([], alias="allowedPublicKeys")
blocked_public_keys = Field([], alias="blockedPublicKeys") blocked_public_keys = Field([], alias="blockedPublicKeys")
def is_author_allowed(self, p: str) -> bool: def is_author_allowed(self, p: str) -> bool:
if p in self.blocked_public_keys: if p in self.blocked_public_keys:
return False return False
@ -43,11 +41,21 @@ class ClientConfig(BaseModel):
@property @property
def created_at_in_past(self) -> int: def created_at_in_past(self) -> int:
return self.created_at_days_past * 86400 + self.created_at_hours_past * 3600 + self.created_at_minutes_past * 60 + self.created_at_seconds_past return (
self.created_at_days_past * 86400
+ self.created_at_hours_past * 3600
+ self.created_at_minutes_past * 60
+ self.created_at_seconds_past
)
@property @property
def created_at_in_future(self) -> int: def created_at_in_future(self) -> int:
return self.created_at_days_future * 86400 + self.created_at_hours_future * 3600 + self.created_at_minutes_future * 60 + self.created_at_seconds_future return (
self.created_at_days_future * 86400
+ self.created_at_hours_future * 3600
+ self.created_at_minutes_future * 60
+ self.created_at_seconds_future
)
@property @property
def free_storage_bytes_value(self): def free_storage_bytes_value(self):
@ -59,6 +67,7 @@ class ClientConfig(BaseModel):
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True
class RelayConfig(ClientConfig): class RelayConfig(ClientConfig):
wallet = Field("") wallet = Field("")
cost_to_join = Field(0, alias="costToJoin") cost_to_join = Field(0, alias="costToJoin")
@ -78,7 +87,6 @@ class NostrRelay(BaseModel):
config: "RelayConfig" = RelayConfig() config: "RelayConfig" = RelayConfig()
@classmethod @classmethod
def from_row(cls, row: Row) -> "NostrRelay": def from_row(cls, row: Row) -> "NostrRelay":
relay = cls(**dict(row)) relay = cls(**dict(row))
@ -86,7 +94,9 @@ class NostrRelay(BaseModel):
return relay return relay
@classmethod @classmethod
def info(cls,) -> dict: def info(
cls,
) -> dict:
return { return {
"contact": "https://t.me/lnbits", "contact": "https://t.me/lnbits",
"supported_nips": [1, 9, 11, 15, 20, 22], "supported_nips": [1, 9, 11, 15, 20, 22],
@ -223,7 +233,9 @@ class NostrFilter(BaseModel):
if not self.limit or self.limit > limit: if not self.limit or self.limit > limit:
self.limit = limit self.limit = limit
def to_sql_components(self, relay_id: str) -> Tuple[List[str], List[str], List[Any]]: def to_sql_components(
self, relay_id: str
) -> Tuple[List[str], List[str], List[Any]]:
inner_joins: List[str] = [] inner_joins: List[str] = []
where = ["deleted=false", "nostrrelay.events.relay_id = ?"] where = ["deleted=false", "nostrrelay.events.relay_id = ?"]
values: List[Any] = [relay_id] values: List[Any] = [relay_id]
@ -231,13 +243,17 @@ class NostrFilter(BaseModel):
if len(self.e): if len(self.e):
values += self.e values += self.e
e_s = ",".join(["?"] * len(self.e)) e_s = ",".join(["?"] * len(self.e))
inner_joins.append("INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id") inner_joins.append(
"INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id"
)
where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')") where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')")
if len(self.p): if len(self.p):
values += self.p values += self.p
p_s = ",".join(["?"] * len(self.p)) p_s = ",".join(["?"] * len(self.p))
inner_joins.append("INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id") inner_joins.append(
"INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id"
)
where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'") where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'")
if len(self.ids) != 0: if len(self.ids) != 0:
@ -263,5 +279,4 @@ class NostrFilter(BaseModel):
where.append("created_at < ?") where.append("created_at < ?")
values += [self.until] values += [self.until]
return inner_joins, where, values return inner_joins, where, values

View file

@ -2,6 +2,7 @@ import json
FIXTURES_PATH = "tests/extensions/nostrrelay/fixture" FIXTURES_PATH = "tests/extensions/nostrrelay/fixture"
def get_fixtures(file): def get_fixtures(file):
""" """
Read the content of the JSON file. Read the content of the JSON file.

View file

@ -40,9 +40,7 @@ class MockWebSocket(WebSocket):
async def wire_mock_data(self, data: dict): async def wire_mock_data(self, data: dict):
await self.fake_wire.put(dumps(data)) await self.fake_wire.put(dumps(data))
async def close( async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
self, code: int = 1000, reason: Optional[str] = None
) -> None:
logger.info(reason) logger.info(reason)
@ -152,7 +150,6 @@ async def bob_wires_contact_list(ws_alice: MockWebSocket, ws_bob: MockWebSocket)
await ws_alice.wire_mock_data(alice["subscribe_to_bob_contact_list"]) await ws_alice.wire_mock_data(alice["subscribe_to_bob_contact_list"])
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
print("### ws_alice.sent_message", ws_alice.sent_messages) print("### ws_alice.sent_message", ws_alice.sent_messages)
print("### ws_bob.sent_message", ws_bob.sent_messages) print("### ws_bob.sent_message", ws_bob.sent_messages)

View file

@ -12,6 +12,7 @@ from .helpers import get_fixtures
RELAY_ID = "r1" RELAY_ID = "r1"
class EventFixture(BaseModel): class EventFixture(BaseModel):
name: str name: str
exception: Optional[str] exception: Optional[str]
@ -23,6 +24,7 @@ def valid_events() -> List[EventFixture]:
data = get_fixtures("events") data = get_fixtures("events")
return [EventFixture.parse_obj(e) for e in data["valid"]] return [EventFixture.parse_obj(e) for e in data["valid"]]
@pytest.fixture @pytest.fixture
def invalid_events() -> List[EventFixture]: def invalid_events() -> List[EventFixture]:
data = get_fixtures("events") data = get_fixtures("events")
@ -37,6 +39,7 @@ def test_valid_event_id_and_signature(valid_events: List[EventFixture]):
logger.error(f"Invalid 'id' ot 'signature' for fixture: '{f.name}'") logger.error(f"Invalid 'id' ot 'signature' for fixture: '{f.name}'")
raise e raise e
def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]): def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]):
for f in invalid_events: for f in invalid_events:
with pytest.raises(ValueError, match=f.exception): with pytest.raises(ValueError, match=f.exception):
@ -54,12 +57,10 @@ async def test_valid_event_crud(valid_events: List[EventFixture]):
for e in all_events: for e in all_events:
await create_event(RELAY_ID, e) await create_event(RELAY_ID, e)
for f in valid_events: for f in valid_events:
await get_by_id(f.data, f.name) await get_by_id(f.data, f.name)
await filter_by_id(all_events, f.data, f.name) await filter_by_id(all_events, f.data, f.name)
await filter_by_author(all_events, author) await filter_by_author(all_events, author)
await filter_by_tag_p(all_events, author) await filter_by_tag_p(all_events, author)
@ -70,21 +71,29 @@ async def test_valid_event_crud(valid_events: List[EventFixture]):
await filter_by_tag_e_p_and_author(all_events, author, event_id, reply_event_id) await filter_by_tag_e_p_and_author(all_events, author, event_id, reply_event_id)
async def get_by_id(data: NostrEvent, test_name: str): async def get_by_id(data: NostrEvent, test_name: str):
event = await get_event(RELAY_ID, data.id) event = await get_event(RELAY_ID, data.id)
assert event, f"Failed to restore event (id='{data.id}')" assert event, f"Failed to restore event (id='{data.id}')"
assert event.json() != json.dumps(data.json()), f"Restored event is different for fixture '{test_name}'" assert event.json() != json.dumps(
data.json()
), f"Restored event is different for fixture '{test_name}'"
async def filter_by_id(all_events: List[NostrEvent], data: NostrEvent, test_name: str): async def filter_by_id(all_events: List[NostrEvent], data: NostrEvent, test_name: str):
filter = NostrFilter(ids=[data.id]) filter = NostrFilter(ids=[data.id])
events = await get_events(RELAY_ID, filter) events = await get_events(RELAY_ID, filter)
assert len(events) == 1, f"Expected one queried event '{test_name}'" assert len(events) == 1, f"Expected one queried event '{test_name}'"
assert events[0].json() != json.dumps(data.json()), f"Queried event is different for fixture '{test_name}'" assert events[0].json() != json.dumps(
data.json()
), f"Queried event is different for fixture '{test_name}'"
filtered_events = [e for e in all_events if filter.matches(e)] filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Expected one filter event '{test_name}'" assert len(filtered_events) == 1, f"Expected one filter event '{test_name}'"
assert filtered_events[0].json() != json.dumps(data.json()), f"Filtered event is different for fixture '{test_name}'" assert filtered_events[0].json() != json.dumps(
data.json()
), f"Filtered event is different for fixture '{test_name}'"
async def filter_by_author(all_events: List[NostrEvent], author): async def filter_by_author(all_events: List[NostrEvent], author):
@ -95,6 +104,7 @@ async def filter_by_author(all_events: List[NostrEvent], author):
filtered_events = [e for e in all_events if filter.matches(e)] filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 5, f"Failed to filter by authors" assert len(filtered_events) == 5, f"Failed to filter by authors"
async def filter_by_tag_p(all_events: List[NostrEvent], author): async def filter_by_tag_p(all_events: List[NostrEvent], author):
# todo: check why constructor does not work for fields with aliases (#e, #p) # todo: check why constructor does not work for fields with aliases (#e, #p)
filter = NostrFilter() filter = NostrFilter()
@ -117,29 +127,43 @@ async def filter_by_tag_e(all_events: List[NostrEvent], event_id):
filtered_events = [e for e in all_events if filter.matches(e)] filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 2, f"Failed to filter by tag 'e'" assert len(filtered_events) == 2, f"Failed to filter by tag 'e'"
async def filter_by_tag_e_and_p(all_events: List[NostrEvent], author, event_id, reply_event_id):
async def filter_by_tag_e_and_p(
all_events: List[NostrEvent], author, event_id, reply_event_id
):
filter = NostrFilter() filter = NostrFilter()
filter.p.append(author) filter.p.append(author)
filter.e.append(event_id) filter.e.append(event_id)
events_related_to_event = await get_events(RELAY_ID, filter) events_related_to_event = await get_events(RELAY_ID, filter)
assert len(events_related_to_event) == 1, f"Failed to quert by tags 'e' & 'p'" assert len(events_related_to_event) == 1, f"Failed to quert by tags 'e' & 'p'"
assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by tags 'e' & 'p'" assert (
events_related_to_event[0].id == reply_event_id
), f"Failed to query the right event by tags 'e' & 'p'"
filtered_events = [e for e in all_events if filter.matches(e)] filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Failed to filter by tags 'e' & 'p'" assert len(filtered_events) == 1, f"Failed to filter by tags 'e' & 'p'"
assert filtered_events[0].id == reply_event_id, f"Failed to find the right event by tags 'e' & 'p'" assert (
filtered_events[0].id == reply_event_id
), f"Failed to find the right event by tags 'e' & 'p'"
async def filter_by_tag_e_p_and_author(all_events: List[NostrEvent], author, event_id, reply_event_id):
async def filter_by_tag_e_p_and_author(
all_events: List[NostrEvent], author, event_id, reply_event_id
):
filter = NostrFilter(authors=[author]) filter = NostrFilter(authors=[author])
filter.p.append(author) filter.p.append(author)
filter.e.append(event_id) filter.e.append(event_id)
events_related_to_event = await get_events(RELAY_ID, filter) events_related_to_event = await get_events(RELAY_ID, filter)
assert len(events_related_to_event) == 1, f"Failed to query by 'author' and tags 'e' & 'p'" assert (
assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by 'author' and tags 'e' & 'p'" len(events_related_to_event) == 1
), f"Failed to query by 'author' and tags 'e' & 'p'"
assert (
events_related_to_event[0].id == reply_event_id
), f"Failed to query the right event by 'author' and tags 'e' & 'p'"
filtered_events = [e for e in all_events if filter.matches(e)] filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Failed to filter by 'author' and tags 'e' & 'p'" assert len(filtered_events) == 1, f"Failed to filter by 'author' and tags 'e' & 'p'"
assert filtered_events[0].id == reply_event_id, f"Failed to filter the right event by 'author' and tags 'e' & 'p'" assert (
filtered_events[0].id == reply_event_id
), f"Failed to filter the right event by 'author' and tags 'e' & 'p'"

View file

@ -30,6 +30,7 @@ from .models import NostrRelay
client_manager = NostrClientManager() client_manager = NostrClientManager()
@nostrrelay_ext.websocket("/{relay_id}") @nostrrelay_ext.websocket("/{relay_id}")
async def websocket_endpoint(relay_id: str, websocket: WebSocket): async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client = NostrClientConnection(relay_id=relay_id, websocket=websocket) client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
@ -44,7 +45,6 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client_manager.remove_client(client) client_manager.remove_client(client)
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK) @nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
async def api_nostrrelay_info(relay_id: str): async def api_nostrrelay_info(relay_id: str):
relay = await get_public_relay(relay_id) relay = await get_public_relay(relay_id)
@ -54,16 +54,20 @@ async def api_nostrrelay_info(relay_id: str):
detail="Relay not found", detail="Relay not found",
) )
return JSONResponse(content=relay, headers={ return JSONResponse(
content=relay,
headers={
"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*", "Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": "GET" "Access-Control-Allow-Methods": "GET",
}) },
)
@nostrrelay_ext.post("/api/v1/relay") @nostrrelay_ext.post("/api/v1/relay")
async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay: async def api_create_relay(
data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)
) -> NostrRelay:
if len(data.id): if len(data.id):
await check_admin(UUID4(wallet.wallet.user)) await check_admin(UUID4(wallet.wallet.user))
else: else:
@ -80,8 +84,11 @@ async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(re
detail="Cannot create relay", detail="Cannot create relay",
) )
@nostrrelay_ext.put("/api/v1/relay/{relay_id}") @nostrrelay_ext.put("/api/v1/relay/{relay_id}")
async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay: async def api_update_relay(
relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)
) -> NostrRelay:
if relay_id != data.id: if relay_id != data.id:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
@ -116,7 +123,9 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn
@nostrrelay_ext.get("/api/v1/relay") @nostrrelay_ext.get("/api/v1/relay")
async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key)) -> List[NostrRelay]: async def api_get_relays(
wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> List[NostrRelay]:
try: try:
return await get_relays(wallet.wallet.user) return await get_relays(wallet.wallet.user)
except Exception as ex: except Exception as ex:
@ -126,8 +135,11 @@ async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key))
detail="Cannot fetch relays", detail="Cannot fetch relays",
) )
@nostrrelay_ext.get("/api/v1/relay/{relay_id}") @nostrrelay_ext.get("/api/v1/relay/{relay_id}")
async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)) -> Optional[NostrRelay]: async def api_get_relay(
relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
) -> Optional[NostrRelay]:
try: try:
relay = await get_relay(wallet.wallet.user, relay_id) relay = await get_relay(wallet.wallet.user, relay_id)
except Exception as ex: except Exception as ex:
@ -143,8 +155,11 @@ async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_
) )
return relay return relay
@nostrrelay_ext.delete("/api/v1/relay/{relay_id}") @nostrrelay_ext.delete("/api/v1/relay/{relay_id}")
async def api_delete_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)): async def api_delete_relay(
relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)
):
try: try:
await client_manager.disable_relay(relay_id) await client_manager.disable_relay(relay_id)
await delete_relay(wallet.wallet.user, relay_id) await delete_relay(wallet.wallet.user, relay_id)