import json
from datetime import datetime
from decimal import Decimal
from typing import Any, Optional

from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash

from .models import (
    Account,
    AccountPermission,
    AccountType,
    AssertionStatus,
    BalanceAssertion,
    CastleSettings,
    CreateAccount,
    CreateAccountPermission,
    CreateBalanceAssertion,
    CreateEntryLine,
    CreateJournalEntry,
    CreateUserEquityStatus,
    EntryLine,
    JournalEntry,
    PermissionType,
    StoredUserWalletSettings,
    UserBalance,
    UserCastleSettings,
    UserEquityStatus,
    UserWalletSettings,
)

# Import core accounting logic
from .core.balance import BalanceCalculator, AccountType as CoreAccountType
from .core.inventory import CastleInventory, CastlePosition
from .core.validation import (
    ValidationError,
    validate_journal_entry,
    validate_balance,
    validate_receivable_entry,
    validate_expense_entry,
    validate_payment_entry,
)

db = Database("ext_castle")


# ===== ROW PARSING HELPERS =====


def _load_json_object(raw: Any) -> dict:
    """Decode a JSON-object column value, treating NULL/empty as ``{}``."""
    if not raw:
        return {}
    if isinstance(raw, str):
        return json.loads(raw)
    # Tolerate drivers that already hand back a decoded mapping.
    return dict(raw)


def _optional_decimal(raw: Any) -> Optional[Decimal]:
    """Convert a stored numeric/text value to ``Decimal``; NULL or "" -> ``None``."""
    if raw is None or raw == "":
        return None
    # Going through str() keeps float inputs from carrying binary noise.
    return Decimal(str(raw))


def _decimal_to_text(value: Optional[Decimal]) -> Optional[str]:
    """Serialize a Decimal for TEXT storage; ``None`` stays NULL, zero stays "0"."""
    return str(value) if value is not None else None


def _row_to_entry_line(row: Any) -> EntryLine:
    """Build an ``EntryLine`` from an ``entry_lines`` row, decoding ``metadata``."""
    return EntryLine(
        id=row["id"],
        journal_entry_id=row["journal_entry_id"],
        account_id=row["account_id"],
        amount=row["amount"],
        description=row["description"],
        metadata=_load_json_object(row["metadata"]),
    )


def _row_to_journal_entry(row: Any) -> JournalEntry:
    """Build a ``JournalEntry`` (with empty ``lines``) from a ``journal_entries`` row."""
    from .models import JournalEntryFlag

    return JournalEntry(
        id=row["id"],
        description=row["description"],
        entry_date=row["entry_date"],
        created_by=row["created_by"],
        created_at=row["created_at"],
        reference=row["reference"],
        # A NULL flag is treated as cleared ('*'), matching the column default
        # that the original parsing code assumed.
        flag=JournalEntryFlag(row.get("flag") or "*"),
        meta=_load_json_object(row.get("meta")),
        lines=[],
    )


def _row_to_balance_assertion(row: Any) -> BalanceAssertion:
    """Build a ``BalanceAssertion`` from a row, parsing Decimal fields stored as TEXT."""
    tolerance_fiat = _optional_decimal(row["tolerance_fiat"])
    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=_optional_decimal(row["expected_balance_fiat"]),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=tolerance_fiat if tolerance_fiat is not None else Decimal("0"),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=_optional_decimal(row["checked_balance_fiat"]),
        difference_sats=row["difference_sats"],
        difference_fiat=_optional_decimal(row["difference_fiat"]),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )


def _row_to_account_permission(row: Any) -> AccountPermission:
    """Build an ``AccountPermission`` from an ``account_permissions`` row."""
    return AccountPermission(
        id=row["id"],
        user_id=row["user_id"],
        account_id=row["account_id"],
        permission_type=PermissionType(row["permission_type"]),
        granted_by=row["granted_by"],
        granted_at=row["granted_at"],
        expires_at=row["expires_at"],
        notes=row["notes"],
    )


# ===== ACCOUNT OPERATIONS =====


async def create_account(data: CreateAccount) -> Account:
    """
    Insert a new account with a freshly generated short-hash id.

    Args:
        data: Name, type, optional description and optional owning user.

    Returns:
        The stored ``Account``.
    """
    account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", account)
    return account


