Updates journal entry retrieval to filter entries based on the user's accounts rather than the user ID. This ensures that users only see journal entries that directly affect their accounts. Also displays fiat amount in journal entries if available in the metadata.
559 lines
17 KiB
Python
559 lines
17 KiB
Python
import json
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from lnbits.db import Database
|
|
from lnbits.helpers import urlsafe_short_hash
|
|
|
|
from .models import (
|
|
Account,
|
|
AccountType,
|
|
CastleSettings,
|
|
CreateAccount,
|
|
CreateEntryLine,
|
|
CreateJournalEntry,
|
|
EntryLine,
|
|
JournalEntry,
|
|
StoredUserWalletSettings,
|
|
UserBalance,
|
|
UserCastleSettings,
|
|
UserWalletSettings,
|
|
)
|
|
|
|
# Database handle for this extension ("ext_castle"); every query helper in
# this module goes through it.
db = Database("ext_castle")
|
|
|
|
|
|
# ===== ACCOUNT OPERATIONS =====
|
|
|
|
|
|
async def create_account(data: CreateAccount) -> Account:
    """Persist a new ledger account described by ``data`` and return it.

    The id is a freshly generated short hash and ``created_at`` is stamped
    with ``datetime.now()``.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)
    return new_account
|
|
|
|
|
|
async def get_account(account_id: str) -> Optional[Account]:
    """Look up one account by primary key; ``None`` when it does not exist."""
    query = "SELECT * FROM accounts WHERE id = :id"
    params = {"id": account_id}
    return await db.fetchone(query, params, Account)
|
|
|
|
|
|
async def get_account_by_name(name: str) -> Optional[Account]:
    """Look up an account by its exact ``name``; ``None`` when not found."""
    query = "SELECT * FROM accounts WHERE name = :name"
    params = {"name": name}
    return await db.fetchone(query, params, Account)
|
|
|
|
|
|
async def get_all_accounts() -> list[Account]:
    """Return every account, grouped by ``account_type`` then sorted by name."""
    query = "SELECT * FROM accounts ORDER BY account_type, name"
    accounts = await db.fetchall(query, model=Account)
    return accounts
|
|
|
|
|
|
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return all accounts of ``account_type``, sorted by name."""
    params = {"type": account_type.value}
    accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        params,
        Account,
    )
    return accounts
|
|
|
|
|
|
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """Get or create a user-specific account (e.g., 'Accounts Payable - User123')

    The account name is ``base_name`` suffixed with the first 8 characters of
    ``user_id``. An existing row matching user, type and name is returned;
    otherwise a new account is created via ``create_account``.
    """
    scoped_name = f"{base_name} - {user_id[:8]}"

    existing = await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": scoped_name},
        Account,
    )
    if existing:
        return existing

    return await create_account(
        CreateAccount(
            name=scoped_name,
            account_type=account_type,
            description=f"User-specific {account_type.value} account",
            user_id=user_id,
        )
    )
