castle/crud.py
padreug 1a9c91d042 FIX Admin Net Balance: Filters entry lines by journal entry flag
Ensures that only cleared entry lines are included when
calculating fiat balances by filtering based on the
journal entry's flag. Excludes pending, flagged, and
voided entries
2025-10-23 01:09:43 +02:00

767 lines
24 KiB
Python

import json
from datetime import datetime
from typing import Optional
from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash
from .models import (
Account,
AccountType,
CastleSettings,
CreateAccount,
CreateEntryLine,
CreateJournalEntry,
EntryLine,
JournalEntry,
StoredUserWalletSettings,
UserBalance,
UserCastleSettings,
UserWalletSettings,
)
# Module-level database handle (opened under the name "ext_castle") that the
# CRUD functions below use for every query.
db = Database("ext_castle")
# ===== ACCOUNT OPERATIONS =====
async def create_account(data: CreateAccount) -> Account:
    """Build an ``Account`` from ``data``, insert it into ``accounts`` and return it.

    The ID comes from ``urlsafe_short_hash()`` and ``created_at`` from
    ``datetime.now()``; all other fields are copied from ``data``.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)
    return new_account
async def get_account(account_id: str) -> Optional[Account]:
    """Fetch the account whose ``id`` equals ``account_id`` (result of ``db.fetchone``)."""
    query = "SELECT * FROM accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
async def get_account_by_name(name: str) -> Optional[Account]:
    """Fetch an account by its exact (hierarchical-format) name."""
    query = "SELECT * FROM accounts WHERE name = :name"
    params = {"name": name}
    return await db.fetchone(query, params, Account)
async def get_all_accounts() -> list[Account]:
    """Return every account, sorted by account type and then by name."""
    query = "SELECT * FROM accounts ORDER BY account_type, name"
    return await db.fetchall(query, model=Account)
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return the accounts of one ``AccountType``, sorted by name.

    The enum's ``.value`` is what gets compared against the stored column.
    """
    query = "SELECT * FROM accounts WHERE account_type = :type ORDER BY name"
    params = {"type": account_type.value}
    return await db.fetchall(query, params, Account)
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Return the user's account for the given type/base name, creating it if absent.

    The account name is built by ``format_hierarchical_account_name``.
    Examples carried over from the original documentation:
        ("af983632", AccountType.ASSET, "Accounts Receivable")
            -> "Assets:Receivable:User-af983632"
        ("af983632", AccountType.LIABILITY, "Accounts Payable")
            -> "Liabilities:Payable:User-af983632"
    """
    from .account_utils import format_hierarchical_account_name

    hierarchical_name = format_hierarchical_account_name(
        account_type, base_name, user_id
    )

    # An account only counts as "existing" when owner, type and name all match.
    lookup_params = {
        "user_id": user_id,
        "type": account_type.value,
        "name": hierarchical_name,
    }
    existing = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        lookup_params,
        Account,
    )
    if existing:
        return existing

    # Nothing matched: create the account under the hierarchical name.
    return await create_account(
        CreateAccount(
            name=hierarchical_name,
            account_type=account_type,
            description=f"User-specific {account_type.value} account",
            user_id=user_id,
        )
    )
# ===== JOURNAL ENTRY OPERATIONS =====
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate, persist and return a new journal entry together with its lines.

    The header row is written to ``journal_entries`` and each line to
    ``entry_lines``; ``meta`` and per-line ``metadata`` are stored as JSON text.
    The returned entry carries the lines that were inserted.

    Raises:
        ValueError: if the summed debits of ``data.lines`` differ from the
            summed credits.
    """
    new_entry_id = urlsafe_short_hash()

    # Double-entry invariant: total debits must equal total credits.
    debit_sum = sum(item.debit for item in data.lines)
    credit_sum = sum(item.credit for item in data.lines)
    if debit_sum != credit_sum:
        raise ValueError(
            f"Journal entry must balance: debits={debit_sum}, credits={credit_sum}"
        )

    # Fall back to "now" when the caller did not supply an entry date.
    effective_date = data.entry_date if data.entry_date else datetime.now()
    journal_entry = JournalEntry(
        id=new_entry_id,
        description=data.description,
        entry_date=effective_date,
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
        flag=data.flag,
        meta=data.meta,
    )

    # The header is inserted by hand because ``lines`` is not a column of
    # journal_entries (lines are stored in the entry_lines table).
    header_params = {
        "id": journal_entry.id,
        "description": journal_entry.description,
        "entry_date": journal_entry.entry_date,
        "created_by": journal_entry.created_by,
        "created_at": journal_entry.created_at,
        "reference": journal_entry.reference,
        "flag": journal_entry.flag.value,
        "meta": json.dumps(journal_entry.meta),
    }
    await db.execute(
        """
        INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference, flag, meta)
        VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference, :flag, :meta)
        """,
        header_params,
    )

    line_insert_sql = """
        INSERT INTO entry_lines (id, journal_entry_id, account_id, debit, credit, description, metadata)
        VALUES (:id, :journal_entry_id, :account_id, :debit, :credit, :description, :metadata)
        """
    stored_lines = []
    for spec in data.lines:
        entry_line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=new_entry_id,
            account_id=spec.account_id,
            debit=spec.debit,
            credit=spec.credit,
            description=spec.description,
            metadata=spec.metadata,
        )
        # metadata is serialised to a JSON string for storage.
        await db.execute(
            line_insert_sql,
            {
                "id": entry_line.id,
                "journal_entry_id": entry_line.journal_entry_id,
                "account_id": entry_line.account_id,
                "debit": entry_line.debit,
                "credit": entry_line.credit,
                "description": entry_line.description,
                "metadata": json.dumps(entry_line.metadata),
            },
        )
        stored_lines.append(entry_line)

    journal_entry.lines = stored_lines
    return journal_entry
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Fetch one journal entry by ID and attach its entry lines.

    When no entry is found the falsy ``db.fetchone`` result is returned as is.
    """
    found = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not found:
        return found
    found.lines = await get_entry_lines(entry_id)
    return found
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return the ``EntryLine`` objects that belong to one journal entry.

    The stored ``metadata`` JSON string is decoded; an empty/NULL value
    becomes ``{}``.
    """
    records = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )
    return [
        EntryLine(
            id=record.id,
            journal_entry_id=record.journal_entry_id,
            account_id=record.account_id,
            debit=record.debit,
            credit=record.credit,
            description=record.description,
            metadata=json.loads(record.metadata) if record.metadata else {},
        )
        for record in records
    ]
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with their lines.

    Entries are ordered by ``entry_date`` and then ``created_at``, both
    descending. A missing or NULL ``flag`` is treated as ``"*"`` and an
    empty/NULL ``meta`` as ``{}``.
    """
    # Imported once per call instead of once per row; kept function-local as
    # in the original.
    from .models import JournalEntryFlag

    rows = await db.fetchall(
        """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"limit": limit},
    )

    entries = []
    for row in rows:
        # ``row.get("flag", "*")`` only defaults when the key is absent; a NULL
        # column would reach the enum constructor as None and raise ValueError.
        raw_flag = row.get("flag")
        raw_meta = row.get("meta")
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            flag=JournalEntryFlag("*" if raw_flag is None else raw_flag),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries that touch any of the user's accounts.

    An entry qualifies when at least one of its lines posts to an account
    whose ``user_id`` matches. Results are ordered by ``entry_date`` and then
    ``created_at``, both descending, and each entry carries all of its lines.
    A user without accounts yields an empty list. A missing or NULL ``flag``
    is treated as ``"*"`` and an empty/NULL ``meta`` as ``{}``.
    """
    # Imported once per call instead of once per row; kept function-local as
    # in the original.
    from .models import JournalEntryFlag

    # The user's account IDs are resolved by a subquery, so there is no
    # dynamically built IN (...) placeholder list and only one round-trip.
    rows = await db.fetchall(
        """
        SELECT DISTINCT je.*
        FROM journal_entries je
        JOIN entry_lines el ON je.id = el.journal_entry_id
        WHERE el.account_id IN (
            SELECT id FROM accounts WHERE user_id = :user_id
        )
        ORDER BY je.entry_date DESC, je.created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
    )

    entries = []
    for row in rows:
        # ``row.get("flag", "*")`` only defaults when the key is absent; a NULL
        # column would reach the enum constructor as None and raise ValueError.
        raw_flag = row.get("flag")
        raw_meta = row.get("meta")
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            flag=JournalEntryFlag("*" if raw_flag is None else raw_flag),
            meta=json.loads(raw_meta) if raw_meta else {},
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)
    return entries