async def get_account(account_id: str) -> Optional[Account]:
    """Return the account with ``account_id``, or ``None`` if it does not exist."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )


async def get_account_by_name(name: str) -> Optional[Account]:
    """Get account by name (hierarchical format), or ``None`` if absent."""
    return await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )


async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then name."""
    return await db.fetchall(
        "SELECT * FROM accounts ORDER BY account_type, name",
        model=Account,
    )


async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, ordered by name."""
    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        {"type": account_type.value},
        Account,
    )


async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Get or create a user-specific account with hierarchical naming.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        → "Assets:Receivable:User-af983632"

        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        → "Liabilities:Payable:User-af983632"
    """
    from .account_utils import format_hierarchical_account_name

    # Generate hierarchical account name
    account_name = format_hierarchical_account_name(account_type, base_name, user_id)

    # Try to find existing account with this hierarchical name
    account = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )

    if not account:
        # Create new account with hierarchical name
        account = await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )

    return account


# ===== JOURNAL ENTRY OPERATIONS =====


async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """
    Persist a balanced journal entry together with its entry lines.

    Amounts follow the Beancount convention: debits are positive and credits
    are negative, and the lines of one entry must sum to exactly zero.

    The header row and all line rows are written on one connection obtained
    from ``db.connect()`` so they are committed together.
    NOTE(review): this relies on lnbits' ``Database.connect()`` committing on
    normal exit and discarding work on error — confirm for the lnbits version
    in use.

    Args:
        data: Entry header fields plus the lines to post.
        created_by: Identifier stored in ``created_by``.

    Returns:
        The stored ``JournalEntry`` with its ``EntryLine`` objects attached.

    Raises:
        ValueError: If the line amounts do not sum to zero.
    """
    # Beancount-style: positive amounts cancel out negative amounts
    total_amount = sum(line.amount for line in data.lines)
    if total_amount != 0:
        raise ValueError(
            f"Journal entry must balance (sum of amounts = 0): sum={total_amount}"
        )

    entry_id = urlsafe_short_hash()
    now = datetime.now()
    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or now,
        created_by=created_by,
        created_at=now,
        reference=data.reference,
        lines=[],
        flag=data.flag,
        meta=data.meta,
    )

    lines = [
        EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            amount=line_data.amount,
            description=line_data.description,
            metadata=line_data.metadata,
        )
        for line_data in data.lines
    ]

    async with db.connect() as conn:
        # Insert the header without the lines field (lines are stored in the
        # entry_lines table); meta is serialized to JSON text.
        await conn.execute(
            """
            INSERT INTO journal_entries
                (id, description, entry_date, created_by, created_at, reference, flag, meta)
            VALUES
                (:id, :description, :entry_date, :created_by, :created_at, :reference, :flag, :meta)
            """,
            {
                "id": journal_entry.id,
                "description": journal_entry.description,
                "entry_date": journal_entry.entry_date,
                "created_by": journal_entry.created_by,
                "created_at": journal_entry.created_at,
                "reference": journal_entry.reference,
                "flag": journal_entry.flag.value,
                "meta": json.dumps(journal_entry.meta),
            },
        )

        for line in lines:
            # Insert with metadata as JSON string
            await conn.execute(
                """
                INSERT INTO entry_lines
                    (id, journal_entry_id, account_id, amount, description, metadata)
                VALUES
                    (:id, :journal_entry_id, :account_id, :amount, :description, :metadata)
                """,
                {
                    "id": line.id,
                    "journal_entry_id": line.journal_entry_id,
                    "account_id": line.account_id,
                    "amount": line.amount,
                    "description": line.description,
                    "metadata": json.dumps(line.metadata),
                },
            )

    journal_entry.lines = lines
    return journal_entry


async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Return the journal entry with ``entry_id`` (lines included), or ``None``."""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry_id)

    return entry


async def get_journal_entry_by_reference(reference: str) -> Optional[JournalEntry]:
    """Get a journal entry by its reference field (e.g., payment_hash), lines included."""
    entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE reference = :reference",
        {"reference": reference},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry.id)

    return entry


async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all entry lines of one journal entry, with ``metadata`` decoded."""
    rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )
    return [_row_to_entry_line(row) for row in rows]


