From 5589d813f0b8ed9c96b08b2a52c9c4973cf71c7d Mon Sep 17 00:00:00 2001 From: padreug Date: Wed, 22 Oct 2025 12:52:52 +0200 Subject: [PATCH] Simplifies database queries and updates auth Removes the `castle.` prefix from database table names in queries, streamlining data access. Updates authentication to use `WalletTypeInfo` dependency injection for retrieving wallet information. This improves security and aligns with LNBits' authentication patterns. Also modifies the main router's tag to uppercase. --- __init__.py | 5 +++-- crud.py | 30 +++++++++++++++--------------- views_api.py | 39 +++++++++++++++++++++------------------ 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/__init__.py b/__init__.py index b17296e..56cd641 100644 --- a/__init__.py +++ b/__init__.py @@ -1,9 +1,10 @@ from fastapi import APIRouter +from .crud import db from .views import castle_generic_router from .views_api import castle_api_router -castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["castle"]) +castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["Castle"]) castle_ext.include_router(castle_generic_router) castle_ext.include_router(castle_api_router) @@ -14,4 +15,4 @@ castle_static_files = [ } ] -__all__ = ["castle_ext", "castle_static_files"] +__all__ = ["castle_ext", "castle_static_files", "db"] diff --git a/crud.py b/crud.py index d52b007..226dba2 100644 --- a/crud.py +++ b/crud.py @@ -31,13 +31,13 @@ async def create_account(data: CreateAccount) -> Account: user_id=data.user_id, created_at=datetime.now(), ) - await db.insert("castle.accounts", account) + await db.insert("accounts", account) return account async def get_account(account_id: str) -> Optional[Account]: return await db.fetchone( - "SELECT * FROM castle.accounts WHERE id = :id", + "SELECT * FROM accounts WHERE id = :id", {"id": account_id}, Account, ) @@ -45,7 +45,7 @@ async def get_account(account_id: str) -> Optional[Account]: async def get_account_by_name(name: str) -> Optional[Account]: return await db.fetchone( - "SELECT * FROM castle.accounts WHERE name = :name", + "SELECT * FROM accounts WHERE name = :name", {"name": name}, Account, ) @@ -53,14 +53,14 @@ async def get_account_by_name(name: str) -> Optional[Account]: async def get_all_accounts() -> list[Account]: return await db.fetchall( - "SELECT * FROM castle.accounts ORDER BY account_type, name", + "SELECT * FROM accounts ORDER BY account_type, name", model=Account, ) async def get_accounts_by_type(account_type: AccountType) -> list[Account]: return await db.fetchall( - "SELECT * FROM castle.accounts WHERE account_type = :type ORDER BY name", + "SELECT * FROM accounts WHERE account_type = :type ORDER BY name", {"type": account_type.value}, Account, ) @@ -74,7 +74,7 @@ async def get_or_create_user_account( account = await db.fetchone( """ - SELECT * FROM castle.accounts + SELECT * FROM accounts WHERE user_id = :user_id AND account_type = :type AND name = :name """, {"user_id": user_id, "type": account_type.value, "name": account_name}, @@ -123,7 +123,7 @@ async def create_journal_entry( lines=[], ) - await db.insert("castle.journal_entries", journal_entry) + await db.insert("journal_entries", journal_entry) # Create entry lines lines = [] @@ -137,7 +137,7 @@ async def create_journal_entry( credit=line_data.credit, description=line_data.description, ) - await db.insert("castle.entry_lines", line) + await db.insert("entry_lines", line) lines.append(line) journal_entry.lines = lines @@ -146,7 +146,7 @@ async def create_journal_entry( async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]: entry = await db.fetchone( - "SELECT * FROM castle.journal_entries WHERE id = :id", + "SELECT * FROM journal_entries WHERE id = :id", {"id": entry_id}, JournalEntry, ) @@ -159,7 +159,7 @@ async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]: async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]: return await db.fetchall( - "SELECT * FROM castle.entry_lines WHERE journal_entry_id = :id", + "SELECT * FROM entry_lines WHERE journal_entry_id = :id", {"id": journal_entry_id}, EntryLine, ) @@ -168,7 +168,7 @@ async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]: async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]: entries = await db.fetchall( """ - SELECT * FROM castle.journal_entries + SELECT * FROM journal_entries ORDER BY entry_date DESC, created_at DESC LIMIT :limit """, @@ -187,7 +187,7 @@ async def get_journal_entries_by_user( ) -> list[JournalEntry]: entries = await db.fetchall( """ - SELECT * FROM castle.journal_entries + SELECT * FROM journal_entries WHERE created_by = :user_id ORDER BY entry_date DESC, created_at DESC LIMIT :limit @@ -212,7 +212,7 @@ async def get_account_balance(account_id: str) -> int: SELECT COALESCE(SUM(debit), 0) as total_debit, COALESCE(SUM(credit), 0) as total_credit - FROM castle.entry_lines + FROM entry_lines WHERE account_id = :id """, {"id": account_id}, @@ -241,7 +241,7 @@ async def get_user_balance(user_id: str) -> UserBalance: """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)""" # Get all user-specific accounts user_accounts = await db.fetchall( - "SELECT * FROM castle.accounts WHERE user_id = :user_id", + "SELECT * FROM accounts WHERE user_id = :user_id", {"user_id": user_id}, Account, ) @@ -272,7 +272,7 @@ async def get_account_transactions( """Get all transactions affecting a specific account""" lines = await db.fetchall( """ - SELECT * FROM castle.entry_lines + SELECT * FROM entry_lines WHERE account_id = :id ORDER BY id DESC LIMIT :limit diff --git a/views_api.py b/views_api.py index 439c6ee..9a7fc6a 100644 --- a/views_api.py +++ b/views_api.py @@ -1,7 +1,7 @@ from http import HTTPStatus -from fastapi import APIRouter, HTTPException -from lnbits.core.crud import get_user +from fastapi import APIRouter, Depends, HTTPException +from lnbits.core.models import WalletTypeInfo from lnbits.decorators import require_admin_key, require_invoice_key from .crud import ( @@ -46,7 +46,7 @@ async def api_get_accounts() -> list[Account]: @castle_api_router.post("/accounts", status_code=HTTPStatus.CREATED) async def api_create_account( data: CreateAccount, - wallet_id: str = require_admin_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_admin_key), ) -> Account: """Create a new account (admin only)""" return await create_account(data) @@ -94,10 +94,11 @@ async def api_get_journal_entries(limit: int = 100) -> list[JournalEntry]: @castle_api_router.get("/entries/user") async def api_get_user_entries( - wallet_id: str = require_invoice_key, limit: int = 100 # type: ignore + wallet: WalletTypeInfo = Depends(require_invoice_key), + limit: int = 100, ) -> list[JournalEntry]: """Get journal entries created by the current user""" - return await get_journal_entries_by_user(wallet_id, limit) + return await get_journal_entries_by_user(wallet.wallet.id, limit) @castle_api_router.get("/entries/{entry_id}") @@ -114,11 +115,11 @@ async def api_get_journal_entry(entry_id: str) -> JournalEntry: @castle_api_router.post("/entries", status_code=HTTPStatus.CREATED) async def api_create_journal_entry( data: CreateJournalEntry, - wallet_id: str = require_invoice_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_invoice_key), ) -> JournalEntry: """Create a new journal entry""" try: - return await create_journal_entry(data, wallet_id) + return await create_journal_entry(data, wallet.wallet.id) except ValueError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) @@ -129,7 +130,7 @@ async def api_create_journal_entry( @castle_api_router.post("/entries/expense", status_code=HTTPStatus.CREATED) async def api_create_expense_entry( data: ExpenseEntry, - wallet_id: str = require_invoice_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_invoice_key), ) -> JournalEntry: """ Create an expense entry for a user. @@ -180,13 +181,13 @@ async def api_create_expense_entry( ], ) - return await create_journal_entry(entry_data, wallet_id) + return await create_journal_entry(entry_data, wallet.wallet.id) @castle_api_router.post("/entries/receivable", status_code=HTTPStatus.CREATED) async def api_create_receivable_entry( data: ReceivableEntry, - wallet_id: str = require_admin_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_admin_key), ) -> JournalEntry: """ Create an accounts receivable entry (user owes castle). @@ -228,13 +229,13 @@ async def api_create_receivable_entry( ], ) - return await create_journal_entry(entry_data, wallet_id) + return await create_journal_entry(entry_data, wallet.wallet.id) @castle_api_router.post("/entries/revenue", status_code=HTTPStatus.CREATED) async def api_create_revenue_entry( data: RevenueEntry, - wallet_id: str = require_admin_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_admin_key), ) -> JournalEntry: """ Create a revenue entry (castle receives payment). @@ -281,7 +282,7 @@ async def api_create_revenue_entry( ], ) - return await create_journal_entry(entry_data, wallet_id) + return await create_journal_entry(entry_data, wallet.wallet.id) # ===== USER BALANCE ENDPOINTS ===== @@ -289,10 +290,10 @@ async def api_create_revenue_entry( @castle_api_router.get("/balance") async def api_get_my_balance( - wallet_id: str = require_invoice_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_invoice_key), ) -> UserBalance: """Get current user's balance with the Castle""" - return await get_user_balance(wallet_id) + return await get_user_balance(wallet.wallet.id) @castle_api_router.get("/balance/{user_id}") @@ -307,12 +308,14 @@ async def api_get_user_balance(user_id: str) -> UserBalance: @castle_api_router.post("/pay-balance") async def api_pay_balance( amount: int, - wallet_id: str = require_invoice_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_invoice_key), ) -> dict: """ Record a payment from user to castle (reduces what user owes or what castle owes user). This should be called after an invoice is paid. """ + wallet_id = wallet.wallet.id + # Get user's receivable account (what user owes) user_receivable = await get_or_create_user_account( wallet_id, AccountType.ASSET, "Accounts Receivable" @@ -361,7 +364,7 @@ async def api_pay_balance( async def api_pay_user( user_id: str, amount: int, - wallet_id: str = require_admin_key, # type: ignore + wallet: WalletTypeInfo = Depends(require_admin_key), ) -> dict: """ Record a payment from castle to user (reduces what castle owes user). @@ -399,7 +402,7 @@ async def api_pay_user( ], ) - entry = await create_journal_entry(entry_data, wallet_id) + entry = await create_journal_entry(entry_data, wallet.wallet.id) # Get updated balance balance = await get_user_balance(user_id)