castle/crud.py
2025-10-22 12:33:45 +02:00

290 lines
7.7 KiB
Python

from datetime import datetime
from typing import Optional
from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash
from .models import (
Account,
AccountType,
CreateAccount,
CreateEntryLine,
CreateJournalEntry,
EntryLine,
JournalEntry,
UserBalance,
)
# Database handle named "ext_castle"; the queries in this module address its
# tables with a "castle." prefix (castle.accounts, castle.journal_entries,
# castle.entry_lines).
db = Database("ext_castle")
# ===== ACCOUNT OPERATIONS =====
async def create_account(data: CreateAccount) -> Account:
    """Persist a new account built from ``data`` and return it.

    A fresh short id is generated and ``created_at`` is stamped with the
    current (naive, local) time before the row is inserted into
    ``castle.accounts``.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("castle.accounts", new_account)
    return new_account
async def get_account(account_id: str) -> Optional[Account]:
    """Look up a single account by its primary key.

    Returns the ``Account`` model, or ``None`` when no row matches.
    """
    query = "SELECT * FROM castle.accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
async def get_account_by_name(name: str) -> Optional[Account]:
    """Look up an account by its exact name; ``None`` if there is no match."""
    query = "SELECT * FROM castle.accounts WHERE name = :name"
    params = {"name": name}
    return await db.fetchone(query, params, Account)
async def get_all_accounts() -> list[Account]:
    """Return every account, sorted by account type and then by name."""
    query = "SELECT * FROM castle.accounts ORDER BY account_type, name"
    return await db.fetchall(query, model=Account)
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, alphabetised by name.

    The enum's ``.value`` is what gets compared against the stored column.
    """
    query = "SELECT * FROM castle.accounts WHERE account_type = :type ORDER BY name"
    params = {"type": account_type.value}
    return await db.fetchall(query, params, Account)
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """Return the per-user account of ``account_type``, creating it if absent.

    The account is named ``"<base_name> - <first 8 chars of user_id>"``
    (e.g. ``'Accounts Payable - User123'``). An existing account is matched
    on user id, account type and that name together. Otherwise a new one is
    created through ``create_account``.
    """
    name = f"{base_name} - {user_id[:8]}"
    existing = await db.fetchone(
        """
        SELECT * FROM castle.accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": name},
        Account,
    )
    if existing:
        return existing

    return await create_account(
        CreateAccount(
            name=name,
            account_type=account_type,
            description=f"User-specific {account_type.value} account",
            user_id=user_id,
        )
    )
# ===== JOURNAL ENTRY OPERATIONS =====
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate and persist a double-entry journal entry together with its lines.

    Args:
        data: The entry header (description, optional entry_date, reference)
            plus its debit/credit lines.
        created_by: Identifier stored in the entry's ``created_by`` column.

    Returns:
        The stored ``JournalEntry`` with ``lines`` populated by the inserted
        ``EntryLine`` models.

    Raises:
        ValueError: If the entry has no lines, any line carries a negative
            debit or credit, or total debits differ from total credits.
    """
    # Validate everything before writing anything, so a rejected entry
    # leaves no rows behind.
    if not data.lines:
        # Without this check an empty entry would "balance" trivially (0 == 0).
        raise ValueError("Journal entry must have at least one line")
    for line_data in data.lines:
        # Debits and credits live in separate columns, so a negative amount
        # would silently act as the opposite side and bypass the model.
        if line_data.debit < 0 or line_data.credit < 0:
            raise ValueError(
                "Journal entry lines must not have negative amounts: "
                f"debit={line_data.debit}, credit={line_data.credit}"
            )
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)
    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    entry_id = urlsafe_short_hash()
    # A single timestamp, so a defaulted entry_date equals created_at exactly.
    now = datetime.now()
    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or now,
        created_by=created_by,
        created_at=now,
        reference=data.reference,
        lines=[],
    )
    await db.insert("castle.journal_entries", journal_entry)

    # NOTE(review): the header and each line are separate inserts, so a failure
    # part-way through can leave a partially written entry. Consider running
    # them in one transaction if the db layer supports it.
    lines = []
    for line_data in data.lines:
        line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
        )
        await db.insert("castle.entry_lines", line)
        lines.append(line)
    journal_entry.lines = lines
    return journal_entry
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Load one journal entry by id, with its ``lines`` attached.

    Returns whatever the lookup produced (``None``) when no entry matches.
    """
    found = await db.fetchone(
        "SELECT * FROM castle.journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not found:
        return found
    found.lines = await get_entry_lines(entry_id)
    return found
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Fetch every line belonging to the given journal entry (unordered)."""
    query = "SELECT * FROM castle.entry_lines WHERE journal_entry_id = :id"
    params = {"id": journal_entry_id}
    return await db.fetchall(query, params, EntryLine)
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with lines loaded.

    Entries are ordered by ``entry_date`` and then ``created_at``, both
    descending. Each entry's lines are fetched with a separate query.
    """
    recent = await db.fetchall(
        """
        SELECT * FROM castle.journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
        JournalEntry,
    )
    for journal_entry in recent:
        journal_entry.lines = await get_entry_lines(journal_entry.id)
    return recent
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Return up to ``limit`` entries created by ``user_id``, newest first.

    Entries are matched on the ``created_by`` column (not on account
    ownership), ordered by ``entry_date`` then ``created_at`` descending.
    Each entry's lines are loaded before returning.
    """
    params = {"user_id": user_id, "limit": limit}
    authored = await db.fetchall(
        """
        SELECT * FROM castle.journal_entries
        WHERE created_by = :user_id
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        params,
        JournalEntry,
    )
    for journal_entry in authored:
        journal_entry.lines = await get_entry_lines(journal_entry.id)
    return authored
