castle/core/balance.py
padreug 5cc2630777 REFACTOR Migrates to single 'amount' field for transactions
Refactors the data model to use a single 'amount' field for journal entry lines, aligning with the Beancount approach.
This simplifies the model, enhances compatibility, and eliminates invalid states.

Includes a database migration to convert existing debit/credit columns to the new 'amount' field.

Updates balance calculation logic to utilize the new amount field for improved accuracy and efficiency.
2025-11-08 10:33:17 +01:00

245 lines
8.6 KiB
Python

"""
Balance calculation logic for Castle accounting.
Pure functions for calculating account and user balances from journal entries,
following double-entry accounting principles.
"""
from decimal import Decimal
from typing import Any, Dict, List, Optional
from enum import Enum
from .inventory import CastleInventory, CastlePosition
class AccountType(str, Enum):
    """
    Closed set of account categories used in double-entry accounting.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (e.g. ``AccountType.ASSET == "asset"``) and can be built
    from that value with ``AccountType("asset")``.
    """

    # Debit-normal categories (balance = debits - credits).
    ASSET = "asset"

    # Credit-normal categories (balance = credits - debits).
    LIABILITY = "liability"
    EQUITY = "equity"
    REVENUE = "revenue"

    # Debit-normal, like ASSET.
    EXPENSE = "expense"
