castle/crud.py
padreug cb62cbb0a2 Update account sync to mark orphaned accounts as inactive
- Added update_account_is_active() function in crud.py
- Updated sync_accounts_from_beancount() to:
  * Mark accounts in Castle DB but not in Beancount as inactive
  * Reactivate accounts that return to Beancount
  * Track deactivated and reactivated counts in sync stats
- Improved sync efficiency with lookup maps
- Enhanced logging for deactivation/reactivation events

This completes the soft delete implementation for orphaned accounts.
When accounts are removed from the Beancount ledger, they are now
automatically marked as inactive in Castle DB during the hourly sync.
2025-11-11 01:54:04 +01:00

1134 lines
36 KiB
Python

import json
from datetime import datetime
from typing import Optional
import httpx
from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash
from lnbits.utils.cache import Cache
from .models import (
Account,
AccountPermission,
AccountType,
AssertionStatus,
BalanceAssertion,
CastleSettings,
CreateAccount,
CreateAccountPermission,
CreateBalanceAssertion,
CreateEntryLine,
CreateJournalEntry,
CreateUserEquityStatus,
EntryLine,
JournalEntry,
PermissionType,
StoredUserWalletSettings,
UserBalance,
UserCastleSettings,
UserEquityStatus,
UserWalletSettings,
)
# Import core accounting logic
from .core.validation import (
ValidationError,
validate_journal_entry,
validate_balance,
validate_receivable_entry,
validate_expense_entry,
validate_payment_entry,
)
# Database handle shared by every CRUD function in this module.
db = Database("ext_castle")
# ===== CACHING =====
# In-memory caches for account and permission lookups to reduce DB queries.
# The expiry is not configured on the Cache instances themselves: it is passed
# per entry on every set() call, using the TTL constants below.
account_cache = Cache() # account lookups, keyed by id and by name
permission_cache = Cache() # per-user permission lists
# Cache TTLs, in seconds
ACCOUNT_CACHE_TTL = 300 # 5 minutes (accounts rarely change)
PERMISSION_CACHE_TTL = 60 # 1 minute (permissions may change frequently)
# ===== ACCOUNT OPERATIONS =====
async def create_account(data: CreateAccount) -> Account:
    """Insert a new account row built from ``data`` and return it.

    The id is generated with ``urlsafe_short_hash`` and ``created_at`` is set
    to the current local time.
    """
    new_id = urlsafe_short_hash()
    account = Account(
        id=new_id,
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        created_at=datetime.now(),
    )
    await db.insert("accounts", account)

    # Drop any cached lookups for this id/name so the next read goes to the DB.
    # Entries are removed directly from the cache's backing dict.
    for stale_key in (f"account:id:{new_id}", f"account:name:{data.name}"):
        account_cache._values.pop(stale_key, None)
    return account
async def get_account(account_id: str) -> Optional[Account]:
    """Return the account with the given ID, or None if it does not exist.

    Accounts that are found are cached for ACCOUNT_CACHE_TTL seconds under
    the key ``account:id:<account_id>``.
    """
    cache_key = f"account:id:{account_id}"

    # Try cache first
    cached = account_cache.get(cache_key)
    if cached is not None:
        return cached

    # Query DB
    account = await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )

    # Only cache real rows. A stored None cannot be told apart from a cache
    # miss by the check above, so caching it would never save a query and
    # would only leave a dead entry behind for every unknown ID.
    if account is not None:
        account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL)
    return account
async def get_account_by_name(name: str) -> Optional[Account]:
    """Return the account with the given hierarchical name, or None.

    Accounts that are found are cached for ACCOUNT_CACHE_TTL seconds under
    the key ``account:name:<name>``.
    """
    cache_key = f"account:name:{name}"

    # Try cache first
    cached = account_cache.get(cache_key)
    if cached is not None:
        return cached

    # Query DB
    account = await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )

    # Only cache real rows. A stored None cannot be told apart from a cache
    # miss by the check above, so caching it would never save a query and
    # would only leave a dead entry behind for every unknown name.
    if account is not None:
        account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL)
    return account
