diff --git a/core/__init__.py b/core/__init__.py index 9b4cf2b..662bb20 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -4,8 +4,6 @@ Castle Core Module - Pure accounting logic separated from database operations. This module contains the core business logic for double-entry accounting, following Beancount patterns for clean architecture: -- inventory.py: Position tracking across currencies -- balance.py: Balance calculation logic - validation.py: Comprehensive validation rules Benefits: @@ -13,16 +11,14 @@ Benefits: - Reusable across different storage backends - Clear separation of concerns - Easier to audit and verify + +Note: Balance calculation and inventory tracking have been migrated to Fava/Beancount. +All accounting calculations are now performed via Fava's query API. """ -from .inventory import CastleInventory, CastlePosition -from .balance import BalanceCalculator from .validation import ValidationError, validate_journal_entry, validate_balance __all__ = [ - "CastleInventory", - "CastlePosition", - "BalanceCalculator", "ValidationError", "validate_journal_entry", "validate_balance", diff --git a/core/balance.py b/core/balance.py deleted file mode 100644 index 37a113c..0000000 --- a/core/balance.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Balance calculation logic for Castle accounting. - -Pure functions for calculating account and user balances from journal entries, -following double-entry accounting principles. -""" - -from decimal import Decimal -from typing import Any, Dict, List, Optional -from enum import Enum - -from .inventory import CastleInventory, CastlePosition - - -class AccountType(str, Enum): - """Account types in double-entry accounting""" - ASSET = "asset" - LIABILITY = "liability" - EQUITY = "equity" - REVENUE = "revenue" - EXPENSE = "expense" - - -class BalanceCalculator: - """ - Pure logic for calculating balances from journal entries. - - This class contains no database access - it operates on data structures - passed to it, making it easy to test and reuse. - """ - - @staticmethod - def calculate_account_balance( - total_debit: int, - total_credit: int, - account_type: AccountType - ) -> int: - """ - Calculate account balance based on account type. - - Normal balances: - - Assets and Expenses: Debit balance (debit - credit) - - Liabilities, Equity, and Revenue: Credit balance (credit - debit) - - Args: - total_debit: Sum of all debits in satoshis - total_credit: Sum of all credits in satoshis - account_type: Type of account - - Returns: - Balance in satoshis - """ - if account_type in [AccountType.ASSET, AccountType.EXPENSE]: - return total_debit - total_credit - else: - return total_credit - total_debit - - @staticmethod - def calculate_account_balance_from_amount( - total_amount: int, - account_type: AccountType - ) -> int: - """ - Calculate account balance from total amount (Beancount-style single amount field). - - This method uses Beancount's elegant single amount field approach: - - Positive amounts represent debits (increase assets/expenses) - - Negative amounts represent credits (increase liabilities/equity/revenue) - - Args: - total_amount: Sum of all amounts for this account (positive/negative) - account_type: Type of account - - Returns: - Balance in satoshis - - Examples: - # Asset account with +100 (debit): - calculate_account_balance_from_amount(100, AccountType.ASSET) → 100 - - # Liability account with -100 (credit = liability increase): - calculate_account_balance_from_amount(-100, AccountType.LIABILITY) → 100 - """ - if account_type in [AccountType.ASSET, AccountType.EXPENSE]: - # For assets and expenses, positive amounts increase balance - return total_amount - else: - # For liabilities, equity, and revenue, negative amounts increase balance - # So we invert the sign for display - return -total_amount - - @staticmethod - def build_inventory_from_entry_lines( - entry_lines: List[Dict[str, Any]], - account_type: AccountType - ) -> CastleInventory: - """ - Build a CastleInventory from journal entry lines (Beancount-style with single amount field). - - Args: - entry_lines: List of entry line dictionaries with keys: - - amount: int (satoshis; positive = debit, negative = credit) - - metadata: str (JSON string with optional fiat_currency, fiat_amount) - account_type: Type of account (affects sign of amounts) - - Returns: - CastleInventory with positions for sats and fiat currencies - """ - import json - - inventory = CastleInventory() - - for line in entry_lines: - # Parse metadata - metadata = json.loads(line.get("metadata", "{}")) if line.get("metadata") else {} - fiat_currency = metadata.get("fiat_currency") - fiat_amount_raw = metadata.get("fiat_amount") - - # Convert fiat amount to Decimal - fiat_amount = Decimal(str(fiat_amount_raw)) if fiat_amount_raw else None - - # Get amount (Beancount-style: positive = debit, negative = credit) - amount = line.get("amount", 0) - - if amount != 0: - sats_amount = Decimal(amount) - - # Apply account-specific sign adjustment - # For liability/equity/revenue: negative amounts increase balance - # For assets/expenses: positive amounts increase balance - if account_type in [AccountType.LIABILITY, AccountType.EQUITY, AccountType.REVENUE]: - # Invert sign for liability-type accounts - sats_amount = -sats_amount - fiat_amount = -fiat_amount if fiat_amount else None - - inventory.add_position( - CastlePosition( - currency="SATS", - amount=sats_amount, - cost_currency=fiat_currency, - cost_amount=fiat_amount, - metadata=metadata, - ) - ) - - return inventory - - @staticmethod - def calculate_user_balance( - accounts: List[Dict[str, Any]], - account_balances: Dict[str, int], - account_inventories: Dict[str, CastleInventory] - ) -> Dict[str, Any]: - """ - Calculate user's total balance across all their accounts. - - User balance represents what the Castle owes the user: - - Positive: Castle owes user - - Negative: User owes Castle - - Args: - accounts: List of account dictionaries with keys: - - id: str - - account_type: str (asset/liability/equity) - account_balances: Dict mapping account_id to balance in sats - account_inventories: Dict mapping account_id to CastleInventory - - Returns: - Dictionary with: - - balance: int (total sats, positive = castle owes user) - - fiat_balances: Dict[str, Decimal] (fiat balances by currency) - """ - total_balance = 0 - combined_inventory = CastleInventory() - - for account in accounts: - account_id = account["id"] - account_type = AccountType(account["account_type"]) - balance = account_balances.get(account_id, 0) - inventory = account_inventories.get(account_id, CastleInventory()) - - # Add sats balance based on account type - if account_type == AccountType.LIABILITY: - # Liability: positive balance means castle owes user - total_balance += balance - elif account_type == AccountType.ASSET: - # Asset (receivable): positive balance means user owes castle (negative for user) - total_balance -= balance - # Equity contributions don't affect what castle owes - - # Merge inventories for fiat tracking (exclude equity) - if account_type != AccountType.EQUITY: - for position in inventory.positions.values(): - # Adjust sign based on account type - if account_type == AccountType.ASSET: - # For receivables, negate the position - combined_inventory.add_position(position.negate()) - else: - combined_inventory.add_position(position) - - fiat_balances = combined_inventory.get_all_fiat_balances() - - return { - "balance": total_balance, - "fiat_balances": fiat_balances, - } - - @staticmethod - def check_balance_matches( - actual_balance_sats: int, - expected_balance_sats: int, - tolerance_sats: int = 0 - ) -> bool: - """ - Check if actual balance matches expected within tolerance. - - Args: - actual_balance_sats: Actual calculated balance - expected_balance_sats: Expected balance from assertion - tolerance_sats: Allowed difference (±) - - Returns: - True if balances match within tolerance - """ - difference = abs(actual_balance_sats - expected_balance_sats) - return difference <= tolerance_sats - - @staticmethod - def check_fiat_balance_matches( - actual_balance_fiat: Decimal, - expected_balance_fiat: Decimal, - tolerance_fiat: Decimal = Decimal(0) - ) -> bool: - """ - Check if actual fiat balance matches expected within tolerance. - - Args: - actual_balance_fiat: Actual calculated fiat balance - expected_balance_fiat: Expected fiat balance from assertion - tolerance_fiat: Allowed difference (±) - - Returns: - True if balances match within tolerance - """ - difference = abs(actual_balance_fiat - expected_balance_fiat) - return difference <= tolerance_fiat diff --git a/core/inventory.py b/core/inventory.py deleted file mode 100644 index 858ff43..0000000 --- a/core/inventory.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Inventory system for position tracking. - -Similar to Beancount's Inventory class, this module provides position tracking -across multiple currencies with cost basis information. -""" - -from dataclasses import dataclass, field -from datetime import datetime -from decimal import Decimal -from typing import Any, Dict, Optional, Tuple - - -@dataclass(frozen=True) -class CastlePosition: - """ - A position in the Castle inventory. - - Represents an amount in a specific currency, optionally with cost basis - information for tracking currency conversions. - - Examples: - # Simple sats position - CastlePosition(currency="SATS", amount=Decimal("100000")) - - # Sats with EUR cost basis - CastlePosition( - currency="SATS", - amount=Decimal("100000"), - cost_currency="EUR", - cost_amount=Decimal("50.00") - ) - """ - - currency: str # "SATS", "EUR", "USD", etc. - amount: Decimal - - # Cost basis (for tracking conversions) - cost_currency: Optional[str] = None # Original currency if converted - cost_amount: Optional[Decimal] = None # Original amount - - # Metadata - date: Optional[datetime] = None - metadata: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - """Validate position data""" - if not isinstance(self.amount, Decimal): - object.__setattr__(self, "amount", Decimal(str(self.amount))) - - if self.cost_amount is not None and not isinstance(self.cost_amount, Decimal): - object.__setattr__( - self, "cost_amount", Decimal(str(self.cost_amount)) - ) - - def __add__(self, other: "CastlePosition") -> "CastlePosition": - """Add two positions (must be same currency and cost_currency)""" - if self.currency != other.currency: - raise ValueError(f"Cannot add positions with different currencies: {self.currency} != {other.currency}") - - if self.cost_currency != other.cost_currency: - raise ValueError(f"Cannot add positions with different cost currencies: {self.cost_currency} != {other.cost_currency}") - - return CastlePosition( - currency=self.currency, - amount=self.amount + other.amount, - cost_currency=self.cost_currency, - cost_amount=( - (self.cost_amount or Decimal(0)) + (other.cost_amount or Decimal(0)) - if self.cost_amount is not None or other.cost_amount is not None - else None - ), - date=other.date, # Use most recent date - metadata={**self.metadata, **other.metadata}, - ) - - def negate(self) -> "CastlePosition": - """Return a position with negated amount""" - return CastlePosition( - currency=self.currency, - amount=-self.amount, - cost_currency=self.cost_currency, - cost_amount=-self.cost_amount if self.cost_amount else None, - date=self.date, - metadata=self.metadata, - ) - - -class CastleInventory: - """ - Track balances across multiple currencies with conversion tracking. - - Similar to Beancount's Inventory but optimized for Castle's use case. - Positions are keyed by (currency, cost_currency) to track different - cost bases separately. - - Examples: - inv = CastleInventory() - inv.add_position(CastlePosition("SATS", Decimal("100000"))) - inv.add_position(CastlePosition("SATS", Decimal("50000"), "EUR", Decimal("25"))) - - inv.get_balance_sats() # Returns: Decimal("150000") - inv.get_balance_fiat("EUR") # Returns: Decimal("25") - """ - - def __init__(self): - self.positions: Dict[Tuple[str, Optional[str]], CastlePosition] = {} - - def add_position(self, position: CastlePosition): - """ - Add or merge a position into the inventory. - - Positions with the same (currency, cost_currency) key are merged. - """ - key = (position.currency, position.cost_currency) - - if key in self.positions: - self.positions[key] = self.positions[key] + position - else: - self.positions[key] = position - - def get_balance_sats(self) -> Decimal: - """Get total balance in satoshis""" - return sum( - pos.amount - for (curr, _), pos in self.positions.items() - if curr == "SATS" - ) - - def get_balance_fiat(self, currency: str) -> Decimal: - """ - Get balance in specific fiat currency from cost metadata. - - This sums up all cost_amount values for positions that have - the specified cost_currency. - """ - return sum( - pos.cost_amount or Decimal(0) - for (_, cost_curr), pos in self.positions.items() - if cost_curr == currency - ) - - def get_all_fiat_balances(self) -> Dict[str, Decimal]: - """Get balances for all fiat currencies present in the inventory""" - fiat_currencies = set( - cost_curr - for _, cost_curr in self.positions.keys() - if cost_curr - ) - - return { - curr: self.get_balance_fiat(curr) - for curr in fiat_currencies - } - - def is_empty(self) -> bool: - """Check if inventory has no positions""" - return len(self.positions) == 0 - - def is_zero(self) -> bool: - """ - Check if all positions sum to zero. - - Returns True if the inventory has positions but they all sum to zero. - """ - return all( - pos.amount == Decimal(0) - for pos in self.positions.values() - ) - - def to_dict(self) -> dict: - """ - Export inventory to dictionary format. - - Returns: - { - "sats": 100000, - "fiat": { - "EUR": 50.00, - "USD": 60.00 - } - } - """ - fiat_balances = self.get_all_fiat_balances() - - return { - "sats": int(self.get_balance_sats()), - "fiat": { - curr: float(amount) - for curr, amount in fiat_balances.items() - }, - } - - def __repr__(self) -> str: - """String representation for debugging""" - if self.is_empty(): - return "CastleInventory(empty)" - - positions_str = ", ".join( - f"{curr}: {pos.amount}" - for (curr, _), pos in self.positions.items() - ) - return f"CastleInventory({positions_str})" diff --git a/crud.py b/crud.py index 57eea2c..b70be70 100644 --- a/crud.py +++ b/crud.py @@ -29,8 +29,6 @@ from .models import ( ) # Import core accounting logic -from .core.balance import BalanceCalculator, AccountType as CoreAccountType -from .core.inventory import CastleInventory, CastlePosition from .core.validation import ( ValidationError, validate_journal_entry, @@ -484,128 +482,6 @@ async def count_journal_entries_by_user_and_account_type(user_id: str, account_t # ===== BALANCE AND REPORTING ===== -async def get_account_balance(account_id: str) -> int: - """ - Calculate account balance using single amount field (Beancount-style). - Only includes entries that are cleared (flag='*'), excludes pending/flagged/voided entries. - - For each account type: - - Assets/Expenses: balance = sum of amounts (positive amounts increase, negative decrease) - - Liabilities/Equity/Revenue: balance = -sum of amounts (negative amounts increase, positive decrease) - - This works because we store amounts consistently: - - Debit (asset/expense increase) = positive amount - - Credit (liability/equity/revenue increase) = negative amount - """ - result = await db.fetchone( - """ - SELECT COALESCE(SUM(el.amount), 0) as total_amount - FROM entry_lines el - JOIN journal_entries je ON el.journal_entry_id = je.id - WHERE el.account_id = :id - AND je.flag = '*' - """, - {"id": account_id}, - ) - - if not result: - return 0 - - account = await get_account(account_id) - if not account: - return 0 - - total_amount = result["total_amount"] - - # Use core BalanceCalculator for consistent logic - core_account_type = CoreAccountType(account.account_type.value) - return BalanceCalculator.calculate_account_balance_from_amount( - total_amount, core_account_type - ) - - -async def get_user_balance(user_id: str) -> UserBalance: - """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)""" - # Get all user-specific accounts - user_accounts = await db.fetchall( - "SELECT * FROM accounts WHERE user_id = :user_id", - {"user_id": user_id}, - Account, - ) - - # Calculate balances for each account - account_balances = {} - account_inventories = {} - - for account in user_accounts: - # Get satoshi balance - balance = await get_account_balance(account.id) - account_balances[account.id] = balance - - # Get all entry lines for this account to build inventory - # Only include cleared entries (exclude pending/flagged/voided) - entry_lines = await db.fetchall( - """ - SELECT el.* - FROM entry_lines el - JOIN journal_entries je ON el.journal_entry_id = je.id - WHERE el.account_id = :account_id - AND je.flag = '*' - """, - {"account_id": account.id}, - ) - - # Use BalanceCalculator to build inventory from entry lines - core_account_type = CoreAccountType(account.account_type.value) - inventory = BalanceCalculator.build_inventory_from_entry_lines( - [dict(line) for line in entry_lines], - core_account_type - ) - account_inventories[account.id] = inventory - - # Use BalanceCalculator to calculate total user balance - accounts_list = [ - {"id": acc.id, "account_type": acc.account_type.value} - for acc in user_accounts - ] - balance_result = BalanceCalculator.calculate_user_balance( - accounts_list, - account_balances, - account_inventories - ) - - return UserBalance( - user_id=user_id, - balance=balance_result["balance"], - accounts=user_accounts, - fiat_balances=balance_result["fiat_balances"], - ) - - -async def get_all_user_balances() -> list[UserBalance]: - """Get balances for all users (used by castle to see who they owe)""" - # Get all user-specific accounts - all_accounts = await db.fetchall( - "SELECT * FROM accounts WHERE user_id IS NOT NULL", - {}, - Account, - ) - - # Get unique user IDs - user_ids = set(account.user_id for account in all_accounts if account.user_id) - - # Calculate balance for each user using the refactored function - user_balances = [] - for user_id in user_ids: - balance = await get_user_balance(user_id) - - # Include users with non-zero balance or fiat balances - if balance.balance != 0 or balance.fiat_balances: - user_balances.append(balance) - - return user_balances - - async def get_account_transactions( account_id: str, limit: int = 100 ) -> list[tuple[JournalEntry, EntryLine]]: @@ -1013,26 +889,31 @@ async def check_balance_assertion(assertion_id: str) -> BalanceAssertion: """ Check a balance assertion by comparing expected vs actual balance. Updates the assertion with the check results. + Uses Fava/Beancount for balance queries. """ from decimal import Decimal + from .fava_client import get_fava_client assertion = await get_balance_assertion(assertion_id) if not assertion: raise ValueError(f"Balance assertion {assertion_id} not found") - # Get actual account balance + # Get actual account balance from Fava account = await get_account(assertion.account_id) if not account: raise ValueError(f"Account {assertion.account_id} not found") - # Calculate balance at the assertion date - actual_balance = await get_account_balance(assertion.account_id) + fava = get_fava_client() + + # Get balance from Fava + balance_data = await fava.get_account_balance(account.name) + actual_balance = balance_data["sats"] # Get fiat balance if needed actual_fiat_balance = None if assertion.fiat_currency and account.user_id: - user_balance = await get_user_balance(account.user_id) - actual_fiat_balance = user_balance.fiat_balances.get(assertion.fiat_currency, Decimal("0")) + user_balance_data = await fava.get_user_balance(account.user_id) + actual_fiat_balance = user_balance_data["fiat_balances"].get(assertion.fiat_currency, Decimal("0")) # Check sats balance difference_sats = actual_balance - assertion.expected_balance_sats