Removes the `castle.` prefix from database table names in queries, simplifying data access. Switches authentication to `WalletTypeInfo` dependency injection for retrieving wallet information, which improves security and aligns with LNbits' authentication patterns. Also changes the main router's tag to uppercase.
290 lines
7.6 KiB
Python
290 lines
7.6 KiB
Python
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from lnbits.db import Database
|
|
from lnbits.helpers import urlsafe_short_hash
|
|
|
|
from .models import (
|
|
Account,
|
|
AccountType,
|
|
CreateAccount,
|
|
CreateEntryLine,
|
|
CreateJournalEntry,
|
|
EntryLine,
|
|
JournalEntry,
|
|
UserBalance,
|
|
)
|
|
|
|
# Module-wide database handle used by every query helper below.
# NOTE(review): "ext_castle" is handed to lnbits' Database; presumably it
# selects this extension's schema/database — confirm against lnbits.db.
db: Database = Database("ext_castle")
|
|
|
|
|
|
# ===== ACCOUNT OPERATIONS =====
|
|
|
|
|
|
async def create_account(data: CreateAccount) -> Account:
    """Persist a new account built from ``data`` and return it.

    A fresh id is generated with ``urlsafe_short_hash()`` and ``created_at``
    is stamped with ``datetime.now()`` (naive local time) before the row is
    inserted into the ``accounts`` table.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)
    return new_account
|
|
|
|
|
|
async def get_account(account_id: str) -> Optional[Account]:
    """Look up a single account by its primary key.

    Returns ``None`` when no row in ``accounts`` has the given id.
    """
    query = "SELECT * FROM accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
|
|
|
|
|
|
async def get_account_by_name(name: str) -> Optional[Account]:
    """Fetch the account whose ``name`` column equals ``name``, if any.

    NOTE(review): this query does not enforce unique names; if duplicate
    names exist, whichever row the database returns first is used.
    """
    query = "SELECT * FROM accounts WHERE name = :name"
    return await db.fetchone(query, {"name": name}, Account)
|
|
|
|
|
|
async def get_all_accounts() -> list[Account]:
    """Return every account, sorted by ``account_type`` and then by ``name``."""
    query = "SELECT * FROM accounts ORDER BY account_type, name"
    return await db.fetchall(query, model=Account)
|
|
|
|
|
|
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, sorted by name.

    The enum member's ``.value`` is what gets compared to the stored column.
    """
    type_value = account_type.value
    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        {"type": type_value},
        Account,
    )
|
|
|
|
|
|
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """Return the per-user account of the given type, creating it if missing.

    The account name is ``"<base_name> - <first 8 chars of user_id>"``
    (e.g. ``'Accounts Payable - User123'``). The lookup matches on user id,
    account type and that name together; on a miss a new account is created
    with a generic "User-specific ... account" description.

    NOTE(review): lookup and insert are separate statements, so two
    concurrent callers could both miss and both create the account.
    """
    account_name = f"{base_name} - {user_id[:8]}"
    type_value = account_type.value

    existing = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": type_value, "name": account_name},
        Account,
    )
    if existing:
        return existing

    return await create_account(
        CreateAccount(
            name=account_name,
            account_type=account_type,
            description=f"User-specific {type_value} account",
            user_id=user_id,
        )
    )
|
|
|
|
|
|
# ===== JOURNAL ENTRY OPERATIONS =====
|
|
|
|
|
|
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate and persist a balanced journal entry together with its lines.

    Args:
        data: The entry header (description, optional entry_date, reference)
            and its debit/credit lines.
        created_by: Value stored in the entry's ``created_by`` column.

    Returns:
        The stored ``JournalEntry`` whose ``lines`` holds the ``EntryLine``
        objects that were inserted.

    Raises:
        ValueError: If ``data.lines`` is empty, or if total debits do not
            equal total credits. Validation runs before anything is written.
    """
    # With no lines, debits and credits are both 0 and "balance", but the
    # entry would record nothing; reject it instead of persisting an empty
    # header row.
    if not data.lines:
        raise ValueError("Journal entry must have at least one line")

    # Double-entry invariant: the entry's debits must equal its credits.
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)
    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    # Take a single timestamp so that, when no entry_date is supplied,
    # entry_date and created_at are identical rather than microseconds apart.
    now = datetime.now()
    entry_id = urlsafe_short_hash()

    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or now,
        created_by=created_by,
        created_at=now,
        reference=data.reference,
        lines=[],
    )
    # NOTE(review): the header and each line are separate inserts with no
    # enclosing transaction here, so a failure part-way through can leave a
    # header without all of its lines. Consider a single transaction if the
    # db layer supports one.
    await db.insert("journal_entries", journal_entry)

    lines = []
    for line_data in data.lines:
        line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
        )
        await db.insert("entry_lines", line)
        lines.append(line)

    journal_entry.lines = lines
    return journal_entry
|
|
|
|
|
|
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Load one journal entry by id, with its ``lines`` attached.

    If no entry matches, the (empty) lookup result is returned unchanged.
    Otherwise the entry's lines are loaded via ``get_entry_lines`` and
    assigned to ``entry.lines``.
    """
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not entry:
        return entry

    entry.lines = await get_entry_lines(entry_id)
    return entry