async def _hydrate_journal_entries(rows: list) -> list[JournalEntry]:
    """Convert ``journal_entries`` rows to entries and load each entry's lines."""
    entries = []
    for row in rows:
        entry = _row_to_journal_entry(row)
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries


async def get_all_journal_entries(limit: int = 100, offset: int = 0) -> list[JournalEntry]:
    """
    Return one page of journal entries, newest first, each with its lines.

    Args:
        limit: Maximum number of entries to return.
        offset: Number of entries to skip (for pagination).
    """
    rows = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit OFFSET :offset
        """,
        {"limit": limit, "offset": offset},
    )
    return await _hydrate_journal_entries(rows)


async def get_journal_entries_by_user(
    user_id: str, limit: int = 100, offset: int = 0
) -> list[JournalEntry]:
    """
    Get journal entries that affect the user's accounts, newest first.

    An entry is included when at least one of its lines posts to an account
    whose ``user_id`` is ``user_id``. The account filter is a subquery, so
    the statement does not grow with the number of user accounts.

    Args:
        user_id: Owner of the accounts to match.
        limit: Maximum number of entries to return.
        offset: Number of entries to skip (for pagination).
    """
    rows = await db.fetchall(
        """
        SELECT je.*
        FROM journal_entries je
        WHERE je.id IN (
            SELECT el.journal_entry_id
            FROM entry_lines el
            JOIN accounts a ON a.id = el.account_id
            WHERE a.user_id = :user_id
        )
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit OFFSET :offset
        """,
        {"user_id": user_id, "limit": limit, "offset": offset},
    )
    return await _hydrate_journal_entries(rows)


async def count_all_journal_entries() -> int:
    """Count total number of journal entries."""
    result = await db.fetchone(
        "SELECT COUNT(*) as total FROM journal_entries"
    )
    return result["total"] if result else 0


async def count_journal_entries_by_user(user_id: str) -> int:
    """Count journal entries that have at least one line on the user's accounts."""
    result = await db.fetchone(
        """
        SELECT COUNT(DISTINCT el.journal_entry_id) as total
        FROM entry_lines el
        JOIN journal_entries je ON je.id = el.journal_entry_id
        JOIN accounts a ON a.id = el.account_id
        WHERE a.user_id = :user_id
        """,
        {"user_id": user_id},
    )
    return result["total"] if result else 0


# ===== BALANCE AND REPORTING =====


async def get_account_balance(account_id: str) -> int:
    """
    Calculate account balance using single amount field (Beancount-style).

    Only includes entries that are cleared (flag='*'), excludes
    pending/flagged/voided entries.

    For each account type:
    - Assets/Expenses: balance = sum of amounts (positive amounts increase, negative decrease)
    - Liabilities/Equity/Revenue: balance = -sum of amounts (negative amounts increase, positive decrease)

    This works because we store amounts consistently:
    - Debit (asset/expense increase) = positive amount
    - Credit (liability/equity/revenue increase) = negative amount

    Returns 0 when the account does not exist.
    """
    # Look the account up first: its type is needed for the sign convention,
    # and an unknown account needs no SUM query at all.
    account = await get_account(account_id)
    if not account:
        return 0

    result = await db.fetchone(
        """
        SELECT COALESCE(SUM(el.amount), 0) as total_amount
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
          AND je.flag = '*'
        """,
        {"id": account_id},
    )
    if not result:
        return 0

    # Use core BalanceCalculator for consistent logic
    core_account_type = CoreAccountType(account.account_type.value)
    return BalanceCalculator.calculate_account_balance_from_amount(
        result["total_amount"], core_account_type
    )


