castle/crud.py
padreug 3af93c3479 Add receivable/payable filtering with database-level query optimization
Add account type filtering to Recent Transactions table and fix pagination issue where filters were applied after fetching results, causing incomplete data display.

Database layer (crud.py):
  - Add get_journal_entries_by_user_and_account_type() to filter entries by
    both user_id and account_type at SQL query level
  - Add count_journal_entries_by_user_and_account_type() for accurate counts
  - Filters apply before pagination, ensuring all matching records are fetched

API layer (views_api.py):
  - Add filter_account_type parameter ('asset' for receivable, 'liability' for payable)
  - Refactor filtering logic to use new database-level filter functions
  - Support filter combinations: user only, account_type only, user+account_type, or all
  - Enrich entries with account_type metadata for UI display

Frontend (index.js):
  - Add account_type to transactionFilter state
  - Add accountTypeOptions computed property with receivable/payable choices
  - Reorder table columns to show User before Date
  - Update loadTransactions to send account_type filter parameter
  - Update clearTransactionFilter to clear both user and account_type filters

UI (index.html):
  - Add second filter dropdown for account type (Receivable/Payable)
  - Show clear button when either filter is active
  - Update button label from "Clear Filter" to "Clear Filters"

This fixes the critical bug where filtering for receivables would only show a subset of results (e.g., 2 out of 20 entries fetched) instead of all matching receivables. Now filters are applied at the database level before pagination, ensuring users see all relevant transactions.
2025-11-09 00:28:54 +01:00

1441 lines
45 KiB
Python

import json
from datetime import datetime
from typing import Optional
from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash
from .models import (
Account,
AccountPermission,
AccountType,
AssertionStatus,
BalanceAssertion,
CastleSettings,
CreateAccount,
CreateAccountPermission,
CreateBalanceAssertion,
CreateEntryLine,
CreateJournalEntry,
CreateUserEquityStatus,
EntryLine,
JournalEntry,
PermissionType,
StoredUserWalletSettings,
UserBalance,
UserCastleSettings,
UserEquityStatus,
UserWalletSettings,
)
# Import core accounting logic
from .core.balance import BalanceCalculator, AccountType as CoreAccountType
from .core.inventory import CastleInventory, CastlePosition
from .core.validation import (
ValidationError,
validate_journal_entry,
validate_balance,
validate_receivable_entry,
validate_expense_entry,
validate_payment_entry,
)
# Shared database handle (opened under the name "ext_castle") used by the CRUD functions below.
db = Database("ext_castle")
# ===== ACCOUNT OPERATIONS =====
async def create_account(data: CreateAccount) -> Account:
    """Insert a new row into ``accounts`` built from ``data`` and return it.

    A fresh short-hash id and the current time are assigned here; every
    other field is copied from ``data`` unchanged.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)
    return new_account
async def get_account(account_id: str) -> Optional[Account]:
    """Fetch the account whose ``id`` equals ``account_id`` (may be ``None``)."""
    query = "SELECT * FROM accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
async def get_account_by_name(name: str) -> Optional[Account]:
    """Fetch an account by its exact (hierarchical) ``name``; may be ``None``."""
    query = "SELECT * FROM accounts WHERE name = :name"
    params = {"name": name}
    return await db.fetchone(query, params, Account)
async def get_all_accounts() -> list[Account]:
    """Return every account, sorted by account type and then by name."""
    query = "SELECT * FROM accounts ORDER BY account_type, name"
    return await db.fetchall(query, model=Account)
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, sorted by name.

    The enum's ``.value`` is what gets compared against the stored column.
    """
    query = "SELECT * FROM accounts WHERE account_type = :type ORDER BY name"
    params = {"type": account_type.value}
    return await db.fetchall(query, params, Account)
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Return the user's account for ``base_name``, creating it on first use.

    The account name is derived with ``format_hierarchical_account_name``;
    the lookup matches on user id, account type and that derived name.

    Examples (call, then the expected account name):
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
            "Assets:Receivable:User-af983632"
        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
            "Liabilities:Payable:User-af983632"
    """
    from .account_utils import format_hierarchical_account_name

    hierarchical_name = format_hierarchical_account_name(
        account_type, base_name, user_id
    )

    lookup_params = {
        "user_id": user_id,
        "type": account_type.value,
        "name": hierarchical_name,
    }
    existing = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        lookup_params,
        Account,
    )
    if existing:
        return existing

    # No match yet: create the account under the hierarchical name.
    return await create_account(
        CreateAccount(
            name=hierarchical_name,
            account_type=account_type,
            description=f"User-specific {account_type.value} account",
            user_id=user_id,
        )
    )
# ===== JOURNAL ENTRY OPERATIONS =====
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """
    Store a journal entry and its lines, and return the populated entry.

    Amounts are signed (Beancount-style), so the amounts of all lines of a
    valid entry must add up to exactly zero.

    Args:
        data: entry header fields plus the lines to record.
        created_by: identifier stored in the entry's ``created_by`` column.

    Raises:
        ValueError: if the line amounts do not sum to zero.
    """
    entry_id = urlsafe_short_hash()

    balance = sum(item.amount for item in data.lines)
    if balance != 0:
        raise ValueError(
            f"Journal entry must balance (sum of amounts = 0): sum={balance}"
        )

    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        # Fall back to "now" when the caller supplied no entry date.
        entry_date=data.entry_date or datetime.now(),
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
        flag=data.flag,
        meta=data.meta,
    )

    # The header row has no `lines` column; lines live in entry_lines.
    header_params = {
        "id": journal_entry.id,
        "description": journal_entry.description,
        "entry_date": journal_entry.entry_date,
        "created_by": journal_entry.created_by,
        "created_at": journal_entry.created_at,
        "reference": journal_entry.reference,
        "flag": journal_entry.flag.value,
        "meta": json.dumps(journal_entry.meta),
    }
    await db.execute(
        """
        INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference, flag, meta)
        VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference, :flag, :meta)
        """,
        header_params,
    )

    stored_lines = []
    for item in data.lines:
        entry_line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=item.account_id,
            amount=item.amount,
            description=item.description,
            metadata=item.metadata,
        )
        # Line metadata is persisted as a JSON string.
        line_params = {
            "id": entry_line.id,
            "journal_entry_id": entry_line.journal_entry_id,
            "account_id": entry_line.account_id,
            "amount": entry_line.amount,
            "description": entry_line.description,
            "metadata": json.dumps(entry_line.metadata),
        }
        await db.execute(
            """
            INSERT INTO entry_lines (id, journal_entry_id, account_id, amount, description, metadata)
            VALUES (:id, :journal_entry_id, :account_id, :amount, :description, :metadata)
            """,
            line_params,
        )
        stored_lines.append(entry_line)

    journal_entry.lines = stored_lines
    return journal_entry
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Fetch one journal entry by id and attach its lines when it exists."""
    found = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not found:
        return found
    found.lines = await get_entry_lines(entry_id)
    return found