|
|
|
|
|
|
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return every line belonging to the given journal entry.

    No ORDER BY is applied, so line order is whatever the database yields.
    """
    query = "SELECT * FROM entry_lines WHERE journal_entry_id = :id"
    return await db.fetchall(query, {"id": journal_entry_id}, EntryLine)
|
|
|
|
|
|
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with lines loaded.

    Entries are ordered by ``entry_date`` and then ``created_at``, both
    descending. NOTE(review): lines are fetched with one extra query per
    entry.
    """
    query = """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
    """
    entries = await db.fetchall(query, {"limit": limit}, JournalEntry)
    for journal_entry in entries:
        journal_entry.lines = await get_entry_lines(journal_entry.id)
    return entries
|
|
|
|
|
|
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Return up to ``limit`` entries whose ``created_by`` is ``user_id``.

    Results are ordered newest first by ``entry_date`` and then
    ``created_at``. Each entry's lines are loaded with a separate
    ``get_entry_lines`` call.
    """
    params = {"user_id": user_id, "limit": limit}
    entries = await db.fetchall(
        """
        SELECT * FROM journal_entries
        WHERE created_by = :user_id
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        params,
        JournalEntry,
    )
    for journal_entry in entries:
        journal_entry.lines = await get_entry_lines(journal_entry.id)
    return entries
|
|
|
|
|
|
# ===== BALANCE AND REPORTING =====
|
|
|
|
|
|
async def get_account_balance(account_id: str) -> int:
    """Return an account's balance in its normal-balance direction.

    Asset and expense accounts report ``debits - credits``. Every other
    account type (e.g. liabilities, equity, revenue) reports
    ``credits - debits``. Returns 0 if the totals query yields no row or
    the account does not exist.
    """
    totals = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(debit), 0) as total_debit,
            COALESCE(SUM(credit), 0) as total_credit
        FROM entry_lines
        WHERE account_id = :id
        """,
        {"id": account_id},
    )
    if not totals:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    debit_sum = totals["total_debit"]
    credit_sum = totals["total_credit"]

    # Debit-normal accounts grow with debits; all others grow with credits.
    debit_normal = account.account_type in (AccountType.ASSET, AccountType.EXPENSE)
    return debit_sum - credit_sum if debit_normal else credit_sum - debit_sum
|
|
|
|
|
|
async def get_user_balance(user_id: str) -> UserBalance:
    """Net a user's position with the Castle across all of their accounts.

    Sign convention: positive means the castle owes the user, negative means
    the user owes the castle. Liability balances are added, asset balances
    are subtracted, and any other account type (e.g. equity contributions)
    does not affect the total.
    """
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    net = 0
    for acct in user_accounts:
        acct_balance = await get_account_balance(acct.id)
        if acct.account_type == AccountType.LIABILITY:
            # Castle owes the user.
            net += acct_balance
        elif acct.account_type == AccountType.ASSET:
            # User owes the castle.
            net -= acct_balance

    return UserBalance(user_id=user_id, balance=net, accounts=user_accounts)
|
|
|
|
|
|
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Return the most recent entry lines posted to an account, with their entries.

    Lines are ordered newest first by their parent journal entry's
    ``entry_date`` and then ``created_at``. At most ``limit`` lines are
    returned.

    Returns:
        ``(journal_entry, entry_line)`` pairs. Each ``journal_entry`` is
        loaded via ``get_journal_entry``, so it carries all of its lines, not
        just the one posted to this account.
    """
    # Entry-line ids come from urlsafe_short_hash() and are not chronological,
    # so ORDER BY id produced an arbitrary order and LIMIT an arbitrary subset.
    # Join the parent entry to sort by its dates instead. The inner join also
    # keeps lines whose entry is missing from consuming the LIMIT. el.id is
    # only a deterministic tie-breaker.
    lines = await db.fetchall(
        """
        SELECT el.* FROM entry_lines el
        JOIN journal_entries je ON je.id = el.journal_entry_id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
        EntryLine,
    )

    transactions = []
    for line in lines:
        entry = await get_journal_entry(line.journal_entry_id)
        # Guard against the entry disappearing between the two queries.
        if entry:
            transactions.append((entry, line))
    return transactions
|