async def get_user_balance(user_id: str) -> UserBalance:
    """
    Get user's balance with the Castle.

    Positive = castle owes user, negative = user owes castle. Only cleared
    entries (flag='*') are considered.

    For each of the user's accounts, the cleared lines are fetched once. They
    are used both for the satoshi balance (same figure as
    ``get_account_balance``) and for the inventory passed to
    ``BalanceCalculator``.
    """
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    account_balances = {}
    account_inventories = {}

    for account in user_accounts:
        # Only include cleared entries (exclude pending/flagged/voided)
        rows = await db.fetchall(
            """
            SELECT el.*
            FROM entry_lines el
            JOIN journal_entries je ON el.journal_entry_id = je.id
            WHERE el.account_id = :account_id
              AND je.flag = '*'
            """,
            {"account_id": account.id},
        )
        entry_lines = [dict(row) for row in rows]
        core_account_type = CoreAccountType(account.account_type.value)

        # Equivalent to SQL COALESCE(SUM(amount), 0): NULL amounts are skipped
        # and an empty account sums to 0. Computed before the inventory build
        # in case that call touches the dicts.
        total_amount = sum(line["amount"] or 0 for line in entry_lines)
        account_balances[account.id] = (
            BalanceCalculator.calculate_account_balance_from_amount(
                total_amount, core_account_type
            )
        )
        account_inventories[account.id] = (
            BalanceCalculator.build_inventory_from_entry_lines(
                entry_lines, core_account_type
            )
        )

    # Use BalanceCalculator to calculate total user balance
    accounts_list = [
        {"id": acc.id, "account_type": acc.account_type.value}
        for acc in user_accounts
    ]
    balance_result = BalanceCalculator.calculate_user_balance(
        accounts_list, account_balances, account_inventories
    )

    return UserBalance(
        user_id=user_id,
        balance=balance_result["balance"],
        accounts=user_accounts,
        fiat_balances=balance_result["fiat_balances"],
    )


async def get_all_user_balances() -> list[UserBalance]:
    """
    Get balances for all users (used by castle to see who they owe).

    Users are visited in ``user_id`` order, so the result order is stable.
    Users whose satoshi balance is zero and who have no fiat balances are
    omitted.
    """
    rows = await db.fetchall(
        """
        SELECT DISTINCT user_id FROM accounts
        WHERE user_id IS NOT NULL
        ORDER BY user_id
        """,
        {},
    )

    user_balances = []
    for row in rows:
        user_id = row["user_id"]
        if not user_id:
            continue
        balance = await get_user_balance(user_id)
        # Include users with non-zero balance or fiat balances
        if balance.balance != 0 or balance.fiat_balances:
            user_balances.append(balance)

    return user_balances


async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """
    Get the most recent transactions affecting a specific account.

    Lines are ordered by their journal entry's ``entry_date`` and then
    ``created_at`` (newest first). Ids are random short hashes and carry no
    chronological meaning. All flags are included.

    Args:
        account_id: Account whose lines to list.
        limit: Maximum number of lines to return.

    Returns:
        ``(journal_entry, line)`` pairs. The entry includes all of its lines.
    """
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    # Several lines can belong to one entry; fetch each entry only once.
    entries_by_id: dict[str, Optional[JournalEntry]] = {}
    transactions = []
    for row in rows:
        line = _row_to_entry_line(row)
        if line.journal_entry_id not in entries_by_id:
            entries_by_id[line.journal_entry_id] = await get_journal_entry(
                line.journal_entry_id
            )
        entry = entries_by_id[line.journal_entry_id]
        if entry:
            transactions.append((entry, line))

    return transactions


# ===== SETTINGS =====


async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert castle settings keyed by ``user_id`` and return the stored record."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", settings)
    return settings


async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the castle settings stored under ``user_id``, or ``None``."""
    return await db.fetchone(
        """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        CastleSettings,
    )


async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the castle settings stored under ``user_id``."""
    settings = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", settings)
    return settings


# ===== USER WALLET SETTINGS =====


async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` and return the stored record."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", settings)
    return settings


async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings stored for ``user_id``, or ``None``."""
    return await db.fetchone(
        """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """,
        {"user_id": user_id},
        UserWalletSettings,
    )


async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings stored for ``user_id``."""
    settings = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", settings)
    return settings


async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Get all user wallet settings, ordered by user id."""
    return await db.fetchall(
        "SELECT * FROM user_wallet_settings ORDER BY id",
        {},
        StoredUserWalletSettings,
    )