class BalanceCalculator:
    """
    Pure logic for calculating balances from journal entries.

    This class contains no database access - it operates on data structures
    passed to it, making it easy to test and reuse. Every method is a
    ``@staticmethod``; the class is only a namespace.
    """

    @staticmethod
    def calculate_account_balance(
        total_debit: int,
        total_credit: int,
        account_type: AccountType
    ) -> int:
        """
        Calculate account balance based on account type.

        Normal balances:
        - Assets and Expenses: Debit balance (debit - credit)
        - Liabilities, Equity, and Revenue: Credit balance (credit - debit)

        Args:
            total_debit: Sum of all debits in satoshis
            total_credit: Sum of all credits in satoshis
            account_type: Type of account

        Returns:
            Balance in satoshis
        """
        if account_type in (AccountType.ASSET, AccountType.EXPENSE):
            return total_debit - total_credit
        return total_credit - total_debit

    @staticmethod
    def calculate_account_balance_from_amount(
        total_amount: int,
        account_type: AccountType
    ) -> int:
        """
        Calculate account balance from total amount (Beancount-style single amount field).

        Sign convention of the single amount field:
        - Positive amounts represent debits (increase assets/expenses)
        - Negative amounts represent credits (increase liabilities/equity/revenue)

        Args:
            total_amount: Sum of all amounts for this account (positive/negative)
            account_type: Type of account

        Returns:
            Balance in satoshis

        Examples:
            # Asset account with +100 (debit):
            calculate_account_balance_from_amount(100, AccountType.ASSET) → 100
            # Liability account with -100 (credit = liability increase):
            calculate_account_balance_from_amount(-100, AccountType.LIABILITY) → 100
        """
        if account_type in (AccountType.ASSET, AccountType.EXPENSE):
            # Debit-normal: the signed total already is the balance.
            return total_amount
        # Credit-normal: credits are stored as negatives, so flip the sign
        # to report an increase as a positive balance.
        return -total_amount

    @staticmethod
    def _parse_line_metadata(line: Dict[str, Any]) -> tuple:
        """
        Extract metadata and fiat cost information from one entry line.

        Args:
            line: Entry line dictionary. ``metadata`` may be a JSON string,
                an already-decoded dict, or missing/empty.

        Returns:
            Tuple ``(metadata, fiat_currency, fiat_amount)`` where
            ``metadata`` is a dict, ``fiat_currency`` is the value stored
            under ``"fiat_currency"`` (or None) and ``fiat_amount`` is a
            ``Decimal`` (or None when absent or falsy).

        Raises:
            json.JSONDecodeError: If ``metadata`` is a non-empty string that
                is not valid JSON.
        """
        import json

        raw = line.get("metadata")
        if not raw:
            metadata = {}
        elif isinstance(raw, dict):
            # Already decoded upstream; copy so the position does not alias
            # the caller's dict.
            metadata = dict(raw)
        else:
            metadata = json.loads(raw)
            if metadata is None:
                # JSON ``null`` carries no metadata.
                metadata = {}

        fiat_currency = metadata.get("fiat_currency")
        fiat_amount_raw = metadata.get("fiat_amount")
        # Go through str() so floats are not converted with binary noise.
        fiat_amount = Decimal(str(fiat_amount_raw)) if fiat_amount_raw else None
        return metadata, fiat_currency, fiat_amount

    @staticmethod
    def build_inventory_from_entry_lines(
        entry_lines: List[Dict[str, Any]],
        account_type: AccountType
    ) -> CastleInventory:
        """
        Build a CastleInventory from journal entry lines (Beancount-style with single amount field).

        Lines whose amount is zero or missing/None contribute no position.

        Args:
            entry_lines: List of entry line dictionaries with keys:
                - amount: int (satoshis; positive = debit, negative = credit)
                - metadata: JSON string (or already-decoded dict) with
                  optional fiat_currency, fiat_amount
            account_type: Type of account (affects sign of amounts)

        Returns:
            CastleInventory with one "SATS" position added per non-zero line

        Raises:
            json.JSONDecodeError: If a line's metadata string is not valid JSON.
        """
        inventory = CastleInventory()
        # Liability/equity/revenue grow with credits (negative amounts), so
        # their amounts are sign-flipped before being added.
        invert_sign = account_type in (
            AccountType.LIABILITY,
            AccountType.EQUITY,
            AccountType.REVENUE,
        )

        for line in entry_lines:
            metadata, fiat_currency, fiat_amount = (
                BalanceCalculator._parse_line_metadata(line)
            )

            # Treat a missing or NULL amount as zero instead of crashing in
            # Decimal(None).
            amount = line.get("amount") or 0
            if amount == 0:
                continue

            sats_amount = Decimal(amount)
            if invert_sign:
                sats_amount = -sats_amount
                # NOTE(review): a zero fiat amount becomes None here; this
                # mirrors the previous behavior - confirm it is intended.
                fiat_amount = -fiat_amount if fiat_amount else None

            inventory.add_position(
                CastlePosition(
                    currency="SATS",
                    amount=sats_amount,
                    cost_currency=fiat_currency,
                    cost_amount=fiat_amount,
                    metadata=metadata,
                )
            )

        return inventory

    @staticmethod
    def calculate_user_balance(
        accounts: List[Dict[str, Any]],
        account_balances: Dict[str, int],
        account_inventories: Dict[str, CastleInventory]
    ) -> Dict[str, Any]:
        """
        Calculate user's total balance across all their accounts.

        User balance represents what the Castle owes the user:
        - Positive: Castle owes user
        - Negative: User owes Castle

        Args:
            accounts: List of account dictionaries with keys:
                - id: str
                - account_type: str (a valid AccountType value)
            account_balances: Dict mapping account_id to balance in sats
                (accounts without an entry count as 0)
            account_inventories: Dict mapping account_id to CastleInventory

        Returns:
            Dictionary with:
            - balance: int (total sats, positive = castle owes user)
            - fiat_balances: result of get_all_fiat_balances() on the
              combined inventory

        Raises:
            KeyError: If an account dict lacks "id" or "account_type".
            ValueError: If "account_type" is not a valid AccountType value.
        """
        total_balance = 0
        combined_inventory = CastleInventory()

        for account in accounts:
            account_id = account["id"]
            account_type = AccountType(account["account_type"])
            balance = account_balances.get(account_id, 0)
            inventory = account_inventories.get(account_id, CastleInventory())

            if account_type == AccountType.LIABILITY:
                # Liability: positive balance means castle owes user
                total_balance += balance
            elif account_type == AccountType.ASSET:
                # Asset (receivable): positive balance means user owes castle
                total_balance -= balance
            # Other account types do not change the sats total.

            # NOTE(review): positions of every non-asset account (including
            # equity/revenue/expense) are merged below, although only
            # liabilities and assets affect the sats total - confirm.
            for position in inventory.positions.values():
                if account_type == AccountType.ASSET:
                    # For receivables, negate the position
                    combined_inventory.add_position(position.negate())
                else:
                    combined_inventory.add_position(position)

        fiat_balances = combined_inventory.get_all_fiat_balances()
        return {
            "balance": total_balance,
            "fiat_balances": fiat_balances,
        }

    @staticmethod
    def check_balance_matches(
        actual_balance_sats: int,
        expected_balance_sats: int,
        tolerance_sats: int = 0
    ) -> bool:
        """
        Check if actual balance matches expected within tolerance.

        Args:
            actual_balance_sats: Actual calculated balance
            expected_balance_sats: Expected balance from assertion
            tolerance_sats: Allowed absolute difference (±); a negative
                tolerance can never be satisfied

        Returns:
            True if balances match within tolerance
        """
        difference = abs(actual_balance_sats - expected_balance_sats)
        return difference <= tolerance_sats

    @staticmethod
    def check_fiat_balance_matches(
        actual_balance_fiat: Decimal,
        expected_balance_fiat: Decimal,
        tolerance_fiat: Decimal = Decimal(0)
    ) -> bool:
        """
        Check if actual fiat balance matches expected within tolerance.

        Args:
            actual_balance_fiat: Actual calculated fiat balance
            expected_balance_fiat: Expected fiat balance from assertion
            tolerance_fiat: Allowed absolute difference (±); a negative
                tolerance can never be satisfied

        Returns:
            True if balances match within tolerance
        """
        difference = abs(actual_balance_fiat - expected_balance_fiat)
        return difference <= tolerance_fiat