Removes core balance calculation logic

Migrates balance calculation and inventory tracking to
Fava/Beancount, leveraging Fava's query API for all
accounting calculations. This simplifies the core module
and centralizes accounting logic in Fava.
This commit is contained in:
padreug 2025-11-09 23:13:26 +01:00
parent efc09aa5ce
commit 88ff3821ce
4 changed files with 13 additions and 585 deletions

View file

@ -4,8 +4,6 @@ Castle Core Module - Pure accounting logic separated from database operations.
This module contains the core business logic for double-entry accounting, This module contains the core business logic for double-entry accounting,
following Beancount patterns for clean architecture: following Beancount patterns for clean architecture:
- inventory.py: Position tracking across currencies
- balance.py: Balance calculation logic
- validation.py: Comprehensive validation rules - validation.py: Comprehensive validation rules
Benefits: Benefits:
@ -13,16 +11,14 @@ Benefits:
- Reusable across different storage backends - Reusable across different storage backends
- Clear separation of concerns - Clear separation of concerns
- Easier to audit and verify - Easier to audit and verify
Note: Balance calculation and inventory tracking have been migrated to Fava/Beancount.
All accounting calculations are now performed via Fava's query API.
""" """
from .inventory import CastleInventory, CastlePosition
from .balance import BalanceCalculator
from .validation import ValidationError, validate_journal_entry, validate_balance from .validation import ValidationError, validate_journal_entry, validate_balance
__all__ = [ __all__ = [
"CastleInventory",
"CastlePosition",
"BalanceCalculator",
"ValidationError", "ValidationError",
"validate_journal_entry", "validate_journal_entry",
"validate_balance", "validate_balance",

View file

@ -1,246 +0,0 @@
"""
Balance calculation logic for Castle accounting.
Pure functions for calculating account and user balances from journal entries,
following double-entry accounting principles.
"""
from decimal import Decimal
from typing import Any, Dict, List, Optional
from enum import Enum
from .inventory import CastleInventory, CastlePosition
class AccountType(str, Enum):
"""Account types in double-entry accounting"""
ASSET = "asset"
LIABILITY = "liability"
EQUITY = "equity"
REVENUE = "revenue"
EXPENSE = "expense"
class BalanceCalculator:
"""
Pure logic for calculating balances from journal entries.
This class contains no database access - it operates on data structures
passed to it, making it easy to test and reuse.
"""
@staticmethod
def calculate_account_balance(
total_debit: int,
total_credit: int,
account_type: AccountType
) -> int:
"""
Calculate account balance based on account type.
Normal balances:
- Assets and Expenses: Debit balance (debit - credit)
- Liabilities, Equity, and Revenue: Credit balance (credit - debit)
Args:
total_debit: Sum of all debits in satoshis
total_credit: Sum of all credits in satoshis
account_type: Type of account
Returns:
Balance in satoshis
"""
if account_type in [AccountType.ASSET, AccountType.EXPENSE]:
return total_debit - total_credit
else:
return total_credit - total_debit
@staticmethod
def calculate_account_balance_from_amount(
total_amount: int,
account_type: AccountType
) -> int:
"""
Calculate account balance from total amount (Beancount-style single amount field).
This method uses Beancount's elegant single amount field approach:
- Positive amounts represent debits (increase assets/expenses)
- Negative amounts represent credits (increase liabilities/equity/revenue)
Args:
total_amount: Sum of all amounts for this account (positive/negative)
account_type: Type of account
Returns:
Balance in satoshis
Examples:
# Asset account with +100 (debit):
calculate_account_balance_from_amount(100, AccountType.ASSET) 100
# Liability account with -100 (credit = liability increase):
calculate_account_balance_from_amount(-100, AccountType.LIABILITY) 100
"""
if account_type in [AccountType.ASSET, AccountType.EXPENSE]:
# For assets and expenses, positive amounts increase balance
return total_amount
else:
# For liabilities, equity, and revenue, negative amounts increase balance
# So we invert the sign for display
return -total_amount
@staticmethod
def build_inventory_from_entry_lines(
entry_lines: List[Dict[str, Any]],
account_type: AccountType
) -> CastleInventory:
"""
Build a CastleInventory from journal entry lines (Beancount-style with single amount field).
Args:
entry_lines: List of entry line dictionaries with keys:
- amount: int (satoshis; positive = debit, negative = credit)
- metadata: str (JSON string with optional fiat_currency, fiat_amount)
account_type: Type of account (affects sign of amounts)
Returns:
CastleInventory with positions for sats and fiat currencies
"""
import json
inventory = CastleInventory()
for line in entry_lines:
# Parse metadata
metadata = json.loads(line.get("metadata", "{}")) if line.get("metadata") else {}
fiat_currency = metadata.get("fiat_currency")
fiat_amount_raw = metadata.get("fiat_amount")
# Convert fiat amount to Decimal
fiat_amount = Decimal(str(fiat_amount_raw)) if fiat_amount_raw else None
# Get amount (Beancount-style: positive = debit, negative = credit)
amount = line.get("amount", 0)
if amount != 0:
sats_amount = Decimal(amount)
# Apply account-specific sign adjustment
# For liability/equity/revenue: negative amounts increase balance
# For assets/expenses: positive amounts increase balance
if account_type in [AccountType.LIABILITY, AccountType.EQUITY, AccountType.REVENUE]:
# Invert sign for liability-type accounts
sats_amount = -sats_amount
fiat_amount = -fiat_amount if fiat_amount else None
inventory.add_position(
CastlePosition(
currency="SATS",
amount=sats_amount,
cost_currency=fiat_currency,
cost_amount=fiat_amount,
metadata=metadata,
)
)
return inventory
@staticmethod
def calculate_user_balance(
accounts: List[Dict[str, Any]],
account_balances: Dict[str, int],
account_inventories: Dict[str, CastleInventory]
) -> Dict[str, Any]:
"""
Calculate user's total balance across all their accounts.
User balance represents what the Castle owes the user:
- Positive: Castle owes user
- Negative: User owes Castle
Args:
accounts: List of account dictionaries with keys:
- id: str
- account_type: str (asset/liability/equity)
account_balances: Dict mapping account_id to balance in sats
account_inventories: Dict mapping account_id to CastleInventory
Returns:
Dictionary with:
- balance: int (total sats, positive = castle owes user)
- fiat_balances: Dict[str, Decimal] (fiat balances by currency)
"""
total_balance = 0
combined_inventory = CastleInventory()
for account in accounts:
account_id = account["id"]
account_type = AccountType(account["account_type"])
balance = account_balances.get(account_id, 0)
inventory = account_inventories.get(account_id, CastleInventory())
# Add sats balance based on account type
if account_type == AccountType.LIABILITY:
# Liability: positive balance means castle owes user
total_balance += balance
elif account_type == AccountType.ASSET:
# Asset (receivable): positive balance means user owes castle (negative for user)
total_balance -= balance
# Equity contributions don't affect what castle owes
# Merge inventories for fiat tracking (exclude equity)
if account_type != AccountType.EQUITY:
for position in inventory.positions.values():
# Adjust sign based on account type
if account_type == AccountType.ASSET:
# For receivables, negate the position
combined_inventory.add_position(position.negate())
else:
combined_inventory.add_position(position)
fiat_balances = combined_inventory.get_all_fiat_balances()
return {
"balance": total_balance,
"fiat_balances": fiat_balances,
}
@staticmethod
def check_balance_matches(
actual_balance_sats: int,
expected_balance_sats: int,
tolerance_sats: int = 0
) -> bool:
"""
Check if actual balance matches expected within tolerance.
Args:
actual_balance_sats: Actual calculated balance
expected_balance_sats: Expected balance from assertion
tolerance_sats: Allowed difference (±)
Returns:
True if balances match within tolerance
"""
difference = abs(actual_balance_sats - expected_balance_sats)
return difference <= tolerance_sats
@staticmethod
def check_fiat_balance_matches(
actual_balance_fiat: Decimal,
expected_balance_fiat: Decimal,
tolerance_fiat: Decimal = Decimal(0)
) -> bool:
"""
Check if actual fiat balance matches expected within tolerance.
Args:
actual_balance_fiat: Actual calculated fiat balance
expected_balance_fiat: Expected fiat balance from assertion
tolerance_fiat: Allowed difference (±)
Returns:
True if balances match within tolerance
"""
difference = abs(actual_balance_fiat - expected_balance_fiat)
return difference <= tolerance_fiat

View file

@ -1,203 +0,0 @@
"""
Inventory system for position tracking.
Similar to Beancount's Inventory class, this module provides position tracking
across multiple currencies with cost basis information.
"""
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple
@dataclass(frozen=True)
class CastlePosition:
"""
A position in the Castle inventory.
Represents an amount in a specific currency, optionally with cost basis
information for tracking currency conversions.
Examples:
# Simple sats position
CastlePosition(currency="SATS", amount=Decimal("100000"))
# Sats with EUR cost basis
CastlePosition(
currency="SATS",
amount=Decimal("100000"),
cost_currency="EUR",
cost_amount=Decimal("50.00")
)
"""
currency: str # "SATS", "EUR", "USD", etc.
amount: Decimal
# Cost basis (for tracking conversions)
cost_currency: Optional[str] = None # Original currency if converted
cost_amount: Optional[Decimal] = None # Original amount
# Metadata
date: Optional[datetime] = None
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
"""Validate position data"""
if not isinstance(self.amount, Decimal):
object.__setattr__(self, "amount", Decimal(str(self.amount)))
if self.cost_amount is not None and not isinstance(self.cost_amount, Decimal):
object.__setattr__(
self, "cost_amount", Decimal(str(self.cost_amount))
)
def __add__(self, other: "CastlePosition") -> "CastlePosition":
"""Add two positions (must be same currency and cost_currency)"""
if self.currency != other.currency:
raise ValueError(f"Cannot add positions with different currencies: {self.currency} != {other.currency}")
if self.cost_currency != other.cost_currency:
raise ValueError(f"Cannot add positions with different cost currencies: {self.cost_currency} != {other.cost_currency}")
return CastlePosition(
currency=self.currency,
amount=self.amount + other.amount,
cost_currency=self.cost_currency,
cost_amount=(
(self.cost_amount or Decimal(0)) + (other.cost_amount or Decimal(0))
if self.cost_amount is not None or other.cost_amount is not None
else None
),
date=other.date, # Use most recent date
metadata={**self.metadata, **other.metadata},
)
def negate(self) -> "CastlePosition":
"""Return a position with negated amount"""
return CastlePosition(
currency=self.currency,
amount=-self.amount,
cost_currency=self.cost_currency,
cost_amount=-self.cost_amount if self.cost_amount else None,
date=self.date,
metadata=self.metadata,
)
class CastleInventory:
"""
Track balances across multiple currencies with conversion tracking.
Similar to Beancount's Inventory but optimized for Castle's use case.
Positions are keyed by (currency, cost_currency) to track different
cost bases separately.
Examples:
inv = CastleInventory()
inv.add_position(CastlePosition("SATS", Decimal("100000")))
inv.add_position(CastlePosition("SATS", Decimal("50000"), "EUR", Decimal("25")))
inv.get_balance_sats() # Returns: Decimal("150000")
inv.get_balance_fiat("EUR") # Returns: Decimal("25")
"""
def __init__(self):
self.positions: Dict[Tuple[str, Optional[str]], CastlePosition] = {}
def add_position(self, position: CastlePosition):
"""
Add or merge a position into the inventory.
Positions with the same (currency, cost_currency) key are merged.
"""
key = (position.currency, position.cost_currency)
if key in self.positions:
self.positions[key] = self.positions[key] + position
else:
self.positions[key] = position
def get_balance_sats(self) -> Decimal:
"""Get total balance in satoshis"""
return sum(
pos.amount
for (curr, _), pos in self.positions.items()
if curr == "SATS"
)
def get_balance_fiat(self, currency: str) -> Decimal:
"""
Get balance in specific fiat currency from cost metadata.
This sums up all cost_amount values for positions that have
the specified cost_currency.
"""
return sum(
pos.cost_amount or Decimal(0)
for (_, cost_curr), pos in self.positions.items()
if cost_curr == currency
)
def get_all_fiat_balances(self) -> Dict[str, Decimal]:
"""Get balances for all fiat currencies present in the inventory"""
fiat_currencies = set(
cost_curr
for _, cost_curr in self.positions.keys()
if cost_curr
)
return {
curr: self.get_balance_fiat(curr)
for curr in fiat_currencies
}
def is_empty(self) -> bool:
"""Check if inventory has no positions"""
return len(self.positions) == 0
def is_zero(self) -> bool:
"""
Check if all positions sum to zero.
Returns True if the inventory has positions but they all sum to zero.
"""
return all(
pos.amount == Decimal(0)
for pos in self.positions.values()
)
def to_dict(self) -> dict:
"""
Export inventory to dictionary format.
Returns:
{
"sats": 100000,
"fiat": {
"EUR": 50.00,
"USD": 60.00
}
}
"""
fiat_balances = self.get_all_fiat_balances()
return {
"sats": int(self.get_balance_sats()),
"fiat": {
curr: float(amount)
for curr, amount in fiat_balances.items()
},
}
def __repr__(self) -> str:
"""String representation for debugging"""
if self.is_empty():
return "CastleInventory(empty)"
positions_str = ", ".join(
f"{curr}: {pos.amount}"
for (curr, _), pos in self.positions.items()
)
return f"CastleInventory({positions_str})"

139
crud.py
View file

@ -29,8 +29,6 @@ from .models import (
) )
# Import core accounting logic # Import core accounting logic
from .core.balance import BalanceCalculator, AccountType as CoreAccountType
from .core.inventory import CastleInventory, CastlePosition
from .core.validation import ( from .core.validation import (
ValidationError, ValidationError,
validate_journal_entry, validate_journal_entry,
@ -484,128 +482,6 @@ async def count_journal_entries_by_user_and_account_type(user_id: str, account_t
# ===== BALANCE AND REPORTING ===== # ===== BALANCE AND REPORTING =====
async def get_account_balance(account_id: str) -> int:
"""
Calculate account balance using single amount field (Beancount-style).
Only includes entries that are cleared (flag='*'), excludes pending/flagged/voided entries.
For each account type:
- Assets/Expenses: balance = sum of amounts (positive amounts increase, negative decrease)
- Liabilities/Equity/Revenue: balance = -sum of amounts (negative amounts increase, positive decrease)
This works because we store amounts consistently:
- Debit (asset/expense increase) = positive amount
- Credit (liability/equity/revenue increase) = negative amount
"""
result = await db.fetchone(
"""
SELECT COALESCE(SUM(el.amount), 0) as total_amount
FROM entry_lines el
JOIN journal_entries je ON el.journal_entry_id = je.id
WHERE el.account_id = :id
AND je.flag = '*'
""",
{"id": account_id},
)
if not result:
return 0
account = await get_account(account_id)
if not account:
return 0
total_amount = result["total_amount"]
# Use core BalanceCalculator for consistent logic
core_account_type = CoreAccountType(account.account_type.value)
return BalanceCalculator.calculate_account_balance_from_amount(
total_amount, core_account_type
)
async def get_user_balance(user_id: str) -> UserBalance:
"""Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
# Get all user-specific accounts
user_accounts = await db.fetchall(
"SELECT * FROM accounts WHERE user_id = :user_id",
{"user_id": user_id},
Account,
)
# Calculate balances for each account
account_balances = {}
account_inventories = {}
for account in user_accounts:
# Get satoshi balance
balance = await get_account_balance(account.id)
account_balances[account.id] = balance
# Get all entry lines for this account to build inventory
# Only include cleared entries (exclude pending/flagged/voided)
entry_lines = await db.fetchall(
"""
SELECT el.*
FROM entry_lines el
JOIN journal_entries je ON el.journal_entry_id = je.id
WHERE el.account_id = :account_id
AND je.flag = '*'
""",
{"account_id": account.id},
)
# Use BalanceCalculator to build inventory from entry lines
core_account_type = CoreAccountType(account.account_type.value)
inventory = BalanceCalculator.build_inventory_from_entry_lines(
[dict(line) for line in entry_lines],
core_account_type
)
account_inventories[account.id] = inventory
# Use BalanceCalculator to calculate total user balance
accounts_list = [
{"id": acc.id, "account_type": acc.account_type.value}
for acc in user_accounts
]
balance_result = BalanceCalculator.calculate_user_balance(
accounts_list,
account_balances,
account_inventories
)
return UserBalance(
user_id=user_id,
balance=balance_result["balance"],
accounts=user_accounts,
fiat_balances=balance_result["fiat_balances"],
)
async def get_all_user_balances() -> list[UserBalance]:
"""Get balances for all users (used by castle to see who they owe)"""
# Get all user-specific accounts
all_accounts = await db.fetchall(
"SELECT * FROM accounts WHERE user_id IS NOT NULL",
{},
Account,
)
# Get unique user IDs
user_ids = set(account.user_id for account in all_accounts if account.user_id)
# Calculate balance for each user using the refactored function
user_balances = []
for user_id in user_ids:
balance = await get_user_balance(user_id)
# Include users with non-zero balance or fiat balances
if balance.balance != 0 or balance.fiat_balances:
user_balances.append(balance)
return user_balances
async def get_account_transactions( async def get_account_transactions(
account_id: str, limit: int = 100 account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]: ) -> list[tuple[JournalEntry, EntryLine]]:
@ -1013,26 +889,31 @@ async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
""" """
Check a balance assertion by comparing expected vs actual balance. Check a balance assertion by comparing expected vs actual balance.
Updates the assertion with the check results. Updates the assertion with the check results.
Uses Fava/Beancount for balance queries.
""" """
from decimal import Decimal from decimal import Decimal
from .fava_client import get_fava_client
assertion = await get_balance_assertion(assertion_id) assertion = await get_balance_assertion(assertion_id)
if not assertion: if not assertion:
raise ValueError(f"Balance assertion {assertion_id} not found") raise ValueError(f"Balance assertion {assertion_id} not found")
# Get actual account balance # Get actual account balance from Fava
account = await get_account(assertion.account_id) account = await get_account(assertion.account_id)
if not account: if not account:
raise ValueError(f"Account {assertion.account_id} not found") raise ValueError(f"Account {assertion.account_id} not found")
# Calculate balance at the assertion date fava = get_fava_client()
actual_balance = await get_account_balance(assertion.account_id)
# Get balance from Fava
balance_data = await fava.get_account_balance(account.name)
actual_balance = balance_data["sats"]
# Get fiat balance if needed # Get fiat balance if needed
actual_fiat_balance = None actual_fiat_balance = None
if assertion.fiat_currency and account.user_id: if assertion.fiat_currency and account.user_id:
user_balance = await get_user_balance(account.user_id) user_balance_data = await fava.get_user_balance(account.user_id)
actual_fiat_balance = user_balance.fiat_balances.get(assertion.fiat_currency, Decimal("0")) actual_fiat_balance = user_balance_data["fiat_balances"].get(assertion.fiat_currency, Decimal("0"))
# Check sats balance # Check sats balance
difference_sats = actual_balance - assertion.expected_balance_sats difference_sats = actual_balance - assertion.expected_balance_sats