async def get_all_accounts() -> list[Account]:
    """Return every account, ordered by account type and then by name."""
    query = "SELECT * FROM accounts ORDER BY account_type, name"
    accounts = await db.fetchall(query, model=Account)
    return accounts
async def get_accounts_by_type(account_type: AccountType) -> list[Account]:
    """Return the accounts of a single account type, ordered by name."""
    params = {"type": account_type.value}
    accounts = await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        params,
        Account,
    )
    return accounts
async def update_account_is_active(account_id: str, is_active: bool) -> None:
    """
    Update the is_active status of an account (soft delete/reactivate).

    Both cache entries for the account (by ID and by name) are invalidated so
    that neither get_account() nor get_account_by_name() keeps serving the
    previous status until the cache TTL runs out.

    Args:
        account_id: Account ID to update
        is_active: True to activate, False to deactivate
    """
    await db.execute(
        """
        UPDATE accounts
        SET is_active = :is_active
        WHERE id = :account_id
        """,
        {"account_id": account_id, "is_active": is_active},
    )

    # Invalidate the ID-keyed cache entry.
    account_cache._values.pop(f"account:id:{account_id}", None)

    # The name-keyed entry can only be addressed once the name is known, so
    # read the row straight from the DB (bypassing the cache).
    account = await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )
    if account is not None:
        account_cache._values.pop(f"account:name:{account.name}", None)
async def _fetch_user_account(
    user_id: str, account_type: AccountType, account_name: str
) -> Optional[Account]:
    """Look up a user's account in Castle DB by owner, type and hierarchical name."""
    return await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )


async def _ensure_account_in_fava(
    user_id: str, account_type: AccountType, account_name: str
) -> None:
    """Best effort: add the account to Fava/Beancount if Fava does not list it yet."""
    from .fava_client import get_fava_client
    from loguru import logger

    fava = get_fava_client()
    try:
        # NOTE(review): account_name is interpolated into the query text. It
        # is derived from base_name and user_id, so confirm that callers can
        # never supply a value containing a single quote.
        query = f"SELECT account WHERE account = '{account_name}'"
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(
                f"{fava.base_url}/query",
                params={"query_string": query},
            )
            response.raise_for_status()
            result = response.json()

        # Any returned row means Fava already knows the account.
        fava_account_exists = len(result.get("data", {}).get("rows", [])) > 0
        logger.info(f"[FAVA CHECK] Account {account_name} exists in Fava: {fava_account_exists}")

        if not fava_account_exists:
            # Create account in Fava/Beancount via Open directive
            logger.info(f"[FAVA CREATE] Creating account in Fava: {account_name}")
            await fava.add_account(
                account_name=account_name,
                currencies=["EUR", "SATS", "USD"],  # Support common currencies
                metadata={
                    "user_id": user_id,
                    "description": f"User-specific {account_type.value} account",
                },
            )
            logger.info(f"[FAVA CREATE] Successfully created account in Fava: {account_name}")
    except Exception as e:
        # Deliberately broad: a Fava failure must not block the Castle DB
        # part, which is still useful for metadata.
        # logger.exception() attaches the traceback. No extra args/kwargs are
        # passed, so loguru does not run str.format() on the message, which
        # could itself raise if the error text contains braces.
        logger.exception(f"[FAVA ERROR] Could not check/create account in Fava: {e}")