async def get_journal_entry_by_reference(reference: str) -> Optional[JournalEntry]:
    """Fetch a journal entry by its ``reference`` column (e.g. a payment_hash).

    The entry's lines are attached when a row is found.
    """
    found = await db.fetchone(
        "SELECT * FROM journal_entries WHERE reference = :reference",
        {"reference": reference},
        JournalEntry,
    )
    if not found:
        return found
    found.lines = await get_entry_lines(found.id)
    return found
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all lines that belong to the given journal entry.

    The ``metadata`` column holds a JSON string; an empty/NULL value is
    returned as an empty dict.
    """
    raw_rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )
    return [
        EntryLine(
            id=raw.id,
            journal_entry_id=raw.journal_entry_id,
            account_id=raw.account_id,
            amount=raw.amount,
            description=raw.description,
            metadata=json.loads(raw.metadata) if raw.metadata else {},
        )
        for raw in raw_rows
    ]
async def get_all_journal_entries(limit: int = 100, offset: int = 0) -> list[JournalEntry]:
    """
    Return one page of journal entries, newest first.

    Args:
        limit: maximum number of entries to return.
        offset: number of entries to skip before the page starts.

    Returns:
        Entries ordered by ``entry_date`` then ``created_at`` (both
        descending), each with its lines loaded.
    """
    # Imported once per call rather than once per row.
    from .models import JournalEntryFlag

    rows = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit OFFSET :offset
        """,
        {"limit": limit, "offset": offset},
    )

    entries = []
    for row in rows:
        raw_meta = row.get("meta")
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            # `or "*"` also covers a column that is present but NULL, which
            # a plain .get("flag", "*") default would pass through as None.
            flag=JournalEntryFlag(row.get("flag") or "*"),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100, offset: int = 0
) -> list[JournalEntry]:
    """
    Return one page of journal entries that touch any of the user's accounts.

    Args:
        user_id: owner of the accounts to match.
        limit: maximum number of entries to return.
        offset: number of matching entries to skip.

    Returns:
        Distinct entries ordered by ``entry_date`` then ``created_at`` (both
        descending), each with its lines loaded. Empty when the user owns
        no accounts.
    """
    # Imported once per call rather than once per row.
    from .models import JournalEntryFlag

    user_accounts = await db.fetchall(
        "SELECT id FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
    )
    if not user_accounts:
        return []

    # One named bind parameter per account id for the IN (...) clause.
    params = {f"account_{i}": acc["id"] for i, acc in enumerate(user_accounts)}
    placeholders = ",".join(f":{name}" for name in params)
    params["limit"] = limit
    params["offset"] = offset

    rows = await db.fetchall(
        f"""
        SELECT DISTINCT je.*
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN ({placeholders})
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit OFFSET :offset
        """,
        params,
    )

    entries = []
    for row in rows:
        raw_meta = row.get("meta")
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            # `or "*"` also covers a column that is present but NULL, which
            # a plain .get("flag", "*") default would pass through as None.
            flag=JournalEntryFlag(row.get("flag") or "*"),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries
async def count_all_journal_entries() -> int:
    """Return the number of rows in ``journal_entries`` (0 if no result row)."""
    row = await db.fetchone("SELECT COUNT(*) as total FROM journal_entries")
    if not row:
        return 0
    return row["total"]
async def count_journal_entries_by_user(user_id: str) -> int:
    """Count distinct journal entries that touch any of the user's accounts.

    Returns 0 when the user owns no accounts or the count query yields no row.
    """
    owned = await db.fetchall(
        "SELECT id FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
    )
    if not owned:
        return 0

    # One named bind parameter per account id for the IN (...) clause.
    bind = {f"account_{idx}": acc["id"] for idx, acc in enumerate(owned)}
    in_list = ",".join(f":{key}" for key in bind)

    row = await db.fetchone(
        f"""
        SELECT COUNT(DISTINCT je.id) as total
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN ({in_list})
        """,
        bind,
    )
    if not row:
        return 0
    return row["total"]
async def get_journal_entries_by_user_and_account_type(
    user_id: str, account_type: str, limit: int = 100, offset: int = 0
) -> list[JournalEntry]:
    """
    Return one page of journal entries touching the user's accounts of one type.

    Args:
        user_id: owner of the accounts to match.
        account_type: value compared against ``accounts.account_type``.
        limit: maximum number of entries to return.
        offset: number of matching entries to skip.

    Returns:
        Distinct entries ordered by ``entry_date`` then ``created_at`` (both
        descending), each with its lines loaded. Empty when the user owns
        no account of that type.
    """
    # Imported once per call rather than once per row.
    from .models import JournalEntryFlag

    user_accounts = await db.fetchall(
        "SELECT id FROM accounts WHERE user_id = :user_id AND account_type = :account_type",
        {"user_id": user_id, "account_type": account_type},
    )
    if not user_accounts:
        return []

    # One named bind parameter per account id for the IN (...) clause.
    params = {f"account_{i}": acc["id"] for i, acc in enumerate(user_accounts)}
    placeholders = ",".join(f":{name}" for name in params)
    params["limit"] = limit
    params["offset"] = offset

    # Filtering happens in SQL, before LIMIT/OFFSET are applied.
    rows = await db.fetchall(
        f"""
        SELECT DISTINCT je.*
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN ({placeholders})
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit OFFSET :offset
        """,
        params,
    )

    entries = []
    for row in rows:
        raw_meta = row.get("meta")
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            # `or "*"` also covers a column that is present but NULL, which
            # a plain .get("flag", "*") default would pass through as None.
            flag=JournalEntryFlag(row.get("flag") or "*"),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries
async def count_journal_entries_by_user_and_account_type(user_id: str, account_type: str) -> int:
    """Count distinct journal entries touching the user's accounts of one type.

    Returns 0 when the user owns no account of ``account_type`` or the count
    query yields no row.
    """
    owned = await db.fetchall(
        "SELECT id FROM accounts WHERE user_id = :user_id AND account_type = :account_type",
        {"user_id": user_id, "account_type": account_type},
    )
    if not owned:
        return 0

    # One named bind parameter per account id for the IN (...) clause.
    bind = {f"account_{idx}": acc["id"] for idx, acc in enumerate(owned)}
    in_list = ",".join(f":{key}" for key in bind)

    row = await db.fetchone(
        f"""
        SELECT COUNT(DISTINCT je.id) as total
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN ({in_list})
        """,
        bind,
    )
    if not row:
        return 0
    return row["total"]
# ===== BALANCE AND REPORTING =====
async def get_account_balance(account_id: str) -> int:
    """
    Return the balance of one account from its cleared entry lines.

    Only lines whose journal entry has flag '*' (cleared) are summed;
    entries with any other flag are left out. The signed sum is then handed
    to ``BalanceCalculator.calculate_account_balance_from_amount`` together
    with the account's type, which applies the per-type sign convention:
    - Debit (asset/expense increase) is stored as a positive amount
    - Credit (liability/equity/revenue increase) is stored as a negative amount

    Returns 0 when the sum query yields no row or the account is unknown.
    """
    summed = await db.fetchone(
        """
        SELECT COALESCE(SUM(el.amount), 0) as total_amount
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        AND je.flag = '*'
        """,
        {"id": account_id},
    )
    if not summed:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    # Delegate the sign handling to the core calculator.
    return BalanceCalculator.calculate_account_balance_from_amount(
        summed["total_amount"],
        CoreAccountType(account.account_type.value),
    )
