[feat] fetch all payments for user (#3132)

This commit is contained in:
Vlad Stan 2025-04-29 13:52:07 +03:00 committed by GitHub
parent 2dee26b728
commit e339bb6181
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 87 additions and 5 deletions

View file

@ -1,7 +1,7 @@
from time import time
from typing import Any, Optional, Tuple
from lnbits.core.crud.wallets import get_total_balance, get_wallet
from lnbits.core.crud.wallets import get_total_balance, get_wallet, get_wallets_ids
from lnbits.core.db import db
from lnbits.core.models import PaymentState
from lnbits.db import Connection, DateTrunc, Filters, Page
@ -95,6 +95,7 @@ async def get_latest_payments_by_extension(
async def get_payments_paginated(
*,
wallet_id: Optional[str] = None,
user_id: Optional[str] = None,
complete: bool = False,
pending: bool = False,
failed: bool = False,
@ -121,6 +122,13 @@ async def get_payments_paginated(
if wallet_id:
values["wallet_id"] = wallet_id
clause.append("wallet_id = :wallet_id")
elif user_id:
wallet_ids = await get_wallets_ids(user_id=user_id, conn=conn) or [
"no-wallets-for-user"
]
# wallet ids are safe to use in sql queries
wallet_ids_str = [f"'{w}'" for w in wallet_ids]
clause.append(f""" wallet_id IN ({", ".join(wallet_ids_str)}) """)
if complete and pending:
clause.append(

View file

@ -135,6 +135,20 @@ async def get_wallets(
)
async def get_wallets_ids(
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> list[str]:
where = "AND deleted = :deleted" if deleted is not None else ""
result: list[dict] = await (conn or db).fetchall(
f"""
SELECT id FROM wallets
WHERE "user" = :user {where}
""",
{"user": user_id, "deleted": deleted},
)
return [row["id"] for row in result]
async def get_wallets_count():
result = await db.execute("SELECT COUNT(*) as count FROM wallets")
row = result.mappings().first()

View file

@ -264,13 +264,21 @@ async def _api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
response_description="list of payments",
response_model=Page[Payment],
openapi_extra=generate_filter_params_openapi(PaymentFilters),
dependencies=[Depends(check_admin)],
)
async def api_all_payments_paginated(
filters: Filters = Depends(parse_filters(PaymentFilters)),
user: User = Depends(check_user_exists),
):
if user.admin:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
return await get_payments_paginated(
filters=filters,
user_id=for_user_id,
)

View file

@ -8,10 +8,10 @@ from bolt11 import encode as bolt11_encode
from bolt11.types import MilliSatoshi
from pytest_mock.plugin import MockerFixture
from lnbits.core.crud import get_standalone_payment, get_wallet
from lnbits.core.crud.payments import get_payment
from lnbits.core.crud import create_wallet, get_standalone_payment, get_wallet
from lnbits.core.crud.payments import get_payment, get_payments_paginated
from lnbits.core.models import Payment, PaymentState, Wallet
from lnbits.core.services import create_invoice, pay_invoice
from lnbits.core.services import create_invoice, create_user_account, pay_invoice
from lnbits.exceptions import InvoiceError, PaymentError
from lnbits.settings import Settings
from lnbits.tasks import (
@ -596,3 +596,55 @@ async def test_service_fee(
assert service_fee_payment.amount == 422_400
assert service_fee_payment.bolt11 == external_invoice.payment_request
assert service_fee_payment.preimage is None
@pytest.mark.anyio
async def test_get_payments_for_user(to_wallet: Wallet):
all_payments = await get_payments_paginated()
total_before = all_payments.total
user = await create_user_account()
wallet_one = await create_wallet(user_id=user.id, wallet_name="first wallet")
wallet_two = await create_wallet(user_id=user.id, wallet_name="second wallet")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 0
payment = await create_invoice(wallet_id=wallet_one.id, amount=100, memo="one")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 1
# this will create a payment in the to_wallet that we need to count for at the end
await pay_invoice(
wallet_id=to_wallet.id,
payment_request=payment.bolt11,
)
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 1
payment = await create_invoice(wallet_id=wallet_one.id, amount=3, memo="two")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 2
payment = await create_invoice(wallet_id=wallet_two.id, amount=3, memo="three")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 3
await pay_invoice(
wallet_id=wallet_one.id,
payment_request=payment.bolt11,
)
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 4
all_payments = await get_payments_paginated()
total_after = all_payments.total
assert total_after == total_before + 5, "Total payments should be updated."
@pytest.mark.anyio
async def test_get_payments_for_non_user():
user_payments = await get_payments_paginated(user_id="nonexistent")
assert (
user_payments.total == 0
), "No payments should be found for non-existent user."