Implements a background task that listens for paid invoices and automatically records them in the accounting system. This ensures payments are captured even if the user closes their browser before the client-side polling detects the payment. Introduces a new `get_journal_entry_by_reference` function to improve idempotency when recording payments.
951 lines
29 KiB
Python
951 lines
29 KiB
Python
import json
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from lnbits.db import Database
|
|
from lnbits.helpers import urlsafe_short_hash
|
|
|
|
from .models import (
|
|
Account,
|
|
AccountType,
|
|
AssertionStatus,
|
|
BalanceAssertion,
|
|
CastleSettings,
|
|
CreateAccount,
|
|
CreateBalanceAssertion,
|
|
CreateEntryLine,
|
|
CreateJournalEntry,
|
|
EntryLine,
|
|
JournalEntry,
|
|
StoredUserWalletSettings,
|
|
UserBalance,
|
|
UserCastleSettings,
|
|
UserWalletSettings,
|
|
)
|
|
|
|
# Import core accounting logic
|
|
from .core.balance import BalanceCalculator, AccountType as CoreAccountType
|
|
from .core.inventory import CastleInventory, CastlePosition
|
|
from .core.validation import (
|
|
ValidationError,
|
|
validate_journal_entry,
|
|
validate_balance,
|
|
validate_receivable_entry,
|
|
validate_expense_entry,
|
|
validate_payment_entry,
|
|
)
|
|
|
|
# Database handle for this extension's tables, named "ext_castle"; the query
# helpers below all issue their SQL through it.
db = Database("ext_castle")
|
|
|
|
|
|
# ===== ACCOUNT OPERATIONS =====
|
|
|
|
|
|
async def create_account(data: CreateAccount) -> Account:
    """Persist a new account described by ``data`` and return it.

    The id is a fresh ``urlsafe_short_hash()`` and ``created_at`` is stamped
    with the local ``datetime.now()``.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)
    return new_account
|
|
|
|
|
|
async def get_account(account_id: str) -> Optional[Account]:
    """Fetch one account by primary key, or ``None`` when no row matches."""
    query = "SELECT * FROM accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
|
|
|
|
|
|
async def get_account_by_name(name: str) -> Optional[Account]:
    """Fetch an account by its (hierarchical) name, or ``None`` if absent."""
    query = "SELECT * FROM accounts WHERE name = :name"
    params = {"name": name}
    return await db.fetchone(query, params, Account)
|
|
|
|
|
|
async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then by name."""
    sql = "SELECT * FROM accounts ORDER BY account_type, name"
    return await db.fetchall(sql, model=Account)
|
|
|
|
|
|
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts whose type equals ``account_type``, ordered by name."""
    sql = "SELECT * FROM accounts WHERE account_type = :type ORDER BY name"
    bind = {"type": account_type.value}
    return await db.fetchall(sql, bind, Account)
|
|
|
|
|
|
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Return the user-specific account for ``account_type``/``base_name``, creating it if needed.

    The account name is produced by ``format_hierarchical_account_name``.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        → "Assets:Receivable:User-af983632"

        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        → "Liabilities:Payable:User-af983632"
    """
    from .account_utils import format_hierarchical_account_name

    hierarchical_name = format_hierarchical_account_name(
        account_type, base_name, user_id
    )

    # Match on owner, type and full hierarchical name.
    existing = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": hierarchical_name},
        Account,
    )
    if existing:
        return existing

    # NOTE(review): lookup-then-insert is not atomic; two concurrent callers
    # could both create this account unless the schema enforces uniqueness.
    return await create_account(
        CreateAccount(
            name=hierarchical_name,
            account_type=account_type,
            description=f"User-specific {account_type.value} account",
            user_id=user_id,
        )
    )