|
|
|
|
|
|
# ===== JOURNAL ENTRY OPERATIONS =====
|
|
|
|
|
|
async def create_journal_entry(
    data: CreateJournalEntry, created_by: str
) -> JournalEntry:
    """Validate a double-entry journal entry and persist it with its lines.

    The header row goes into ``journal_entries``. Each line goes into
    ``entry_lines``, with its ``metadata`` serialised via ``json.dumps``.

    Args:
        data: Description, optional entry date/reference, and the lines to post.
        created_by: Identifier recorded as the entry's creator.

    Returns:
        The stored entry with ``lines`` populated.

    Raises:
        ValueError: If the summed debits do not equal the summed credits.
    """
    entry_id = urlsafe_short_hash()

    # Double-entry invariant: checked before anything is written.
    debit_sum = sum(item.debit for item in data.lines)
    credit_sum = sum(item.credit for item in data.lines)
    if debit_sum != credit_sum:
        raise ValueError(
            f"Journal entry must balance: debits={debit_sum}, credits={credit_sum}"
        )

    journal_entry = JournalEntry(
        id=entry_id,
        description=data.description,
        entry_date=data.entry_date or datetime.now(),
        created_by=created_by,
        created_at=datetime.now(),
        reference=data.reference,
        lines=[],
    )

    # `lines` is not written to journal_entries; only the header columns are
    # inserted here and the lines go to entry_lines below.
    header_columns = (
        "id",
        "description",
        "entry_date",
        "created_by",
        "created_at",
        "reference",
    )
    await db.execute(
        """
        INSERT INTO journal_entries (id, description, entry_date, created_by, created_at, reference)
        VALUES (:id, :description, :entry_date, :created_by, :created_at, :reference)
        """,
        {column: getattr(journal_entry, column) for column in header_columns},
    )

    stored_lines = []
    for item in data.lines:
        entry_line = EntryLine(
            id=urlsafe_short_hash(),
            journal_entry_id=entry_id,
            account_id=item.account_id,
            debit=item.debit,
            credit=item.credit,
            description=item.description,
            metadata=item.metadata,
        )
        await db.execute(
            """
            INSERT INTO entry_lines (id, journal_entry_id, account_id, debit, credit, description, metadata)
            VALUES (:id, :journal_entry_id, :account_id, :debit, :credit, :description, :metadata)
            """,
            {
                "id": entry_line.id,
                "journal_entry_id": entry_line.journal_entry_id,
                "account_id": entry_line.account_id,
                "debit": entry_line.debit,
                "credit": entry_line.credit,
                "description": entry_line.description,
                # Metadata is persisted as a JSON string.
                "metadata": json.dumps(entry_line.metadata),
            },
        )
        stored_lines.append(entry_line)

    journal_entry.lines = stored_lines
    return journal_entry
|
|
|
|
|
|
async def get_journal_entry(entry_id: str) -> Optional[JournalEntry]:
    """Load a journal entry by id together with all of its lines.

    Returns ``None`` when no entry with ``entry_id`` exists.
    """
    journal_entry = await db.fetchone(
        "SELECT * FROM journal_entries WHERE id = :id",
        {"id": entry_id},
        JournalEntry,
    )
    if not journal_entry:
        return None

    journal_entry.lines = await get_entry_lines(entry_id)
    return journal_entry
|
|
|
|
|
|
async def get_entry_lines(journal_entry_id: str) -> list[EntryLine]:
    """Return all lines belonging to ``journal_entry_id``.

    The stored ``metadata`` JSON string is decoded back into a dict; an
    empty or missing value becomes ``{}``.
    """
    rows = await db.fetchall(
        "SELECT * FROM entry_lines WHERE journal_entry_id = :id",
        {"id": journal_entry_id},
    )

    return [
        EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            debit=row.debit,
            credit=row.credit,
            description=row.description,
            metadata=json.loads(row.metadata) if row.metadata else {},
        )
        for row in rows
    ]
|
|
|
|
|
|
async def get_all_journal_entries(limit: int = 100) -> list[JournalEntry]:
    """Return up to ``limit`` journal entries, newest first, with their lines.

    Entries are ordered by ``entry_date`` then ``created_at``, both
    descending. Each entry's lines are loaded with a separate query.
    """
    newest_first = """
        SELECT * FROM journal_entries
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
    """
    journal_entries = await db.fetchall(newest_first, {"limit": limit}, JournalEntry)

    for journal_entry in journal_entries:
        journal_entry.lines = await get_entry_lines(journal_entry.id)

    return journal_entries
|
|
|
|
|
|
async def get_journal_entries_by_user(
    user_id: str, limit: int = 100
) -> list[JournalEntry]:
    """Get journal entries that affect the user's accounts.

    An entry qualifies when at least one of its lines posts to an account
    whose ``user_id`` matches. Results are ordered newest first (by
    ``entry_date``, then ``created_at``) and capped at ``limit``. Each entry
    has its full set of lines loaded, not only the lines touching the user's
    accounts.

    The account filter is a subquery, so the database resolves it in a
    single round trip. The previous approach fetched the account ids first
    and then spliced a dynamically sized ``IN (...)`` list into the SQL.
    A user with no accounts yields an empty subquery and therefore ``[]``.
    """
    rows = await db.fetchall(
        """
        SELECT * FROM journal_entries
        WHERE id IN (
            SELECT el.journal_entry_id
            FROM entry_lines el
            JOIN accounts a ON a.id = el.account_id
            WHERE a.user_id = :user_id
        )
        ORDER BY entry_date DESC, created_at DESC
        LIMIT :limit
        """,
        {"user_id": user_id, "limit": limit},
    )

    entries = []
    for row in rows:
        entry = JournalEntry(
            id=row["id"],
            description=row["description"],
            entry_date=row["entry_date"],
            created_by=row["created_by"],
            created_at=row["created_at"],
            reference=row["reference"],
            lines=[],
        )
        entry.lines = await get_entry_lines(entry.id)
        entries.append(entry)

    return entries