async def get_user_balance(user_id: str) -> UserBalance:
    """
    Compute the user's overall balance with the Castle.

    Sign convention: positive = castle owes user, negative = user owes castle.
    For each account owned by the user this gathers the satoshi balance and an
    inventory built from its cleared (flag '*') entry lines, then lets
    ``BalanceCalculator.calculate_user_balance`` combine them.
    """
    cleared_lines_query = """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :account_id
        AND je.flag = '*'
        """

    owned_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    sats_by_account = {}
    inventory_by_account = {}
    for acct in owned_accounts:
        sats_by_account[acct.id] = await get_account_balance(acct.id)

        # Pending/flagged/voided entries are excluded by the flag filter.
        line_rows = await db.fetchall(cleared_lines_query, {"account_id": acct.id})
        inventory_by_account[acct.id] = BalanceCalculator.build_inventory_from_entry_lines(
            [dict(line_row) for line_row in line_rows],
            CoreAccountType(acct.account_type.value),
        )

    account_summaries = []
    for acct in owned_accounts:
        account_summaries.append(
            {"id": acct.id, "account_type": acct.account_type.value}
        )

    totals = BalanceCalculator.calculate_user_balance(
        account_summaries,
        sats_by_account,
        inventory_by_account,
    )
    return UserBalance(
        user_id=user_id,
        balance=totals["balance"],
        accounts=owned_accounts,
        fiat_balances=totals["fiat_balances"],
    )
async def get_all_user_balances() -> list[UserBalance]:
    """Return balances for every user that has a non-zero sats balance or any fiat balance.

    Users are discovered through accounts whose ``user_id`` is set.
    """
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )
    distinct_users = {acct.user_id for acct in user_accounts if acct.user_id}

    results = []
    for uid in distinct_users:
        user_balance = await get_user_balance(uid)
        # Skip users with nothing outstanding in either sats or fiat.
        if user_balance.balance == 0 and not user_balance.fiat_balances:
            continue
        results.append(user_balance)
    return results
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """
    Return the most recent transactions affecting a specific account.

    Args:
        account_id: account whose entry lines are listed.
        limit: maximum number of entry lines to fetch.

    Returns:
        ``(journal_entry, entry_line)`` pairs, newest entry first. Lines
        whose journal entry cannot be loaded are skipped.
    """
    # Line ids are random short hashes, so ordering by id is arbitrary and
    # would make LIMIT keep an arbitrary subset. Order by the owning entry's
    # dates instead (same ordering as the journal-entry listings); el.id is
    # only a tie-breaker to keep the order stable.
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    transactions = []
    for row in rows:
        line = EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            amount=row.amount,
            description=row.description,
            # metadata is stored as a JSON string; empty/NULL means no metadata.
            metadata=json.loads(row.metadata) if row.metadata else {},
        )
        entry = await get_journal_entry(line.journal_entry_id)
        if entry:
            transactions.append((entry, line))
    return transactions