# ===== MANUAL PAYMENT REQUESTS =====


async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Create a new manual payment request with status ``"pending"``."""
    from .models import ManualPaymentRequest

    request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        {
            "id": request.id,
            "user_id": request.user_id,
            "amount": request.amount,
            "description": request.description,
            "status": request.status,
            "created_at": request.created_at,
        },
    )

    return request


async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Get a manual payment request by ID, or ``None`` if absent."""
    from .models import ManualPaymentRequest

    return await db.fetchone(
        "SELECT * FROM manual_payment_requests WHERE id = :id",
        {"id": request_id},
        ManualPaymentRequest,
    )


async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Get a user's manual payment requests, newest first, up to ``limit``."""
    from .models import ManualPaymentRequest

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
        ManualPaymentRequest,
    )


async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Get all manual payment requests, newest first, optionally filtered by status."""
    from .models import ManualPaymentRequest

    if status:
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            WHERE status = :status
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"status": status, "limit": limit},
            ManualPaymentRequest,
        )
    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
        ManualPaymentRequest,
    )


async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as approved and link its journal entry.

    The current status is not checked here. Returns the updated request, or
    ``None`` if it does not exist.
    """
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
            "journal_entry_id": journal_entry_id,
        },
    )

    return await get_manual_payment_request(request_id)


async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as rejected.

    The current status is not checked here. Returns the updated request, or
    ``None`` if it does not exist.
    """
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
        },
    )

    return await get_manual_payment_request(request_id)


# ===== BALANCE ASSERTION OPERATIONS =====


async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """
    Create a new balance assertion in ``PENDING`` status.

    Decimal fields are stored as TEXT. An expected fiat balance of exactly
    zero is stored as "0", not NULL, so it is still checked later. A missing
    fiat tolerance is stored as "0", which is the value the read path already
    assumes.

    Args:
        data: Assertion parameters; ``date`` defaults to now.
        created_by: Identifier stored in ``created_by``.
    """
    now = datetime.now()
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else now,
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=now,
    )

    tolerance_fiat = (
        assertion.tolerance_fiat
        if assertion.tolerance_fiat is not None
        else Decimal("0")
    )

    # Manually insert with Decimal fields converted to strings
    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        )
        VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            "expected_balance_fiat": _decimal_to_text(assertion.expected_balance_fiat),
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": str(tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )

    return assertion


async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """Get a balance assertion by ID, or ``None`` if absent."""
    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    return _row_to_balance_assertion(row) if row else None


async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """
    Get balance assertions, newest ``date`` first, with optional filters.

    Args:
        account_id: If given, only assertions for this account.
        status: If given, only assertions with this status.
        limit: Maximum number of assertions to return.
    """
    # Only fixed SQL fragments are joined below; all values stay bound params.
    conditions: list[str] = []
    params: dict[str, Any] = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )
    return [_row_to_balance_assertion(row) for row in rows]


async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Check a balance assertion by comparing expected vs actual balance.

    The sats side is compared against the account's current cleared balance
    (``get_account_balance``), not a balance as of ``assertion.date``.
    NOTE(review): confirm that ignoring the assertion date is intended.

    When ``fiat_currency`` is set and the account belongs to a user, the fiat
    side is compared against that user's aggregated ``fiat_balances`` entry.
    This is the total across all of the user's accounts, not just this one.

    The computed values and the resulting PASSED/FAILED status are written
    back, and the refreshed assertion is returned.

    Raises:
        ValueError: If the assertion or its account does not exist.
    """
    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    # Current cleared balance (not date-bounded; see docstring)
    actual_balance = await get_account_balance(assertion.account_id)

    # Get fiat balance if needed
    actual_fiat_balance = None
    if assertion.fiat_currency and account.user_id:
        user_balance = await get_user_balance(account.user_id)
        actual_fiat_balance = user_balance.fiat_balances.get(
            assertion.fiat_currency, Decimal("0")
        )

    # Check sats balance
    difference_sats = actual_balance - assertion.expected_balance_sats
    sats_match = abs(difference_sats) <= assertion.tolerance_sats

    # Check fiat balance if applicable
    fiat_match = True
    difference_fiat = None
    if assertion.expected_balance_fiat is not None and actual_fiat_balance is not None:
        difference_fiat = actual_fiat_balance - assertion.expected_balance_fiat
        fiat_match = abs(difference_fiat) <= assertion.tolerance_fiat

    status = AssertionStatus.PASSED if (sats_match and fiat_match) else AssertionStatus.FAILED

    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_balance,
            "checked_fiat": _decimal_to_text(actual_fiat_balance),
            "diff_sats": difference_sats,
            "diff_fiat": _decimal_to_text(difference_fiat),
            "status": status.value,
            "checked_at": datetime.now(),
        },
    )

    return await get_balance_assertion(assertion_id)