# ===== BALANCE AND REPORTING =====
async def get_account_balance(account_id: str) -> int:
    """Return the account's balance in its normal-balance direction.

    Assets and expenses are reported as debits minus credits; every other
    account type (liabilities, equity, revenue) as credits minus debits.
    Only lines of cleared journal entries (``flag = '*'``) are summed, so
    pending, flagged and voided entries are excluded. Returns 0 when the
    account does not exist.
    """
    totals = await db.fetchone(
        """
        SELECT
        COALESCE(SUM(el.debit), 0) as total_debit,
        COALESCE(SUM(el.credit), 0) as total_credit
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        AND je.flag = '*'
        """,
        {"id": account_id},
    )
    if not totals:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    debit_total = totals["total_debit"]
    credit_total = totals["total_credit"]

    debit_normal = account.account_type in (AccountType.ASSET, AccountType.EXPENSE)
    if debit_normal:
        return debit_total - credit_total
    return credit_total - debit_total
async def get_user_balance(user_id: str) -> UserBalance:
    """Compute what the Castle and the user owe each other.

    The satoshi balance is positive when the castle owes the user and
    negative when the user owes the castle: liability accounts add to it and
    asset accounts subtract from it. Other account types (e.g. equity) are
    ignored for the satoshi total.

    Fiat balances are accumulated per currency from the ``fiat_currency`` /
    ``fiat_amount`` metadata of entry lines on cleared journal entries
    (``flag = '*'``), using the same sign convention.
    """
    from decimal import Decimal

    accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    sats_total = 0
    fiat_totals = {}  # currency code -> Decimal balance

    for acct in accounts:
        acct_balance = await get_account_balance(acct.id)

        # Only lines of cleared entries count (pending/flagged/voided are excluded).
        cleared_lines = await db.fetchall(
            """
            SELECT el.*
            FROM entry_lines el
            JOIN journal_entries je ON el.journal_entry_id = je.id
            WHERE el.account_id = :account_id
            AND je.flag = '*'
            """,
            {"account_id": acct.id},
        )

        is_liability = acct.account_type == AccountType.LIABILITY
        is_asset = acct.account_type == AccountType.ASSET

        for row in cleared_lines:
            meta = json.loads(row["metadata"]) if row.get("metadata") else {}
            currency = meta.get("fiat_currency")
            amount = meta.get("fiat_amount")
            if not (currency and amount):
                continue

            # The currency is registered even if this account type ends up
            # not adjusting it.
            fiat_totals.setdefault(currency, Decimal("0"))
            fiat_value = Decimal(str(amount))

            if is_liability:
                # Liability: a credit means the castle owes more, a debit less.
                if row["credit"] > 0:
                    fiat_totals[currency] += fiat_value
                elif row["debit"] > 0:
                    fiat_totals[currency] -= fiat_value
            elif is_asset:
                # Asset (receivable): a debit means the user owes more, a credit less.
                if row["debit"] > 0:
                    fiat_totals[currency] -= fiat_value
                elif row["credit"] > 0:
                    fiat_totals[currency] += fiat_value

        # Satoshi side: liabilities count for the user, receivables against.
        if is_liability:
            sats_total += acct_balance
        elif is_asset:
            sats_total -= acct_balance

    return UserBalance(
        user_id=user_id,
        balance=sats_total,
        accounts=accounts,
        fiat_balances=fiat_totals,
    )
