diff --git a/crud.py b/crud.py index 8335078..588271b 100644 --- a/crud.py +++ b/crud.py @@ -30,25 +30,24 @@ async def create_event(relay_id: str, e: NostrEvent): async def get_events(relay_id: str, filter: NostrFilter) -> List[NostrEvent]: - values: List[Any] = [] - query = "SELECT id, pubkey, created_at, kind, content, sig FROM nostrrelay.events" - if len(filter.e) or len(filter.p): - query += " INNER JOIN nostrrelay.event_tags ON nostrrelay.events.id = nostrrelay.event_tags.event_id WHERE" - if len(filter.e): - values += filter.e - e_s = ",".join(["?"] * len(filter.e)) - query += f" nostrrelay.event_tags.value in ({e_s}) AND nostrrelay.event_tags.name = 'e'" + values: List[Any] = [relay_id] + query = "SELECT id, pubkey, created_at, kind, content, sig FROM nostrrelay.events " + + inner_joins = [] + where = ["nostrrelay.events.relay_id = ?"] + if len(filter.e): + values += filter.e + e_s = ",".join(["?"] * len(filter.e)) + inner_joins.append("INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id") + where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')") - if len(filter.p): - values += filter.p - p_s = ",".join(["?"] * len(filter.p)) - and_op = " AND " if len(filter.e) else "" - query += f"{and_op} nostrrelay.event_tags.value in ({p_s}) AND nostrrelay.event_tags.name = 'p'" - query += " AND nostrrelay.events.relay_id = ?" - else: - query += " WHERE nostrrelay.events.relay_id = ?" + if len(filter.p): + values += filter.p + p_s = ",".join(["?"] * len(filter.p)) + inner_joins.append("INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id") + where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'") + query += " ".join(inner_joins)+ " WHERE " + " AND ".join(where) - values.append(relay_id) if len(filter.ids) != 0: ids = ",".join(["?"] * len(filter.ids)) diff --git a/tests/test_events.py b/tests/test_events.py index e09690b..63fda50 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -57,6 +57,30 @@ async def test_valid_event_crud(valid_events: List[EventFixture]): events = await get_events(relay_id, filter) assert len(events) == 1, f"Expected one filter event '{f.name}'" + author = "a24496bca5dd73300f4e5d5d346c73132b7354c597fcbb6509891747b4689211" + event_id = "3219eec7427e365585d5adf26f5d2dd2709d3f0f2c0e1f79dc9021e951c67d96" + events_by_author = await get_events(relay_id, NostrFilter(authors=[author])) + assert len(events_by_author) == 5, f"Failed to filter by authors" + + # todo: check why constructor does not work for fields with aliases (#e, #p) + filter = NostrFilter() + filter.p.append(author) + events_related_to_author = await get_events(relay_id, filter) + assert len(events_related_to_author) == 5, f"Failed to filter by tag 'p'" + + filter = NostrFilter() + filter.e.append(event_id) + events_related_to_event = await get_events(relay_id, filter) + assert len(events_related_to_event) == 2, f"Failed to filter by tag 'e'" + + reply_event_id = "6b2b6cb9c72caaf3dfbc5baa5e68d75ac62f38ec011b36cc83832218c36e4894" + filter = NostrFilter(authors=[author]) + filter.p.append(author) + filter.e.append(event_id) + events_related_to_event = await get_events(relay_id, filter) + assert len(events_related_to_event) == 1, f"Failed to filter by tags 'e' & 'p'" + assert events_related_to_event[0].id == reply_event_id, f"Failed to find the right event by tags 'e' & 'p'" + def get_fixtures(file):