# ===== BALANCE AND REPORTING =====
async def get_account_balance(account_id: str) -> int:
    """Return an account's balance in its normal-balance direction.

    ASSET and EXPENSE accounts are debit-normal (debits minus credits). Every
    other type (liabilities, equity, revenue) is credit-normal (credits minus
    debits). Returns 0 when the totals query yields no row or the account
    does not exist.
    """
    totals = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(debit), 0) as total_debit,
            COALESCE(SUM(credit), 0) as total_credit
        FROM castle.entry_lines
        WHERE account_id = :id
        """,
        {"id": account_id},
    )
    if not totals:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    debit_total = totals["total_debit"]
    credit_total = totals["total_credit"]
    debit_normal_types = (AccountType.ASSET, AccountType.EXPENSE)
    if account.account_type in debit_normal_types:
        return debit_total - credit_total
    return credit_total - debit_total
async def get_user_balance(user_id: str) -> UserBalance:
    """Get a user's net balance with the Castle.

    Positive means the castle owes the user; negative means the user owes
    the castle. Only the user's LIABILITY and ASSET accounts count toward the
    total. A liability account's credit balance adds to it, and an asset
    account's debit balance subtracts from it. Other account types (e.g.
    equity contributions) are returned in ``accounts`` but do not affect what
    the castle owes.

    Debit/credit totals for all of the user's accounts come from one grouped
    query. The previous per-account ``get_account_balance`` calls cost two
    queries each and re-fetched accounts that were already loaded here.
    """
    user_accounts = await db.fetchall(
        "SELECT * FROM castle.accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )
    if not user_accounts:
        return UserBalance(user_id=user_id, balance=0, accounts=user_accounts)

    rows = await db.fetchall(
        """
        SELECT
            l.account_id AS account_id,
            COALESCE(SUM(l.debit), 0) AS total_debit,
            COALESCE(SUM(l.credit), 0) AS total_credit
        FROM castle.entry_lines l
        JOIN castle.accounts a ON a.id = l.account_id
        WHERE a.user_id = :user_id
        GROUP BY l.account_id
        """,
        {"user_id": user_id},
    )
    # Accounts without any entry lines are absent from the result; treat them as 0/0.
    totals = {
        row["account_id"]: (row["total_debit"], row["total_credit"]) for row in rows
    }

    total_balance = 0
    for account in user_accounts:
        debit, credit = totals.get(account.id, (0, 0))
        if account.account_type == AccountType.LIABILITY:
            # Credit-normal: the castle owes the user, so the balance adds.
            total_balance += credit - debit
        elif account.account_type == AccountType.ASSET:
            # Debit-normal: the user owes the castle, so the balance subtracts.
            total_balance -= debit - credit
        # Equity contributions are tracked but don't affect what castle owes

    return UserBalance(
        user_id=user_id,
        balance=total_balance,
        accounts=user_accounts,
    )
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get the most recent transactions affecting a specific account.

    Args:
        account_id: Account whose entry lines are listed.
        limit: Maximum number of ``(entry, line)`` pairs to return.

    Returns:
        ``(journal_entry, entry_line)`` pairs ordered newest entry first, by
        ``entry_date`` then ``created_at`` (as in get_all_journal_entries).
        Each journal entry has its full ``lines`` loaded.
    """
    # Line ids are generated by urlsafe_short_hash() rather than a sequence,
    # so ordering by id is not chronological. Order by the parent entry's
    # dates instead; l.id only breaks ties deterministically. The inner join
    # also stops lines without a parent entry from consuming the LIMIT.
    lines = await db.fetchall(
        """
        SELECT l.* FROM castle.entry_lines l
        JOIN castle.journal_entries e ON e.id = l.journal_entry_id
        WHERE l.account_id = :id
        ORDER BY e.entry_date DESC, e.created_at DESC, l.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
        EntryLine,
    )

    # Several lines for this account may share an entry; load each entry once.
    entries_by_id = {}
    transactions = []
    for line in lines:
        entry_id = line.journal_entry_id
        if entry_id not in entries_by_id:
            entries_by_id[entry_id] = await get_journal_entry(entry_id)
        entry = entries_by_id[entry_id]
        if entry:
            transactions.append((entry, line))
    return transactions