|
|
|
|
|
|
# ===== JOURNAL ENTRY OPERATIONS =====
|
|
|
|
|
|
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate and persist a journal entry together with its entry lines.

    Args:
        data: Entry header fields plus the lines to record. ``entry_date``
            falls back to ``datetime.now()`` when falsy.
        created_by: Value stored in the entry's ``created_by`` column.

    Returns:
        The stored ``JournalEntry`` with ``lines`` populated.

    Raises:
        ValueError: If the summed debits differ from the summed credits.
    """
    # Validate that debits equal credits before touching the database.
    total_debits = sum(line.debit for line in data.lines)
    total_credits = sum(line.credit for line in data.lines)
    if total_debits != total_credits:
        raise ValueError(
            f"Journal entry must balance: debits={total_debits}, credits={total_credits}"
        )

    entry_id = urlsafe_short_hash()
    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or datetime.now(),
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
        flag=data.flag,
        meta=data.meta,
    )

    # Build every line model up front so a validation error cannot occur
    # after some rows have already been written.
    lines = [
        EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=line_data.account_id,
            debit=line_data.debit,
            credit=line_data.credit,
            description=line_data.description,
            metadata=line_data.metadata,
        )
        for line_data in data.lines
    ]

    # Write the header and all lines in a single transaction. Previously each
    # INSERT committed on its own, so a failure part-way through left a
    # partial (unbalanced) entry in the ledger.
    async with db.connect() as conn:
        # The lines field is not a column; lines live in entry_lines.
        await conn.execute(
            """
            INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference, flag, meta)
            VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference, :flag, :meta)
            """,
            {
                "id": journal_entry.id,
                "description": journal_entry.description,
                "entry_date": journal_entry.entry_date,
                "created_by": journal_entry.created_by,
                "created_at": journal_entry.created_at,
                "reference": journal_entry.reference,
                "flag": journal_entry.flag.value,
                "meta": json.dumps(journal_entry.meta),
            },
        )
        for line in lines:
            # metadata is stored as a JSON string.
            await conn.execute(
                """
                INSERT INTO entry_lines (id, journal_entry_id, account_id, debit, credit, description, metadata)
                VALUES (:id, :journal_entry_id, :account_id, :debit, :credit, :description, :metadata)
                """,
                {
                    "id": line.id,
                    "journal_entry_id": line.journal_entry_id,
                    "account_id": line.account_id,
                    "debit": line.debit,
                    "credit": line.credit,
                    "description": line.description,
                    "metadata": json.dumps(line.metadata),
                },
            )

    journal_entry.lines = lines
    return journal_entry
|
|
|
|
|
|
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Load a journal entry by id with its lines attached; ``None`` if missing."""
    found = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not found:
        return found
    found.lines = await get_entry_lines(entry_id)
    return found
|
|
|
|
|
|
async def get_journal_entry_by_reference(reference: str) -> Optional[JournalEntry]:
    """Get a journal entry by its reference field (e.g., payment_hash).

    Intended for idempotency checks: if an entry with this reference already
    exists, the payment has already been recorded.

    Returns ``None`` for an empty or ``None`` reference. Without that guard, an
    empty string would match any entry stored with a blank reference and make
    an unrelated entry look like a duplicate. If several entries share the
    reference, the earliest-created one is returned, so the result is
    deterministic.
    """
    if not reference:
        return None

    entry = await db.fetchone(
        """
        SELECT * FROM journal_entries
        WHERE reference = :reference
        ORDER BY created_at ASC
        LIMIT 1
        """,
        {"reference": reference},
        JournalEntry,
    )

    if entry:
        entry.lines = await get_entry_lines(entry.id)

    return entry
|
|
|
|
|
|
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return the lines of one journal entry as ``EntryLine`` models.

    Each row's ``metadata`` column holds a JSON string; a NULL or empty value
    becomes ``{}``. No ORDER BY is applied, so line order is whatever the
    database returns.
    """
    rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )

    def _row_to_line(row) -> EntryLine:
        decoded_meta = json.loads(row.metadata) if row.metadata else {}
        return EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            debit=row.debit,
            credit=row.credit,
            description=row.description,
            metadata=decoded_meta,
        )

    return [_row_to_line(row) for row in rows]
|
|
|
|
|
|
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with lines attached.

    Rows are parsed by hand:
    - ``flag`` is converted to ``JournalEntryFlag`` and defaults to ``"*"``
      when the column is missing or NULL.
    - ``meta`` is decoded from its JSON string, or becomes ``{}`` when empty.
    """
    # Imported once per call rather than once per row.
    from .models import JournalEntryFlag

    entries_data = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
    )

    entries = []
    for entry_data in entries_data:
        raw_meta = entry_data.get("meta")
        entry = JournalEntry(
            id=entry_data["id"],
            description=entry_data["description"],
            entry_date=entry_data["entry_date"],
            created_by=entry_data["created_by"],
            created_at=entry_data["created_at"],
            reference=entry_data["reference"],
            # `.get("flag", "*")` returned None for a NULL column, and
            # JournalEntryFlag(None) raises; fall back to "*" instead.
            flag=JournalEntryFlag(entry_data.get("flag") or "*"),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)

    return entries
