import json from datetime import datetime from typing import Optional import httpx from lnbits.db import Database from lnbits.helpers import urlsafe_short_hash from lnbits.utils.cache import Cache from .models import ( Account, AccountPermission, AccountType, AssertionStatus, AssignUserRole, BalanceAssertion, CastleSettings, CreateAccount, CreateAccountPermission, CreateBalanceAssertion, CreateEntryLine, CreateJournalEntry, CreateRole, CreateRolePermission, CreateUserEquityStatus, EntryLine, JournalEntry, PermissionType, Role, RolePermission, RoleWithPermissions, StoredUserWalletSettings, UpdateRole, UserBalance, UserCastleSettings, UserEquityStatus, UserRole, UserWalletSettings, UserWithRoles, ) # Import core accounting logic from .core.validation import ( ValidationError, validate_journal_entry, validate_balance, validate_receivable_entry, validate_expense_entry, validate_payment_entry, ) db = Database("ext_castle") # ===== CACHING ===== # Cache for account and permission lookups to reduce DB queries # TTLs: accounts=300s (5min), permissions=60s (1min) account_cache = Cache() # 5 minutes for accounts (rarely change) permission_cache = Cache() # 1 minute for permissions (may change frequently) # Cache TTLs ACCOUNT_CACHE_TTL = 300 # 5 minutes PERMISSION_CACHE_TTL = 60 # 1 minute # ===== ACCOUNT OPERATIONS ===== async def create_account(data: CreateAccount) -> Account: account_id = urlsafe_short_hash() account = Account( id=account_id, name=data.name, account_type=data.account_type, description=data.description, user_id=data.user_id, is_virtual=data.is_virtual, created_at=datetime.now(), ) await db.insert("accounts", account) # Invalidate cache for this account (Cache class doesn't have delete method, use pop) account_cache._values.pop(f"account:id:{account_id}", None) account_cache._values.pop(f"account:name:{data.name}", None) return account async def get_account(account_id: str) -> Optional[Account]: """Get account by ID with caching""" cache_key = f"account:id:{account_id}" # Try cache first cached = account_cache.get(cache_key) if cached is not None: return cached # Query DB account = await db.fetchone( "SELECT * FROM accounts WHERE id = :id", {"id": account_id}, Account, ) # Cache result (even if None) account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL) return account async def get_account_by_name(name: str) -> Optional[Account]: """Get account by name (hierarchical format) with caching""" cache_key = f"account:name:{name}" # Try cache first cached = account_cache.get(cache_key) if cached is not None: return cached # Query DB account = await db.fetchone( "SELECT * FROM accounts WHERE name = :name", {"name": name}, Account, ) # Cache result (even if None) account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL) return account async def get_all_accounts(include_inactive: bool = False) -> list[Account]: """ Get all accounts, optionally including inactive ones. Args: include_inactive: If True, include inactive accounts. Default False. Returns: List of Account objects """ if include_inactive: query = "SELECT * FROM accounts ORDER BY account_type, name" else: query = "SELECT * FROM accounts WHERE is_active = TRUE ORDER BY account_type, name" return await db.fetchall(query, model=Account) async def get_accounts_by_type( account_type: AccountType, include_inactive: bool = False ) -> list[Account]: """ Get accounts by type, optionally including inactive ones. Args: account_type: The account type to filter by include_inactive: If True, include inactive accounts. Default False. Returns: List of Account objects """ if include_inactive: query = "SELECT * FROM accounts WHERE account_type = :type ORDER BY name" else: query = "SELECT * FROM accounts WHERE account_type = :type AND is_active = TRUE ORDER BY name" return await db.fetchall(query, {"type": account_type.value}, Account) async def update_account_is_active(account_id: str, is_active: bool) -> None: """ Update the is_active status of an account (soft delete/reactivate). Args: account_id: Account ID to update is_active: True to activate, False to deactivate """ await db.execute( """ UPDATE accounts SET is_active = :is_active WHERE id = :account_id """, {"account_id": account_id, "is_active": is_active}, ) # Invalidate cache account_cache._values.pop(f"account:id:{account_id}", None) async def get_or_create_user_account( user_id: str, account_type: AccountType, base_name: str ) -> Account: """ Get or create a user-specific account with hierarchical naming. This function checks if the account exists in Fava/Beancount and creates it if it doesn't exist. The account is also registered in Castle's database for metadata tracking (permissions, descriptions, etc.). Examples: get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable") → "Assets:Receivable:User-af983632" get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable") → "Liabilities:Payable:User-af983632" """ from .account_utils import format_hierarchical_account_name from .fava_client import get_fava_client from loguru import logger # Generate hierarchical account name account_name = format_hierarchical_account_name(account_type, base_name, user_id) # Try to find existing account with this hierarchical name in Castle DB account = await db.fetchone( """ SELECT * FROM accounts WHERE user_id = :user_id AND account_type = :type AND name = :name """, {"user_id": user_id, "type": account_type.value, "name": account_name}, Account, ) logger.info(f"[ACCOUNT CHECK] User {user_id[:8]}, Account: {account_name}, In Castle DB: {account is not None}") # Always check/create in Fava, even if account exists in Castle DB # This ensures Beancount has the Open directive fava_account_exists = False if True: # Always check Fava # Check if account exists in Fava/Beancount fava = get_fava_client() try: # Query Fava for this account query = f"SELECT account WHERE account = '{account_name}'" async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get( f"{fava.base_url}/query", params={"query_string": query} ) response.raise_for_status() result = response.json() # Check if account exists in Fava fava_account_exists = len(result.get("data", {}).get("rows", [])) > 0 logger.info(f"[FAVA CHECK] Account {account_name} exists in Fava: {fava_account_exists}") if not fava_account_exists: # Create account in Fava/Beancount via Open directive logger.info(f"[FAVA CREATE] Creating account in Fava: {account_name}") await fava.add_account( account_name=account_name, currencies=["EUR", "SATS", "USD"], # Support common currencies metadata={ "user_id": user_id, "description": f"User-specific {account_type.value} account" } ) logger.info(f"[FAVA CREATE] Successfully created account in Fava: {account_name}") except Exception as e: logger.error(f"[FAVA ERROR] Could not check/create account in Fava: {e}", exc_info=True) # Continue anyway - account creation in Castle DB is still useful for metadata # Ensure account exists in Castle DB (sync from Beancount if needed) # This uses the account sync module for consistency if not account: logger.info(f"[CASTLE DB] Syncing account from Beancount to Castle DB: {account_name}") from .account_sync import sync_single_account_from_beancount # Sync from Beancount to Castle DB created = await sync_single_account_from_beancount(account_name) if created: logger.info(f"[CASTLE DB] Account synced from Beancount: {account_name}") else: logger.warning(f"[CASTLE DB] Failed to sync account from Beancount: {account_name}") # Fetch the account from Castle DB account = await db.fetchone( """ SELECT * FROM accounts WHERE user_id = :user_id AND account_type = :type AND name = :name """, {"user_id": user_id, "type": account_type.value, "name": account_name}, Account, ) if not account: logger.error(f"[CASTLE DB] Account still not found after sync: {account_name}") # Fallback: create directly in Castle DB if sync failed logger.info(f"[CASTLE DB] Creating account directly in Castle DB: {account_name}") try: account = await create_account( CreateAccount( name=account_name, account_type=account_type, description=f"User-specific {account_type.value} account", user_id=user_id, ) ) except Exception as e: # Handle UNIQUE constraint error - account already exists if "UNIQUE constraint failed" in str(e) and "accounts.name" in str(e): logger.warning(f"[CASTLE DB] Account already exists (UNIQUE constraint), fetching by name: {account_name}") # Fetch existing account by name only (ignore user_id in query) account = await db.fetchone( """ SELECT * FROM accounts WHERE name = :name """, {"name": account_name}, Account, ) if account: logger.info(f"[CASTLE DB] Found existing account: {account_name} (user_id: {account.user_id})") # Update user_id if it's NULL or different if account.user_id != user_id: logger.info(f"[CASTLE DB] Updating account user_id from {account.user_id} to {user_id}") await db.execute( """ UPDATE accounts SET user_id = :user_id WHERE name = :name """, {"user_id": user_id, "name": account_name} ) # Refresh account from DB account = await db.fetchone( """ SELECT * FROM accounts WHERE name = :name """, {"name": account_name}, Account, ) else: # Re-raise if it's a different error raise else: logger.info(f"[CASTLE DB] Account already exists in Castle DB: {account_name}") return account # ===== JOURNAL ENTRY OPERATIONS ===== # ===== JOURNAL ENTRY OPERATIONS (REMOVED) ===== # # All journal entry operations have been moved to Fava/Beancount. # Castle no longer maintains its own journal_entries and entry_lines tables. # # For journal entry operations, see: # - views_api.py: api_create_journal_entry() - writes to Fava via FavaClient # - views_api.py: API endpoints query Fava via FavaClient for reading entries # # Migration: m016_drop_obsolete_journal_tables # Removed functions: # - create_journal_entry() # - get_journal_entry() # - get_journal_entry_by_reference() # - get_entry_lines() # - get_all_journal_entries() # - get_journal_entries_by_user() # - count_all_journal_entries() # - count_journal_entries_by_user() # - get_journal_entries_by_user_and_account_type() # - count_journal_entries_by_user_and_account_type() # - get_account_transactions() # ===== SETTINGS ===== async def create_castle_settings( user_id: str, data: CastleSettings ) -> CastleSettings: settings = UserCastleSettings(**data.dict(), id=user_id) await db.insert("extension_settings", settings) return settings async def get_castle_settings(user_id: str) -> Optional[CastleSettings]: return await db.fetchone( """ SELECT * FROM extension_settings WHERE id = :user_id """, {"user_id": user_id}, CastleSettings, ) async def update_castle_settings( user_id: str, data: CastleSettings ) -> CastleSettings: settings = UserCastleSettings(**data.dict(), id=user_id) await db.update("extension_settings", settings) return settings # ===== USER WALLET SETTINGS ===== async def create_user_wallet_settings( user_id: str, data: UserWalletSettings ) -> UserWalletSettings: settings = StoredUserWalletSettings(**data.dict(), id=user_id) await db.insert("user_wallet_settings", settings) return settings async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]: return await db.fetchone( """ SELECT * FROM user_wallet_settings WHERE id = :user_id """, {"user_id": user_id}, UserWalletSettings, ) async def update_user_wallet_settings( user_id: str, data: UserWalletSettings ) -> UserWalletSettings: settings = StoredUserWalletSettings(**data.dict(), id=user_id) await db.update("user_wallet_settings", settings) return settings async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]: """Get all user wallet settings""" return await db.fetchall( "SELECT * FROM user_wallet_settings ORDER BY id", {}, StoredUserWalletSettings, ) # ===== MANUAL PAYMENT REQUESTS ===== async def create_manual_payment_request( user_id: str, amount: int, description: str ) -> "ManualPaymentRequest": """Create a new manual payment request""" from .models import ManualPaymentRequest request_id = urlsafe_short_hash() request = ManualPaymentRequest( id=request_id, user_id=user_id, amount=amount, description=description, status="pending", created_at=datetime.now(), ) await db.execute( """ INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at) VALUES (:id, :user_id, :amount, :description, :status, :created_at) """, { "id": request.id, "user_id": request.user_id, "amount": request.amount, "description": request.description, "status": request.status, "created_at": request.created_at, }, ) return request async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]: """Get a manual payment request by ID""" from .models import ManualPaymentRequest return await db.fetchone( "SELECT * FROM manual_payment_requests WHERE id = :id", {"id": request_id}, ManualPaymentRequest, ) async def get_user_manual_payment_requests( user_id: str, limit: int = 100 ) -> list["ManualPaymentRequest"]: """Get all manual payment requests for a specific user""" from .models import ManualPaymentRequest return await db.fetchall( """ SELECT * FROM manual_payment_requests WHERE user_id = :user_id ORDER BY created_at DESC LIMIT :limit """, {"user_id": user_id, "limit": limit}, ManualPaymentRequest, ) async def get_all_manual_payment_requests( status: Optional[str] = None, limit: int = 100 ) -> list["ManualPaymentRequest"]: """Get all manual payment requests, optionally filtered by status""" from .models import ManualPaymentRequest if status: return await db.fetchall( """ SELECT * FROM manual_payment_requests WHERE status = :status ORDER BY created_at DESC LIMIT :limit """, {"status": status, "limit": limit}, ManualPaymentRequest, ) else: return await db.fetchall( """ SELECT * FROM manual_payment_requests ORDER BY created_at DESC LIMIT :limit """, {"limit": limit}, ManualPaymentRequest, ) async def approve_manual_payment_request( request_id: str, reviewed_by: str, journal_entry_id: str ) -> Optional["ManualPaymentRequest"]: """Approve a manual payment request""" from .models import ManualPaymentRequest await db.execute( """ UPDATE manual_payment_requests SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id WHERE id = :id """, { "id": request_id, "reviewed_at": datetime.now(), "reviewed_by": reviewed_by, "journal_entry_id": journal_entry_id, }, ) return await get_manual_payment_request(request_id) async def reject_manual_payment_request( request_id: str, reviewed_by: str ) -> Optional["ManualPaymentRequest"]: """Reject a manual payment request""" from .models import ManualPaymentRequest await db.execute( """ UPDATE manual_payment_requests SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by WHERE id = :id """, { "id": request_id, "reviewed_at": datetime.now(), "reviewed_by": reviewed_by, }, ) return await get_manual_payment_request(request_id) # ===== BALANCE ASSERTION OPERATIONS ===== async def create_balance_assertion( data: CreateBalanceAssertion, created_by: str ) -> BalanceAssertion: """Create a new balance assertion""" from decimal import Decimal assertion_id = urlsafe_short_hash() assertion_date = data.date if data.date else datetime.now() assertion = BalanceAssertion( id=assertion_id, date=assertion_date, account_id=data.account_id, expected_balance_sats=data.expected_balance_sats, expected_balance_fiat=data.expected_balance_fiat, fiat_currency=data.fiat_currency, tolerance_sats=data.tolerance_sats, tolerance_fiat=data.tolerance_fiat, status=AssertionStatus.PENDING, created_by=created_by, created_at=datetime.now(), ) # Manually insert with Decimal fields converted to strings await db.execute( """ INSERT INTO balance_assertions ( id, date, account_id, expected_balance_sats, expected_balance_fiat, fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at ) VALUES ( :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat, :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at ) """, { "id": assertion.id, "date": assertion.date, "account_id": assertion.account_id, "expected_balance_sats": assertion.expected_balance_sats, "expected_balance_fiat": str(assertion.expected_balance_fiat) if assertion.expected_balance_fiat else None, "fiat_currency": assertion.fiat_currency, "tolerance_sats": assertion.tolerance_sats, "tolerance_fiat": str(assertion.tolerance_fiat), "status": assertion.status.value, "created_by": assertion.created_by, "created_at": assertion.created_at, }, ) return assertion async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]: """Get a balance assertion by ID""" from decimal import Decimal row = await db.fetchone( "SELECT * FROM balance_assertions WHERE id = :id", {"id": assertion_id}, ) if not row: return None # Parse Decimal fields from TEXT storage return BalanceAssertion( id=row["id"], date=row["date"], account_id=row["account_id"], expected_balance_sats=row["expected_balance_sats"], expected_balance_fiat=Decimal(row["expected_balance_fiat"]) if row["expected_balance_fiat"] else None, fiat_currency=row["fiat_currency"], tolerance_sats=row["tolerance_sats"], tolerance_fiat=Decimal(row["tolerance_fiat"]) if row["tolerance_fiat"] else Decimal("0"), checked_balance_sats=row["checked_balance_sats"], checked_balance_fiat=Decimal(row["checked_balance_fiat"]) if row["checked_balance_fiat"] else None, difference_sats=row["difference_sats"], difference_fiat=Decimal(row["difference_fiat"]) if row["difference_fiat"] else None, status=AssertionStatus(row["status"]), created_by=row["created_by"], created_at=row["created_at"], checked_at=row["checked_at"], ) async def get_balance_assertions( account_id: Optional[str] = None, status: Optional[AssertionStatus] = None, limit: int = 100, ) -> list[BalanceAssertion]: """Get balance assertions with optional filters""" from decimal import Decimal if account_id and status: rows = await db.fetchall( """ SELECT * FROM balance_assertions WHERE account_id = :account_id AND status = :status ORDER BY date DESC LIMIT :limit """, {"account_id": account_id, "status": status.value, "limit": limit}, ) elif account_id: rows = await db.fetchall( """ SELECT * FROM balance_assertions WHERE account_id = :account_id ORDER BY date DESC LIMIT :limit """, {"account_id": account_id, "limit": limit}, ) elif status: rows = await db.fetchall( """ SELECT * FROM balance_assertions WHERE status = :status ORDER BY date DESC LIMIT :limit """, {"status": status.value, "limit": limit}, ) else: rows = await db.fetchall( """ SELECT * FROM balance_assertions ORDER BY date DESC LIMIT :limit """, {"limit": limit}, ) assertions = [] for row in rows: assertions.append( BalanceAssertion( id=row["id"], date=row["date"], account_id=row["account_id"], expected_balance_sats=row["expected_balance_sats"], expected_balance_fiat=Decimal(row["expected_balance_fiat"]) if row["expected_balance_fiat"] else None, fiat_currency=row["fiat_currency"], tolerance_sats=row["tolerance_sats"], tolerance_fiat=Decimal(row["tolerance_fiat"]) if row["tolerance_fiat"] else Decimal("0"), checked_balance_sats=row["checked_balance_sats"], checked_balance_fiat=Decimal(row["checked_balance_fiat"]) if row["checked_balance_fiat"] else None, difference_sats=row["difference_sats"], difference_fiat=Decimal(row["difference_fiat"]) if row["difference_fiat"] else None, status=AssertionStatus(row["status"]), created_by=row["created_by"], created_at=row["created_at"], checked_at=row["checked_at"], ) ) return assertions async def check_balance_assertion(assertion_id: str) -> BalanceAssertion: """ Check a balance assertion by comparing expected vs actual balance. Updates the assertion with the check results. Uses Fava/Beancount for balance queries. """ from decimal import Decimal from .fava_client import get_fava_client assertion = await get_balance_assertion(assertion_id) if not assertion: raise ValueError(f"Balance assertion {assertion_id} not found") # Get actual account balance from Fava account = await get_account(assertion.account_id) if not account: raise ValueError(f"Account {assertion.account_id} not found") fava = get_fava_client() # Get balance from Fava balance_data = await fava.get_account_balance(account.name) actual_balance = balance_data["sats"] # Get fiat balance if needed actual_fiat_balance = None if assertion.fiat_currency and account.user_id: user_balance_data = await fava.get_user_balance(account.user_id) actual_fiat_balance = user_balance_data["fiat_balances"].get(assertion.fiat_currency, Decimal("0")) # Check sats balance difference_sats = actual_balance - assertion.expected_balance_sats sats_match = abs(difference_sats) <= assertion.tolerance_sats # Check fiat balance if applicable fiat_match = True difference_fiat = None if assertion.expected_balance_fiat is not None and actual_fiat_balance is not None: difference_fiat = actual_fiat_balance - assertion.expected_balance_fiat fiat_match = abs(difference_fiat) <= assertion.tolerance_fiat # Determine overall status status = AssertionStatus.PASSED if (sats_match and fiat_match) else AssertionStatus.FAILED # Update assertion with check results await db.execute( """ UPDATE balance_assertions SET checked_balance_sats = :checked_sats, checked_balance_fiat = :checked_fiat, difference_sats = :diff_sats, difference_fiat = :diff_fiat, status = :status, checked_at = :checked_at WHERE id = :id """, { "id": assertion_id, "checked_sats": actual_balance, "checked_fiat": str(actual_fiat_balance) if actual_fiat_balance is not None else None, "diff_sats": difference_sats, "diff_fiat": str(difference_fiat) if difference_fiat is not None else None, "status": status.value, "checked_at": datetime.now(), }, ) # Return updated assertion return await get_balance_assertion(assertion_id) async def delete_balance_assertion(assertion_id: str) -> None: """Delete a balance assertion""" await db.execute( "DELETE FROM balance_assertions WHERE id = :id", {"id": assertion_id}, ) # User Equity Status CRUD operations async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]: """Get user's equity eligibility status""" from .models import UserEquityStatus row = await db.fetchone( """ SELECT * FROM user_equity_status WHERE user_id = :user_id """, {"user_id": user_id}, ) return UserEquityStatus(**row) if row else None async def create_or_update_user_equity_status( data: "CreateUserEquityStatus", granted_by: str ) -> "UserEquityStatus": """Create or update user equity eligibility status""" from datetime import datetime from .models import UserEquityStatus, AccountType import uuid # Auto-create user-specific equity account if granting eligibility if data.is_equity_eligible: # Generate equity account name: Equity:User-{user_id} equity_account_name = f"Equity:User-{data.user_id[:8]}" # Check if the equity account already exists equity_account = await get_account_by_name(equity_account_name) if not equity_account: # Create the user-specific equity account await db.execute( """ INSERT INTO accounts (id, name, account_type, description, user_id, created_at) VALUES (:id, :name, :type, :description, :user_id, :created_at) """, { "id": str(uuid.uuid4()), "name": equity_account_name, "type": AccountType.EQUITY.value, "description": f"Equity contributions for user {data.user_id[:8]}", "user_id": data.user_id, "created_at": datetime.now(), }, ) # Auto-populate equity_account_name in the data data.equity_account_name = equity_account_name # Check if user already has equity status existing = await get_user_equity_status(data.user_id) if existing: # Update existing record await db.execute( """ UPDATE user_equity_status SET is_equity_eligible = :is_equity_eligible, equity_account_name = :equity_account_name, notes = :notes, granted_by = :granted_by, granted_at = :granted_at, revoked_at = :revoked_at WHERE user_id = :user_id """, { "user_id": data.user_id, "is_equity_eligible": data.is_equity_eligible, "equity_account_name": data.equity_account_name, "notes": data.notes, "granted_by": granted_by, "granted_at": datetime.now(), "revoked_at": None if data.is_equity_eligible else datetime.now(), }, ) else: # Create new record await db.execute( """ INSERT INTO user_equity_status ( user_id, is_equity_eligible, equity_account_name, notes, granted_by, granted_at ) VALUES ( :user_id, :is_equity_eligible, :equity_account_name, :notes, :granted_by, :granted_at ) """, { "user_id": data.user_id, "is_equity_eligible": data.is_equity_eligible, "equity_account_name": data.equity_account_name, "notes": data.notes, "granted_by": granted_by, "granted_at": datetime.now(), }, ) # Return the created/updated record result = await get_user_equity_status(data.user_id) if not result: raise ValueError(f"Failed to create/update equity status for user {data.user_id}") return result async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]: """Revoke user's equity contribution eligibility""" from datetime import datetime await db.execute( """ UPDATE user_equity_status SET is_equity_eligible = FALSE, revoked_at = :revoked_at WHERE user_id = :user_id """, {"user_id": user_id, "revoked_at": datetime.now()}, ) return await get_user_equity_status(user_id) async def get_all_equity_eligible_users() -> list["UserEquityStatus"]: """Get all equity-eligible users""" from .models import UserEquityStatus rows = await db.fetchall( """ SELECT * FROM user_equity_status WHERE is_equity_eligible = TRUE ORDER BY granted_at DESC """ ) return [UserEquityStatus(**row) for row in rows] # ===== ACCOUNT PERMISSION OPERATIONS ===== async def create_account_permission( data: "CreateAccountPermission", granted_by: str ) -> "AccountPermission": """ Create a new account permission. Raises: ValueError: If account is inactive or doesn't exist """ from .models import AccountPermission # Validate account exists and is active account = await get_account(data.account_id) if not account: raise ValueError(f"Account {data.account_id} not found") if not account.is_active: raise ValueError( f"Cannot grant permission on inactive account: {account.name}" ) permission_id = urlsafe_short_hash() permission = AccountPermission( id=permission_id, user_id=data.user_id, account_id=data.account_id, permission_type=data.permission_type, granted_by=granted_by, granted_at=datetime.now(), expires_at=data.expires_at, notes=data.notes, ) await db.execute( """ INSERT INTO account_permissions ( id, user_id, account_id, permission_type, granted_by, granted_at, expires_at, notes ) VALUES ( :id, :user_id, :account_id, :permission_type, :granted_by, :granted_at, :expires_at, :notes ) """, { "id": permission.id, "user_id": permission.user_id, "account_id": permission.account_id, "permission_type": permission.permission_type.value, "granted_by": permission.granted_by, "granted_at": permission.granted_at, "expires_at": permission.expires_at, "notes": permission.notes, }, ) # Invalidate permission cache for this user (Cache class doesn't have delete method, use pop) permission_cache._values.pop(f"permissions:user:{permission.user_id}", None) permission_cache._values.pop(f"permissions:user:{permission.user_id}:{permission.permission_type.value}", None) return permission async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]: """Get account permission by ID""" from .models import AccountPermission, PermissionType row = await db.fetchone( "SELECT * FROM account_permissions WHERE id = :id", {"id": permission_id}, ) if not row: return None return AccountPermission( id=row["id"], user_id=row["user_id"], account_id=row["account_id"], permission_type=PermissionType(row["permission_type"]), granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) async def get_user_permissions( user_id: str, permission_type: Optional["PermissionType"] = None ) -> list["AccountPermission"]: """Get all permissions for a specific user with caching""" from .models import AccountPermission, PermissionType # Build cache key cache_key = f"permissions:user:{user_id}" if permission_type: cache_key += f":{permission_type.value}" # Try cache first cached = permission_cache.get(cache_key) if cached is not None: return cached # Query DB if permission_type: rows = await db.fetchall( """ SELECT * FROM account_permissions WHERE user_id = :user_id AND permission_type = :permission_type AND (expires_at IS NULL OR expires_at > :now) ORDER BY granted_at DESC """, { "user_id": user_id, "permission_type": permission_type.value, "now": datetime.now(), }, ) else: rows = await db.fetchall( """ SELECT * FROM account_permissions WHERE user_id = :user_id AND (expires_at IS NULL OR expires_at > :now) ORDER BY granted_at DESC """, {"user_id": user_id, "now": datetime.now()}, ) permissions = [ AccountPermission( id=row["id"], user_id=row["user_id"], account_id=row["account_id"], permission_type=PermissionType(row["permission_type"]), granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) for row in rows ] # Cache result permission_cache.set(cache_key, permissions, PERMISSION_CACHE_TTL) return permissions async def get_account_permissions(account_id: str) -> list["AccountPermission"]: """Get all permissions for a specific account""" from .models import AccountPermission, PermissionType rows = await db.fetchall( """ SELECT * FROM account_permissions WHERE account_id = :account_id AND (expires_at IS NULL OR expires_at > :now) ORDER BY granted_at DESC """, {"account_id": account_id, "now": datetime.now()}, ) return [ AccountPermission( id=row["id"], user_id=row["user_id"], account_id=row["account_id"], permission_type=PermissionType(row["permission_type"]), granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) for row in rows ] async def delete_account_permission(permission_id: str) -> None: """Delete (revoke) an account permission""" # Get permission first to invalidate cache permission = await get_account_permission(permission_id) await db.execute( "DELETE FROM account_permissions WHERE id = :id", {"id": permission_id}, ) # Invalidate permission cache for this user (Cache class doesn't have delete method, use pop) if permission: permission_cache._values.pop(f"permissions:user:{permission.user_id}", None) permission_cache._values.pop(f"permissions:user:{permission.user_id}:{permission.permission_type.value}", None) async def check_user_has_permission( user_id: str, account_id: str, permission_type: "PermissionType" ) -> bool: """Check if user has a specific permission on an account (direct permission only, no inheritance)""" row = await db.fetchone( """ SELECT id FROM account_permissions WHERE user_id = :user_id AND account_id = :account_id AND permission_type = :permission_type AND (expires_at IS NULL OR expires_at > :now) """, { "user_id": user_id, "account_id": account_id, "permission_type": permission_type.value, "now": datetime.now(), }, ) return row is not None async def get_user_permissions_with_inheritance( user_id: str, account_name: str, permission_type: "PermissionType" ) -> list[tuple["AccountPermission", Optional[str]]]: """ Get all permissions for a user on an account, including inherited permissions from parent accounts. Includes both direct permissions AND role-based permissions. Returns list of tuples: (permission, parent_account_name or None) Example: If user has permission on "Expenses:Food", they also have permission on "Expenses:Food:Groceries" Returns: [(permission_on_food, "Expenses:Food")] """ from .models import AccountPermission, PermissionType # Get direct user permissions of this type direct_permissions = await get_user_permissions(user_id, permission_type) # Get role-based permissions of this type role_permissions_list = await get_user_permissions_from_roles(user_id) role_perms = [] for role, perms in role_permissions_list: # Filter for the specific permission type role_perms.extend([p for p in perms if p.permission_type == permission_type]) # Combine direct and role-based permissions all_permissions = list(direct_permissions) + role_perms # Find which permissions apply to this account (direct or inherited) applicable_permissions = [] for perm in all_permissions: # Get the account for this permission account = await get_account(perm.account_id) if not account: continue # Check if this account is a parent of the target account # Parent accounts are indicated by hierarchical names (colon-separated) # e.g., "Expenses:Food" is parent of "Expenses:Food:Groceries" if account_name == account.name: # Direct permission applicable_permissions.append((perm, None)) elif account_name.startswith(account.name + ":"): # Inherited permission from parent account applicable_permissions.append((perm, account.name)) return applicable_permissions # ===== ROLE-BASED ACCESS CONTROL (RBAC) OPERATIONS ===== async def create_role(data: CreateRole, created_by: str) -> Role: """Create a new role""" role_id = urlsafe_short_hash() role = Role( id=role_id, name=data.name, description=data.description, is_default=data.is_default, created_by=created_by, created_at=datetime.now(), ) await db.execute( """ INSERT INTO roles (id, name, description, is_default, created_by, created_at) VALUES (:id, :name, :description, :is_default, :created_by, :created_at) """, { "id": role.id, "name": role.name, "description": role.description, "is_default": role.is_default, "created_by": role.created_by, "created_at": role.created_at, }, ) return role async def get_role(role_id: str) -> Optional[Role]: """Get role by ID""" row = await db.fetchone( "SELECT * FROM roles WHERE id = :id", {"id": role_id}, ) if not row: return None return Role( id=row["id"], name=row["name"], description=row["description"], is_default=row["is_default"], created_by=row["created_by"], created_at=row["created_at"], ) async def get_role_by_name(name: str) -> Optional[Role]: """Get role by name""" row = await db.fetchone( "SELECT * FROM roles WHERE name = :name", {"name": name}, ) if not row: return None return Role( id=row["id"], name=row["name"], description=row["description"], is_default=row["is_default"], created_by=row["created_by"], created_at=row["created_at"], ) async def get_all_roles() -> list[Role]: """Get all roles""" rows = await db.fetchall( "SELECT * FROM roles ORDER BY name", ) return [ Role( id=row["id"], name=row["name"], description=row["description"], is_default=row["is_default"], created_by=row["created_by"], created_at=row["created_at"], ) for row in rows ] async def get_default_role() -> Optional[Role]: """Get the default role that is auto-assigned to new users""" row = await db.fetchone( "SELECT * FROM roles WHERE is_default = TRUE LIMIT 1", ) if not row: return None return Role( id=row["id"], name=row["name"], description=row["description"], is_default=row["is_default"], created_by=row["created_by"], created_at=row["created_at"], ) async def update_role(role_id: str, data: UpdateRole) -> Optional[Role]: """Update a role""" # If setting this role as default, unset any other default roles if data.is_default is True: await db.execute( "UPDATE roles SET is_default = FALSE WHERE id != :id", {"id": role_id}, ) # Build update statement dynamically based on provided fields updates = [] params = {"id": role_id} if data.name is not None: updates.append("name = :name") params["name"] = data.name if data.description is not None: updates.append("description = :description") params["description"] = data.description if data.is_default is not None: updates.append("is_default = :is_default") params["is_default"] = data.is_default if not updates: return await get_role(role_id) await db.execute( f"UPDATE roles SET {', '.join(updates)} WHERE id = :id", params, ) return await get_role(role_id) async def delete_role(role_id: str) -> None: """Delete a role (cascade deletes role_permissions and user_roles)""" await db.execute( "DELETE FROM roles WHERE id = :id", {"id": role_id}, ) # ===== ROLE PERMISSION OPERATIONS ===== async def create_role_permission(data: CreateRolePermission) -> RolePermission: """Create a permission for a role""" permission_id = urlsafe_short_hash() permission = RolePermission( id=permission_id, role_id=data.role_id, account_id=data.account_id, permission_type=data.permission_type, notes=data.notes, created_at=datetime.now(), ) await db.execute( """ INSERT INTO role_permissions (id, role_id, account_id, permission_type, notes, created_at) VALUES (:id, :role_id, :account_id, :permission_type, :notes, :created_at) """, { "id": permission.id, "role_id": permission.role_id, "account_id": permission.account_id, "permission_type": permission.permission_type.value, "notes": permission.notes, "created_at": permission.created_at, }, ) return permission async def get_role_permissions(role_id: str) -> list[RolePermission]: """Get all permissions for a specific role""" rows = await db.fetchall( """ SELECT * FROM role_permissions WHERE role_id = :role_id ORDER BY created_at DESC """, {"role_id": role_id}, ) return [ RolePermission( id=row["id"], role_id=row["role_id"], account_id=row["account_id"], permission_type=PermissionType(row["permission_type"]), notes=row["notes"], created_at=row["created_at"], ) for row in rows ] async def delete_role_permission(permission_id: str) -> None: """Delete a role permission""" await db.execute( "DELETE FROM role_permissions WHERE id = :id", {"id": permission_id}, ) # ===== USER ROLE OPERATIONS ===== async def assign_user_role(data: AssignUserRole, granted_by: str) -> UserRole: """Assign a user to a role""" user_role_id = urlsafe_short_hash() user_role = UserRole( id=user_role_id, user_id=data.user_id, role_id=data.role_id, granted_by=granted_by, granted_at=datetime.now(), expires_at=data.expires_at, notes=data.notes, ) await db.execute( """ INSERT INTO user_roles (id, user_id, role_id, granted_by, granted_at, expires_at, notes) VALUES (:id, :user_id, :role_id, :granted_by, :granted_at, :expires_at, :notes) """, { "id": user_role.id, "user_id": user_role.user_id, "role_id": user_role.role_id, "granted_by": user_role.granted_by, "granted_at": user_role.granted_at, "expires_at": user_role.expires_at, "notes": user_role.notes, }, ) return user_role async def get_user_roles(user_id: str) -> list[UserRole]: """Get all active roles for a user""" rows = await db.fetchall( """ SELECT * FROM user_roles WHERE user_id = :user_id AND (expires_at IS NULL OR expires_at > :now) ORDER BY granted_at DESC """, {"user_id": user_id, "now": datetime.now()}, ) return [ UserRole( id=row["id"], user_id=row["user_id"], role_id=row["role_id"], granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) for row in rows ] async def get_all_user_roles() -> list[UserRole]: """Get all active user role assignments""" rows = await db.fetchall( """ SELECT * FROM user_roles WHERE (expires_at IS NULL OR expires_at > :now) ORDER BY user_id, granted_at DESC """, {"now": datetime.now()}, ) return [ UserRole( id=row["id"], user_id=row["user_id"], role_id=row["role_id"], granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) for row in rows ] async def get_role_users(role_id: str) -> list[UserRole]: """Get all users assigned to a role""" rows = await db.fetchall( """ SELECT * FROM user_roles WHERE role_id = :role_id AND (expires_at IS NULL OR expires_at > :now) ORDER BY granted_at DESC """, {"role_id": role_id, "now": datetime.now()}, ) return [ UserRole( id=row["id"], user_id=row["user_id"], role_id=row["role_id"], granted_by=row["granted_by"], granted_at=row["granted_at"], expires_at=row["expires_at"], notes=row["notes"], ) for row in rows ] async def revoke_user_role(user_role_id: str) -> None: """Revoke a user's role assignment""" await db.execute( "DELETE FROM user_roles WHERE id = :id", {"id": user_role_id}, ) async def get_role_count_for_user(user_id: str) -> int: """Get count of active roles for a user""" row = await db.fetchone( """ SELECT COUNT(*) as count FROM user_roles WHERE user_id = :user_id AND (expires_at IS NULL OR expires_at > :now) """, {"user_id": user_id, "now": datetime.now()}, ) return row["count"] if row else 0 async def auto_assign_default_role(user_id: str, assigned_by: str) -> UserRole | None: """ Auto-assign the default role to a user if they don't have any roles yet. Returns the created UserRole if assigned, None if user already has roles or no default role exists. """ from loguru import logger logger.info(f"[AUTO-ASSIGN] Checking auto-assignment for user {user_id}") # Check if user already has any roles user_role_count = await get_role_count_for_user(user_id) logger.info(f"[AUTO-ASSIGN] User {user_id} has {user_role_count} roles") if user_role_count > 0: logger.info(f"[AUTO-ASSIGN] User {user_id} already has roles, skipping auto-assignment") return None # Find the default role default_role = await get_default_role() if not default_role: logger.warning(f"[AUTO-ASSIGN] No default role found, cannot auto-assign for user {user_id}") return None logger.info(f"[AUTO-ASSIGN] Found default role: {default_role.name} (id: {default_role.id})") # Assign the default role data = AssignUserRole( user_id=user_id, role_id=default_role.id, notes="Auto-assigned default role on first access", ) result = await assign_user_role(data, assigned_by) logger.info(f"[AUTO-ASSIGN] Successfully assigned role {default_role.name} to user {user_id}") return result async def get_user_count_for_role(role_id: str) -> int: """Get count of users assigned to a role""" row = await db.fetchone( """ SELECT COUNT(*) as count FROM user_roles WHERE role_id = :role_id AND (expires_at IS NULL OR expires_at > :now) """, {"role_id": role_id, "now": datetime.now()}, ) return row["count"] if row else 0 # ===== RBAC HELPER FUNCTIONS ===== async def get_user_permissions_from_roles( user_id: str, ) -> list[tuple[Role, list[RolePermission]]]: """ Get all permissions a user has through their role assignments. Returns list of tuples: (role, list of permissions from that role) """ # Get user's active roles user_roles = await get_user_roles(user_id) result = [] for user_role in user_roles: role = await get_role(user_role.role_id) if role: permissions = await get_role_permissions(role.id) result.append((role, permissions)) return result async def check_user_has_role_permission( user_id: str, account_id: str, permission_type: PermissionType ) -> bool: """Check if user has a specific permission through any of their roles""" # Get all permissions from user's roles role_permissions = await get_user_permissions_from_roles(user_id) # Check if any role grants the required permission on this account for role, permissions in role_permissions: for perm in permissions: if perm.account_id == account_id and perm.permission_type == permission_type: return True return False