|
|
|
|
|
|
# ===== BALANCE AND REPORTING =====
|
|
|
|
|
|
async def get_account_balance(account_id: str) -> int:
    """Calculate account balance (debits - credits for assets/expenses, credits - debits for liabilities/equity/revenue)

    Returns 0 when the totals query yields nothing or the account is unknown.
    """
    totals = await db.fetchone(
        """
        SELECT
            COALESCE(SUM(debit), 0) as total_debit,
            COALESCE(SUM(credit), 0) as total_credit
        FROM entry_lines
        WHERE account_id = :id
        """,
        {"id": account_id},
    )
    if not totals:
        return 0

    account = await get_account(account_id)
    if not account:
        return 0

    debits = totals["total_debit"]
    credits = totals["total_credit"]

    # Assets and expenses carry a debit-normal balance; every other type
    # (liability, equity, revenue) is credit-normal.
    debit_normal = account.account_type in [AccountType.ASSET, AccountType.EXPENSE]
    return debits - credits if debit_normal else credits - debits
|
|
|
|
|
|
async def _summarize_user_accounts(
    accounts: list[Account],
) -> tuple[int, dict[str, float]]:
    """Compute the net sat balance and per-currency fiat balances of accounts.

    Sign convention: positive means the Castle owes the user, negative means
    the user owes the Castle. Only LIABILITY and ASSET accounts contribute.
    Other account types (e.g. equity contributions) affect neither the sat
    total nor the fiat balances, and they no longer create spurious 0.0
    currency entries.

    Returns:
        ``(sat_balance, fiat_balances)`` where ``fiat_balances`` maps a
        currency code (taken from line metadata) to its accumulated amount.
    """
    sat_balance = 0
    fiat_balances: dict[str, float] = {}

    for account in accounts:
        if account.account_type not in (AccountType.LIABILITY, AccountType.ASSET):
            # Equity contributions are tracked but don't affect what castle owes.
            continue
        is_liability = account.account_type == AccountType.LIABILITY

        lines = await db.fetchall(
            "SELECT * FROM entry_lines WHERE account_id = :account_id",
            {"account_id": account.id},
        )

        for line in lines:
            debit = line["debit"] or 0
            credit = line["credit"] or 0

            # Liability balance is (credit - debit) and is added; asset
            # balance is (debit - credit) and is subtracted. Both reduce to
            # adding (credit - debit), computed from the lines already
            # fetched instead of re-querying the account's totals.
            sat_balance += credit - debit

            metadata = json.loads(line["metadata"]) if line.get("metadata") else {}
            fiat_currency = metadata.get("fiat_currency")
            fiat_amount = metadata.get("fiat_amount")
            if not (fiat_currency and fiat_amount):
                continue

            delta = 0.0
            if is_liability:
                # Liability: credit increases (castle owes more), debit decreases.
                if credit > 0:
                    delta = fiat_amount
                elif debit > 0:
                    delta = -fiat_amount
            else:
                # Asset (receivable): debit increases (user owes more), credit decreases.
                if debit > 0:
                    delta = -fiat_amount
                elif credit > 0:
                    delta = fiat_amount

            fiat_balances[fiat_currency] = fiat_balances.get(fiat_currency, 0.0) + delta

    return sat_balance, fiat_balances


async def get_user_balance(user_id: str) -> UserBalance:
    """Get user's balance with the Castle (positive = castle owes user, negative = user owes castle)"""
    user_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id = :user_id",
        {"user_id": user_id},
        Account,
    )

    sat_balance, fiat_balances = await _summarize_user_accounts(user_accounts)

    return UserBalance(
        user_id=user_id,
        balance=sat_balance,
        accounts=user_accounts,
        fiat_balances=fiat_balances,
    )