|
|
|
|
|
|
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Get journal entries that affect the user's accounts.

    An entry is included when at least one of its lines posts to an account
    whose ``user_id`` equals ``user_id``. At most ``limit`` entries are
    returned, newest first, each with its lines attached. A user with no
    accounts yields ``[]``.
    """
    from .models import JournalEntryFlag

    # A single query with a subquery replaces the old approach: a separate
    # account lookup, a dynamically built IN (...) list, and
    # SELECT DISTINCT je.* over every column.
    entries_data = await db.fetchall(
        """
        SELECT * FROM journal_entries
        WHERE id IN (
            SELECT el.journal_entry_id
            FROM entry_lines el
            JOIN accounts a ON el.account_id = a.id
            WHERE a.user_id = :user_id
        )
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
    )

    entries = []
    for entry_data in entries_data:
        raw_meta = entry_data.get("meta")
        entry = JournalEntry(
            id=entry_data["id"],
            description=entry_data["description"],
            entry_date=entry_data["entry_date"],
            created_by=entry_data["created_by"],
            created_at=entry_data["created_at"],
            reference=entry_data["reference"],
            # A NULL flag column must not reach JournalEntryFlag(None).
            flag=JournalEntryFlag(entry_data.get("flag") or "*"),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)

    return entries
|
|
|
|
|
|
# ===== BALANCE AND REPORTING =====
|
|
|
|
|
|
async def get_account_balance(account_id: str) -> int:
    """Return an account's balance, counting only cleared entries.

    Debits and credits are summed over lines whose journal entry has
    ``flag = '*'``; pending, flagged and voided entries are excluded. The
    sign convention is delegated to
    ``BalanceCalculator.calculate_account_balance`` using the account's type.
    Returns 0 when the aggregate row is missing or the account does not exist.
    """
    totals = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(el.debit), 0) as total_debit,
            COALESCE(SUM(el.credit), 0) as total_credit
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
          AND je.flag = '*'
        """,
        {"id": account_id},
    )
    if not totals:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    return BalanceCalculator.calculate_account_balance(
        totals["total_debit"],
        totals["total_credit"],
        CoreAccountType(account.account_type.value),
    )
|
|
|
|
|
|
async def get_user_balance(user_id: str) -> UserBalance:
    """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle).

    For every account owned by ``user_id`` this collects:
    - the cleared sats balance, via ``get_account_balance``;
    - an inventory built from the account's cleared (``flag = '*'``) entry lines.

    ``BalanceCalculator.calculate_user_balance`` then combines these into the
    returned ``balance`` and ``fiat_balances``.
    """
    owned_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    sats_by_account = {}
    inventory_by_account = {}
    for acct in owned_accounts:
        sats_by_account[acct.id] = await get_account_balance(acct.id)

        # Cleared lines only; pending/flagged/voided entries are excluded.
        cleared_rows = await db.fetchall(
            """
            SELECT el.*
            FROM entry_lines el
            JOIN journal_entries je ON el.journal_entry_id = je.id
            WHERE el.account_id = :account_id
              AND je.flag = '*'
            """,
            {"account_id": acct.id},
        )
        core_type = CoreAccountType(acct.account_type.value)
        inventory_by_account[acct.id] = BalanceCalculator.build_inventory_from_entry_lines(
            [dict(row) for row in cleared_rows],
            core_type,
        )

    account_descriptors = [
        {"id": acct.id, "account_type": acct.account_type.value}
        for acct in owned_accounts
    ]
    summary = BalanceCalculator.calculate_user_balance(
        account_descriptors,
        sats_by_account,
        inventory_by_account,
    )

    return UserBalance(
        user_id=user_id,
        balance=summary["balance"],
        accounts=owned_accounts,
        fiat_balances=summary["fiat_balances"],
    )
|
|
|
|
|
|
async def get_all_user_balances() -> list[UserBalance]:
    """Get balances for all users (used by castle to see who they owe).

    Considers every user that owns at least one account. A user is omitted
    when their sats balance is 0 and they have no fiat balances. Results are
    ordered by ``user_id``.
    """
    all_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )

    # Sort the ids: iterating a bare set gave an arbitrary order that could
    # vary between runs, so the returned list order was unstable.
    user_ids = sorted({account.user_id for account in all_accounts if account.user_id})

    user_balances = []
    for user_id in user_ids:
        balance = await get_user_balance(user_id)
        # Include users with a non-zero balance or any fiat balances.
        if balance.balance != 0 or balance.fiat_balances:
            user_balances.append(balance)

    return user_balances
|
|
|
|
|
|
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get all transactions affecting a specific account.

    Returns up to ``limit`` ``(journal_entry, line)`` pairs for lines posted to
    ``account_id``, newest first by the entry's ``entry_date`` and then
    ``created_at``. Entries of every flag are included; there is no
    cleared-only filter. Lines whose journal entry no longer exists are
    skipped.
    """
    # Previously this used ORDER BY id DESC. Line ids are random
    # urlsafe_short_hash values, so the "latest" N lines were effectively
    # random. Order by the owning entry's dates instead.
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    # Several lines can belong to the same entry; load each entry only once.
    entry_cache: dict[str, Optional[JournalEntry]] = {}
    transactions = []
    for row in rows:
        # metadata is stored as a JSON string.
        metadata = json.loads(row.metadata) if row.metadata else {}
        line = EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            debit=row.debit,
            credit=row.credit,
            description=row.description,
            metadata=metadata,
        )
        if line.journal_entry_id not in entry_cache:
            entry_cache[line.journal_entry_id] = await get_journal_entry(
                line.journal_entry_id
            )
        entry = entry_cache[line.journal_entry_id]
        if entry:
            transactions.append((entry, line))

    return transactions
