From 1eda457067be6f46fb42b9ec5cf77946e11df815 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Fri, 10 Feb 2023 12:16:25 +0200 Subject: [PATCH] chore: force `py` format --- __init__.py | 4 +- client_manager.py | 61 ++++++++++++------- crud.py | 132 +++++++++++++++++++++++++++++++++--------- models.py | 37 ++++++++---- tests/helpers.py | 1 + tests/test_clients.py | 5 +- tests/test_events.py | 64 +++++++++++++------- views_api.py | 43 +++++++++----- 8 files changed, 248 insertions(+), 99 deletions(-) diff --git a/__init__.py b/__init__.py index 849456f..031464e 100644 --- a/__init__.py +++ b/__init__.py @@ -27,7 +27,7 @@ from .views import * # noqa from .views_api import * # noqa settings.lnbits_relay_information = { - "name": "LNbits Nostr Relay", + "name": "LNbits Nostr Relay", "description": "Multiple relays are supported", - **NostrRelay.info() + **NostrRelay.info(), } diff --git a/client_manager.py b/client_manager.py index 70b3be5..1c4cc53 100644 --- a/client_manager.py +++ b/client_manager.py @@ -37,7 +37,6 @@ class NostrClientManager: return True - def remove_client(self, c: "NostrClientConnection"): self.clients(c.relay_id).remove(c) @@ -60,7 +59,7 @@ class NostrClientManager: def get_relay_config(self, relay_id: str) -> RelayConfig: return self._active_relays[relay_id] - + def clients(self, relay_id: str) -> List["NostrClientConnection"]: if relay_id not in self._clients: self._clients[relay_id] = [] @@ -75,25 +74,29 @@ class NostrClientManager: if c.relay_id not in self._active_relays: await c.stop(reason=f"Relay '{c.relay_id}' is not active") return False - #todo: NIP-42: AUTH + # todo: NIP-42: AUTH return True def _set_client_callbacks(self, client): setattr(client, "broadcast_event", self.broadcast_event) + def get_client_config() -> ClientConfig: return self.get_relay_config(client.relay_id) - setattr(client, "get_client_config", get_client_config) - -class NostrClientConnection: + setattr(client, "get_client_config", get_client_config) + + +class NostrClientConnection: def __init__(self, relay_id: str, websocket: WebSocket): self.websocket = websocket self.relay_id = relay_id self.filters: List[NostrFilter] = [] - self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None + self.broadcast_event: Optional[ + Callable[[NostrClientConnection, NostrEvent], Awaitable[None]] + ] = None self.get_client_config: Optional[Callable[[], ClientConfig]] = None - self._last_event_timestamp = 0 # in seconds + self._last_event_timestamp = 0 # in seconds self._event_count_per_timestamp = 0 async def start(self): @@ -184,12 +187,11 @@ class NostrClientConnection: resp_nip20 += [event != None, message] await self._send_msg(resp_nip20) - @property def client_config(self) -> ClientConfig: if not self.get_client_config: - raise Exception("Client not ready!") + raise Exception("Client not ready!") return self.get_client_config() async def _send_msg(self, data: List): @@ -207,7 +209,12 @@ class NostrClientConnection: filter.subscription_id = subscription_id self._remove_filter(subscription_id) if self._can_add_filter(): - return [["NOTICE", f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded."]] + return [ + [ + "NOTICE", + f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded.", + ] + ] filter.enforce_limit(self.client_config.limit_per_filter) self.filters.append(filter) @@ -226,14 +233,20 @@ class NostrClientConnection: self._remove_filter(subscription_id) def _can_add_filter(self) -> bool: - return self.client_config.max_client_filters != 0 and len(self.filters) >= self.client_config.max_client_filters + return ( + self.client_config.max_client_filters != 0 + and len(self.filters) >= self.client_config.max_client_filters + ) - def _validate_event(self, e: NostrEvent)-> Tuple[bool, str]: + def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]: if self._exceeded_max_events_per_second(): return False, f"Exceeded max events per second limit'!" if not self.client_config.is_author_allowed(e.pubkey): - return False, f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!" + return ( + False, + f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!", + ) try: e.check_signature() @@ -252,19 +265,21 @@ class NostrClientConnection: return False, "Cannot write event, relay is read-only" # todo: handeld paid paid plan return True, "Temp OK" - stored_bytes = await get_storage_for_public_key(self.relay_id, e.pubkey) if self.client_config.is_paid_relay: # todo: handeld paid paid plan return True, "Temp OK" - + if (stored_bytes + e.size_bytes) <= self.client_config.free_storage_bytes_value: - return True, "" - + return True, "" + if self.client_config.full_storage_action == "block": - return False, f"Cannot write event, no more storage available for public key: '{e.pubkey}'" - + return ( + False, + f"Cannot write event, no more storage available for public key: '{e.pubkey}'", + ) + await prune_old_events(self.relay_id, e.pubkey, e.size_bytes) return True, "" @@ -280,7 +295,9 @@ class NostrClientConnection: self._last_event_timestamp = current_time self._event_count_per_timestamp = 0 - return self._event_count_per_timestamp > self.client_config.max_events_per_second + return ( + self._event_count_per_timestamp > self.client_config.max_events_per_second + ) def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]: current_time = round(time.time()) @@ -290,4 +307,4 @@ class NostrClientConnection: if self.client_config.created_at_in_future != 0: if created_at > (current_time + self.client_config.created_at_in_future): return False, "created_at is too much into the future" - return True, "" \ No newline at end of file + return True, "" diff --git a/crud.py b/crud.py index a7d2145..70d5dfa 100644 --- a/crud.py +++ b/crud.py @@ -6,18 +6,28 @@ from .models import NostrEvent, NostrFilter, NostrRelay, RelayConfig ########################## RELAYS #################### + async def create_relay(user_id: str, r: NostrRelay) -> NostrRelay: await db.execute( """ INSERT INTO nostrrelay.relays (user_id, id, name, description, pubkey, contact, meta) VALUES (?, ?, ?, ?, ?, ?, ?) """, - (user_id, r.id, r.name, r.description, r.pubkey, r.contact, json.dumps(dict(r.config))), + ( + user_id, + r.id, + r.name, + r.description, + r.pubkey, + r.contact, + json.dumps(dict(r.config)), + ), ) relay = await get_relay(user_id, r.id) assert relay, "Created relay cannot be retrieved" return relay + async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay: await db.execute( """ @@ -25,31 +35,59 @@ async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay: SET (name, description, pubkey, contact, active, meta) = (?, ?, ?, ?, ?, ?) WHERE user_id = ? AND id = ? """, - (r.name, r.description, r.pubkey, r.contact, r.active, json.dumps(dict(r.config)), user_id, r.id), + ( + r.name, + r.description, + r.pubkey, + r.contact, + r.active, + json.dumps(dict(r.config)), + user_id, + r.id, + ), ) - + return r + async def get_relay(user_id: str, relay_id: str) -> Optional[NostrRelay]: - row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,)) + row = await db.fetchone( + """SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", + ( + user_id, + relay_id, + ), + ) return NostrRelay.from_row(row) if row else None + async def get_relays(user_id: str) -> List[NostrRelay]: - rows = await db.fetchall("""SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""", (user_id,)) + rows = await db.fetchall( + """SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""", + (user_id,), + ) return [NostrRelay.from_row(row) for row in rows] + async def get_config_for_all_active_relays() -> dict: - rows = await db.fetchall("SELECT id, meta FROM nostrrelay.relays WHERE active = true",) + rows = await db.fetchall( + "SELECT id, meta FROM nostrrelay.relays WHERE active = true", + ) active_relay_configs = {} for r in rows: - active_relay_configs[r["id"]] = RelayConfig(**json.loads(r["meta"])) #todo: from_json + active_relay_configs[r["id"]] = RelayConfig( + **json.loads(r["meta"]) + ) # todo: from_json return active_relay_configs + async def get_public_relay(relay_id: str) -> Optional[dict]: - row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)) + row = await db.fetchone( + """SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,) + ) if not row: return None @@ -59,14 +97,20 @@ async def get_public_relay(relay_id: str) -> Optional[dict]: **NostrRelay.info(), "id": relay.id, "name": relay.name, - "description":relay.description, - "pubkey":relay.pubkey, - "contact":relay.contact + "description": relay.description, + "pubkey": relay.pubkey, + "contact": relay.contact, } async def delete_relay(user_id: str, relay_id: str): - await db.execute("""DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,)) + await db.execute( + """DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", + ( + user_id, + relay_id, + ), + ) ########################## EVENTS #################### @@ -85,7 +129,16 @@ async def create_event(relay_id: str, e: NostrEvent): ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, - (relay_id, e.id, e.pubkey, e.created_at, e.kind, e.content, e.sig, e.size_bytes), + ( + relay_id, + e.id, + e.pubkey, + e.created_at, + e.kind, + e.content, + e.sig, + e.size_bytes, + ), ) # todo: optimize with bulk insert @@ -94,7 +147,10 @@ async def create_event(relay_id: str, e: NostrEvent): extra = json.dumps(rest) if rest else None await create_event_tags(relay_id, e.id, name, value, extra) -async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) -> List[NostrEvent]: + +async def get_events( + relay_id: str, filter: NostrFilter, include_tags=True +) -> List[NostrEvent]: query, values = build_select_events_query(relay_id, filter) rows = await db.fetchall(query, tuple(values)) @@ -108,8 +164,15 @@ async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) -> return events + async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]: - row = await db.fetchone("SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?", (relay_id, id,)) + row = await db.fetchone( + "SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?", + ( + relay_id, + id, + ), + ) if not row: return None @@ -117,17 +180,25 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]: event.tags = await get_event_tags(relay_id, id) return event + async def get_storage_for_public_key(relay_id: str, pubkey: str) -> int: """Returns the storage space in bytes for all the events of a public key. Deleted events are also counted""" - row = await db.fetchone("SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey", (relay_id, pubkey,)) + row = await db.fetchone( + "SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey", + ( + relay_id, + pubkey, + ), + ) if not row: return 0 return round(row["sum"]) + async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int]]: - """ Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small""" + """Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small""" query = """ SELECT id, size FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? @@ -139,21 +210,26 @@ async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int return [(r["id"], r["size"]) for r in rows] -async def mark_events_deleted(relay_id: str, filter: NostrFilter): +async def mark_events_deleted(relay_id: str, filter: NostrFilter): if filter.is_empty(): return None _, where, values = filter.to_sql_components(relay_id) - await db.execute(f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""", tuple(values)) + await db.execute( + f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""", + tuple(values), + ) -async def delete_events(relay_id: str, filter: NostrFilter): + +async def delete_events(relay_id: str, filter: NostrFilter): if filter.is_empty(): return None _, where, values = filter.to_sql_components(relay_id) query = f"""DELETE from nostrrelay.events WHERE {" AND ".join(where)}""" await db.execute(query, tuple(values)) - #todo: delete tags + # todo: delete tags + async def prune_old_events(relay_id: str, pubkey: str, space_to_regain: int): prunable_events = await get_prunable_events(relay_id, pubkey) @@ -175,8 +251,13 @@ async def delete_all_events(relay_id: str): await db.execute(query, (relay_id,)) # todo: delete tags + async def create_event_tags( - relay_id: str, event_id: str, tag_name: str, tag_value: str, extra_values: Optional[str] + relay_id: str, + event_id: str, + tag_name: str, + tag_value: str, + extra_values: Optional[str], ): await db.execute( """ @@ -192,9 +273,8 @@ async def create_event_tags( (relay_id, event_id, tag_name, tag_value, extra_values), ) -async def get_event_tags( - relay_id: str, event_id: str -) -> List[List[str]]: + +async def get_event_tags(relay_id: str, event_id: str) -> List[List[str]]: rows = await db.fetchall( "SELECT * FROM nostrrelay.event_tags WHERE relay_id = ? and event_id = ?", (relay_id, event_id), @@ -211,7 +291,7 @@ async def get_event_tags( return tags -def build_select_events_query(relay_id:str, filter:NostrFilter): +def build_select_events_query(relay_id: str, filter: NostrFilter): inner_joins, where, values = filter.to_sql_components(relay_id) query = f""" diff --git a/models.py b/models.py index 8ebd3dc..0eb4404 100644 --- a/models.py +++ b/models.py @@ -23,15 +23,13 @@ class ClientConfig(BaseModel): created_at_minutes_future = Field(0, alias="createdAtMinutesFuture") created_at_seconds_future = Field(0, alias="createdAtSecondsFuture") - is_paid_relay = Field(False, alias="isPaidRelay") free_storage_value = Field(1, alias="freeStorageValue") free_storage_unit = Field("MB", alias="freeStorageUnit") full_storage_action = Field("prune", alias="fullStorageAction") - + allowed_public_keys = Field([], alias="allowedPublicKeys") blocked_public_keys = Field([], alias="blockedPublicKeys") - def is_author_allowed(self, p: str) -> bool: if p in self.blocked_public_keys: @@ -43,11 +41,21 @@ class ClientConfig(BaseModel): @property def created_at_in_past(self) -> int: - return self.created_at_days_past * 86400 + self.created_at_hours_past * 3600 + self.created_at_minutes_past * 60 + self.created_at_seconds_past + return ( + self.created_at_days_past * 86400 + + self.created_at_hours_past * 3600 + + self.created_at_minutes_past * 60 + + self.created_at_seconds_past + ) @property def created_at_in_future(self) -> int: - return self.created_at_days_future * 86400 + self.created_at_hours_future * 3600 + self.created_at_minutes_future * 60 + self.created_at_seconds_future + return ( + self.created_at_days_future * 86400 + + self.created_at_hours_future * 3600 + + self.created_at_minutes_future * 60 + + self.created_at_seconds_future + ) @property def free_storage_bytes_value(self): @@ -59,6 +67,7 @@ class ClientConfig(BaseModel): class Config: allow_population_by_field_name = True + class RelayConfig(ClientConfig): wallet = Field("") cost_to_join = Field(0, alias="costToJoin") @@ -78,7 +87,6 @@ class NostrRelay(BaseModel): config: "RelayConfig" = RelayConfig() - @classmethod def from_row(cls, row: Row) -> "NostrRelay": relay = cls(**dict(row)) @@ -86,7 +94,9 @@ class NostrRelay(BaseModel): return relay @classmethod - def info(cls,) -> dict: + def info( + cls, + ) -> dict: return { "contact": "https://t.me/lnbits", "supported_nips": [1, 9, 11, 15, 20, 22], @@ -223,7 +233,9 @@ class NostrFilter(BaseModel): if not self.limit or self.limit > limit: self.limit = limit - def to_sql_components(self, relay_id: str) -> Tuple[List[str], List[str], List[Any]]: + def to_sql_components( + self, relay_id: str + ) -> Tuple[List[str], List[str], List[Any]]: inner_joins: List[str] = [] where = ["deleted=false", "nostrrelay.events.relay_id = ?"] values: List[Any] = [relay_id] @@ -231,13 +243,17 @@ class NostrFilter(BaseModel): if len(self.e): values += self.e e_s = ",".join(["?"] * len(self.e)) - inner_joins.append("INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id") + inner_joins.append( + "INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id" + ) where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')") if len(self.p): values += self.p p_s = ",".join(["?"] * len(self.p)) - inner_joins.append("INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id") + inner_joins.append( + "INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id" + ) where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'") if len(self.ids) != 0: @@ -262,6 +278,5 @@ class NostrFilter(BaseModel): if self.until: where.append("created_at < ?") values += [self.until] - return inner_joins, where, values diff --git a/tests/helpers.py b/tests/helpers.py index 7d004bc..f818bc9 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -2,6 +2,7 @@ import json FIXTURES_PATH = "tests/extensions/nostrrelay/fixture" + def get_fixtures(file): """ Read the content of the JSON file. diff --git a/tests/test_clients.py b/tests/test_clients.py index 73469eb..1ef7212 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -40,9 +40,7 @@ class MockWebSocket(WebSocket): async def wire_mock_data(self, data: dict): await self.fake_wire.put(dumps(data)) - async def close( - self, code: int = 1000, reason: Optional[str] = None - ) -> None: + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: logger.info(reason) @@ -152,7 +150,6 @@ async def bob_wires_contact_list(ws_alice: MockWebSocket, ws_bob: MockWebSocket) await ws_alice.wire_mock_data(alice["subscribe_to_bob_contact_list"]) await asyncio.sleep(0.1) - print("### ws_alice.sent_message", ws_alice.sent_messages) print("### ws_bob.sent_message", ws_bob.sent_messages) diff --git a/tests/test_events.py b/tests/test_events.py index 0793204..0e6c296 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -12,6 +12,7 @@ from .helpers import get_fixtures RELAY_ID = "r1" + class EventFixture(BaseModel): name: str exception: Optional[str] @@ -23,6 +24,7 @@ def valid_events() -> List[EventFixture]: data = get_fixtures("events") return [EventFixture.parse_obj(e) for e in data["valid"]] + @pytest.fixture def invalid_events() -> List[EventFixture]: data = get_fixtures("events") @@ -37,6 +39,7 @@ def test_valid_event_id_and_signature(valid_events: List[EventFixture]): logger.error(f"Invalid 'id' ot 'signature' for fixture: '{f.name}'") raise e + def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]): for f in invalid_events: with pytest.raises(ValueError, match=f.exception): @@ -44,7 +47,7 @@ def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]): @pytest.mark.asyncio -async def test_valid_event_crud(valid_events: List[EventFixture]): +async def test_valid_event_crud(valid_events: List[EventFixture]): author = "a24496bca5dd73300f4e5d5d346c73132b7354c597fcbb6509891747b4689211" event_id = "3219eec7427e365585d5adf26f5d2dd2709d3f0f2c0e1f79dc9021e951c67d96" reply_event_id = "6b2b6cb9c72caaf3dfbc5baa5e68d75ac62f38ec011b36cc83832218c36e4894" @@ -54,12 +57,10 @@ async def test_valid_event_crud(valid_events: List[EventFixture]): for e in all_events: await create_event(RELAY_ID, e) - - for f in valid_events: + for f in valid_events: await get_by_id(f.data, f.name) await filter_by_id(all_events, f.data, f.name) - await filter_by_author(all_events, author) await filter_by_tag_p(all_events, author) @@ -70,22 +71,30 @@ async def test_valid_event_crud(valid_events: List[EventFixture]): await filter_by_tag_e_p_and_author(all_events, author, event_id, reply_event_id) + async def get_by_id(data: NostrEvent, test_name: str): event = await get_event(RELAY_ID, data.id) assert event, f"Failed to restore event (id='{data.id}')" - assert event.json() != json.dumps(data.json()), f"Restored event is different for fixture '{test_name}'" + assert event.json() != json.dumps( + data.json() + ), f"Restored event is different for fixture '{test_name}'" + async def filter_by_id(all_events: List[NostrEvent], data: NostrEvent, test_name: str): filter = NostrFilter(ids=[data.id]) events = await get_events(RELAY_ID, filter) assert len(events) == 1, f"Expected one queried event '{test_name}'" - assert events[0].json() != json.dumps(data.json()), f"Queried event is different for fixture '{test_name}'" + assert events[0].json() != json.dumps( + data.json() + ), f"Queried event is different for fixture '{test_name}'" filtered_events = [e for e in all_events if filter.matches(e)] assert len(filtered_events) == 1, f"Expected one filter event '{test_name}'" - assert filtered_events[0].json() != json.dumps(data.json()), f"Filtered event is different for fixture '{test_name}'" - + assert filtered_events[0].json() != json.dumps( + data.json() + ), f"Filtered event is different for fixture '{test_name}'" + async def filter_by_author(all_events: List[NostrEvent], author): filter = NostrFilter(authors=[author]) @@ -95,9 +104,10 @@ async def filter_by_author(all_events: List[NostrEvent], author): filtered_events = [e for e in all_events if filter.matches(e)] assert len(filtered_events) == 5, f"Failed to filter by authors" + async def filter_by_tag_p(all_events: List[NostrEvent], author): # todo: check why constructor does not work for fields with aliases (#e, #p) - filter = NostrFilter() + filter = NostrFilter() filter.p.append(author) events_related_to_author = await get_events(RELAY_ID, filter) @@ -107,7 +117,7 @@ async def filter_by_tag_p(all_events: List[NostrEvent], author): assert len(filtered_events) == 5, f"Failed to filter by tag 'p'" -async def filter_by_tag_e(all_events: List[NostrEvent], event_id): +async def filter_by_tag_e(all_events: List[NostrEvent], event_id): filter = NostrFilter() filter.e.append(event_id) @@ -117,29 +127,43 @@ async def filter_by_tag_e(all_events: List[NostrEvent], event_id): filtered_events = [e for e in all_events if filter.matches(e)] assert len(filtered_events) == 2, f"Failed to filter by tag 'e'" -async def filter_by_tag_e_and_p(all_events: List[NostrEvent], author, event_id, reply_event_id): + +async def filter_by_tag_e_and_p( + all_events: List[NostrEvent], author, event_id, reply_event_id +): filter = NostrFilter() filter.p.append(author) filter.e.append(event_id) - + events_related_to_event = await get_events(RELAY_ID, filter) assert len(events_related_to_event) == 1, f"Failed to quert by tags 'e' & 'p'" - assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by tags 'e' & 'p'" + assert ( + events_related_to_event[0].id == reply_event_id + ), f"Failed to query the right event by tags 'e' & 'p'" filtered_events = [e for e in all_events if filter.matches(e)] assert len(filtered_events) == 1, f"Failed to filter by tags 'e' & 'p'" - assert filtered_events[0].id == reply_event_id, f"Failed to find the right event by tags 'e' & 'p'" + assert ( + filtered_events[0].id == reply_event_id + ), f"Failed to find the right event by tags 'e' & 'p'" -async def filter_by_tag_e_p_and_author(all_events: List[NostrEvent], author, event_id, reply_event_id): + +async def filter_by_tag_e_p_and_author( + all_events: List[NostrEvent], author, event_id, reply_event_id +): filter = NostrFilter(authors=[author]) filter.p.append(author) filter.e.append(event_id) events_related_to_event = await get_events(RELAY_ID, filter) - assert len(events_related_to_event) == 1, f"Failed to query by 'author' and tags 'e' & 'p'" - assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by 'author' and tags 'e' & 'p'" + assert ( + len(events_related_to_event) == 1 + ), f"Failed to query by 'author' and tags 'e' & 'p'" + assert ( + events_related_to_event[0].id == reply_event_id + ), f"Failed to query the right event by 'author' and tags 'e' & 'p'" filtered_events = [e for e in all_events if filter.matches(e)] assert len(filtered_events) == 1, f"Failed to filter by 'author' and tags 'e' & 'p'" - assert filtered_events[0].id == reply_event_id, f"Failed to filter the right event by 'author' and tags 'e' & 'p'" - - + assert ( + filtered_events[0].id == reply_event_id + ), f"Failed to filter the right event by 'author' and tags 'e' & 'p'" diff --git a/views_api.py b/views_api.py index 7ff8587..a41fe85 100644 --- a/views_api.py +++ b/views_api.py @@ -30,6 +30,7 @@ from .models import NostrRelay client_manager = NostrClientManager() + @nostrrelay_ext.websocket("/{relay_id}") async def websocket_endpoint(relay_id: str, websocket: WebSocket): client = NostrClientConnection(relay_id=relay_id, websocket=websocket) @@ -44,7 +45,6 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket): client_manager.remove_client(client) - @nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK) async def api_nostrrelay_info(relay_id: str): relay = await get_public_relay(relay_id) @@ -54,16 +54,20 @@ async def api_nostrrelay_info(relay_id: str): detail="Relay not found", ) - return JSONResponse(content=relay, headers={ - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "*", - "Access-Control-Allow-Methods": "GET" - }) - + return JSONResponse( + content=relay, + headers={ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "GET", + }, + ) @nostrrelay_ext.post("/api/v1/relay") -async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay: +async def api_create_relay( + data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key) +) -> NostrRelay: if len(data.id): await check_admin(UUID4(wallet.wallet.user)) else: @@ -80,8 +84,11 @@ async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(re detail="Cannot create relay", ) + @nostrrelay_ext.put("/api/v1/relay/{relay_id}") -async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay: +async def api_update_relay( + relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key) +) -> NostrRelay: if relay_id != data.id: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, @@ -97,8 +104,8 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn ) updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)}) updated_relay = await update_relay(wallet.wallet.user, updated_relay) - - if updated_relay.active: + + if updated_relay.active: await client_manager.enable_relay(relay_id, updated_relay.config) else: await client_manager.disable_relay(relay_id) @@ -116,7 +123,9 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn @nostrrelay_ext.get("/api/v1/relay") -async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key)) -> List[NostrRelay]: +async def api_get_relays( + wallet: WalletTypeInfo = Depends(require_invoice_key), +) -> List[NostrRelay]: try: return await get_relays(wallet.wallet.user) except Exception as ex: @@ -126,8 +135,11 @@ async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key)) detail="Cannot fetch relays", ) + @nostrrelay_ext.get("/api/v1/relay/{relay_id}") -async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)) -> Optional[NostrRelay]: +async def api_get_relay( + relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key) +) -> Optional[NostrRelay]: try: relay = await get_relay(wallet.wallet.user, relay_id) except Exception as ex: @@ -143,8 +155,11 @@ async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_ ) return relay + @nostrrelay_ext.delete("/api/v1/relay/{relay_id}") -async def api_delete_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)): +async def api_delete_relay( + relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key) +): try: await client_manager.disable_relay(relay_id) await delete_relay(wallet.wallet.user, relay_id)