async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Get or create a user-specific account with hierarchical naming.

    Steps:
      1. Build the hierarchical account name and look it up in Castle DB.
      2. Always check Fava/Beancount for the account and create it there if it
         is missing. Failures in this step are logged and otherwise ignored.
      3. If the account was not in Castle DB, sync it from Beancount; if it is
         still missing afterwards, create it directly in Castle DB.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        -> "Assets:Receivable:User-af983632"
        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        -> "Liabilities:Payable:User-af983632"

    Returns:
        The Castle DB account record.
    """
    from .account_utils import format_hierarchical_account_name
    from loguru import logger

    # Generate hierarchical account name
    account_name = format_hierarchical_account_name(account_type, base_name, user_id)

    # Try to find existing account with this hierarchical name in Castle DB
    account = await _fetch_user_account(user_id, account_type, account_name)
    logger.info(f"[ACCOUNT CHECK] User {user_id[:8]}, Account: {account_name}, In Castle DB: {account is not None}")

    # Always check/create in Fava, even if the account exists in Castle DB.
    # This ensures Beancount has the Open directive.
    await _ensure_account_in_fava(user_id, account_type, account_name)

    if account:
        logger.info(f"[CASTLE DB] Account already exists in Castle DB: {account_name}")
        return account

    # Ensure account exists in Castle DB (sync from Beancount if needed).
    # This uses the account sync module for consistency.
    logger.info(f"[CASTLE DB] Syncing account from Beancount to Castle DB: {account_name}")
    from .account_sync import sync_single_account_from_beancount

    created = await sync_single_account_from_beancount(account_name)
    if created:
        logger.info(f"[CASTLE DB] Account synced from Beancount: {account_name}")
    else:
        logger.warning(f"[CASTLE DB] Failed to sync account from Beancount: {account_name}")

    # Fetch the account from Castle DB
    account = await _fetch_user_account(user_id, account_type, account_name)
    if not account:
        logger.error(f"[CASTLE DB] Account still not found after sync: {account_name}")
        # Fallback: create directly in Castle DB if sync failed
        logger.info(f"[CASTLE DB] Creating account directly in Castle DB: {account_name}")
        account = await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )
    return account
# ===== JOURNAL ENTRY OPERATIONS (REMOVED) =====
#
# All journal entry operations have been moved to Fava/Beancount.
# Castle no longer maintains its own journal_entries and entry_lines tables.
#
# For journal entry operations, see:
# - views_api.py: api_create_journal_entry() - writes to Fava via FavaClient
# - views_api.py: API endpoints query Fava via FavaClient for reading entries
#
# Migration: m016_drop_obsolete_journal_tables
# Removed functions:
# - create_journal_entry()
# - get_journal_entry()
# - get_journal_entry_by_reference()
# - get_entry_lines()
# - get_all_journal_entries()
# - get_journal_entries_by_user()
# - count_all_journal_entries()
# - count_journal_entries_by_user()
# - get_journal_entries_by_user_and_account_type()
# - count_journal_entries_by_user_and_account_type()
# - get_account_transactions()
# ===== SETTINGS =====
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert ``data`` into ``extension_settings`` with ``user_id`` as the row id."""
    fields = data.dict()
    stored = UserCastleSettings(**fields, id=user_id)
    await db.insert("extension_settings", stored)
    return stored
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Fetch the ``extension_settings`` row whose id is ``user_id``, if any."""
    query = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """
    params = {"user_id": user_id}
    return await db.fetchone(query, params, CastleSettings)
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Build the stored settings for ``user_id`` from ``data`` and pass them to ``db.update``."""
    fields = data.dict()
    updated = UserCastleSettings(**fields, id=user_id)
    await db.update("extension_settings", updated)
    return updated
# ===== USER WALLET SETTINGS =====
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert ``data`` into ``user_wallet_settings`` with ``user_id`` as the row id."""
    fields = data.dict()
    stored = StoredUserWalletSettings(**fields, id=user_id)
    await db.insert("user_wallet_settings", stored)
    return stored
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Fetch the ``user_wallet_settings`` row whose id is ``user_id``, if any."""
    query = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """
    params = {"user_id": user_id}
    return await db.fetchone(query, params, UserWalletSettings)
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Build the stored wallet settings for ``user_id`` from ``data`` and pass them to ``db.update``."""
    fields = data.dict()
    updated = StoredUserWalletSettings(**fields, id=user_id)
    await db.update("user_wallet_settings", updated)
    return updated
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return the wallet settings of all users, ordered by id."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    all_settings = await db.fetchall(query, {}, StoredUserWalletSettings)
    return all_settings
# ===== MANUAL PAYMENT REQUESTS =====
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """Store a new manual payment request with status ``pending`` and return it."""
    from .models import ManualPaymentRequest

    pending = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    # One bind parameter per inserted column, read back from the model.
    columns = ("id", "user_id", "amount", "description", "status", "created_at")
    params = {column: getattr(pending, column) for column in columns}
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        params,
    )
    return pending
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Return the manual payment request with the given id, or None."""
    from .models import ManualPaymentRequest

    query = "SELECT * FROM manual_payment_requests WHERE id = :id"
    params = {"id": request_id}
    return await db.fetchone(query, params, ManualPaymentRequest)
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests of one user, newest first."""
    from .models import ManualPaymentRequest

    query = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """
    params = {"user_id": user_id, "limit": limit}
    return await db.fetchall(query, params, ManualPaymentRequest)
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests, newest first.

    When ``status`` is given (and non-empty) only requests with that status
    are returned.
    """
    from .models import ManualPaymentRequest

    # No status filter: list everything.
    if not status:
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"limit": limit},
            ManualPaymentRequest,
        )

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE status = :status
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"status": status, "limit": limit},
        ManualPaymentRequest,
    )
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as approved and return the stored row.

    Records the reviewer, the review time (now) and the linked journal entry
    id. The result is whatever ``get_manual_payment_request`` finds afterwards.
    """
    from .models import ManualPaymentRequest

    params = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
        "journal_entry_id": journal_entry_id,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        params,
    )
    refreshed = await get_manual_payment_request(request_id)
    return refreshed
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """Mark a manual payment request as rejected and return the stored row.

    Records the reviewer and the review time (now). The result is whatever
    ``get_manual_payment_request`` finds afterwards.
    """
    from .models import ManualPaymentRequest

    params = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        params,
    )
    refreshed = await get_manual_payment_request(request_id)
    return refreshed