|
|
|
|
|
|
# ===== SETTINGS =====
|
|
|
|
|
|
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert ``data`` as the ``extension_settings`` row keyed by ``user_id``; return it."""
    stored = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", stored)
    return stored
|
|
|
|
|
|
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the Castle settings stored under ``user_id``, or ``None`` if unset."""
    query = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, CastleSettings)
|
|
|
|
|
|
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the ``extension_settings`` row keyed by ``user_id`` with ``data``; return it."""
    stored = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", stored)
    return stored
|
|
|
|
|
|
# ===== USER WALLET SETTINGS =====
|
|
|
|
|
|
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", stored)
    return stored
|
|
|
|
|
|
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings stored for ``user_id``, or ``None`` if there are none."""
    query = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, UserWalletSettings)
|
|
|
|
|
|
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings row for ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", stored)
    return stored
|
|
|
|
|
|
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return every stored user wallet settings row, ordered by user id."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    return await db.fetchall(query, {}, StoredUserWalletSettings)
|
|
|
|
|
|
# ===== MANUAL PAYMENT REQUESTS =====
|
|
|
|
|
|
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Store a new manual payment request with status ``"pending"`` and return it."""
    from .models import ManualPaymentRequest

    payment_request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    # Bind each column straight from the same-named model attribute.
    columns = ("id", "user_id", "amount", "description", "status", "created_at")
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        {column: getattr(payment_request, column) for column in columns},
    )

    return payment_request
|
|
|
|
|
|
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Fetch one manual payment request by id, or ``None`` if it does not exist."""
    from .models import ManualPaymentRequest

    query = "SELECT * FROM manual_payment_requests WHERE id = :id"
    return await db.fetchone(query, {"id": request_id}, ManualPaymentRequest)
