Simplifies database queries and updates auth
Removes the `castle.` prefix from database table names in queries, streamlining data access. Updates authentication to use `WalletTypeInfo` dependency injection for retrieving wallet information. This improves security and aligns with LNBits' authentication patterns. Also modifies the main router's tag to uppercase.
This commit is contained in:
parent
cdd0cda001
commit
5589d813f0
3 changed files with 39 additions and 35 deletions
|
|
@ -1,9 +1,10 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from .crud import db
|
||||||
from .views import castle_generic_router
|
from .views import castle_generic_router
|
||||||
from .views_api import castle_api_router
|
from .views_api import castle_api_router
|
||||||
|
|
||||||
castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["castle"])
|
castle_ext: APIRouter = APIRouter(prefix="/castle", tags=["Castle"])
|
||||||
castle_ext.include_router(castle_generic_router)
|
castle_ext.include_router(castle_generic_router)
|
||||||
castle_ext.include_router(castle_api_router)
|
castle_ext.include_router(castle_api_router)
|
||||||
|
|
||||||
|
|
@ -14,4 +15,4 @@ castle_static_files = [
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__ = ["castle_ext", "castle_static_files"]
|
__all__ = ["castle_ext", "castle_static_files", "db"]
|
||||||
|
|
|
||||||
30
crud.py
30
crud.py
|
|
@ -31,13 +31,13 @@ async def create_account(data: CreateAccount) -> Account:
|
||||||
user_id=data.user_id,
|
user_id=data.user_id,
|
||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
)
|
)
|
||||||
await db.insert("castle.accounts", account)
|
await db.insert("accounts", account)
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
async def get_account(account_id: str) -> Optional[Account]:
|
async def get_account(account_id: str) -> Optional[Account]:
|
||||||
return await db.fetchone(
|
return await db.fetchone(
|
||||||
"SELECT * FROM castle.accounts WHERE id = :id",
|
"SELECT * FROM accounts WHERE id = :id",
|
||||||
{"id": account_id},
|
{"id": account_id},
|
||||||
Account,
|
Account,
|
||||||
)
|
)
|
||||||
|
|
@ -45,7 +45,7 @@ async def get_account(account_id: str) -> Optional[Account]:
|
||||||
|
|
||||||
async def get_account_by_name(name: str) -> Optional[Account]:
|
async def get_account_by_name(name: str) -> Optional[Account]:
|
||||||
return await db.fetchone(
|
return await db.fetchone(
|
||||||
"SELECT * FROM castle.accounts WHERE name = :name",
|
"SELECT * FROM accounts WHERE name = :name",
|
||||||
{"name": name},
|
{"name": name},
|
||||||
Account,
|
Account,
|
||||||
)
|
)
|
||||||
|
|
@ -53,14 +53,14 @@ async def get_account_by_name(name: str) -> Optional[Account]:
|
||||||
|
|
||||||
async def get_all_accounts() -> list[Account]:
|
async def get_all_accounts() -> list[Account]:
|
||||||
return await db.fetchall(
|
return await db.fetchall(
|
||||||
"SELECT * FROM castle.accounts ORDER BY account_type, name",
|
"SELECT * FROM accounts ORDER BY account_type, name",
|
||||||
model=Account,
|
model=Account,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
|
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
|
||||||
return await db.fetchall(
|
return await db.fetchall(
|
||||||
"SELECT * FROM castle.accounts WHERE account_type = :type ORDER BY name",
|
"SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
|
||||||
{"type": account_type.value},
|
{"type": account_type.value},
|
||||||
Account,
|
Account,
|
||||||
)
|
)
|
||||||
|
|
@ -74,7 +74,7 @@ async def get_or_create_user_account(
|
||||||
|
|
||||||
account = await db.fetchone(
|
account = await db.fetchone(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM castle.accounts
|
SELECT * FROM accounts
|
||||||
WHERE user_id = :user_id AND account_type = :type AND name = :name
|
WHERE user_id = :user_id AND account_type = :type AND name = :name
|
||||||
""",
|
""",
|
||||||
{"user_id": user_id, "type": account_type.value, "name": account_name},
|
{"user_id": user_id, "type": account_type.value, "name": account_name},
|
||||||
|
|
@ -123,7 +123,7 @@ async def create_journal_entry(
|
||||||
lines=[],
|
lines=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await db.insert("castle.journal_entries", journal_entry)
|
await db.insert("journal_entries", journal_entry)
|
||||||
|
|
||||||
# Create entry lines
|
# Create entry lines
|
||||||
lines = []
|
lines = []
|
||||||
|
|
@ -137,7 +137,7 @@ async def create_journal_entry(
|
||||||
credit=line_data.credit,
|
credit=line_data.credit,
|
||||||
description=line_data.description,
|
description=line_data.description,
|
||||||
)
|
)
|
||||||
await db.insert("castle.entry_lines", line)
|
await db.insert("entry_lines", line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
|
|
||||||
journal_entry.lines = lines
|
journal_entry.lines = lines
|
||||||
|
|
@ -146,7 +146,7 @@ async def create_journal_entry(
|
||||||
|
|
||||||
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
|
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
|
||||||
entry = await db.fetchone(
|
entry = await db.fetchone(
|
||||||
"SELECT * FROM castle.journal_entries WHERE id = :id",
|
"SELECT * FROM journal_entries WHERE id = :id",
|
||||||
{"id": entry_id},
|
{"id": entry_id},
|
||||||
JournalEntry,
|
JournalEntry,
|
||||||
)
|
)
|
||||||
|
|
@ -159,7 +159,7 @@ async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
|
||||||
|
|
||||||
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
|
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
|
||||||
return await db.fetchall(
|
return await db.fetchall(
|
||||||
"SELECT * FROM castle.entry_lines WHERE journal_entry_id = :id",
|
"SELECT * FROM entry_lines WHERE journal_entry_id = :id",
|
||||||
{"id": journal_entry_id},
|
{"id": journal_entry_id},
|
||||||
EntryLine,
|
EntryLine,
|
||||||
)
|
)
|
||||||
|
|
@ -168,7 +168,7 @@ async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
|
||||||
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
|
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
|
||||||
entries = await db.fetchall(
|
entries = await db.fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM castle.journal_entries
|
SELECT * FROM journal_entries
|
||||||
ORDER BY entry_date DESC, created_at DESC
|
ORDER BY entry_date DESC, created_at DESC
|
||||||
LIMIT :limit
|
LIMIT :limit
|
||||||
""",
|
""",
|
||||||
|
|
@ -187,7 +187,7 @@ async def get_journal_entries_by_user(
|
||||||
) -> list[JournalEntry]:
|
) -> list[JournalEntry]:
|
||||||
entries = await db.fetchall(
|
entries = await db.fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM castle.journal_entries
|
SELECT * FROM journal_entries
|
||||||
WHERE created_by = :user_id
|
WHERE created_by = :user_id
|
||||||
ORDER BY entry_date DESC, created_at DESC
|
ORDER BY entry_date DESC, created_at DESC
|
||||||
LIMIT :limit
|
LIMIT :limit
|
||||||
|
|
@ -212,7 +212,7 @@ async def get_account_balance(account_id: str) -> int:
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(SUM(debit), 0) as total_debit,
|
COALESCE(SUM(debit), 0) as total_debit,
|
||||||
COALESCE(SUM(credit), 0) as total_credit
|
COALESCE(SUM(credit), 0) as total_credit
|
||||||
FROM castle.entry_lines
|
FROM entry_lines
|
||||||
WHERE account_id = :id
|
WHERE account_id = :id
|
||||||
""",
|
""",
|
||||||
{"id": account_id},
|
{"id": account_id},
|
||||||
|
|
@ -241,7 +241,7 @@ async def get_user_balance(user_id: str) -> UserBalance:
|
||||||
"""Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
|
"""Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
|
||||||
# Get all user-specific accounts
|
# Get all user-specific accounts
|
||||||
user_accounts = await db.fetchall(
|
user_accounts = await db.fetchall(
|
||||||
"SELECT * FROM castle.accounts WHERE user_id = :user_id",
|
"SELECT * FROM accounts WHERE user_id = :user_id",
|
||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
Account,
|
Account,
|
||||||
)
|
)
|
||||||
|
|
@ -272,7 +272,7 @@ async def get_account_transactions(
|
||||||
"""Get all transactions affecting a specific account"""
|
"""Get all transactions affecting a specific account"""
|
||||||
lines = await db.fetchall(
|
lines = await db.fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT * FROM castle.entry_lines
|
SELECT * FROM entry_lines
|
||||||
WHERE account_id = :id
|
WHERE account_id = :id
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
LIMIT :limit
|
LIMIT :limit
|
||||||
|
|
|
||||||
39
views_api.py
39
views_api.py
|
|
@ -1,7 +1,7 @@
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from lnbits.core.crud import get_user
|
from lnbits.core.models import WalletTypeInfo
|
||||||
from lnbits.decorators import require_admin_key, require_invoice_key
|
from lnbits.decorators import require_admin_key, require_invoice_key
|
||||||
|
|
||||||
from .crud import (
|
from .crud import (
|
||||||
|
|
@ -46,7 +46,7 @@ async def api_get_accounts() -> list[Account]:
|
||||||
@castle_api_router.post("/accounts", status_code=HTTPStatus.CREATED)
|
@castle_api_router.post("/accounts", status_code=HTTPStatus.CREATED)
|
||||||
async def api_create_account(
|
async def api_create_account(
|
||||||
data: CreateAccount,
|
data: CreateAccount,
|
||||||
wallet_id: str = require_admin_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_admin_key),
|
||||||
) -> Account:
|
) -> Account:
|
||||||
"""Create a new account (admin only)"""
|
"""Create a new account (admin only)"""
|
||||||
return await create_account(data)
|
return await create_account(data)
|
||||||
|
|
@ -94,10 +94,11 @@ async def api_get_journal_entries(limit: int = 100) -> list[JournalEntry]:
|
||||||
|
|
||||||
@castle_api_router.get("/entries/user")
|
@castle_api_router.get("/entries/user")
|
||||||
async def api_get_user_entries(
|
async def api_get_user_entries(
|
||||||
wallet_id: str = require_invoice_key, limit: int = 100 # type: ignore
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
|
limit: int = 100,
|
||||||
) -> list[JournalEntry]:
|
) -> list[JournalEntry]:
|
||||||
"""Get journal entries created by the current user"""
|
"""Get journal entries created by the current user"""
|
||||||
return await get_journal_entries_by_user(wallet_id, limit)
|
return await get_journal_entries_by_user(wallet.wallet.id, limit)
|
||||||
|
|
||||||
|
|
||||||
@castle_api_router.get("/entries/{entry_id}")
|
@castle_api_router.get("/entries/{entry_id}")
|
||||||
|
|
@ -114,11 +115,11 @@ async def api_get_journal_entry(entry_id: str) -> JournalEntry:
|
||||||
@castle_api_router.post("/entries", status_code=HTTPStatus.CREATED)
|
@castle_api_router.post("/entries", status_code=HTTPStatus.CREATED)
|
||||||
async def api_create_journal_entry(
|
async def api_create_journal_entry(
|
||||||
data: CreateJournalEntry,
|
data: CreateJournalEntry,
|
||||||
wallet_id: str = require_invoice_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
) -> JournalEntry:
|
) -> JournalEntry:
|
||||||
"""Create a new journal entry"""
|
"""Create a new journal entry"""
|
||||||
try:
|
try:
|
||||||
return await create_journal_entry(data, wallet_id)
|
return await create_journal_entry(data, wallet.wallet.id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e))
|
||||||
|
|
||||||
|
|
@ -129,7 +130,7 @@ async def api_create_journal_entry(
|
||||||
@castle_api_router.post("/entries/expense", status_code=HTTPStatus.CREATED)
|
@castle_api_router.post("/entries/expense", status_code=HTTPStatus.CREATED)
|
||||||
async def api_create_expense_entry(
|
async def api_create_expense_entry(
|
||||||
data: ExpenseEntry,
|
data: ExpenseEntry,
|
||||||
wallet_id: str = require_invoice_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
) -> JournalEntry:
|
) -> JournalEntry:
|
||||||
"""
|
"""
|
||||||
Create an expense entry for a user.
|
Create an expense entry for a user.
|
||||||
|
|
@ -180,13 +181,13 @@ async def api_create_expense_entry(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return await create_journal_entry(entry_data, wallet_id)
|
return await create_journal_entry(entry_data, wallet.wallet.id)
|
||||||
|
|
||||||
|
|
||||||
@castle_api_router.post("/entries/receivable", status_code=HTTPStatus.CREATED)
|
@castle_api_router.post("/entries/receivable", status_code=HTTPStatus.CREATED)
|
||||||
async def api_create_receivable_entry(
|
async def api_create_receivable_entry(
|
||||||
data: ReceivableEntry,
|
data: ReceivableEntry,
|
||||||
wallet_id: str = require_admin_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_admin_key),
|
||||||
) -> JournalEntry:
|
) -> JournalEntry:
|
||||||
"""
|
"""
|
||||||
Create an accounts receivable entry (user owes castle).
|
Create an accounts receivable entry (user owes castle).
|
||||||
|
|
@ -228,13 +229,13 @@ async def api_create_receivable_entry(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return await create_journal_entry(entry_data, wallet_id)
|
return await create_journal_entry(entry_data, wallet.wallet.id)
|
||||||
|
|
||||||
|
|
||||||
@castle_api_router.post("/entries/revenue", status_code=HTTPStatus.CREATED)
|
@castle_api_router.post("/entries/revenue", status_code=HTTPStatus.CREATED)
|
||||||
async def api_create_revenue_entry(
|
async def api_create_revenue_entry(
|
||||||
data: RevenueEntry,
|
data: RevenueEntry,
|
||||||
wallet_id: str = require_admin_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_admin_key),
|
||||||
) -> JournalEntry:
|
) -> JournalEntry:
|
||||||
"""
|
"""
|
||||||
Create a revenue entry (castle receives payment).
|
Create a revenue entry (castle receives payment).
|
||||||
|
|
@ -281,7 +282,7 @@ async def api_create_revenue_entry(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return await create_journal_entry(entry_data, wallet_id)
|
return await create_journal_entry(entry_data, wallet.wallet.id)
|
||||||
|
|
||||||
|
|
||||||
# ===== USER BALANCE ENDPOINTS =====
|
# ===== USER BALANCE ENDPOINTS =====
|
||||||
|
|
@ -289,10 +290,10 @@ async def api_create_revenue_entry(
|
||||||
|
|
||||||
@castle_api_router.get("/balance")
|
@castle_api_router.get("/balance")
|
||||||
async def api_get_my_balance(
|
async def api_get_my_balance(
|
||||||
wallet_id: str = require_invoice_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
) -> UserBalance:
|
) -> UserBalance:
|
||||||
"""Get current user's balance with the Castle"""
|
"""Get current user's balance with the Castle"""
|
||||||
return await get_user_balance(wallet_id)
|
return await get_user_balance(wallet.wallet.id)
|
||||||
|
|
||||||
|
|
||||||
@castle_api_router.get("/balance/{user_id}")
|
@castle_api_router.get("/balance/{user_id}")
|
||||||
|
|
@ -307,12 +308,14 @@ async def api_get_user_balance(user_id: str) -> UserBalance:
|
||||||
@castle_api_router.post("/pay-balance")
|
@castle_api_router.post("/pay-balance")
|
||||||
async def api_pay_balance(
|
async def api_pay_balance(
|
||||||
amount: int,
|
amount: int,
|
||||||
wallet_id: str = require_invoice_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Record a payment from user to castle (reduces what user owes or what castle owes user).
|
Record a payment from user to castle (reduces what user owes or what castle owes user).
|
||||||
This should be called after an invoice is paid.
|
This should be called after an invoice is paid.
|
||||||
"""
|
"""
|
||||||
|
wallet_id = wallet.wallet.id
|
||||||
|
|
||||||
# Get user's receivable account (what user owes)
|
# Get user's receivable account (what user owes)
|
||||||
user_receivable = await get_or_create_user_account(
|
user_receivable = await get_or_create_user_account(
|
||||||
wallet_id, AccountType.ASSET, "Accounts Receivable"
|
wallet_id, AccountType.ASSET, "Accounts Receivable"
|
||||||
|
|
@ -361,7 +364,7 @@ async def api_pay_balance(
|
||||||
async def api_pay_user(
|
async def api_pay_user(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
amount: int,
|
amount: int,
|
||||||
wallet_id: str = require_admin_key, # type: ignore
|
wallet: WalletTypeInfo = Depends(require_admin_key),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Record a payment from castle to user (reduces what castle owes user).
|
Record a payment from castle to user (reduces what castle owes user).
|
||||||
|
|
@ -399,7 +402,7 @@ async def api_pay_user(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
entry = await create_journal_entry(entry_data, wallet_id)
|
entry = await create_journal_entry(entry_data, wallet.wallet.id)
|
||||||
|
|
||||||
# Get updated balance
|
# Get updated balance
|
||||||
balance = await get_user_balance(user_id)
|
balance = await get_user_balance(user_id)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue