"""Database access for the Castle extension: accounts, journal entries, balances and settings."""

import json
from datetime import datetime
from typing import Any, Optional

from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash

from .models import (
    Account,
    AccountType,
    CastleSettings,
    CreateAccount,
    CreateEntryLine,
    CreateJournalEntry,
    EntryLine,
    JournalEntry,
    StoredUserWalletSettings,
    UserBalance,
    UserCastleSettings,
    UserWalletSettings,
)

db = Database("ext_castle")


# ===== ACCOUNT OPERATIONS =====


async def create_account(data: CreateAccount) -> Account:
    """Insert a new account with a freshly generated id and return it."""
    account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", account)
    return account


async def get_account(account_id: str) -> Optional[Account]:
    """Return the account with ``account_id``, or None if it does not exist."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )


async def get_account_by_name(name: str) -> Optional[Account]:
    """Return an account whose name equals ``name``, or None if there is none."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )


async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then by name."""
    return await db.fetchall(
        "SELECT * FROM accounts ORDER BY account_type, name",
        model=Account,
    )


async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, ordered by name."""
    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        {"type": account_type.value},
        Account,
    )


async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """Get or create a user-specific account (e.g., 'Accounts Payable - User123').

    The account name is ``base_name`` suffixed with the first 8 characters of
    ``user_id``; an existing account must match the user, the type and that
    generated name.
    """
    account_name = f"{base_name} - {user_id[:8]}"
    account = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )
    if not account:
        # NOTE(review): lookup-then-insert is not atomic; two concurrent calls
        # for the same user/type could both create an account.
        account = await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )
    return account


# ===== JOURNAL ENTRY OPERATIONS =====


def _parse_metadata(raw: Any) -> dict:
    """Decode an entry line's stored metadata into a dict (empty when absent)."""
    if isinstance(raw, dict):
        # In case the driver hands back an already-decoded value.
        return raw
    if not raw:
        return {}
    parsed = json.loads(raw)
    # json.dumps(None) is stored as the string "null"; normalise that (and any
    # other non-object payload) to {} so callers can always use .get().
    return parsed if isinstance(parsed, dict) else {}


def _row_to_entry_line(row: Any) -> EntryLine:
    """Build an EntryLine from a raw ``entry_lines`` row, decoding its metadata."""
    return EntryLine(
        id=row["id"],
        journal_entry_id=row["journal_entry_id"],
        account_id=row["account_id"],
        debit=row["debit"],
        credit=row["credit"],
        description=row["description"],
        metadata=_parse_metadata(row["metadata"]),
    )


async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Persist a balanced journal entry together with its lines.

    Args:
        data: Entry header fields plus the debit/credit lines to record.
            ``entry_date`` defaults to now when not supplied.
        created_by: Value stored in the entry's ``created_by`` column.

    Returns:
        The stored entry with ``lines`` populated.

    Raises:
        ValueError: If total debits differ from total credits.
    """
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)
    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    entry_id = urlsafe_short_hash()
    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or datetime.now(),
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
    )

    # The header is inserted column-by-column because the model's ``lines``
    # field has no column here; lines are stored in the entry_lines table.
    # NOTE(review): the header and line inserts are not wrapped in a
    # transaction, so a failure part-way through leaves a partial entry.
    await db.execute(
        """
        INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference)
        VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference)
        """,
        {
            "id": journal_entry.id,
            "description": journal_entry.description,
            "entry_date": journal_entry.entry_date,
            "created_by": journal_entry.created_by,
            "created_at": journal_entry.created_at,
            "reference": journal_entry.reference,
        },
    )

    lines = []
    for line_data in data.lines:
        line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
            metadata=line_data.metadata,
        )
        # metadata is persisted as a JSON string
        await db.execute(
            """
            INSERT INTO entry_lines (id, journal_entry_id, account_id, debit, credit, description, metadata)
            VALUES (:id, :journal_entry_id, :account_id, :debit, :credit, :description, :metadata)
            """,
            {
                "id": line.id,
                "journal_entry_id": line.journal_entry_id,
                "account_id": line.account_id,
                "debit": line.debit,
                "credit": line.credit,
                "description": line.description,
                "metadata": json.dumps(line.metadata),
            },
        )
        lines.append(line)

    journal_entry.lines = lines
    return journal_entry


async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Return the journal entry with ``entry_id`` including its lines, or None."""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if entry:
        entry.lines = await get_entry_lines(entry_id)
    return entry


async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all lines of a journal entry with their metadata decoded."""
    rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )
    return [_row_to_entry_line(row) for row in rows]


async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, each with its lines."""
    entries = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
        JournalEntry,
    )
    for entry in entries:
        entry.lines = await get_entry_lines(entry.id)
    return entries


async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Get journal entries that affect the user's accounts.

    An entry is included when at least one of its lines posts to an account
    owned by ``user_id``. Results are newest first (entry date, then creation
    time), capped at ``limit``, and carry all of their lines.
    """
    # A subquery selects the user's accounts in the same round-trip; a user
    # with no accounts simply matches no rows.
    entries = await db.fetchall(
        """
        SELECT DISTINCT je.*
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN (
            SELECT id FROM accounts WHERE user_id = :user_id
        )
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
        JournalEntry,
    )
    for entry in entries:
        entry.lines = await get_entry_lines(entry.id)
    return entries


# ===== BALANCE AND REPORTING =====


async def get_account_balance(account_id: str) -> int:
    """Return an account's balance in its normal-balance direction.

    Assets and expenses are debit-normal (debits - credits); liabilities,
    equity and revenue are credit-normal (credits - debits). Returns 0 for an
    unknown account.
    """
    account = await get_account(account_id)
    if not account:
        return 0

    result = await db.fetchone(
        """
        SELECT COALESCE(SUM(debit), 0) as total_debit,
               COALESCE(SUM(credit), 0) as total_credit
        FROM entry_lines
        WHERE account_id = :id
        """,
        {"id": account_id},
    )
    if not result:
        return 0

    total_debit = result["total_debit"]
    total_credit = result["total_credit"]
    if account.account_type in (AccountType.ASSET, AccountType.EXPENSE):
        return total_debit - total_credit
    return total_credit - total_debit


async def _calculate_user_balance(
    accounts: list[Account],
) -> tuple[int, dict[str, float]]:
    """Sum the sat balance and per-currency fiat balances over a user's accounts.

    Sign convention: positive means the castle owes the user, negative means
    the user owes the castle. Liability accounts count positively and asset
    (receivable) accounts negatively; other account types do not change the
    sat total, but their lines carrying fiat metadata still register the
    currency with a 0.0 starting value.
    """
    total_balance = 0
    fiat_balances: dict[str, float] = {}

    for account in accounts:
        is_liability = account.account_type == AccountType.LIABILITY
        is_asset = account.account_type == AccountType.ASSET

        entry_lines = await db.fetchall(
            "SELECT * FROM entry_lines WHERE account_id = :account_id",
            {"account_id": account.id},
        )
        for line in entry_lines:
            metadata = _parse_metadata(line["metadata"])
            fiat_currency = metadata.get("fiat_currency")
            fiat_amount = metadata.get("fiat_amount")
            if not (fiat_currency and fiat_amount):
                continue

            fiat_balances.setdefault(fiat_currency, 0.0)
            if is_liability:
                # Credit increases what the castle owes; debit reduces it.
                if line["credit"] > 0:
                    fiat_balances[fiat_currency] += fiat_amount
                elif line["debit"] > 0:
                    fiat_balances[fiat_currency] -= fiat_amount
            elif is_asset:
                # Debit increases what the user owes; credit reduces it.
                if line["debit"] > 0:
                    fiat_balances[fiat_currency] -= fiat_amount
                elif line["credit"] > 0:
                    fiat_balances[fiat_currency] += fiat_amount

        if is_liability:
            total_balance += await get_account_balance(account.id)
        elif is_asset:
            total_balance -= await get_account_balance(account.id)

    return total_balance, fiat_balances


async def get_user_balance(user_id: str) -> UserBalance:
    """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)."""
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )
    balance, fiat_balances = await _calculate_user_balance(user_accounts)
    return UserBalance(
        user_id=user_id,
        balance=balance,
        accounts=user_accounts,
        fiat_balances=fiat_balances,
    )


async def get_all_user_balances() -> list[UserBalance]:
    """Get balances for all users (used by castle to see who they owe).

    Users with a zero sat balance and no fiat-tagged lines are omitted.
    """
    all_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )

    accounts_by_user: dict[str, list[Account]] = {}
    for account in all_accounts:
        accounts_by_user.setdefault(account.user_id, []).append(account)

    user_balances = []
    for user_id, accounts in accounts_by_user.items():
        balance, fiat_balances = await _calculate_user_balance(accounts)
        if balance != 0 or fiat_balances:
            user_balances.append(
                UserBalance(
                    user_id=user_id,
                    balance=balance,
                    accounts=accounts,
                    fiat_balances=fiat_balances,
                )
            )
    return user_balances


async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get the most recent transactions affecting a specific account.

    Returns up to ``limit`` (journal entry, entry line) pairs for lines posted
    to ``account_id``, newest first by the parent entry's date and creation
    time.
    """
    # Line ids come from urlsafe_short_hash(), not a sequence, so they carry no
    # chronological order; sort on the parent journal entry instead.
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON je.id = el.journal_entry_id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    transactions = []
    for row in rows:
        line = _row_to_entry_line(row)
        entry = await get_journal_entry(line.journal_entry_id)
        if entry:
            transactions.append((entry, line))
    return transactions


# ===== SETTINGS =====


async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Store castle settings keyed by ``user_id`` and return the stored record."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", settings)
    return settings


async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the castle settings stored for ``user_id``, or None."""
    return await db.fetchone(
        """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        CastleSettings,
    )


async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the castle settings keyed by ``user_id`` and return them."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", settings)
    return settings


# ===== USER WALLET SETTINGS =====


async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Store wallet settings keyed by ``user_id`` and return the stored record."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", settings)
    return settings


async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings stored for ``user_id``, or None."""
    return await db.fetchone(
        """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        UserWalletSettings,
    )


async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings keyed by ``user_id`` and return them."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", settings)
    return settings


async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Get all user wallet settings, ordered by user id."""
    return await db.fetchall(
        "SELECT * FROM user_wallet_settings ORDER BY id",
        {},
        StoredUserWalletSettings,
    )