|
|
|
|
|
|
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` of a user's manual payment requests, newest first."""
    from .models import ManualPaymentRequest

    query = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """
    params = {"user_id": user_id, "limit": limit}
    return await db.fetchall(query, params, ManualPaymentRequest)
|
|
|
|
|
|
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests, newest first.

    When ``status`` is truthy, only requests with that status are returned.
    """
    from .models import ManualPaymentRequest

    params: dict = {"limit": limit}
    status_clause = ""
    if status:
        params["status"] = status
        # Fixed SQL fragment; the value itself is bound as a parameter.
        status_clause = "WHERE status = :status"

    return await db.fetchall(
        f"""
        SELECT * FROM manual_payment_requests
        {status_clause}
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        params,
        ManualPaymentRequest,
    )
|
|
|
|
|
|
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """Approve a manual payment request and link it to its journal entry.

    Sets ``status = 'approved'``, ``reviewed_at`` (local ``datetime.now()``),
    ``reviewed_by`` and ``journal_entry_id``. Returns the request as re-read
    via ``get_manual_payment_request``, or ``None`` if no such request exists.

    NOTE(review): the UPDATE does not check the current status, so a request
    that is already approved or rejected is overwritten. Callers must guard
    against repeat reviews.
    """
    # The unused local `from .models import ManualPaymentRequest` was removed;
    # the return annotation is a string and never needed it.
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
            "journal_entry_id": journal_entry_id,
        },
    )

    return await get_manual_payment_request(request_id)
|
|
|
|
|
|
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """Reject a manual payment request.

    Sets ``status = 'rejected'``, ``reviewed_at`` (local ``datetime.now()``)
    and ``reviewed_by``. Returns the request as re-read via
    ``get_manual_payment_request``, or ``None`` if no such request exists.

    NOTE(review): the UPDATE does not check the current status, so an
    already-approved request can be flipped to rejected while keeping its
    ``journal_entry_id``. Confirm whether callers guard against this.
    """
    # The unused local `from .models import ManualPaymentRequest` was removed.
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
        },
    )

    return await get_manual_payment_request(request_id)
|
|
|
|
|
|
# ===== BALANCE ASSERTION OPERATIONS =====
|
|
|
|
|
|
async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """Create a new balance assertion in ``PENDING`` status and store it.

    ``date`` defaults to ``datetime.now()`` when ``data.date`` is falsy.
    Decimal fields are written as strings; ``get_balance_assertion`` parses
    them back with ``Decimal``. Returns the in-memory assertion that was
    inserted.
    """

    def _decimal_to_text(value):
        # Test against None, not truthiness. Decimal("0") is falsy, so an
        # expected fiat balance of exactly 0 used to be stored as NULL, and
        # check_balance_assertion then silently skipped the fiat comparison.
        # This also stops a None value being stored as the string "None".
        return str(value) if value is not None else None

    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )

    # Insert manually so the Decimal fields can be converted to strings.
    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        ) VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            "expected_balance_fiat": _decimal_to_text(assertion.expected_balance_fiat),
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": _decimal_to_text(assertion.tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )
    return assertion
|
|
|
|
|
|
async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """Get a balance assertion by ID, or ``None`` if it does not exist.

    Decimal columns are stored as text and converted back with ``Decimal``.
    NULL values stay ``None``, except ``tolerance_fiat``, which defaults to
    ``Decimal("0")``.
    """
    from decimal import Decimal

    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not row:
        return None

    def _to_decimal(value):
        # Only NULL (or an empty string) means "no value". The old truthiness
        # test turned a numeric 0, such as a zero difference returned by a
        # NUMERIC column, into None. Going through str() also avoids binary
        # float artefacts if the driver hands back a float.
        if value is None or value == "":
            return None
        return Decimal(str(value))

    tolerance_fiat = _to_decimal(row["tolerance_fiat"])

    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=_to_decimal(row["expected_balance_fiat"]),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=tolerance_fiat if tolerance_fiat is not None else Decimal("0"),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=_to_decimal(row["checked_balance_fiat"]),
        difference_sats=row["difference_sats"],
        difference_fiat=_to_decimal(row["difference_fiat"]),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )
|
|
|
|
|
|
async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """Get up to ``limit`` balance assertions, newest ``date`` first.

    Results can be filtered by ``account_id`` and/or ``status``; each filter
    applies only when its argument is truthy. Decimal columns are parsed the
    same way as in ``get_balance_assertion``.
    """
    from decimal import Decimal

    # Build one query instead of four near-identical copies. Only fixed
    # column predicates are interpolated; the values are bound as parameters.
    conditions = []
    params: dict = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )

    def _to_decimal(value):
        # Only NULL (or an empty string) means "no value"; a numeric 0 must
        # stay 0 rather than becoming None as it did with the truthiness test.
        if value is None or value == "":
            return None
        return Decimal(str(value))

    assertions = []
    for row in rows:
        tolerance_fiat = _to_decimal(row["tolerance_fiat"])
        assertions.append(
            BalanceAssertion(
                id=row["id"],
                date=row["date"],
                account_id=row["account_id"],
                expected_balance_sats=row["expected_balance_sats"],
                expected_balance_fiat=_to_decimal(row["expected_balance_fiat"]),
                fiat_currency=row["fiat_currency"],
                tolerance_sats=row["tolerance_sats"],
                tolerance_fiat=tolerance_fiat if tolerance_fiat is not None else Decimal("0"),
                checked_balance_sats=row["checked_balance_sats"],
                checked_balance_fiat=_to_decimal(row["checked_balance_fiat"]),
                difference_sats=row["difference_sats"],
                difference_fiat=_to_decimal(row["difference_fiat"]),
                status=AssertionStatus(row["status"]),
                created_by=row["created_by"],
                created_at=row["created_at"],
                checked_at=row["checked_at"],
            )
        )

    return assertions
|
|
|
|
|
|
async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Check a balance assertion by comparing expected vs actual balance.

    The actual sats balance is the account's *current* cleared balance from
    ``get_account_balance``; ``assertion.date`` does not limit which entries
    are counted. The assertion row is updated with:
    - the checked balances and the differences;
    - ``PASSED`` when every comparison is within tolerance, else ``FAILED``;
    - ``checked_at``.

    Raises:
        ValueError: If the assertion or its account does not exist.

    Returns:
        The assertion re-read from the database after the update.
    """
    from decimal import Decimal

    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    # Current cleared balance. The previous comment claimed this was the
    # balance "at the assertion date", but no date filter is applied: entries
    # dated after the assertion are included.
    # NOTE(review): confirm which semantics balance assertions should have.
    actual_balance = await get_account_balance(assertion.account_id)

    # The fiat side comes from the owning user's fiat_balances, which
    # get_user_balance computes over all of that user's accounts, not only
    # this one. NOTE(review): confirm this is intended.
    actual_fiat_balance = None
    if assertion.fiat_currency and account.user_id:
        user_balance = await get_user_balance(account.user_id)
        actual_fiat_balance = user_balance.fiat_balances.get(assertion.fiat_currency, Decimal("0"))

    # Sats comparison.
    difference_sats = actual_balance - assertion.expected_balance_sats
    sats_match = abs(difference_sats) <= assertion.tolerance_sats

    # Fiat comparison, only when there is both an expected and an actual value.
    fiat_match = True
    difference_fiat = None
    if assertion.expected_balance_fiat is not None and actual_fiat_balance is not None:
        difference_fiat = actual_fiat_balance - assertion.expected_balance_fiat
        fiat_match = abs(difference_fiat) <= assertion.tolerance_fiat

    status = AssertionStatus.PASSED if (sats_match and fiat_match) else AssertionStatus.FAILED

    # Persist the results; Decimal values are stored as text.
    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_balance,
            "checked_fiat": str(actual_fiat_balance) if actual_fiat_balance is not None else None,
            "diff_sats": difference_sats,
            "diff_fiat": str(difference_fiat) if difference_fiat is not None else None,
            "status": status.value,
            "checked_at": datetime.now(),
        },
    )

    return await get_balance_assertion(assertion_id)
|
|
|
|
|
|
async def delete_balance_assertion(assertion_id: str) -> None:
    """Remove the balance assertion with ``assertion_id``; no error is raised if none exists."""
    query = "DELETE FROM balance_assertions WHERE id = :id"
    await db.execute(query, {"id": assertion_id})
|