Simplifies database queries and updates auth

Removes the `castle.` prefix from database table names in queries, streamlining data access.

Updates authentication to use `WalletTypeInfo` dependency injection for retrieving wallet information. This improves security and aligns with LNBits' authentication patterns. Also modifies the main router's tag to uppercase.
This commit is contained in:
padreug 2025-10-22 12:52:52 +02:00
parent cdd0cda001
commit 5589d813f0
3 changed files with 39 additions and 35 deletions

View file

@ -1,9 +1,10 @@
from fastapi import APIRouter from fastapi import APIRouter
from .crud import db
from .views import castle_generic_router from .views import castle_generic_router
from .views_api import castle_api_router from .views_api import castle_api_router
castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["castle"]) castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["Castle"])
castle_ext.include_router(castle_generic_router) castle_ext.include_router(castle_generic_router)
castle_ext.include_router(castle_api_router) castle_ext.include_router(castle_api_router)
@ -14,4 +15,4 @@ castle_static_files = [
} }
] ]
__all__ = ["castle_ext", "castle_static_files"] __all__ = ["castle_ext", "castle_static_files", "db"]

30
crud.py
View file

@ -31,13 +31,13 @@ async def create_account(data: CreateAccount) -> Account:
user_id=data.user_id, user_id=data.user_id,
created_at=datetime.now(), created_at=datetime.now(),
) )
await db.insert("castle.accounts", account) await db.insert("accounts", account)
return account return account
async def get_account(account_id: str) -> Optional[Account]: async def get_account(account_id: str) -> Optional[Account]:
return await db.fetchone( return await db.fetchone(
"SELECT * FROM castle.accounts WHERE id = :id", "SELECT * FROM accounts WHERE id = :id",
{"id": account_id}, {"id": account_id},
Account, Account,
) )
@ -45,7 +45,7 @@ async def get_account(account_id: str) -> Optional[Account]:
async def get_account_by_name(name: str) -> Optional[Account]: async def get_account_by_name(name: str) -> Optional[Account]:
return await db.fetchone( return await db.fetchone(
"SELECT * FROM castle.accounts WHERE name = :name", "SELECT * FROM accounts WHERE name = :name",
{"name": name}, {"name": name},
Account, Account,
) )
@ -53,14 +53,14 @@ async def get_account_by_name(name: str) -> Optional[Account]:
async def get_all_accounts() -> list[Account]: async def get_all_accounts() -> list[Account]:
return await db.fetchall( return await db.fetchall(
"SELECT * FROM castle.accounts ORDER BY account_type, name", "SELECT * FROM accounts ORDER BY account_type, name",
model=Account, model=Account,
) )
async def get_accounts_by_type(account_type: AccountType) -> list[Account]: async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
return await db.fetchall( return await db.fetchall(
"SELECT * FROM castle.accounts WHERE account_type = :type ORDER BY name", "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
{"type": account_type.value}, {"type": account_type.value},
Account, Account,
) )
@ -74,7 +74,7 @@ async def get_or_create_user_account(
account = await db.fetchone( account = await db.fetchone(
""" """
SELECT * FROM castle.accounts SELECT * FROM accounts
WHERE user_id = :user_id AND account_type = :type AND name = :name WHERE user_id = :user_id AND account_type = :type AND name = :name
""", """,
{"user_id": user_id, "type": account_type.value, "name": account_name}, {"user_id": user_id, "type": account_type.value, "name": account_name},
@ -123,7 +123,7 @@ async def create_journal_entry(
lines=[], lines=[],
) )
await db.insert("castle.journal_entries", journal_entry) await db.insert("journal_entries", journal_entry)
# Create entry lines # Create entry lines
lines = [] lines = []
@ -137,7 +137,7 @@ async def create_journal_entry(
credit=line_data.credit, credit=line_data.credit,
description=line_data.description, description=line_data.description,
) )
await db.insert("castle.entry_lines", line) await db.insert("entry_lines", line)
lines.append(line) lines.append(line)
journal_entry.lines = lines journal_entry.lines = lines
@ -146,7 +146,7 @@ async def create_journal_entry(
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]: async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
entry = await db.fetchone( entry = await db.fetchone(
"SELECT * FROM castle.journal_entries WHERE id = :id", "SELECT * FROM journal_entries WHERE id = :id",
{"id": entry_id}, {"id": entry_id},
JournalEntry, JournalEntry,
) )
@ -159,7 +159,7 @@ async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]: async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
return await db.fetchall( return await db.fetchall(
"SELECT * FROM castle.entry_lines WHERE journal_entry_id = :id", "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
{"id": journal_entry_id}, {"id": journal_entry_id},
EntryLine, EntryLine,
) )
@ -168,7 +168,7 @@ async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]: async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
entries = await db.fetchall( entries = await db.fetchall(
""" """
SELECT * FROM castle.journal_entries SELECT * FROM journal_entries
ORDER BY entry_date DESC, created_at DESC ORDER BY entry_date DESC, created_at DESC
LIMIT :limit LIMIT :limit
""", """,
@ -187,7 +187,7 @@ async def get_journal_entries_by_user(
) -> list[JournalEntry]: ) -> list[JournalEntry]:
entries = await db.fetchall( entries = await db.fetchall(
""" """
SELECT * FROM castle.journal_entries SELECT * FROM journal_entries
WHERE created_by = :user_id WHERE created_by = :user_id
ORDER BY entry_date DESC, created_at DESC ORDER BY entry_date DESC, created_at DESC
LIMIT :limit LIMIT :limit
@ -212,7 +212,7 @@ async def get_account_balance(account_id: str) -> int:
SELECT SELECT
COALESCE(SUM(debit), 0) as total_debit, COALESCE(SUM(debit), 0) as total_debit,
COALESCE(SUM(credit), 0) as total_credit COALESCE(SUM(credit), 0) as total_credit
FROM castle.entry_lines FROM entry_lines
WHERE account_id = :id WHERE account_id = :id
""", """,
{"id": account_id}, {"id": account_id},
@ -241,7 +241,7 @@ async def get_user_balance(user_id: str) -> UserBalance:
"""Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)""" """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
# Get all user-specific accounts # Get all user-specific accounts
user_accounts = await db.fetchall( user_accounts = await db.fetchall(
"SELECT * FROM castle.accounts WHERE user_id = :user_id", "SELECT * FROM accounts WHERE user_id = :user_id",
{"user_id": user_id}, {"user_id": user_id},
Account, Account,
) )
@ -272,7 +272,7 @@ async def get_account_transactions(
"""Get all transactions affecting a specific account""" """Get all transactions affecting a specific account"""
lines = await db.fetchall( lines = await db.fetchall(
""" """
SELECT * FROM castle.entry_lines SELECT * FROM entry_lines
WHERE account_id = :id WHERE account_id = :id
ORDER BY id DESC ORDER BY id DESC
LIMIT :limit LIMIT :limit

View file

@ -1,7 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from lnbits.core.crud import get_user from lnbits.core.models import WalletTypeInfo
from lnbits.decorators import require_admin_key, require_invoice_key from lnbits.decorators import require_admin_key, require_invoice_key
from .crud import ( from .crud import (
@ -46,7 +46,7 @@ async def api_get_accounts() -> list[Account]:
@castle_api_router.post("/accounts", status_code=HTTPStatus.CREATED) @castle_api_router.post("/accounts", status_code=HTTPStatus.CREATED)
async def api_create_account( async def api_create_account(
data: CreateAccount, data: CreateAccount,
wallet_id: str = require_admin_key, # type: ignore wallet: WalletTypeInfo = Depends(require_admin_key),
) -> Account: ) -> Account:
"""Create a new account (admin only)""" """Create a new account (admin only)"""
return await create_account(data) return await create_account(data)
@ -94,10 +94,11 @@ async def api_get_journal_entries(limit: int = 100) -> list[JournalEntry]:
@castle_api_router.get("/entries/user") @castle_api_router.get("/entries/user")
async def api_get_user_entries( async def api_get_user_entries(
wallet_id: str = require_invoice_key, limit: int = 100 # type: ignore wallet: WalletTypeInfo = Depends(require_invoice_key),
limit: int = 100,
) -> list[JournalEntry]: ) -> list[JournalEntry]:
"""Get journal entries created by the current user""" """Get journal entries created by the current user"""
return await get_journal_entries_by_user(wallet_id, limit) return await get_journal_entries_by_user(wallet.wallet.id, limit)
@castle_api_router.get("/entries/{entry_id}") @castle_api_router.get("/entries/{entry_id}")
@ -114,11 +115,11 @@ async def api_get_journal_entry(entry_id: str) -> JournalEntry:
@castle_api_router.post("/entries", status_code=HTTPStatus.CREATED) @castle_api_router.post("/entries", status_code=HTTPStatus.CREATED)
async def api_create_journal_entry( async def api_create_journal_entry(
data: CreateJournalEntry, data: CreateJournalEntry,
wallet_id: str = require_invoice_key, # type: ignore wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> JournalEntry: ) -> JournalEntry:
"""Create a new journal entry""" """Create a new journal entry"""
try: try:
return await create_journal_entry(data, wallet_id) return await create_journal_entry(data, wallet.wallet.id)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
@ -129,7 +130,7 @@ async def api_create_journal_entry(
@castle_api_router.post("/entries/expense", status_code=HTTPStatus.CREATED) @castle_api_router.post("/entries/expense", status_code=HTTPStatus.CREATED)
async def api_create_expense_entry( async def api_create_expense_entry(
data: ExpenseEntry, data: ExpenseEntry,
wallet_id: str = require_invoice_key, # type: ignore wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> JournalEntry: ) -> JournalEntry:
""" """
Create an expense entry for a user. Create an expense entry for a user.
@ -180,13 +181,13 @@ async def api_create_expense_entry(
], ],
) )
return await create_journal_entry(entry_data, wallet_id) return await create_journal_entry(entry_data, wallet.wallet.id)
@castle_api_router.post("/entries/receivable", status_code=HTTPStatus.CREATED) @castle_api_router.post("/entries/receivable", status_code=HTTPStatus.CREATED)
async def api_create_receivable_entry( async def api_create_receivable_entry(
data: ReceivableEntry, data: ReceivableEntry,
wallet_id: str = require_admin_key, # type: ignore wallet: WalletTypeInfo = Depends(require_admin_key),
) -> JournalEntry: ) -> JournalEntry:
""" """
Create an accounts receivable entry (user owes castle). Create an accounts receivable entry (user owes castle).
@ -228,13 +229,13 @@ async def api_create_receivable_entry(
], ],
) )
return await create_journal_entry(entry_data, wallet_id) return await create_journal_entry(entry_data, wallet.wallet.id)
@castle_api_router.post("/entries/revenue", status_code=HTTPStatus.CREATED) @castle_api_router.post("/entries/revenue", status_code=HTTPStatus.CREATED)
async def api_create_revenue_entry( async def api_create_revenue_entry(
data: RevenueEntry, data: RevenueEntry,
wallet_id: str = require_admin_key, # type: ignore wallet: WalletTypeInfo = Depends(require_admin_key),
) -> JournalEntry: ) -> JournalEntry:
""" """
Create a revenue entry (castle receives payment). Create a revenue entry (castle receives payment).
@ -281,7 +282,7 @@ async def api_create_revenue_entry(
], ],
) )
return await create_journal_entry(entry_data, wallet_id) return await create_journal_entry(entry_data, wallet.wallet.id)
# ===== USER BALANCE ENDPOINTS ===== # ===== USER BALANCE ENDPOINTS =====
@ -289,10 +290,10 @@ async def api_create_revenue_entry(
@castle_api_router.get("/balance") @castle_api_router.get("/balance")
async def api_get_my_balance( async def api_get_my_balance(
wallet_id: str = require_invoice_key, # type: ignore wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> UserBalance: ) -> UserBalance:
"""Get current user's balance with the Castle""" """Get current user's balance with the Castle"""
return await get_user_balance(wallet_id) return await get_user_balance(wallet.wallet.id)
@castle_api_router.get("/balance/{user_id}") @castle_api_router.get("/balance/{user_id}")
@ -307,12 +308,14 @@ async def api_get_user_balance(user_id: str) -> UserBalance:
@castle_api_router.post("/pay-balance") @castle_api_router.post("/pay-balance")
async def api_pay_balance( async def api_pay_balance(
amount: int, amount: int,
wallet_id: str = require_invoice_key, # type: ignore wallet: WalletTypeInfo = Depends(require_invoice_key),
) -> dict: ) -> dict:
""" """
Record a payment from user to castle (reduces what user owes or what castle owes user). Record a payment from user to castle (reduces what user owes or what castle owes user).
This should be called after an invoice is paid. This should be called after an invoice is paid.
""" """
wallet_id = wallet.wallet.id
# Get user's receivable account (what user owes) # Get user's receivable account (what user owes)
user_receivable = await get_or_create_user_account( user_receivable = await get_or_create_user_account(
wallet_id, AccountType.ASSET, "Accounts Receivable" wallet_id, AccountType.ASSET, "Accounts Receivable"
@ -361,7 +364,7 @@ async def api_pay_balance(
async def api_pay_user( async def api_pay_user(
user_id: str, user_id: str,
amount: int, amount: int,
wallet_id: str = require_admin_key, # type: ignore wallet: WalletTypeInfo = Depends(require_admin_key),
) -> dict: ) -> dict:
""" """
Record a payment from castle to user (reduces what castle owes user). Record a payment from castle to user (reduces what castle owes user).
@ -399,7 +402,7 @@ async def api_pay_user(
], ],
) )
entry = await create_journal_entry(entry_data, wallet_id) entry = await create_journal_entry(entry_data, wallet.wallet.id)
# Get updated balance # Get updated balance
balance = await get_user_balance(user_id) balance = await get_user_balance(user_id)