"""
Inventory system for position tracking.

Similar to Beancount's Inventory class, this module provides position
tracking across multiple currencies with cost basis information.
"""

from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class CastlePosition:
    """
    A position in the Castle inventory.

    Represents an amount in a specific currency, optionally with cost basis
    information for tracking currency conversions.

    Examples:
        # Simple sats position
        CastlePosition(currency="SATS", amount=Decimal("100000"))

        # Sats with EUR cost basis
        CastlePosition(
            currency="SATS",
            amount=Decimal("100000"),
            cost_currency="EUR",
            cost_amount=Decimal("50.00")
        )
    """

    currency: str  # "SATS", "EUR", "USD", etc.
    amount: Decimal

    # Cost basis (for tracking conversions)
    cost_currency: Optional[str] = None  # Original currency if converted
    cost_amount: Optional[Decimal] = None  # Original amount

    # Metadata
    date: Optional[datetime] = None
    # Excluded from the generated __hash__ because a dict is unhashable;
    # it still participates in equality comparisons.
    metadata: Dict[str, Any] = field(default_factory=dict, hash=False)

    def __post_init__(self):
        """Coerce ``amount`` and ``cost_amount`` to Decimal.

        Values are routed through ``str`` so that floats such as 0.1 become
        Decimal("0.1") rather than their exact binary expansion.
        """
        # object.__setattr__ is required because the dataclass is frozen.
        if not isinstance(self.amount, Decimal):
            object.__setattr__(self, "amount", Decimal(str(self.amount)))
        if self.cost_amount is not None and not isinstance(self.cost_amount, Decimal):
            object.__setattr__(
                self, "cost_amount", Decimal(str(self.cost_amount))
            )

    def __add__(self, other: "CastlePosition") -> "CastlePosition":
        """Add two positions that share ``currency`` and ``cost_currency``.

        Amounts are summed. Cost amounts are summed with a missing side
        treated as zero, but the result stays ``None`` when neither side
        has a cost amount. The result's date is the most recent non-None
        date of the two operands. Metadata dicts are merged, with keys from
        ``other`` overriding keys from ``self``.

        Raises:
            ValueError: if the currencies or cost currencies differ.
        """
        if not isinstance(other, CastlePosition):
            return NotImplemented
        if self.currency != other.currency:
            raise ValueError(f"Cannot add positions with different currencies: {self.currency} != {other.currency}")
        if self.cost_currency != other.cost_currency:
            raise ValueError(f"Cannot add positions with different cost currencies: {self.cost_currency} != {other.cost_currency}")

        if self.cost_amount is None and other.cost_amount is None:
            merged_cost = None
        else:
            merged_cost = (self.cost_amount or Decimal(0)) + (
                other.cost_amount or Decimal(0)
            )

        # Keep the most recent date; never drop a known date in favour of None.
        if self.date is None:
            latest_date = other.date
        elif other.date is None:
            latest_date = self.date
        else:
            latest_date = max(self.date, other.date)

        return CastlePosition(
            currency=self.currency,
            amount=self.amount + other.amount,
            cost_currency=self.cost_currency,
            cost_amount=merged_cost,
            date=latest_date,
            metadata={**self.metadata, **other.metadata},
        )

    def negate(self) -> "CastlePosition":
        """Return a copy of this position with amount and cost amount negated.

        A zero cost amount stays a (negated) zero rather than becoming
        ``None``. The metadata dict is copied so the two positions do not
        share a mutable object.
        """
        return CastlePosition(
            currency=self.currency,
            amount=-self.amount,
            cost_currency=self.cost_currency,
            cost_amount=-self.cost_amount if self.cost_amount is not None else None,
            date=self.date,
            metadata=dict(self.metadata),
        )


class CastleInventory:
    """
    Track balances across multiple currencies with conversion tracking.

    Similar to Beancount's Inventory but optimized for Castle's use case.
    Positions are keyed by (currency, cost_currency) to track different
    cost bases separately.

    Examples:
        inv = CastleInventory()
        inv.add_position(CastlePosition("SATS", Decimal("100000")))
        inv.add_position(CastlePosition("SATS", Decimal("50000"), "EUR", Decimal("25")))
        inv.get_balance_sats()  # Returns: Decimal("150000")
        inv.get_balance_fiat("EUR")  # Returns: Decimal("25")
    """

    def __init__(self):
        # (currency, cost_currency) -> merged position for that key
        self.positions: Dict[Tuple[str, Optional[str]], CastlePosition] = {}

    def add_position(self, position: CastlePosition):
        """
        Add or merge a position into the inventory.

        Positions with the same (currency, cost_currency) key are merged
        with ``CastlePosition.__add__``; otherwise the position is stored
        under a new key.
        """
        key = (position.currency, position.cost_currency)
        if key in self.positions:
            self.positions[key] = self.positions[key] + position
        else:
            self.positions[key] = position

    def get_balance_sats(self) -> Decimal:
        """Return the total amount of all "SATS" positions, across all cost bases.

        Returns Decimal(0) when there are no SATS positions.
        """
        return sum(
            (
                pos.amount
                for (curr, _), pos in self.positions.items()
                if curr == "SATS"
            ),
            Decimal(0),
        )

    def get_balance_fiat(self, currency: str) -> Decimal:
        """
        Get balance in specific fiat currency from cost metadata.

        This sums up all cost_amount values for positions that have the
        specified cost_currency. A position with that cost_currency but no
        cost_amount counts as zero. Returns Decimal(0) when nothing matches.
        """
        return sum(
            (
                pos.cost_amount or Decimal(0)
                for (_, cost_curr), pos in self.positions.items()
                if cost_curr == currency
            ),
            Decimal(0),
        )

    def get_all_fiat_balances(self) -> Dict[str, Decimal]:
        """Return ``{cost_currency: balance}`` for every cost currency present.

        Positions without a cost currency are ignored.
        """
        fiat_currencies = {
            cost_curr for _, cost_curr in self.positions if cost_curr
        }
        return {curr: self.get_balance_fiat(curr) for curr in fiat_currencies}

    def is_empty(self) -> bool:
        """Return True if the inventory holds no positions at all."""
        return not self.positions

    def is_zero(self) -> bool:
        """
        Return True if every position's amount is exactly zero.

        Each position is checked individually; amounts are not netted
        against each other. An empty inventory is vacuously zero and also
        returns True. Use is_empty() to tell the two cases apart.
        """
        return all(pos.amount == Decimal(0) for pos in self.positions.values())

    def to_dict(self) -> dict:
        """
        Export inventory to dictionary format.

        The sats balance is truncated to ``int``, and fiat balances are
        converted to ``float`` (lossy for exact decimal values).

        Returns:
            {
                "sats": 100000,
                "fiat": {
                    "EUR": 50.00,
                    "USD": 60.00
                }
            }
        """
        fiat_balances = self.get_all_fiat_balances()
        return {
            "sats": int(self.get_balance_sats()),
            "fiat": {
                curr: float(amount)
                for curr, amount in fiat_balances.items()
            },
        }

    def __repr__(self) -> str:
        """String representation for debugging."""
        if self.is_empty():
            return "CastleInventory(empty)"
        positions_str = ", ".join(
            f"{curr}: {pos.amount}"
            for (curr, _), pos in self.positions.items()
        )
        return f"CastleInventory({positions_str})"