async def delete_balance_assertion(assertion_id: str) -> None:
    """Delete a balance assertion (no-op if it does not exist)."""
    await db.execute(
        "DELETE FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )


# ===== USER EQUITY STATUS =====


async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]:
    """Get user's equity eligibility status, or ``None`` if none is recorded."""
    row = await db.fetchone(
        """
        SELECT * FROM user_equity_status
        WHERE user_id = :user_id
        """,
        {"user_id": user_id},
    )
    return UserEquityStatus(**row) if row else None


async def create_or_update_user_equity_status(
    data: "CreateUserEquityStatus", granted_by: str
) -> "UserEquityStatus":
    """
    Create or update user equity eligibility status.

    When granting eligibility, the account ``Equity:User-<first 8 chars of
    user_id>`` is created via ``create_account`` if it does not exist yet.
    Its name is written into ``data.equity_account_name``, so ``data`` is
    mutated. An existing row has its grant fields refreshed. ``revoked_at``
    is set to now when ``is_equity_eligible`` is false and cleared otherwise.

    Raises:
        ValueError: If the record cannot be read back after writing.
    """
    if data.is_equity_eligible:
        # NOTE(review): only the first 8 characters of user_id are used, so
        # two users sharing that prefix would resolve to the same account —
        # confirm this cannot happen.
        equity_account_name = f"Equity:User-{data.user_id[:8]}"

        if not await get_account_by_name(equity_account_name):
            await create_account(
                CreateAccount(
                    name=equity_account_name,
                    account_type=AccountType.EQUITY,
                    description=f"Equity contributions for user {data.user_id[:8]}",
                    user_id=data.user_id,
                )
            )

        data.equity_account_name = equity_account_name

    now = datetime.now()
    existing = await get_user_equity_status(data.user_id)

    if existing:
        await db.execute(
            """
            UPDATE user_equity_status
            SET is_equity_eligible = :is_equity_eligible,
                equity_account_name = :equity_account_name,
                notes = :notes,
                granted_by = :granted_by,
                granted_at = :granted_at,
                revoked_at = :revoked_at
            WHERE user_id = :user_id
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": now,
                "revoked_at": None if data.is_equity_eligible else now,
            },
        )
    else:
        await db.execute(
            """
            INSERT INTO user_equity_status (
                user_id, is_equity_eligible, equity_account_name,
                notes, granted_by, granted_at
            )
            VALUES (
                :user_id, :is_equity_eligible, :equity_account_name,
                :notes, :granted_by, :granted_at
            )
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": now,
            },
        )

    result = await get_user_equity_status(data.user_id)
    if not result:
        raise ValueError(f"Failed to create/update equity status for user {data.user_id}")
    return result


async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]:
    """Revoke user's equity contribution eligibility and return the updated status."""
    await db.execute(
        """
        UPDATE user_equity_status
        SET is_equity_eligible = FALSE,
            revoked_at = :revoked_at
        WHERE user_id = :user_id
        """,
        {"user_id": user_id, "revoked_at": datetime.now()},
    )
    return await get_user_equity_status(user_id)


async def get_all_equity_eligible_users() -> list["UserEquityStatus"]:
    """Get all equity-eligible users, most recently granted first."""
    rows = await db.fetchall(
        """
        SELECT * FROM user_equity_status
        WHERE is_equity_eligible = TRUE
        ORDER BY granted_at DESC
        """
    )
    return [UserEquityStatus(**row) for row in rows]