# ===== SETTINGS =====
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert castle settings for ``user_id`` into ``extension_settings``.

    The stored record is ``data`` plus ``id=user_id``; that record is returned.
    """
    fields = data.dict()
    record = UserCastleSettings(id=user_id, **fields)
    await db.insert("extension_settings", record)
    return record
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Fetch the ``extension_settings`` row whose id is ``user_id`` (may be ``None``)."""
    query = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, CastleSettings)
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Update the ``extension_settings`` record for ``user_id`` with ``data``.

    The record passed to the database is ``data`` plus ``id=user_id``; that
    record is returned.
    """
    fields = data.dict()
    record = UserCastleSettings(id=user_id, **fields)
    await db.update("extension_settings", record)
    return record
# ===== USER WALLET SETTINGS =====
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` into ``user_wallet_settings``.

    The stored record is ``data`` plus ``id=user_id``; that record is returned.
    """
    fields = data.dict()
    record = StoredUserWalletSettings(id=user_id, **fields)
    await db.insert("user_wallet_settings", record)
    return record
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Fetch the ``user_wallet_settings`` row whose id is ``user_id`` (may be ``None``)."""
    query = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, UserWalletSettings)
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Update the ``user_wallet_settings`` record for ``user_id`` with ``data``.

    The record passed to the database is ``data`` plus ``id=user_id``; that
    record is returned.
    """
    fields = data.dict()
    record = StoredUserWalletSettings(id=user_id, **fields)
    await db.update("user_wallet_settings", record)
    return record
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return every stored user wallet settings row, ordered by id."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    return await db.fetchall(query, {}, StoredUserWalletSettings)
# ===== MANUAL PAYMENT REQUESTS =====
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Store a new manual payment request with status "pending" and return it.

    A fresh short-hash id and the current time are assigned here.
    """
    from .models import ManualPaymentRequest

    payment_request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )
    row_values = {
        "id": payment_request.id,
        "user_id": payment_request.user_id,
        "amount": payment_request.amount,
        "description": payment_request.description,
        "status": payment_request.status,
        "created_at": payment_request.created_at,
    }
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        row_values,
    )
    return payment_request
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Fetch one manual payment request by id (may be ``None``)."""
    from .models import ManualPaymentRequest

    query = "SELECT * FROM manual_payment_requests WHERE id = :id"
    return await db.fetchone(query, {"id": request_id}, ManualPaymentRequest)
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests of one user, newest first."""
    from .models import ManualPaymentRequest

    query = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """
    params = {"user_id": user_id, "limit": limit}
    return await db.fetchall(query, params, ManualPaymentRequest)
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests, newest first.

    When ``status`` is truthy only requests with that status are returned;
    otherwise no status filter is applied.
    """
    from .models import ManualPaymentRequest

    if not status:
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"limit": limit},
            ManualPaymentRequest,
        )

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE status = :status
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"status": status, "limit": limit},
        ManualPaymentRequest,
    )
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as 'approved' and return the re-read row.

    Records the reviewer, the review time (now) and the linked journal entry.
    """
    from .models import ManualPaymentRequest

    update_values = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
        "journal_entry_id": journal_entry_id,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        update_values,
    )
    return await get_manual_payment_request(request_id)
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as 'rejected' and return the re-read row.

    Records the reviewer and the review time (now).
    """
    from .models import ManualPaymentRequest

    update_values = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        update_values,
    )
    return await get_manual_payment_request(request_id)
# ===== BALANCE ASSERTION OPERATIONS =====
async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """
    Store a new balance assertion in PENDING state and return it.

    Args:
        data: expected balances, tolerances and the target account; when
            ``data.date`` is not set the current time is used.
        created_by: identifier stored in the ``created_by`` column.

    Returns:
        The assertion as it was written to ``balance_assertions``.
    """
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )

    # Decimal fields are written as strings. Compare against None rather than
    # relying on truthiness: Decimal("0") is falsy, and an expected fiat
    # balance of zero must be stored as "0", not as NULL.
    expected_fiat = assertion.expected_balance_fiat
    await db.execute(
        """
        INSERT INTO balance_assertions (
        id, date, account_id, expected_balance_sats, expected_balance_fiat,
        fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        ) VALUES (
        :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
        :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            "expected_balance_fiat": str(expected_fiat) if expected_fiat is not None else None,
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": str(assertion.tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )
    return assertion
async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """Load one balance assertion by id, or ``None`` when no row matches.

    Fiat amounts are converted to ``Decimal``; an empty/NULL stored value
    becomes ``None`` (``Decimal("0")`` for the fiat tolerance).
    """
    from decimal import Decimal

    record = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not record:
        return None

    def optional_decimal(column: str):
        # Falsy stored values (NULL, empty string) map to None.
        stored = record[column]
        if stored:
            return Decimal(stored)
        return None

    if record["tolerance_fiat"]:
        fiat_tolerance = Decimal(record["tolerance_fiat"])
    else:
        fiat_tolerance = Decimal("0")

    return BalanceAssertion(
        id=record["id"],
        date=record["date"],
        account_id=record["account_id"],
        expected_balance_sats=record["expected_balance_sats"],
        expected_balance_fiat=optional_decimal("expected_balance_fiat"),
        fiat_currency=record["fiat_currency"],
        tolerance_sats=record["tolerance_sats"],
        tolerance_fiat=fiat_tolerance,
        checked_balance_sats=record["checked_balance_sats"],
        checked_balance_fiat=optional_decimal("checked_balance_fiat"),
        difference_sats=record["difference_sats"],
        difference_fiat=optional_decimal("difference_fiat"),
        status=AssertionStatus(record["status"]),
        created_by=record["created_by"],
        created_at=record["created_at"],
        checked_at=record["checked_at"],
    )
async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """
    Return balance assertions, newest ``date`` first, with optional filters.

    Args:
        account_id: when truthy, only assertions for this account.
        status: when truthy, only assertions in this status.
        limit: maximum number of rows to return.

    Returns:
        Matching assertions with fiat amounts converted to ``Decimal``.
    """
    from decimal import Decimal

    # Build the WHERE clause from whichever filters were supplied instead of
    # keeping one hand-written query per filter combination.
    conditions = []
    params = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )

    def optional_decimal(stored):
        # Falsy stored values (NULL, empty string) map to None.
        return Decimal(stored) if stored else None

    return [
        BalanceAssertion(
            id=row["id"],
            date=row["date"],
            account_id=row["account_id"],
            expected_balance_sats=row["expected_balance_sats"],
            expected_balance_fiat=optional_decimal(row["expected_balance_fiat"]),
            fiat_currency=row["fiat_currency"],
            tolerance_sats=row["tolerance_sats"],
            tolerance_fiat=Decimal(row["tolerance_fiat"]) if row["tolerance_fiat"] else Decimal("0"),
            checked_balance_sats=row["checked_balance_sats"],
            checked_balance_fiat=optional_decimal(row["checked_balance_fiat"]),
            difference_sats=row["difference_sats"],
            difference_fiat=optional_decimal(row["difference_fiat"]),
            status=AssertionStatus(row["status"]),
            created_by=row["created_by"],
            created_at=row["created_at"],
            checked_at=row["checked_at"],
        )
        for row in rows
    ]
