From 2ebc83a286286ca7af6987fc89cbfc66b9a3c212 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Fri, 17 Feb 2023 14:16:24 +0200 Subject: [PATCH] refactor: extract `NostrFilter` --- crud.py | 12 +--- models.py | 109 ---------------------------------- relay/client_connection.py | 3 +- relay/client_manager.py | 5 +- relay/event.py | 1 - relay/event_validator.py | 3 +- relay/filter.py | 117 +++++++++++++++++++++++++++++++++++++ tests/test_clients.py | 8 ++- tests/test_events.py | 2 +- 9 files changed, 132 insertions(+), 128 deletions(-) create mode 100644 relay/filter.py diff --git a/crud.py b/crud.py index 8eb7aa0..53ffdc4 100644 --- a/crud.py +++ b/crud.py @@ -1,16 +1,10 @@ import json from typing import List, Optional, Tuple -from .relay.event import NostrEvent - from . import db -from .models import ( - NostrAccount, - NostrFilter, - NostrRelay, - RelayPublicSpec, - RelaySpec, -) +from .models import NostrAccount, NostrRelay, RelayPublicSpec, RelaySpec +from .relay.event import NostrEvent +from .relay.filter import NostrFilter ########################## RELAYS #################### diff --git a/models.py b/models.py index 1b3c953..266e35c 100644 --- a/models.py +++ b/models.py @@ -129,115 +129,6 @@ class NostrRelay(BaseModel): "version": "", } -class NostrFilter(BaseModel): - subscription_id: Optional[str] - - ids: List[str] = [] - authors: List[str] = [] - kinds: List[int] = [] - e: List[str] = Field([], alias="#e") - p: List[str] = Field([], alias="#p") - since: Optional[int] - until: Optional[int] - limit: Optional[int] - - def matches(self, e: NostrEvent) -> bool: - # todo: starts with - if len(self.ids) != 0 and e.id not in self.ids: - return False - if len(self.authors) != 0 and e.pubkey not in self.authors: - return False - if len(self.kinds) != 0 and e.kind not in self.kinds: - return False - - if self.since and e.created_at < self.since: - return False - if self.until and self.until > 0 and e.created_at > self.until: - return False - - found_e_tag = self.tag_in_list(e.tags, "e") - found_p_tag = self.tag_in_list(e.tags, "p") - if not found_e_tag or not found_p_tag: - return False - - return True - - def tag_in_list(self, event_tags, tag_name) -> bool: - filter_tags = dict(self).get(tag_name, []) - if len(filter_tags) == 0: - return True - - event_tag_values = [t[1] for t in event_tags if t[0] == tag_name] - - common_tags = [ - event_tag for event_tag in event_tag_values if event_tag in filter_tags - ] - if len(common_tags) == 0: - return False - return True - - def is_empty(self): - return ( - len(self.ids) == 0 - and len(self.authors) == 0 - and len(self.kinds) == 0 - and len(self.e) == 0 - and len(self.p) == 0 - and (not self.since) - and (not self.until) - ) - - def enforce_limit(self, limit: int): - if not self.limit or self.limit > limit: - self.limit = limit - - def to_sql_components( - self, relay_id: str - ) -> Tuple[List[str], List[str], List[Any]]: - inner_joins: List[str] = [] - where = ["deleted=false", "nostrrelay.events.relay_id = ?"] - values: List[Any] = [relay_id] - - if len(self.e): - values += self.e - e_s = ",".join(["?"] * len(self.e)) - inner_joins.append( - "INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id" - ) - where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')") - - if len(self.p): - values += self.p - p_s = ",".join(["?"] * len(self.p)) - inner_joins.append( - "INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id" - ) - where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'") - - if len(self.ids) != 0: - ids = ",".join(["?"] * len(self.ids)) - where.append(f"id IN ({ids})") - values += self.ids - - if len(self.authors) != 0: - authors = ",".join(["?"] * len(self.authors)) - where.append(f"pubkey IN ({authors})") - values += self.authors - - if len(self.kinds) != 0: - kinds = ",".join(["?"] * len(self.kinds)) - where.append(f"kind IN ({kinds})") - values += self.kinds - - if self.since: - where.append("created_at >= ?") - values += [self.since] - - if self.until: - where.append("created_at < ?") - values += [self.until] - - return inner_joins, where, values class BuyOrder(BaseModel): diff --git a/relay/client_connection.py b/relay/client_connection.py index 238873a..9a82859 100644 --- a/relay/client_connection.py +++ b/relay/client_connection.py @@ -14,9 +14,10 @@ from ..crud import ( get_events, mark_events_deleted, ) +from ..models import RelaySpec from .event import NostrEvent, NostrEventType -from ..models import NostrFilter, RelaySpec from .event_validator import EventValidator +from .filter import NostrFilter class NostrClientConnection: diff --git a/relay/client_manager.py b/relay/client_manager.py index 7f41880..bf8e07e 100644 --- a/relay/client_manager.py +++ b/relay/client_manager.py @@ -1,10 +1,9 @@ from typing import List -from .event import NostrEvent -from .client_connection import NostrClientConnection - from ..crud import get_config_for_all_active_relays from ..models import RelaySpec +from .client_connection import NostrClientConnection +from .event import NostrEvent class NostrClientManager: diff --git a/relay/event.py b/relay/event.py index f8e97af..15b3c02 100644 --- a/relay/event.py +++ b/relay/event.py @@ -8,7 +8,6 @@ from pydantic import BaseModel from secp256k1 import PublicKey - class NostrEventType(str, Enum): EVENT = "EVENT" REQ = "REQ" diff --git a/relay/event_validator.py b/relay/event_validator.py index aaa6cf2..e8e0e9f 100644 --- a/relay/event_validator.py +++ b/relay/event_validator.py @@ -1,11 +1,10 @@ import time from typing import Callable, Optional, Tuple -from .event import NostrEvent - from ..crud import get_account, get_storage_for_public_key, prune_old_events from ..helpers import extract_domain from ..models import NostrAccount, RelaySpec +from .event import NostrEvent class EventValidator: diff --git a/relay/filter.py b/relay/filter.py new file mode 100644 index 0000000..68f740d --- /dev/null +++ b/relay/filter.py @@ -0,0 +1,117 @@ + +from typing import Any, List, Optional, Tuple + +from pydantic import BaseModel, Field + +from .event import NostrEvent + + +class NostrFilter(BaseModel): + subscription_id: Optional[str] + + ids: List[str] = [] + authors: List[str] = [] + kinds: List[int] = [] + e: List[str] = Field([], alias="#e") + p: List[str] = Field([], alias="#p") + since: Optional[int] + until: Optional[int] + limit: Optional[int] + + def matches(self, e: NostrEvent) -> bool: + # todo: starts with + if len(self.ids) != 0 and e.id not in self.ids: + return False + if len(self.authors) != 0 and e.pubkey not in self.authors: + return False + if len(self.kinds) != 0 and e.kind not in self.kinds: + return False + + if self.since and e.created_at < self.since: + return False + if self.until and self.until > 0 and e.created_at > self.until: + return False + + found_e_tag = self.tag_in_list(e.tags, "e") + found_p_tag = self.tag_in_list(e.tags, "p") + if not found_e_tag or not found_p_tag: + return False + + return True + + def tag_in_list(self, event_tags, tag_name) -> bool: + filter_tags = dict(self).get(tag_name, []) + if len(filter_tags) == 0: + return True + + event_tag_values = [t[1] for t in event_tags if t[0] == tag_name] + + common_tags = [ + event_tag for event_tag in event_tag_values if event_tag in filter_tags + ] + if len(common_tags) == 0: + return False + return True + + def is_empty(self): + return ( + len(self.ids) == 0 + and len(self.authors) == 0 + and len(self.kinds) == 0 + and len(self.e) == 0 + and len(self.p) == 0 + and (not self.since) + and (not self.until) + ) + + def enforce_limit(self, limit: int): + if not self.limit or self.limit > limit: + self.limit = limit + + def to_sql_components( + self, relay_id: str + ) -> Tuple[List[str], List[str], List[Any]]: + inner_joins: List[str] = [] + where = ["deleted=false", "nostrrelay.events.relay_id = ?"] + values: List[Any] = [relay_id] + + if len(self.e): + values += self.e + e_s = ",".join(["?"] * len(self.e)) + inner_joins.append( + "INNER JOIN nostrrelay.event_tags e_tags ON nostrrelay.events.id = e_tags.event_id" + ) + where.append(f" (e_tags.value in ({e_s}) AND e_tags.name = 'e')") + + if len(self.p): + values += self.p + p_s = ",".join(["?"] * len(self.p)) + inner_joins.append( + "INNER JOIN nostrrelay.event_tags p_tags ON nostrrelay.events.id = p_tags.event_id" + ) + where.append(f" p_tags.value in ({p_s}) AND p_tags.name = 'p'") + + if len(self.ids) != 0: + ids = ",".join(["?"] * len(self.ids)) + where.append(f"id IN ({ids})") + values += self.ids + + if len(self.authors) != 0: + authors = ",".join(["?"] * len(self.authors)) + where.append(f"pubkey IN ({authors})") + values += self.authors + + if len(self.kinds) != 0: + kinds = ",".join(["?"] * len(self.kinds)) + where.append(f"kind IN ({kinds})") + values += self.kinds + + if self.since: + where.append("created_at >= ?") + values += [self.since] + + if self.until: + where.append("created_at < ?") + values += [self.until] + + return inner_joins, where, values diff --git a/tests/test_clients.py b/tests/test_clients.py index 56d93df..61fe83d 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -7,8 +7,12 @@ from fastapi import WebSocket from loguru import logger from lnbits.extensions.nostrrelay.models import RelaySpec # type: ignore -from lnbits.extensions.nostrrelay.relay.client_manager import NostrClientManager # type: ignore -from lnbits.extensions.nostrrelay.relay.client_connection import NostrClientConnection # type: ignore +from lnbits.extensions.nostrrelay.relay.client_connection import ( + NostrClientConnection, # type: ignore +) +from lnbits.extensions.nostrrelay.relay.client_manager import ( + NostrClientManager, # type: ignore +) from .helpers import get_fixtures diff --git a/tests/test_events.py b/tests/test_events.py index 0a42ab2..8e0249e 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -10,8 +10,8 @@ from lnbits.extensions.nostrrelay.crud import ( # type: ignore get_event, get_events, ) -from lnbits.extensions.nostrrelay.models import NostrFilter # type: ignore from lnbits.extensions.nostrrelay.relay.event import NostrEvent # type: ignore +from lnbits.extensions.nostrrelay.relay.filter import NostrFilter # type: ignore from .helpers import get_fixtures