# ===== ACCOUNT PERMISSION OPERATIONS =====


async def create_account_permission(
    data: "CreateAccountPermission", granted_by: str
) -> "AccountPermission":
    """Create a new account permission granted by ``granted_by`` at the current time."""
    permission = AccountPermission(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    await db.execute(
        """
        INSERT INTO account_permissions (
            id, user_id, account_id, permission_type, granted_by,
            granted_at, expires_at, notes
        )
        VALUES (
            :id, :user_id, :account_id, :permission_type, :granted_by,
            :granted_at, :expires_at, :notes
        )
        """,
        {
            "id": permission.id,
            "user_id": permission.user_id,
            "account_id": permission.account_id,
            "permission_type": permission.permission_type.value,
            "granted_by": permission.granted_by,
            "granted_at": permission.granted_at,
            "expires_at": permission.expires_at,
            "notes": permission.notes,
        },
    )

    return permission


async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]:
    """Get account permission by ID (expired or not), or ``None`` if absent."""
    row = await db.fetchone(
        "SELECT * FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )
    return _row_to_account_permission(row) if row else None


async def get_user_permissions(
    user_id: str, permission_type: Optional["PermissionType"] = None
) -> list["AccountPermission"]:
    """
    Get a user's non-expired permissions, most recently granted first.

    Args:
        user_id: Grantee to look up.
        permission_type: If given, only permissions of this type.
    """
    params: dict[str, Any] = {"user_id": user_id, "now": datetime.now()}
    type_clause = ""
    if permission_type:
        type_clause = "AND permission_type = :permission_type"
        params["permission_type"] = permission_type.value

    rows = await db.fetchall(
        f"""
        SELECT * FROM account_permissions
        WHERE user_id = :user_id
          {type_clause}
          AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        params,
    )
    return [_row_to_account_permission(row) for row in rows]


async def get_account_permissions(account_id: str) -> list["AccountPermission"]:
    """Get all non-expired permissions on an account, most recently granted first."""
    rows = await db.fetchall(
        """
        SELECT * FROM account_permissions
        WHERE account_id = :account_id
          AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        {"account_id": account_id, "now": datetime.now()},
    )
    return [_row_to_account_permission(row) for row in rows]


async def delete_account_permission(permission_id: str) -> None:
    """Delete (revoke) an account permission (no-op if it does not exist)."""
    await db.execute(
        "DELETE FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )


async def check_user_has_permission(
    user_id: str, account_id: str, permission_type: "PermissionType"
) -> bool:
    """Check if user has a non-expired permission on an account (direct only, no inheritance)."""
    row = await db.fetchone(
        """
        SELECT id FROM account_permissions
        WHERE user_id = :user_id
          AND account_id = :account_id
          AND permission_type = :permission_type
          AND (expires_at IS NULL OR expires_at > :now)
        """,
        {
            "user_id": user_id,
            "account_id": account_id,
            "permission_type": permission_type.value,
            "now": datetime.now(),
        },
    )
    return row is not None


async def get_user_permissions_with_inheritance(
    user_id: str, account_name: str, permission_type: "PermissionType"
) -> list[tuple["AccountPermission", Optional[str]]]:
    """
    Get all permissions for a user on an account, including inherited
    permissions from parent accounts.

    Returns list of tuples: (permission, parent_account_name or None)

    Example:
        If user has permission on "Expenses:Food", they also have permission
        on "Expenses:Food:Groceries".
        Returns: [(permission_on_food, "Expenses:Food")]
    """
    user_permissions = await get_user_permissions(user_id, permission_type)

    applicable_permissions = []
    for perm in user_permissions:
        account = await get_account(perm.account_id)
        if not account:
            continue

        # Parent accounts are indicated by hierarchical names (colon-separated);
        # the trailing ":" keeps "Expenses:Foo" from matching "Expenses:Food".
        if account_name == account.name:
            # Direct permission
            applicable_permissions.append((perm, None))
        elif account_name.startswith(account.name + ":"):
            # Inherited permission from parent account
            applicable_permissions.append((perm, account.name))

    return applicable_permissions