async def get_all_user_balances() -> list[UserBalance]:
    """Get balances for all users (used by castle to see who they owe).

    Users are included when their sat balance is non-zero or when any of
    their liability/asset lines carried fiat metadata.
    """
    all_accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE user_id IS NOT NULL",
        {},
        Account,
    )

    accounts_by_user: dict[str, list[Account]] = {}
    for account in all_accounts:
        accounts_by_user.setdefault(account.user_id, []).append(account)

    user_balances = []
    for user_id, accounts in accounts_by_user.items():
        sat_balance, fiat_balances = await _summarize_user_accounts(accounts)
        if sat_balance != 0 or fiat_balances:
            user_balances.append(
                UserBalance(
                    user_id=user_id,
                    balance=sat_balance,
                    accounts=accounts,
                    fiat_balances=fiat_balances,
                )
            )

    return user_balances
|
|
|
|
|
|
async def get_account_transactions(
    account_id: str, limit: int = 100
) -> list[tuple[JournalEntry, EntryLine]]:
    """Get the most recent transactions affecting a specific account.

    Each result pairs an entry line posted to ``account_id`` with its parent
    journal entry, which has all of its lines loaded. Results are ordered
    newest first by the parent entry's ``entry_date``, then ``created_at``,
    and capped at ``limit``.

    The previous ``ORDER BY id DESC`` was meaningless. Line ids come from
    ``urlsafe_short_hash()`` and carry no chronological order, so ``LIMIT``
    returned an arbitrary subset instead of the latest postings.
    """
    rows = await db.fetchall(
        """
        SELECT el.* FROM entry_lines el
        JOIN journal_entries je ON je.id = el.journal_entry_id
        WHERE el.account_id = :id
        ORDER BY je.entry_date DESC, je.created_at DESC, el.id DESC
        LIMIT :limit
        """,
        {"id": account_id, "limit": limit},
    )

    # One entry may post several lines to the same account; load it only once.
    entries_by_id: dict[str, JournalEntry] = {}
    transactions: list[tuple[JournalEntry, EntryLine]] = []
    for row in rows:
        # Metadata is stored as a JSON string.
        metadata = json.loads(row.metadata) if row.metadata else {}
        line = EntryLine(
            id=row.id,
            journal_entry_id=row.journal_entry_id,
            account_id=row.account_id,
            debit=row.debit,
            credit=row.credit,
            description=row.description,
            metadata=metadata,
        )

        entry = entries_by_id.get(line.journal_entry_id)
        if entry is None:
            entry = await get_journal_entry(line.journal_entry_id)
            if entry is None:
                continue
            entries_by_id[line.journal_entry_id] = entry

        transactions.append((entry, line))

    return transactions
|
|
|
|
|
|
# ===== SETTINGS =====
|
|
|
|
|
|
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert Castle settings for ``user_id`` and return the stored model.

    The row is keyed by ``user_id`` (used as the ``id`` field).
    """
    stored = UserCastleSettings(**data.dict(), id=user_id)
    await db.insert("extension_settings", stored)
    return stored
|
|
|
|
|
|
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Fetch the Castle settings row whose id is ``user_id``, if any."""
    query = "SELECT * FROM extension_settings WHERE id = :user_id"
    return await db.fetchone(query, {"user_id": user_id}, CastleSettings)
|
|
|
|
|
|
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the Castle settings keyed by ``user_id`` and return them."""
    updated = UserCastleSettings(**data.dict(), id=user_id)
    await db.update("extension_settings", updated)
    return updated
|
|
|
|
|
|
# ===== USER WALLET SETTINGS =====
|
|
|
|
|
|
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.insert("user_wallet_settings", stored)
    return stored
|
|
|
|
|
|
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Fetch the wallet settings row whose id is ``user_id``, if any."""
    query = "SELECT * FROM user_wallet_settings WHERE id = :user_id"
    return await db.fetchone(query, {"user_id": user_id}, UserWalletSettings)
|
|
|
|
|
|
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings keyed by ``user_id`` and return them."""
    updated = StoredUserWalletSettings(**data.dict(), id=user_id)
    await db.update("user_wallet_settings", updated)
    return updated
|
|
|
|
|
|
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Get all user wallet settings, ordered by id."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    all_settings = await db.fetchall(query, {}, StoredUserWalletSettings)
    return all_settings
|