# ===== BALANCE ASSERTION OPERATIONS =====
async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """Create a new balance assertion in the PENDING state and return it.

    Args:
        data: Assertion parameters. When ``data.date`` is not set, the current
            time is used as the assertion date.
        created_by: Identifier stored in the ``created_by`` column.

    Fiat amounts are stored as text. An expected fiat balance of zero is
    stored as "0" and not as NULL, so an assertion that the fiat balance is
    zero is not lost.
    """
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )

    # Compare against None explicitly: a zero Decimal is falsy, and a
    # truthiness test would silently turn "expect 0" into "no expectation".
    expected_fiat = (
        str(assertion.expected_balance_fiat)
        if assertion.expected_balance_fiat is not None
        else None
    )

    # Manually insert with Decimal fields converted to strings
    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        ) VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            "expected_balance_fiat": expected_fiat,
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": str(assertion.tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )
    return assertion
def _balance_assertion_from_row(row) -> BalanceAssertion:
    """Build a BalanceAssertion from a ``balance_assertions`` row.

    Fiat amounts are kept in TEXT columns, so they are parsed back into
    Decimal here. Empty fiat values become None, except ``tolerance_fiat``,
    which falls back to Decimal("0").
    """
    from decimal import Decimal

    def fiat_or_none(value):
        return Decimal(value) if value else None

    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=fiat_or_none(row["expected_balance_fiat"]),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=Decimal(row["tolerance_fiat"]) if row["tolerance_fiat"] else Decimal("0"),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=fiat_or_none(row["checked_balance_fiat"]),
        difference_sats=row["difference_sats"],
        difference_fiat=fiat_or_none(row["difference_fiat"]),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )


async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """Get a balance assertion by ID, or None if it does not exist."""
    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not row:
        return None
    return _balance_assertion_from_row(row)


async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """Get balance assertions with optional filters, newest date first.

    Args:
        account_id: Only return assertions for this account, when given.
        status: Only return assertions with this status, when given.
        limit: Maximum number of rows to return.
    """
    # The WHERE clause is assembled from fixed fragments only; every
    # caller-supplied value travels as a bound parameter.
    conditions = []
    params = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )
    return [_balance_assertion_from_row(row) for row in rows]