async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Evaluate a stored balance assertion and persist the outcome.

    The account's balance is compared with the expected sats balance (and,
    when both sides are available, the expected fiat balance) within the
    assertion's tolerances. The checked values, differences, resulting
    status and check time are written back to the row.

    Raises:
        ValueError: if the assertion or its account does not exist.

    Returns:
        The assertion as re-read from the database after the update.
    """
    from decimal import Decimal

    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    # NOTE(review): get_account_balance takes no date, so this is the
    # account's balance at check time, not at the assertion's date.
    actual_sats = await get_account_balance(assertion.account_id)

    # Fiat is only looked up for user-owned accounts with a currency set.
    # NOTE(review): the figure comes from get_user_balance, i.e. it is
    # user-level rather than specific to this account. Confirm intended.
    actual_fiat = None
    if assertion.fiat_currency and account.user_id:
        owner_balance = await get_user_balance(account.user_id)
        actual_fiat = owner_balance.fiat_balances.get(assertion.fiat_currency, Decimal("0"))

    sats_delta = actual_sats - assertion.expected_balance_sats
    sats_ok = abs(sats_delta) <= assertion.tolerance_sats

    # The fiat check counts as passed unless both expected and actual exist.
    fiat_delta = None
    fiat_ok = True
    if assertion.expected_balance_fiat is not None and actual_fiat is not None:
        fiat_delta = actual_fiat - assertion.expected_balance_fiat
        fiat_ok = abs(fiat_delta) <= assertion.tolerance_fiat

    if sats_ok and fiat_ok:
        outcome = AssertionStatus.PASSED
    else:
        outcome = AssertionStatus.FAILED

    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
        checked_balance_fiat = :checked_fiat,
        difference_sats = :diff_sats,
        difference_fiat = :diff_fiat,
        status = :status,
        checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_sats,
            "checked_fiat": str(actual_fiat) if actual_fiat is not None else None,
            "diff_sats": sats_delta,
            "diff_fiat": str(fiat_delta) if fiat_delta is not None else None,
            "status": outcome.value,
            "checked_at": datetime.now(),
        },
    )
    return await get_balance_assertion(assertion_id)
async def delete_balance_assertion(assertion_id: str) -> None:
    """Remove the balance assertion with the given id from the database."""
    statement = "DELETE FROM balance_assertions WHERE id = :id"
    await db.execute(statement, {"id": assertion_id})
# ===== USER EQUITY STATUS OPERATIONS =====
async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]:
    """Return the equity eligibility record for ``user_id``, or ``None`` if there is none."""
    from .models import UserEquityStatus

    record = await db.fetchone(
        """
        SELECT * FROM user_equity_status
        WHERE user_id = :user_id
        """,
        {"user_id": user_id},
    )
    if not record:
        return None
    return UserEquityStatus(**record)
async def create_or_update_user_equity_status(
    data: "CreateUserEquityStatus", granted_by: str
) -> "UserEquityStatus":
    """
    Create or update a user's equity eligibility record.

    When eligibility is being granted, a user-specific equity account named
    ``Equity:User-<first 8 chars of user_id>`` is created if it does not
    exist yet, and ``data.equity_account_name`` is set to that name (the
    passed-in ``data`` object is modified).

    Args:
        data: the desired eligibility state for ``data.user_id``.
        granted_by: identifier stored in the ``granted_by`` column.

    Returns:
        The record as re-read from the database after the write.

    Raises:
        ValueError: if the record cannot be read back after the write.
    """
    if data.is_equity_eligible:
        short_user_id = data.user_id[:8]
        equity_account_name = f"Equity:User-{short_user_id}"

        if not await get_account_by_name(equity_account_name):
            # Go through create_account so this account gets the same id
            # scheme and insert path as every other account.
            await create_account(
                CreateAccount(
                    name=equity_account_name,
                    account_type=AccountType.EQUITY,
                    description=f"Equity contributions for user {short_user_id}",
                    user_id=data.user_id,
                )
            )

        data.equity_account_name = equity_account_name

    # One timestamp for the whole operation keeps granted_at/revoked_at aligned.
    now = datetime.now()
    common_values = {
        "user_id": data.user_id,
        "is_equity_eligible": data.is_equity_eligible,
        "equity_account_name": data.equity_account_name,
        "notes": data.notes,
        "granted_by": granted_by,
        "granted_at": now,
    }

    existing = await get_user_equity_status(data.user_id)
    if existing:
        await db.execute(
            """
            UPDATE user_equity_status
            SET is_equity_eligible = :is_equity_eligible,
            equity_account_name = :equity_account_name,
            notes = :notes,
            granted_by = :granted_by,
            granted_at = :granted_at,
            revoked_at = :revoked_at
            WHERE user_id = :user_id
            """,
            {
                **common_values,
                # Granting clears any earlier revocation; revoking stamps it.
                "revoked_at": None if data.is_equity_eligible else now,
            },
        )
    else:
        await db.execute(
            """
            INSERT INTO user_equity_status (
            user_id, is_equity_eligible, equity_account_name,
            notes, granted_by, granted_at
            )
            VALUES (
            :user_id, :is_equity_eligible, :equity_account_name,
            :notes, :granted_by, :granted_at
            )
            """,
            common_values,
        )

    result = await get_user_equity_status(data.user_id)
    if not result:
        raise ValueError(f"Failed to create/update equity status for user {data.user_id}")
    return result
async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]:
    """Revoke a user's equity contribution eligibility.

    Sets ``is_equity_eligible`` to FALSE and stamps ``revoked_at`` with the
    current time on the user's ``user_equity_status`` row, then re-reads the
    row and returns it.

    Args:
        user_id: ID of the user whose eligibility is revoked.

    Returns:
        The result of ``get_user_equity_status(user_id)`` after the update.
        The UPDATE only touches an existing row, so the result may be ``None``
        (presumably when the user never had an equity status record).
    """
    # ``datetime`` comes from the module-level import, matching the other
    # functions in this module; no function-scope re-import is needed.
    await db.execute(
        """
        UPDATE user_equity_status
        SET is_equity_eligible = FALSE,
            revoked_at = :revoked_at
        WHERE user_id = :user_id
        """,
        {"user_id": user_id, "revoked_at": datetime.now()},
    )
    return await get_user_equity_status(user_id)
async def get_all_equity_eligible_users() -> list["UserEquityStatus"]:
    """Return every user currently flagged as equity-eligible.

    Returns:
        One ``UserEquityStatus`` per ``user_equity_status`` row whose
        ``is_equity_eligible`` is TRUE, most recently granted first.
    """
    from .models import UserEquityStatus

    eligible_query = """
        SELECT * FROM user_equity_status
        WHERE is_equity_eligible = TRUE
        ORDER BY granted_at DESC
        """
    records = await db.fetchall(eligible_query)
    # Each row is expanded into keyword arguments for the model.
    return [UserEquityStatus(**record) for record in records]
# ===== ACCOUNT PERMISSION OPERATIONS =====
async def create_account_permission(
    data: "CreateAccountPermission", granted_by: str
) -> "AccountPermission":
    """Create and persist a new account permission.

    Args:
        data: Request payload supplying ``user_id``, ``account_id``,
            ``permission_type``, ``expires_at`` and ``notes``.
        granted_by: Identifier recorded as the grantor of the permission.

    Returns:
        The ``AccountPermission`` that was inserted, carrying a freshly
        generated ID and a ``granted_at`` timestamp of the current time.
    """
    from .models import AccountPermission

    new_permission = AccountPermission(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    insert_sql = """
        INSERT INTO account_permissions (
            id, user_id, account_id, permission_type, granted_by,
            granted_at, expires_at, notes
        )
        VALUES (
            :id, :user_id, :account_id, :permission_type, :granted_by,
            :granted_at, :expires_at, :notes
        )
        """
    bind_values = {
        "id": new_permission.id,
        "user_id": new_permission.user_id,
        "account_id": new_permission.account_id,
        # The enum's underlying value is what gets stored in the column.
        "permission_type": new_permission.permission_type.value,
        "granted_by": new_permission.granted_by,
        "granted_at": new_permission.granted_at,
        "expires_at": new_permission.expires_at,
        "notes": new_permission.notes,
    }
    await db.execute(insert_sql, bind_values)

    return new_permission
async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]:
    """Look up a single account permission by its ID.

    Args:
        permission_id: Primary key of the ``account_permissions`` row.

    Returns:
        The matching ``AccountPermission``, or ``None`` when no row is found.
    """
    from .models import AccountPermission, PermissionType

    record = await db.fetchone(
        "SELECT * FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )
    if record:
        return AccountPermission(
            id=record["id"],
            user_id=record["user_id"],
            account_id=record["account_id"],
            # Stored as the raw enum value; convert back to the enum member.
            permission_type=PermissionType(record["permission_type"]),
            granted_by=record["granted_by"],
            granted_at=record["granted_at"],
            expires_at=record["expires_at"],
            notes=record["notes"],
        )
    return None
async def get_user_permissions(
    user_id: str, permission_type: Optional["PermissionType"] = None
) -> list["AccountPermission"]:
    """List a user's unexpired account permissions.

    Args:
        user_id: ID of the user whose permissions are fetched.
        permission_type: When given (truthy), restrict the result to this
            permission type; otherwise all types are returned.

    Returns:
        ``AccountPermission`` objects whose ``expires_at`` is NULL or later
        than the current time, most recently granted first.
    """
    from .models import AccountPermission, PermissionType

    def _to_permission(record) -> "AccountPermission":
        # Map one account_permissions row onto the model.
        return AccountPermission(
            id=record["id"],
            user_id=record["user_id"],
            account_id=record["account_id"],
            permission_type=PermissionType(record["permission_type"]),
            granted_by=record["granted_by"],
            granted_at=record["granted_at"],
            expires_at=record["expires_at"],
            notes=record["notes"],
        )

    # Pick the statement and its bind values first, then run a single fetch.
    if permission_type:
        query = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND permission_type = :permission_type
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
        bind_values = {
            "user_id": user_id,
            "permission_type": permission_type.value,
            "now": datetime.now(),
        }
    else:
        query = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
        bind_values = {"user_id": user_id, "now": datetime.now()}

    records = await db.fetchall(query, bind_values)
    return [_to_permission(record) for record in records]
async def get_account_permissions(account_id: str) -> list["AccountPermission"]:
    """List the unexpired permissions granted on one account.

    Args:
        account_id: ID of the account whose permissions are fetched.

    Returns:
        ``AccountPermission`` objects whose ``expires_at`` is NULL or later
        than the current time, most recently granted first.
    """
    from .models import AccountPermission, PermissionType

    query = """
        SELECT * FROM account_permissions
        WHERE account_id = :account_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """
    bind_values = {"account_id": account_id, "now": datetime.now()}
    records = await db.fetchall(query, bind_values)

    permissions = []
    for record in records:
        permissions.append(
            AccountPermission(
                id=record["id"],
                user_id=record["user_id"],
                account_id=record["account_id"],
                # Stored as the raw enum value; convert back to the enum member.
                permission_type=PermissionType(record["permission_type"]),
                granted_by=record["granted_by"],
                granted_at=record["granted_at"],
                expires_at=record["expires_at"],
                notes=record["notes"],
            )
        )
    return permissions
async def delete_account_permission(permission_id: str) -> None:
    """Revoke an account permission by deleting its row.

    Args:
        permission_id: Primary key of the ``account_permissions`` row to remove.
    """
    delete_sql = "DELETE FROM account_permissions WHERE id = :id"
    bind_values = {"id": permission_id}
    await db.execute(delete_sql, bind_values)
async def check_user_has_permission(
    user_id: str, account_id: str, permission_type: "PermissionType"
) -> bool:
    """Tell whether a user holds a given permission directly on an account.

    Only a permission row for exactly this ``account_id`` counts; permissions
    on other (e.g. parent) accounts are not considered here.

    Args:
        user_id: ID of the user being checked.
        account_id: ID of the account the permission must be granted on.
        permission_type: The permission type to look for.

    Returns:
        ``True`` if an unexpired matching row exists (``expires_at`` NULL or
        later than the current time), ``False`` otherwise.
    """
    lookup_sql = """
        SELECT id FROM account_permissions
        WHERE user_id = :user_id
        AND account_id = :account_id
        AND permission_type = :permission_type
        AND (expires_at IS NULL OR expires_at > :now)
        """
    bind_values = {
        "user_id": user_id,
        "account_id": account_id,
        "permission_type": permission_type.value,
        "now": datetime.now(),
    }
    match = await db.fetchone(lookup_sql, bind_values)
    if match is None:
        return False
    return True
async def get_user_permissions_with_inheritance(
    user_id: str, account_name: str, permission_type: "PermissionType"
) -> list[tuple["AccountPermission", Optional[str]]]:
    """
    Get a user's permissions that apply to an account, directly or by inheritance.

    Account hierarchy is expressed through colon-separated names, so a
    permission granted on a parent account also applies to its descendants.

    Example:
        If user has permission on "Expenses:Food", they also have permission
        on "Expenses:Food:Groceries".
        Returns: [(permission_on_food, "Expenses:Food")]

    Args:
        user_id: ID of the user whose permissions are evaluated.
        account_name: Full hierarchical name of the target account.
        permission_type: The permission type to look for.

    Returns:
        List of ``(permission, parent_account_name)`` tuples in the order
        produced by ``get_user_permissions``. ``parent_account_name`` is
        ``None`` when the permission is granted on the target account itself,
        otherwise it is the name of the ancestor account it is inherited from.
        Permissions whose account cannot be loaded are skipped.
    """
    # Get all user's permissions of this type
    user_permissions = await get_user_permissions(user_id, permission_type)

    # Find which permissions apply to this account (direct or inherited)
    applicable_permissions = []
    for perm in user_permissions:
        # Get the account for this permission
        account = await get_account(perm.account_id)
        if not account:
            continue

        if account_name == account.name:
            # Direct permission
            applicable_permissions.append((perm, None))
        elif account_name.startswith(account.name + ":"):
            # Inherited permission from parent account. The ":" suffix makes
            # the match respect segment boundaries, so "Expenses:Foo" is not
            # treated as a parent of "Expenses:Food".
            applicable_permissions.append((perm, account.name))

    return applicable_permissions