chore: force py format

This commit is contained in:
Vlad Stan 2023-02-10 12:16:25 +02:00
parent 55f9142f3d
commit 1eda457067
8 changed files with 248 additions and 99 deletions

View file

@ -27,7 +27,7 @@ from .views import * # noqa
from .views_api import * # noqa
settings.lnbits_relay_information = {
"name": "LNbits Nostr Relay",
"name": "LNbits Nostr Relay",
"description": "Multiple relays are supported",
**NostrRelay.info()
**NostrRelay.info(),
}

View file

@ -37,7 +37,6 @@ class NostrClientManager:
return True
def remove_client(self, c: "NostrClientConnection"):
self.clients(c.relay_id).remove(c)
@ -60,7 +59,7 @@ class NostrClientManager:
def get_relay_config(self, relay_id: str) -> RelayConfig:
return self._active_relays[relay_id]
def clients(self, relay_id: str) -> List["NostrClientConnection"]:
if relay_id not in self._clients:
self._clients[relay_id] = []
@ -75,25 +74,29 @@ class NostrClientManager:
if c.relay_id not in self._active_relays:
await c.stop(reason=f"Relay '{c.relay_id}' is not active")
return False
#todo: NIP-42: AUTH
# todo: NIP-42: AUTH
return True
def _set_client_callbacks(self, client):
setattr(client, "broadcast_event", self.broadcast_event)
def get_client_config() -> ClientConfig:
return self.get_relay_config(client.relay_id)
setattr(client, "get_client_config", get_client_config)
class NostrClientConnection:
setattr(client, "get_client_config", get_client_config)
class NostrClientConnection:
def __init__(self, relay_id: str, websocket: WebSocket):
self.websocket = websocket
self.relay_id = relay_id
self.filters: List[NostrFilter] = []
self.broadcast_event: Optional[Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]] = None
self.broadcast_event: Optional[
Callable[[NostrClientConnection, NostrEvent], Awaitable[None]]
] = None
self.get_client_config: Optional[Callable[[], ClientConfig]] = None
self._last_event_timestamp = 0 # in seconds
self._last_event_timestamp = 0 # in seconds
self._event_count_per_timestamp = 0
async def start(self):
@ -184,12 +187,11 @@ class NostrClientConnection:
resp_nip20 += [event != None, message]
await self._send_msg(resp_nip20)
@property
def client_config(self) -> ClientConfig:
if not self.get_client_config:
raise Exception("Client not ready!")
raise Exception("Client not ready!")
return self.get_client_config()
async def _send_msg(self, data: List):
@ -207,7 +209,12 @@ class NostrClientConnection:
filter.subscription_id = subscription_id
self._remove_filter(subscription_id)
if self._can_add_filter():
return [["NOTICE", f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded."]]
return [
[
"NOTICE",
f"Maximum number of filters ({self.client_config.max_client_filters}) exceeded.",
]
]
filter.enforce_limit(self.client_config.limit_per_filter)
self.filters.append(filter)
@ -226,14 +233,20 @@ class NostrClientConnection:
self._remove_filter(subscription_id)
def _can_add_filter(self) -> bool:
return self.client_config.max_client_filters != 0 and len(self.filters) >= self.client_config.max_client_filters
return (
self.client_config.max_client_filters != 0
and len(self.filters) >= self.client_config.max_client_filters
)
def _validate_event(self, e: NostrEvent)-> Tuple[bool, str]:
def _validate_event(self, e: NostrEvent) -> Tuple[bool, str]:
if self._exceeded_max_events_per_second():
return False, f"Exceeded max events per second limit'!"
if not self.client_config.is_author_allowed(e.pubkey):
return False, f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!"
return (
False,
f"Public key '{e.pubkey}' is not allowed in relay '{self.relay_id}'!",
)
try:
e.check_signature()
@ -252,19 +265,21 @@ class NostrClientConnection:
return False, "Cannot write event, relay is read-only"
# todo: handeld paid paid plan
return True, "Temp OK"
stored_bytes = await get_storage_for_public_key(self.relay_id, e.pubkey)
if self.client_config.is_paid_relay:
# todo: handeld paid paid plan
return True, "Temp OK"
if (stored_bytes + e.size_bytes) <= self.client_config.free_storage_bytes_value:
return True, ""
return True, ""
if self.client_config.full_storage_action == "block":
return False, f"Cannot write event, no more storage available for public key: '{e.pubkey}'"
return (
False,
f"Cannot write event, no more storage available for public key: '{e.pubkey}'",
)
await prune_old_events(self.relay_id, e.pubkey, e.size_bytes)
return True, ""
@ -280,7 +295,9 @@ class NostrClientConnection:
self._last_event_timestamp = current_time
self._event_count_per_timestamp = 0
return self._event_count_per_timestamp > self.client_config.max_events_per_second
return (
self._event_count_per_timestamp > self.client_config.max_events_per_second
)
def _created_at_in_range(self, created_at: int) -> Tuple[bool, str]:
current_time = round(time.time())
@ -290,4 +307,4 @@ class NostrClientConnection:
if self.client_config.created_at_in_future != 0:
if created_at > (current_time + self.client_config.created_at_in_future):
return False, "created_at is too much into the future"
return True, ""
return True, ""

132
crud.py
View file

@ -6,18 +6,28 @@ from .models import NostrEvent, NostrFilter, NostrRelay, RelayConfig
########################## RELAYS ####################
async def create_relay(user_id: str, r: NostrRelay) -> NostrRelay:
await db.execute(
"""
INSERT INTO nostrrelay.relays (user_id, id, name, description, pubkey, contact, meta)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(user_id, r.id, r.name, r.description, r.pubkey, r.contact, json.dumps(dict(r.config))),
(
user_id,
r.id,
r.name,
r.description,
r.pubkey,
r.contact,
json.dumps(dict(r.config)),
),
)
relay = await get_relay(user_id, r.id)
assert relay, "Created relay cannot be retrieved"
return relay
async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay:
await db.execute(
"""
@ -25,31 +35,59 @@ async def update_relay(user_id: str, r: NostrRelay) -> NostrRelay:
SET (name, description, pubkey, contact, active, meta) = (?, ?, ?, ?, ?, ?)
WHERE user_id = ? AND id = ?
""",
(r.name, r.description, r.pubkey, r.contact, r.active, json.dumps(dict(r.config)), user_id, r.id),
(
r.name,
r.description,
r.pubkey,
r.contact,
r.active,
json.dumps(dict(r.config)),
user_id,
r.id,
),
)
return r
async def get_relay(user_id: str, relay_id: str) -> Optional[NostrRelay]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,))
row = await db.fetchone(
"""SELECT * FROM nostrrelay.relays WHERE user_id = ? AND id = ?""",
(
user_id,
relay_id,
),
)
return NostrRelay.from_row(row) if row else None
async def get_relays(user_id: str) -> List[NostrRelay]:
rows = await db.fetchall("""SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""", (user_id,))
rows = await db.fetchall(
"""SELECT * FROM nostrrelay.relays WHERE user_id = ? ORDER BY id ASC""",
(user_id,),
)
return [NostrRelay.from_row(row) for row in rows]
async def get_config_for_all_active_relays() -> dict:
rows = await db.fetchall("SELECT id, meta FROM nostrrelay.relays WHERE active = true",)
rows = await db.fetchall(
"SELECT id, meta FROM nostrrelay.relays WHERE active = true",
)
active_relay_configs = {}
for r in rows:
active_relay_configs[r["id"]] = RelayConfig(**json.loads(r["meta"])) #todo: from_json
active_relay_configs[r["id"]] = RelayConfig(
**json.loads(r["meta"])
) # todo: from_json
return active_relay_configs
async def get_public_relay(relay_id: str) -> Optional[dict]:
row = await db.fetchone("""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,))
row = await db.fetchone(
"""SELECT * FROM nostrrelay.relays WHERE id = ?""", (relay_id,)
)
if not row:
return None
@ -59,14 +97,20 @@ async def get_public_relay(relay_id: str) -> Optional[dict]:
**NostrRelay.info(),
"id": relay.id,
"name": relay.name,
"description":relay.description,
"pubkey":relay.pubkey,
"contact":relay.contact
"description": relay.description,
"pubkey": relay.pubkey,
"contact": relay.contact,
}
async def delete_relay(user_id: str, relay_id: str):
await db.execute("""DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""", (user_id, relay_id,))
await db.execute(
"""DELETE FROM nostrrelay.relays WHERE user_id = ? AND id = ?""",
(
user_id,
relay_id,
),
)
########################## EVENTS ####################
@ -85,7 +129,16 @@ async def create_event(relay_id: str, e: NostrEvent):
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(relay_id, e.id, e.pubkey, e.created_at, e.kind, e.content, e.sig, e.size_bytes),
(
relay_id,
e.id,
e.pubkey,
e.created_at,
e.kind,
e.content,
e.sig,
e.size_bytes,
),
)
# todo: optimize with bulk insert
@ -94,7 +147,10 @@ async def create_event(relay_id: str, e: NostrEvent):
extra = json.dumps(rest) if rest else None
await create_event_tags(relay_id, e.id, name, value, extra)
async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) -> List[NostrEvent]:
async def get_events(
relay_id: str, filter: NostrFilter, include_tags=True
) -> List[NostrEvent]:
query, values = build_select_events_query(relay_id, filter)
rows = await db.fetchall(query, tuple(values))
@ -108,8 +164,15 @@ async def get_events(relay_id: str, filter: NostrFilter, include_tags = True) ->
return events
async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
row = await db.fetchone("SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?", (relay_id, id,))
row = await db.fetchone(
"SELECT * FROM nostrrelay.events WHERE relay_id = ? AND id = ?",
(
relay_id,
id,
),
)
if not row:
return None
@ -117,17 +180,25 @@ async def get_event(relay_id: str, id: str) -> Optional[NostrEvent]:
event.tags = await get_event_tags(relay_id, id)
return event
async def get_storage_for_public_key(relay_id: str, pubkey: str) -> int:
"""Returns the storage space in bytes for all the events of a public key. Deleted events are also counted"""
row = await db.fetchone("SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey", (relay_id, pubkey,))
row = await db.fetchone(
"SELECT SUM(size) as sum FROM nostrrelay.events WHERE relay_id = ? AND pubkey = ? GROUP BY pubkey",
(
relay_id,
pubkey,
),
)
if not row:
return 0
return round(row["sum"])
async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int]]:
""" Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small"""
"""Return the oldest 10 000 events. Only the `id` and the size are returned, so the data size should be small"""
query = """
SELECT id, size FROM nostrrelay.events
WHERE relay_id = ? AND pubkey = ?
@ -139,21 +210,26 @@ async def get_prunable_events(relay_id: str, pubkey: str) -> List[Tuple[str, int
return [(r["id"], r["size"]) for r in rows]
async def mark_events_deleted(relay_id: str, filter: NostrFilter):
async def mark_events_deleted(relay_id: str, filter: NostrFilter):
if filter.is_empty():
return None
_, where, values = filter.to_sql_components(relay_id)
await db.execute(f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""", tuple(values))
await db.execute(
f"""UPDATE nostrrelay.events SET deleted=true WHERE {" AND ".join(where)}""",
tuple(values),
)
async def delete_events(relay_id: str, filter: NostrFilter):
async def delete_events(relay_id: str, filter: NostrFilter):
if filter.is_empty():
return None
_, where, values = filter.to_sql_components(relay_id)
query = f"""DELETE from nostrrelay.events WHERE {" AND ".join(where)}"""
await db.execute(query, tuple(values))
#todo: delete tags
# todo: delete tags
async def prune_old_events(relay_id: str, pubkey: str, space_to_regain: int):
prunable_events = await get_prunable_events(relay_id, pubkey)
@ -175,8 +251,13 @@ async def delete_all_events(relay_id: str):
await db.execute(query, (relay_id,))
# todo: delete tags
async def create_event_tags(
relay_id: str, event_id: str, tag_name: str, tag_value: str, extra_values: Optional[str]
relay_id: str,
event_id: str,
tag_name: str,
tag_value: str,
extra_values: Optional[str],
):
await db.execute(
"""
@ -192,9 +273,8 @@ async def create_event_tags(
(relay_id, event_id, tag_name, tag_value, extra_values),
)
async def get_event_tags(
relay_id: str, event_id: str
) -> List[List[str]]:
async def get_event_tags(relay_id: str, event_id: str) -> List[List[str]]:
rows = await db.fetchall(
"SELECT * FROM nostrrelay.event_tags WHERE relay_id = ? and event_id = ?",
(relay_id, event_id),
@ -211,7 +291,7 @@ async def get_event_tags(
return tags
def build_select_events_query(relay_id:str, filter:NostrFilter):
def build_select_events_query(relay_id: str, filter: NostrFilter):
inner_joins, where, values = filter.to_sql_components(relay_id)
query = f"""

View file

@ -23,15 +23,13 @@ class ClientConfig(BaseModel):
created_at_minutes_future = Field(0, alias="createdAtMinutesFuture")
created_at_seconds_future = Field(0, alias="createdAtSecondsFuture")
is_paid_relay = Field(False, alias="isPaidRelay")
free_storage_value = Field(1, alias="freeStorageValue")
free_storage_unit = Field("MB", alias="freeStorageUnit")
full_storage_action = Field("prune", alias="fullStorageAction")
allowed_public_keys = Field([], alias="allowedPublicKeys")
blocked_public_keys = Field([], alias="blockedPublicKeys")
def is_author_allowed(self, p: str) -> bool:
if p in self.blocked_public_keys:
@ -43,11 +41,21 @@ class ClientConfig(BaseModel):
@property
def created_at_in_past(self) -> int:
return self.created_at_days_past * 86400 + self.created_at_hours_past * 3600 + self.created_at_minutes_past * 60 + self.created_at_seconds_past
return (
self.created_at_days_past * 86400
+ self.created_at_hours_past * 3600
+ self.created_at_minutes_past * 60
+ self.created_at_seconds_past
)
@property
def created_at_in_future(self) -> int:
return self.created_at_days_future * 86400 + self.created_at_hours_future * 3600 + self.created_at_minutes_future * 60 + self.created_at_seconds_future
return (
self.created_at_days_future * 86400
+ self.created_at_hours_future * 3600
+ self.created_at_minutes_future * 60
+ self.created_at_seconds_future
)
@property
def free_storage_bytes_value(self):
@ -59,6 +67,7 @@ class ClientConfig(BaseModel):
class Config:
allow_population_by_field_name = True
class RelayConfig(ClientConfig):
wallet = Field("")
cost_to_join = Field(0, alias="costToJoin")
@ -78,7 +87,6 @@ class NostrRelay(BaseModel):
config: "RelayConfig" = RelayConfig()
@classmethod
def from_row(cls, row: Row) -> "NostrRelay":
relay = cls(**dict(row))
@ -86,7 +94,9 @@ class NostrRelay(BaseModel):
return relay
@classmethod
def info(cls,) -> dict:
def info(
cls,
) -> dict:
return {
"contact": "https://t.me/lnbits",
"supported_nips": [1, 9, 11, 15, 20, 22],
@ -223,7 +233,9 @@ class NostrFilter(BaseModel):
if not self.limit or self.limit > limit:
self.limit = limit
def to_sql_components(self, relay_id: str) -> Tuple[List[str], List[str], List[Any]]:
def to_sql_components(
self, relay_id: str
) -> Tuple[List[str], List[str], List[Any]]:
inner_joins: List[str] = []
where = ["deleted=false", "nostrrelay.events.relay_id = ?"]
values: List[Any] = [relay_id]
@ -231,13 +243,17 @@ class NostrFilter(BaseModel):
if len(self.e):
values += self.e
e_s = ",".join(["?"] * len(self.e))
inner_joins.append("INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id")
inner_joins.append(
"INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id"
)
where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')")
if len(self.p):
values += self.p
p_s = ",".join(["?"] * len(self.p))
inner_joins.append("INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id")
inner_joins.append(
"INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id"
)
where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'")
if len(self.ids) != 0:
@ -262,6 +278,5 @@ class NostrFilter(BaseModel):
if self.until:
where.append("created_at < ?")
values += [self.until]
return inner_joins, where, values

View file

@ -2,6 +2,7 @@ import json
FIXTURES_PATH = "tests/extensions/nostrrelay/fixture"
def get_fixtures(file):
"""
Read the content of the JSON file.

View file

@ -40,9 +40,7 @@ class MockWebSocket(WebSocket):
async def wire_mock_data(self, data: dict):
await self.fake_wire.put(dumps(data))
async def close(
self, code: int = 1000, reason: Optional[str] = None
) -> None:
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
logger.info(reason)
@ -152,7 +150,6 @@ async def bob_wires_contact_list(ws_alice: MockWebSocket, ws_bob: MockWebSocket)
await ws_alice.wire_mock_data(alice["subscribe_to_bob_contact_list"])
await asyncio.sleep(0.1)
print("### ws_alice.sent_message", ws_alice.sent_messages)
print("### ws_bob.sent_message", ws_bob.sent_messages)

View file

@ -12,6 +12,7 @@ from .helpers import get_fixtures
RELAY_ID = "r1"
class EventFixture(BaseModel):
name: str
exception: Optional[str]
@ -23,6 +24,7 @@ def valid_events() -> List[EventFixture]:
data = get_fixtures("events")
return [EventFixture.parse_obj(e) for e in data["valid"]]
@pytest.fixture
def invalid_events() -> List[EventFixture]:
data = get_fixtures("events")
@ -37,6 +39,7 @@ def test_valid_event_id_and_signature(valid_events: List[EventFixture]):
logger.error(f"Invalid 'id' ot 'signature' for fixture: '{f.name}'")
raise e
def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]):
for f in invalid_events:
with pytest.raises(ValueError, match=f.exception):
@ -44,7 +47,7 @@ def test_invalid_event_id_and_signature(invalid_events: List[EventFixture]):
@pytest.mark.asyncio
async def test_valid_event_crud(valid_events: List[EventFixture]):
async def test_valid_event_crud(valid_events: List[EventFixture]):
author = "a24496bca5dd73300f4e5d5d346c73132b7354c597fcbb6509891747b4689211"
event_id = "3219eec7427e365585d5adf26f5d2dd2709d3f0f2c0e1f79dc9021e951c67d96"
reply_event_id = "6b2b6cb9c72caaf3dfbc5baa5e68d75ac62f38ec011b36cc83832218c36e4894"
@ -54,12 +57,10 @@ async def test_valid_event_crud(valid_events: List[EventFixture]):
for e in all_events:
await create_event(RELAY_ID, e)
for f in valid_events:
for f in valid_events:
await get_by_id(f.data, f.name)
await filter_by_id(all_events, f.data, f.name)
await filter_by_author(all_events, author)
await filter_by_tag_p(all_events, author)
@ -70,22 +71,30 @@ async def test_valid_event_crud(valid_events: List[EventFixture]):
await filter_by_tag_e_p_and_author(all_events, author, event_id, reply_event_id)
async def get_by_id(data: NostrEvent, test_name: str):
event = await get_event(RELAY_ID, data.id)
assert event, f"Failed to restore event (id='{data.id}')"
assert event.json() != json.dumps(data.json()), f"Restored event is different for fixture '{test_name}'"
assert event.json() != json.dumps(
data.json()
), f"Restored event is different for fixture '{test_name}'"
async def filter_by_id(all_events: List[NostrEvent], data: NostrEvent, test_name: str):
filter = NostrFilter(ids=[data.id])
events = await get_events(RELAY_ID, filter)
assert len(events) == 1, f"Expected one queried event '{test_name}'"
assert events[0].json() != json.dumps(data.json()), f"Queried event is different for fixture '{test_name}'"
assert events[0].json() != json.dumps(
data.json()
), f"Queried event is different for fixture '{test_name}'"
filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Expected one filter event '{test_name}'"
assert filtered_events[0].json() != json.dumps(data.json()), f"Filtered event is different for fixture '{test_name}'"
assert filtered_events[0].json() != json.dumps(
data.json()
), f"Filtered event is different for fixture '{test_name}'"
async def filter_by_author(all_events: List[NostrEvent], author):
filter = NostrFilter(authors=[author])
@ -95,9 +104,10 @@ async def filter_by_author(all_events: List[NostrEvent], author):
filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 5, f"Failed to filter by authors"
async def filter_by_tag_p(all_events: List[NostrEvent], author):
# todo: check why constructor does not work for fields with aliases (#e, #p)
filter = NostrFilter()
filter = NostrFilter()
filter.p.append(author)
events_related_to_author = await get_events(RELAY_ID, filter)
@ -107,7 +117,7 @@ async def filter_by_tag_p(all_events: List[NostrEvent], author):
assert len(filtered_events) == 5, f"Failed to filter by tag 'p'"
async def filter_by_tag_e(all_events: List[NostrEvent], event_id):
async def filter_by_tag_e(all_events: List[NostrEvent], event_id):
filter = NostrFilter()
filter.e.append(event_id)
@ -117,29 +127,43 @@ async def filter_by_tag_e(all_events: List[NostrEvent], event_id):
filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 2, f"Failed to filter by tag 'e'"
async def filter_by_tag_e_and_p(all_events: List[NostrEvent], author, event_id, reply_event_id):
async def filter_by_tag_e_and_p(
all_events: List[NostrEvent], author, event_id, reply_event_id
):
filter = NostrFilter()
filter.p.append(author)
filter.e.append(event_id)
events_related_to_event = await get_events(RELAY_ID, filter)
assert len(events_related_to_event) == 1, f"Failed to quert by tags 'e' & 'p'"
assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by tags 'e' & 'p'"
assert (
events_related_to_event[0].id == reply_event_id
), f"Failed to query the right event by tags 'e' & 'p'"
filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Failed to filter by tags 'e' & 'p'"
assert filtered_events[0].id == reply_event_id, f"Failed to find the right event by tags 'e' & 'p'"
assert (
filtered_events[0].id == reply_event_id
), f"Failed to find the right event by tags 'e' & 'p'"
async def filter_by_tag_e_p_and_author(all_events: List[NostrEvent], author, event_id, reply_event_id):
async def filter_by_tag_e_p_and_author(
all_events: List[NostrEvent], author, event_id, reply_event_id
):
filter = NostrFilter(authors=[author])
filter.p.append(author)
filter.e.append(event_id)
events_related_to_event = await get_events(RELAY_ID, filter)
assert len(events_related_to_event) == 1, f"Failed to query by 'author' and tags 'e' & 'p'"
assert events_related_to_event[0].id == reply_event_id, f"Failed to query the right event by 'author' and tags 'e' & 'p'"
assert (
len(events_related_to_event) == 1
), f"Failed to query by 'author' and tags 'e' & 'p'"
assert (
events_related_to_event[0].id == reply_event_id
), f"Failed to query the right event by 'author' and tags 'e' & 'p'"
filtered_events = [e for e in all_events if filter.matches(e)]
assert len(filtered_events) == 1, f"Failed to filter by 'author' and tags 'e' & 'p'"
assert filtered_events[0].id == reply_event_id, f"Failed to filter the right event by 'author' and tags 'e' & 'p'"
assert (
filtered_events[0].id == reply_event_id
), f"Failed to filter the right event by 'author' and tags 'e' & 'p'"

View file

@ -30,6 +30,7 @@ from .models import NostrRelay
client_manager = NostrClientManager()
@nostrrelay_ext.websocket("/{relay_id}")
async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client = NostrClientConnection(relay_id=relay_id, websocket=websocket)
@ -44,7 +45,6 @@ async def websocket_endpoint(relay_id: str, websocket: WebSocket):
client_manager.remove_client(client)
@nostrrelay_ext.get("/{relay_id}", status_code=HTTPStatus.OK)
async def api_nostrrelay_info(relay_id: str):
relay = await get_public_relay(relay_id)
@ -54,16 +54,20 @@ async def api_nostrrelay_info(relay_id: str):
detail="Relay not found",
)
return JSONResponse(content=relay, headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": "GET"
})
return JSONResponse(
content=relay,
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": "GET",
},
)
@nostrrelay_ext.post("/api/v1/relay")
async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay:
async def api_create_relay(
data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)
) -> NostrRelay:
if len(data.id):
await check_admin(UUID4(wallet.wallet.user))
else:
@ -80,8 +84,11 @@ async def api_create_relay(data: NostrRelay, wallet: WalletTypeInfo = Depends(re
detail="Cannot create relay",
)
@nostrrelay_ext.put("/api/v1/relay/{relay_id}")
async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)) -> NostrRelay:
async def api_update_relay(
relay_id: str, data: NostrRelay, wallet: WalletTypeInfo = Depends(require_admin_key)
) -> NostrRelay:
if relay_id != data.id:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
@ -97,8 +104,8 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn
)
updated_relay = NostrRelay.parse_obj({**dict(relay), **dict(data)})
updated_relay = await update_relay(wallet.wallet.user, updated_relay)
if updated_relay.active:
if updated_relay.active:
await client_manager.enable_relay(relay_id, updated_relay.config)
else:
await client_manager.disable_relay(relay_id)
@ -116,7 +123,9 @@ async def api_update_relay(relay_id: str, data: NostrRelay, wallet: WalletTypeIn
@nostrrelay_ext.get("/api/v1/relay")
async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key)) -> List[NostrRelay]:
async def api_get_relays(
wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> List[NostrRelay]:
try:
return await get_relays(wallet.wallet.user)
except Exception as ex:
@ -126,8 +135,11 @@ async def api_get_relays(wallet: WalletTypeInfo = Depends(require_invoice_key))
detail="Cannot fetch relays",
)
@nostrrelay_ext.get("/api/v1/relay/{relay_id}")
async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)) -> Optional[NostrRelay]:
async def api_get_relay(
relay_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
) -> Optional[NostrRelay]:
try:
relay = await get_relay(wallet.wallet.user, relay_id)
except Exception as ex:
@ -143,8 +155,11 @@ async def api_get_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_
)
return relay
@nostrrelay_ext.delete("/api/v1/relay/{relay_id}")
async def api_delete_relay(relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)):
async def api_delete_relay(
relay_id: str, wallet: WalletTypeInfo = Depends(require_admin_key)
):
try:
await client_manager.disable_relay(relay_id)
await delete_relay(wallet.wallet.user, relay_id)