async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Evaluate a stored balance assertion against the balances reported by Fava.

    The sats balance is read for the assertion's account. A fiat balance is
    read from the owning user's balance only when the assertion names a fiat
    currency and the account has a user_id. The outcome (checked balances,
    differences, PASSED/FAILED status and check time) is written back to the
    assertion row, which is then re-read and returned.

    Raises:
        ValueError: if the assertion or its account cannot be found.
    """
    from decimal import Decimal
    from .fava_client import get_fava_client

    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    fava = get_fava_client()

    # Actual balances as reported by Fava.
    account_balance = await fava.get_account_balance(account.name)
    sats_actual = account_balance["sats"]

    fiat_actual = None
    if assertion.fiat_currency and account.user_id:
        user_balances = await fava.get_user_balance(account.user_id)
        fiat_actual = user_balances["fiat_balances"].get(
            assertion.fiat_currency, Decimal("0")
        )

    # Sats comparison always happens.
    sats_delta = sats_actual - assertion.expected_balance_sats
    within_sats = abs(sats_delta) <= assertion.tolerance_sats

    # Fiat comparison only happens when both sides are known; otherwise it
    # counts as a match.
    fiat_delta = None
    within_fiat = True
    if assertion.expected_balance_fiat is not None and fiat_actual is not None:
        fiat_delta = fiat_actual - assertion.expected_balance_fiat
        within_fiat = abs(fiat_delta) <= assertion.tolerance_fiat

    if within_sats and within_fiat:
        outcome = AssertionStatus.PASSED
    else:
        outcome = AssertionStatus.FAILED

    def text_or_none(amount):
        # Fiat values are written as text; missing values stay NULL.
        return None if amount is None else str(amount)

    results = {
        "id": assertion_id,
        "checked_sats": sats_actual,
        "checked_fiat": text_or_none(fiat_actual),
        "diff_sats": sats_delta,
        "diff_fiat": text_or_none(fiat_delta),
        "status": outcome.value,
        "checked_at": datetime.now(),
    }
    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        results,
    )

    # Re-read so the caller sees exactly what was stored.
    return await get_balance_assertion(assertion_id)
async def delete_balance_assertion(assertion_id: str) -> None:
    """Remove the balance assertion row with the given id."""
    query = "DELETE FROM balance_assertions WHERE id = :id"
    params = {"id": assertion_id}
    await db.execute(query, params)
# ===== USER EQUITY STATUS OPERATIONS =====
async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]:
    """Return the equity status record stored for ``user_id``, or None."""
    from .models import UserEquityStatus

    query = """
        SELECT * FROM user_equity_status
        WHERE user_id = :user_id
        """
    row = await db.fetchone(query, {"user_id": user_id})
    if not row:
        return None
    return UserEquityStatus(**row)
async def create_or_update_user_equity_status(
    data: "CreateUserEquityStatus", granted_by: str
) -> "UserEquityStatus":
    """Create or update a user's equity eligibility status.

    When eligibility is being granted, the user's equity account
    (``Equity:User-<first 8 chars of user_id>``) is created if it does not
    exist yet, and its name is written into ``data.equity_account_name``.

    Args:
        data: Desired equity status. Mutated as described above.
        granted_by: Identifier stored in the ``granted_by`` column.

    Returns:
        The stored equity status record.

    Raises:
        ValueError: if the record cannot be read back after the write.
    """
    # Auto-create user-specific equity account if granting eligibility
    if data.is_equity_eligible:
        short_user_id = data.user_id[:8]
        equity_account_name = f"Equity:User-{short_user_id}"

        # Check if the equity account already exists
        equity_account = await get_account_by_name(equity_account_name)
        if not equity_account:
            # Go through create_account() like every other account creation
            # in this module, so the ID scheme and cache invalidation match.
            await create_account(
                CreateAccount(
                    name=equity_account_name,
                    account_type=AccountType.EQUITY,
                    description=f"Equity contributions for user {short_user_id}",
                    user_id=data.user_id,
                )
            )

        # Auto-populate equity_account_name in the data
        data.equity_account_name = equity_account_name

    now = datetime.now()

    # Check if user already has equity status
    existing = await get_user_equity_status(data.user_id)
    if existing:
        # Update existing record. revoked_at is cleared when eligibility is
        # granted and stamped when it is withdrawn.
        await db.execute(
            """
            UPDATE user_equity_status
            SET is_equity_eligible = :is_equity_eligible,
                equity_account_name = :equity_account_name,
                notes = :notes,
                granted_by = :granted_by,
                granted_at = :granted_at,
                revoked_at = :revoked_at
            WHERE user_id = :user_id
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": now,
                "revoked_at": None if data.is_equity_eligible else now,
            },
        )
    else:
        # Create new record
        await db.execute(
            """
            INSERT INTO user_equity_status (
                user_id, is_equity_eligible, equity_account_name,
                notes, granted_by, granted_at
            )
            VALUES (
                :user_id, :is_equity_eligible, :equity_account_name,
                :notes, :granted_by, :granted_at
            )
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": now,
            },
        )

    # Return the created/updated record
    result = await get_user_equity_status(data.user_id)
    if not result:
        raise ValueError(f"Failed to create/update equity status for user {data.user_id}")
    return result
async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]:
    """Clear a user's equity eligibility flag and stamp ``revoked_at`` with the current time.

    Returns whatever ``get_user_equity_status`` finds for the user afterwards.
    """
    from datetime import datetime

    params = {"user_id": user_id, "revoked_at": datetime.now()}
    await db.execute(
        """
        UPDATE user_equity_status
        SET is_equity_eligible = FALSE,
            revoked_at = :revoked_at
        WHERE user_id = :user_id
        """,
        params,
    )
    current = await get_user_equity_status(user_id)
    return current
async def get_all_equity_eligible_users() -> list["UserEquityStatus"]:
    """Return the equity status records of all eligible users, most recently granted first."""
    from .models import UserEquityStatus

    rows = await db.fetchall(
        """
        SELECT * FROM user_equity_status
        WHERE is_equity_eligible = TRUE
        ORDER BY granted_at DESC
        """
    )
    eligible = []
    for row in rows:
        eligible.append(UserEquityStatus(**row))
    return eligible
# ===== ACCOUNT PERMISSION OPERATIONS =====
async def create_account_permission(
    data: "CreateAccountPermission", granted_by: str
) -> "AccountPermission":
    """Store a new account permission and return it.

    The permission gets a fresh id and a ``granted_at`` of now. The cached
    permission lists of the affected user are dropped afterwards.
    """
    from .models import AccountPermission

    permission = AccountPermission(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    # Bind parameters mirror the model; the enum is stored by its value.
    type_value = permission.permission_type.value
    params = {
        "id": permission.id,
        "user_id": permission.user_id,
        "account_id": permission.account_id,
        "permission_type": type_value,
        "granted_by": permission.granted_by,
        "granted_at": permission.granted_at,
        "expires_at": permission.expires_at,
        "notes": permission.notes,
    }
    await db.execute(
        """
        INSERT INTO account_permissions (
            id, user_id, account_id, permission_type, granted_by,
            granted_at, expires_at, notes
        )
        VALUES (
            :id, :user_id, :account_id, :permission_type, :granted_by,
            :granted_at, :expires_at, :notes
        )
        """,
        params,
    )

    # Drop the user's cached permission lists (unfiltered and for this type).
    # Entries are removed directly from the cache's backing dict.
    user_key = f"permissions:user:{permission.user_id}"
    for stale_key in (user_key, f"{user_key}:{type_value}"):
        permission_cache._values.pop(stale_key, None)
    return permission
async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]:
    """Return the account permission with the given id, or None."""
    from .models import AccountPermission, PermissionType

    query = "SELECT * FROM account_permissions WHERE id = :id"
    row = await db.fetchone(query, {"id": permission_id})
    if row:
        # The stored permission type is converted back into the enum.
        return AccountPermission(
            id=row["id"],
            user_id=row["user_id"],
            account_id=row["account_id"],
            permission_type=PermissionType(row["permission_type"]),
            granted_by=row["granted_by"],
            granted_at=row["granted_at"],
            expires_at=row["expires_at"],
            notes=row["notes"],
        )
    return None
async def get_user_permissions(
    user_id: str, permission_type: Optional["PermissionType"] = None
) -> list["AccountPermission"]:
    """Return a user's non-expired permissions, newest grant first.

    When ``permission_type`` is given, only permissions of that type are
    returned. Results are cached per user (and per type) for
    PERMISSION_CACHE_TTL seconds.
    """
    from .models import AccountPermission, PermissionType

    # Filtered and unfiltered lists are cached under separate keys.
    if permission_type:
        cache_key = f"permissions:user:{user_id}:{permission_type.value}"
    else:
        cache_key = f"permissions:user:{user_id}"

    cached = permission_cache.get(cache_key)
    if cached is not None:
        return cached

    # Cache miss: pick the query matching the requested filter.
    if permission_type:
        query = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND permission_type = :permission_type
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
        params = {
            "user_id": user_id,
            "permission_type": permission_type.value,
            "now": datetime.now(),
        }
    else:
        query = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
        params = {"user_id": user_id, "now": datetime.now()}
    rows = await db.fetchall(query, params)

    permissions = []
    for row in rows:
        permissions.append(
            AccountPermission(
                id=row["id"],
                user_id=row["user_id"],
                account_id=row["account_id"],
                permission_type=PermissionType(row["permission_type"]),
                granted_by=row["granted_by"],
                granted_at=row["granted_at"],
                expires_at=row["expires_at"],
                notes=row["notes"],
            )
        )

    permission_cache.set(cache_key, permissions, PERMISSION_CACHE_TTL)
    return permissions
async def get_account_permissions(account_id: str) -> list["AccountPermission"]:
    """Return the non-expired permissions granted on one account, newest grant first."""
    from .models import AccountPermission, PermissionType

    params = {"account_id": account_id, "now": datetime.now()}
    rows = await db.fetchall(
        """
        SELECT * FROM account_permissions
        WHERE account_id = :account_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        params,
    )

    permissions = []
    for row in rows:
        permissions.append(
            AccountPermission(
                id=row["id"],
                user_id=row["user_id"],
                account_id=row["account_id"],
                permission_type=PermissionType(row["permission_type"]),
                granted_by=row["granted_by"],
                granted_at=row["granted_at"],
                expires_at=row["expires_at"],
                notes=row["notes"],
            )
        )
    return permissions