async def get_all_user_balances() -> list[UserBalance]:
    """Return balances for every user that has a user-specific account.

    Used by the castle to see who it owes. Each balance is computed by
    ``get_user_balance`` so the satoshi/fiat rules (including the
    cleared-entries-only filter) live in one place instead of being
    duplicated here. Users whose satoshi balance is zero and who have no
    fiat balances are left out.
    """
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )

    # dict.fromkeys de-duplicates while keeping first-seen order, matching the
    # order in which the original grouping encountered users.
    user_ids = dict.fromkeys(account.user_id for account in user_accounts)

    balances = []
    for user_id in user_ids:
        user_balance = await get_user_balance(user_id)
        if user_balance.balance != 0 or user_balance.fiat_balances:
            balances.append(user_balance)
    return balances
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Return up to ``limit`` (journal entry, entry line) pairs for an account.

    Pairs are ordered newest first by the journal entry's ``entry_date`` and
    then ``created_at``, the same ordering the journal-entry listings use.
    Each journal entry is loaded through ``get_journal_entry`` and therefore
    carries all of its lines, not just the one posting to this account.
    """
    # Line IDs come from urlsafe_short_hash(), so ordering by el.id alone (as
    # this query used to) is not chronological and LIMIT would keep an
    # arbitrary subset. el.id remains only as a deterministic tie-breaker.
    rows = await db.fetchall(
        """
        SELECT el.*
        FROM entry_lines el
        JOIN journal_entries je ON el.journal_entry_id = je.id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    transactions = []
    for row in rows:
        # metadata is stored as a JSON string; empty/NULL becomes {}.
        metadata = json.loads(row.metadata) if row.metadata else {}
        line = EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            debit=row.debit,
            credit=row.credit,
            description=row.description,
            metadata=metadata,
        )
        entry = await get_journal_entry(line.journal_entry_id)
        if entry:
            transactions.append((entry, line))
    return transactions