async def delete_account_permission(permission_id: str) -> None:
    """Delete (revoke) an account permission and drop the owner's cached permission lists."""
    # Read the permission before deleting it: its user and type are needed to
    # address the cache entries afterwards.
    doomed = await get_account_permission(permission_id)

    query = "DELETE FROM account_permissions WHERE id = :id"
    await db.execute(query, {"id": permission_id})

    if not doomed:
        return

    # Entries are removed directly from the cache's backing dict.
    user_key = f"permissions:user:{doomed.user_id}"
    for stale_key in (user_key, f"{user_key}:{doomed.permission_type.value}"):
        permission_cache._values.pop(stale_key, None)
async def check_user_has_permission(
    user_id: str, account_id: str, permission_type: "PermissionType"
) -> bool:
    """Tell whether the user holds a non-expired permission of this type directly on the account.

    Only permissions granted on the account itself are considered;
    permissions on parent accounts are not.
    """
    params = {
        "user_id": user_id,
        "account_id": account_id,
        "permission_type": permission_type.value,
        "now": datetime.now(),
    }
    match = await db.fetchone(
        """
        SELECT id FROM account_permissions
        WHERE user_id = :user_id
        AND account_id = :account_id
        AND permission_type = :permission_type
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        params,
    )
    if match is None:
        return False
    return True
async def get_user_permissions_with_inheritance(
    user_id: str, account_name: str, permission_type: "PermissionType"
) -> list[tuple["AccountPermission", Optional[str]]]:
    """
    Collect the user's permissions of one type that apply to ``account_name``.

    A permission applies when it was granted on the account itself or on one
    of its ancestors in the colon-separated hierarchy.

    Returns list of tuples: (permission, source), where source is None for a
    permission on the account itself and the ancestor's name for an inherited
    one.

    Example:
        A permission on "Expenses:Food" also applies to "Expenses:Food:Groceries"
        and is returned as (permission_on_food, "Expenses:Food").
    """
    from .models import AccountPermission, PermissionType

    matches = []
    for perm in await get_user_permissions(user_id, permission_type):
        granted_on = await get_account(perm.account_id)
        if not granted_on:
            # Permission points at an account that cannot be loaded; skip it.
            continue

        if granted_on.name == account_name:
            inherited_from = None
        elif account_name.startswith(f"{granted_on.name}:"):
            # The trailing colon keeps "Expenses:Foo" from matching "Expenses:Food".
            inherited_from = granted_on.name
        else:
            continue
        matches.append((perm, inherited_from))
    return matches