# ===== SETTINGS =====
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert a settings row for ``user_id`` into ``extension_settings``.

    The fields of ``data`` are combined with ``id=user_id`` into a
    ``UserCastleSettings`` record, which is also the return value.
    """
    stored = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", stored)
    return stored
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Fetch the ``extension_settings`` row whose ``id`` equals ``user_id``."""
    query = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, CastleSettings)
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Write ``data`` for ``user_id`` to ``extension_settings`` via ``db.update``.

    Returns the ``UserCastleSettings`` record that was passed to the update.
    """
    # NOTE(review): presumably db.update matches the row on ``id`` — confirm.
    stored = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", stored)
    return stored
# ===== USER WALLET SETTINGS =====
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` into ``user_wallet_settings``.

    The fields of ``data`` are combined with ``id=user_id`` into a
    ``StoredUserWalletSettings`` record, which is also the return value.
    """
    stored = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", stored)
    return stored
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Fetch the ``user_wallet_settings`` row whose ``id`` equals ``user_id``."""
    query = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """
    return await db.fetchone(query, {"user_id": user_id}, UserWalletSettings)
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Write ``data`` for ``user_id`` to ``user_wallet_settings`` via ``db.update``.

    Returns the ``StoredUserWalletSettings`` record that was passed to the update.
    """
    # NOTE(review): presumably db.update matches the row on ``id`` — confirm.
    stored = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", stored)
    return stored
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return every stored user wallet settings row, ordered by ``id``."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    return await db.fetchall(query, {}, StoredUserWalletSettings)
# ===== MANUAL PAYMENT REQUESTS =====
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Store a new manual payment request in ``pending`` status and return it.

    The ID comes from ``urlsafe_short_hash()`` and ``created_at`` from
    ``datetime.now()``.
    """
    from .models import ManualPaymentRequest

    new_id = urlsafe_short_hash()
    payment_request = ManualPaymentRequest(
        id=new_id,
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    # Values are read back from the model so the row mirrors the returned object.
    row_values = {
        "id": payment_request.id,
        "user_id": payment_request.user_id,
        "amount": payment_request.amount,
        "description": payment_request.description,
        "status": payment_request.status,
        "created_at": payment_request.created_at,
    }
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        row_values,
    )
    return payment_request
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Fetch a single manual payment request by its ID."""
    from .models import ManualPaymentRequest

    query = "SELECT * FROM manual_payment_requests WHERE id = :id"
    params = {"id": request_id}
    return await db.fetchone(query, params, ManualPaymentRequest)
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests of one user, newest first."""
    from .models import ManualPaymentRequest

    query = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """
    params = {"user_id": user_id, "limit": limit}
    return await db.fetchall(query, params, ManualPaymentRequest)
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests, newest first.

    When ``status`` is truthy only requests with that status are returned;
    otherwise (``None`` or empty string) no status filter is applied.
    """
    from .models import ManualPaymentRequest

    if not status:
        # No filter requested: list everything.
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"limit": limit},
            ManualPaymentRequest,
        )

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE status = :status
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"status": status, "limit": limit},
        ManualPaymentRequest,
    )
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as approved and return the updated row.

    Sets ``status`` to ``'approved'`` and records ``reviewed_at``
    (``datetime.now()``), ``reviewed_by`` and the linked ``journal_entry_id``.
    The request is then re-read through ``get_manual_payment_request``.
    """
    # The former function-local import of ManualPaymentRequest was unused at
    # runtime (the return annotation is a string forward reference).
    # NOTE(review): the UPDATE does not check the current status, so a request
    # that was already reviewed would be overwritten — confirm this is intended.
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
            "journal_entry_id": journal_entry_id,
        },
    )
    return await get_manual_payment_request(request_id)
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as rejected and return the updated row.

    Sets ``status`` to ``'rejected'`` and records ``reviewed_at``
    (``datetime.now()``) and ``reviewed_by``. The request is then re-read
    through ``get_manual_payment_request``.
    """
    # The former function-local import of ManualPaymentRequest was unused at
    # runtime (the return annotation is a string forward reference).
    # NOTE(review): the UPDATE does not check the current status, so a request
    # that was already reviewed would be overwritten — confirm this is intended.
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        {
            "id": request_id,
            "reviewed_at": datetime.now(),
            "reviewed_by": reviewed_by,
        },
    )
    return await get_manual_payment_request(request_id)