castle/crud.py
padreug f2df2f543b Enhance RBAC user management UI and fix permission checks
- Add role management to "By User" tab
  - Show all users with roles and/or direct permissions
  - Add ability to assign/revoke roles from users
  - Display role chips as clickable and removable
  - Add "Assign Role" button for each user

- Fix account_id validation error in permission granting
  - Extract account_id string from Quasar q-select object
  - Apply fix to grantPermission, bulkGrantPermissions, and addRolePermission

- Fix role-based permission checking for expense submission
  - Update get_user_permissions_with_inheritance() to include role permissions
  - Ensures users with role-based permissions can submit expenses

- Improve Vue reactivity for role details dialog
  - Use spread operator to create fresh arrays
  - Add $nextTick() before showing dialog

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-13 10:17:28 +01:00

1674 lines
52 KiB
Python

import json
from datetime import datetime
from typing import Optional
import httpx
from lnbits.db import Database
from lnbits.helpers import urlsafe_short_hash
from lnbits.utils.cache import Cache
from .models import (
Account,
AccountPermission,
AccountType,
AssertionStatus,
AssignUserRole,
BalanceAssertion,
CastleSettings,
CreateAccount,
CreateAccountPermission,
CreateBalanceAssertion,
CreateEntryLine,
CreateJournalEntry,
CreateRole,
CreateRolePermission,
CreateUserEquityStatus,
EntryLine,
JournalEntry,
PermissionType,
Role,
RolePermission,
RoleWithPermissions,
StoredUserWalletSettings,
UpdateRole,
UserBalance,
UserCastleSettings,
UserEquityStatus,
UserRole,
UserWalletSettings,
UserWithRoles,
)
# Import core accounting logic
from .core.validation import (
ValidationError,
validate_journal_entry,
validate_balance,
validate_receivable_entry,
validate_expense_entry,
validate_payment_entry,
)
db = Database("ext_castle")  # database handle used by every CRUD function in this module
# ===== CACHING =====
# In-process caches for account and permission lookups, used to reduce DB queries.
# NOTE: both Cache instances are created with default constructor arguments; the
# lifetime of each entry is chosen per call to .set() using the TTL constants
# below, not configured on the instance itself.
account_cache = Cache()  # account lookups; entries are set with ACCOUNT_CACHE_TTL
permission_cache = Cache()  # permission lookups; entries are set with PERMISSION_CACHE_TTL
# Cache TTLs in seconds, passed to Cache.set() at each call site
ACCOUNT_CACHE_TTL = 300  # 5 minutes: accounts rarely change
PERMISSION_CACHE_TTL = 60  # 1 minute: permissions may change frequently
# ===== ACCOUNT OPERATIONS =====
async def create_account(data: CreateAccount) -> Account:
    """
    Persist a new account built from ``data`` and return it.

    The account gets a fresh ``urlsafe_short_hash`` ID and the current time
    as ``created_at``. Cached lookups under the new ID and name are dropped
    so later reads hit the database.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        is_virtual=data.is_virtual,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)

    # The Cache class has no delete(); remove entries from its backing dict.
    for stale_key in (
        f"account:id:{new_account.id}",
        f"account:name:{data.name}",
    ):
        account_cache._values.pop(stale_key, None)

    return new_account
async def get_account(account_id: str) -> Optional[Account]:
    """
    Look up an account by ID, consulting ``account_cache`` first.

    The DB result is always written back to the cache, ``None`` included.
    Only non-``None`` cache hits are returned, though, so a missing account
    is re-queried on every call.
    """
    key = f"account:id:{account_id}"

    hit = account_cache.get(key)
    if hit is not None:
        return hit

    found = await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )
    account_cache.set(key, found, ACCOUNT_CACHE_TTL)
    return found
async def get_account_by_name(name: str) -> Optional[Account]:
    """
    Look up an account by its hierarchical name, consulting ``account_cache`` first.

    The DB result is always written back to the cache, ``None`` included.
    Only non-``None`` cache hits are returned, so a missing name is
    re-queried on every call.
    """
    key = f"account:name:{name}"

    hit = account_cache.get(key)
    if hit is not None:
        return hit

    found = await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )
    account_cache.set(key, found, ACCOUNT_CACHE_TTL)
    return found
async def get_all_accounts(include_inactive: bool = False) -> list[Account]:
    """
    List accounts ordered by account type, then name.

    Args:
        include_inactive: When True, accounts with ``is_active = FALSE`` are
            returned as well. Defaults to False (active accounts only).

    Returns:
        List of Account objects.
    """
    if not include_inactive:
        return await db.fetchall(
            "SELECT * FROM accounts WHERE is_active = TRUE ORDER BY account_type, name",
            model=Account,
        )
    return await db.fetchall(
        "SELECT * FROM accounts ORDER BY account_type, name",
        model=Account,
    )
async def get_accounts_by_type(
    account_type: AccountType, include_inactive: bool = False
) -> list[Account]:
    """
    List accounts of a single type, ordered by name.

    Args:
        account_type: Only accounts whose ``account_type`` equals this
            enum's value are returned.
        include_inactive: When True, inactive accounts are included too.
            Defaults to False.

    Returns:
        List of Account objects.
    """
    sql = (
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name"
        if include_inactive
        else "SELECT * FROM accounts WHERE account_type = :type AND is_active = TRUE ORDER BY name"
    )
    return await db.fetchall(sql, {"type": account_type.value}, Account)
async def update_account_is_active(account_id: str, is_active: bool) -> None:
    """
    Update the is_active status of an account (soft delete/reactivate).

    Both cache entries that can hold this account are invalidated:
    - the one keyed by ID (read by ``get_account``);
    - the one keyed by name (read by ``get_account_by_name``).

    Previously only the ID entry was dropped, so a lookup by name could keep
    returning the old ``is_active`` value for up to ``ACCOUNT_CACHE_TTL``
    seconds.

    Args:
        account_id: Account ID to update
        is_active: True to activate, False to deactivate
    """
    await db.execute(
        """
        UPDATE accounts
        SET is_active = :is_active
        WHERE id = :account_id
        """,
        {"account_id": account_id, "is_active": is_active},
    )

    # The Cache class has no delete(); remove entries from its backing dict.
    account_cache._values.pop(f"account:id:{account_id}", None)

    # The name-keyed entry can't be derived from the ID, so look the name up.
    row = await db.fetchone(
        "SELECT name FROM accounts WHERE id = :id",
        {"id": account_id},
    )
    if row:
        account_cache._values.pop(f"account:name:{row['name']}", None)
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Get or create a user-specific account with hierarchical naming.

    The account is ensured in two places:

    1. Fava/Beancount. Fava is queried for the account name. If it is
       missing, an Open directive is added via ``fava.add_account``. This
       check runs even when the account already exists in Castle's DB. Any
       failure here is logged (with traceback) and does not abort the call.
    2. Castle's ``accounts`` table, which holds metadata (permissions,
       descriptions, ...). If no row matches (user_id, account_type, name):
       - the account is first synced from Beancount;
       - if it is still missing, it is created directly;
       - if that insert fails on the unique account name (SQLite or
         PostgreSQL wording), the existing row with that name is adopted and
         its ``user_id`` is set to ``user_id``.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        -> "Assets:Receivable:User-af983632"
        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        -> "Liabilities:Payable:User-af983632"

    Returns:
        The Castle DB account. NOTE(review): in the unique-violation fallback
        this is ``None`` if no row with the name is found, despite the
        ``Account`` annotation.

    Raises:
        Any exception from the direct insert other than a unique violation on
        the account name.
    """
    from .account_utils import format_hierarchical_account_name
    from .fava_client import get_fava_client
    from loguru import logger

    account_name = format_hierarchical_account_name(account_type, base_name, user_id)

    async def fetch_user_account() -> Optional[Account]:
        # Exact match on owner, type and hierarchical name.
        return await db.fetchone(
            """
            SELECT * FROM accounts
            WHERE user_id = :user_id AND account_type = :type AND name = :name
            """,
            {"user_id": user_id, "type": account_type.value, "name": account_name},
            Account,
        )

    async def fetch_account_by_name() -> Optional[Account]:
        # Match on name only, regardless of owner.
        return await db.fetchone(
            """
            SELECT * FROM accounts
            WHERE name = :name
            """,
            {"name": account_name},
            Account,
        )

    account = await fetch_user_account()
    logger.info(
        f"[ACCOUNT CHECK] User {user_id[:8]}, Account: {account_name}, "
        f"In Castle DB: {account is not None}"
    )

    # --- Fava/Beancount: always check, so Beancount has the Open directive ---
    fava = get_fava_client()
    try:
        query = f"SELECT account WHERE account = '{account_name}'"
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(
                f"{fava.base_url}/query",
                params={"query_string": query},
            )
            response.raise_for_status()
            result = response.json()
        fava_account_exists = len(result.get("data", {}).get("rows", [])) > 0
        logger.info(f"[FAVA CHECK] Account {account_name} exists in Fava: {fava_account_exists}")
        if not fava_account_exists:
            # Create account in Fava/Beancount via Open directive
            logger.info(f"[FAVA CREATE] Creating account in Fava: {account_name}")
            await fava.add_account(
                account_name=account_name,
                currencies=["EUR", "SATS", "USD"],  # Support common currencies
                metadata={
                    "user_id": user_id,
                    "description": f"User-specific {account_type.value} account",
                },
            )
            logger.info(f"[FAVA CREATE] Successfully created account in Fava: {account_name}")
    except Exception as e:
        # Best effort: the Castle DB row below is still useful for metadata.
        # Use logger.exception rather than logger.error(..., exc_info=True).
        # loguru passes extra kwargs to str.format(), so the old call raised
        # whenever the error text contained braces, and it never attached the
        # traceback.
        logger.exception(f"[FAVA ERROR] Could not check/create account in Fava: {e}")

    # --- Castle DB ---
    if account:
        logger.info(f"[CASTLE DB] Account already exists in Castle DB: {account_name}")
        return account

    logger.info(f"[CASTLE DB] Syncing account from Beancount to Castle DB: {account_name}")
    from .account_sync import sync_single_account_from_beancount

    if await sync_single_account_from_beancount(account_name):
        logger.info(f"[CASTLE DB] Account synced from Beancount: {account_name}")
    else:
        logger.warning(f"[CASTLE DB] Failed to sync account from Beancount: {account_name}")

    account = await fetch_user_account()
    if account:
        return account

    logger.error(f"[CASTLE DB] Account still not found after sync: {account_name}")
    # Fallback: create directly in Castle DB if sync failed
    logger.info(f"[CASTLE DB] Creating account directly in Castle DB: {account_name}")
    try:
        return await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )
    except Exception as e:
        error_text = str(e)
        # SQLite:     "UNIQUE constraint failed: accounts.name"
        # PostgreSQL: 'duplicate key value violates unique constraint "..."'
        is_name_conflict = (
            "UNIQUE constraint failed" in error_text and "accounts.name" in error_text
        ) or "duplicate key value violates unique constraint" in error_text
        if not is_name_conflict:
            # Re-raise if it's a different error
            raise

    # The name is taken: adopt the existing row for this user.
    logger.warning(
        f"[CASTLE DB] Account already exists (UNIQUE constraint), fetching by name: {account_name}"
    )
    account = await fetch_account_by_name()
    if account:
        logger.info(f"[CASTLE DB] Found existing account: {account_name} (user_id: {account.user_id})")
        if account.user_id != user_id:
            logger.info(f"[CASTLE DB] Updating account user_id from {account.user_id} to {user_id}")
            await db.execute(
                """
                UPDATE accounts
                SET user_id = :user_id
                WHERE name = :name
                """,
                {"user_id": user_id, "name": account_name},
            )
            # Cached copies (by ID and by name) still carry the old user_id.
            account_cache._values.pop(f"account:id:{account.id}", None)
            account_cache._values.pop(f"account:name:{account_name}", None)
            account = await fetch_account_by_name()
    return account
# ===== JOURNAL ENTRY OPERATIONS (REMOVED) =====
#
# All journal entry operations have been moved to Fava/Beancount.
# Castle no longer maintains its own journal_entries and entry_lines tables.
#
# For journal entry operations, see:
# - views_api.py: api_create_journal_entry() - writes to Fava via FavaClient
# - views_api.py: API endpoints query Fava via FavaClient for reading entries
#
# Migration: m016_drop_obsolete_journal_tables
# Removed functions:
# - create_journal_entry()
# - get_journal_entry()
# - get_journal_entry_by_reference()
# - get_entry_lines()
# - get_all_journal_entries()
# - get_journal_entries_by_user()
# - count_all_journal_entries()
# - count_journal_entries_by_user()
# - get_journal_entries_by_user_and_account_type()
# - count_journal_entries_by_user_and_account_type()
# - get_account_transactions()
# ===== SETTINGS =====
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert the ``extension_settings`` row for ``user_id`` and return the stored model."""
    stored = UserCastleSettings(id=user_id, **data.dict())
    await db.insert("extension_settings", stored)
    return stored
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the extension settings stored under ``user_id``, or ``None`` if none exist."""
    params = {"user_id": user_id}
    sql = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
        """
    return await db.fetchone(sql, params, CastleSettings)
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the ``extension_settings`` row for ``user_id`` and return the stored model."""
    stored = UserCastleSettings(id=user_id, **data.dict())
    await db.update("extension_settings", stored)
    return stored
# ===== USER WALLET SETTINGS =====
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(id=user_id, **data.dict())
    await db.insert("user_wallet_settings", stored)
    return stored
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings stored under ``user_id``, or ``None`` if none exist."""
    params = {"user_id": user_id}
    sql = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
        """
    return await db.fetchone(sql, params, UserWalletSettings)
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings for ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(id=user_id, **data.dict())
    await db.update("user_wallet_settings", stored)
    return stored
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return every stored user wallet settings row, ordered by user ID."""
    no_params: dict = {}
    return await db.fetchall(
        "SELECT * FROM user_wallet_settings ORDER BY id",
        no_params,
        StoredUserWalletSettings,
    )
# ===== MANUAL PAYMENT REQUESTS =====
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """
    Record a new manual payment request in the ``"pending"`` state.

    The request gets a fresh ``urlsafe_short_hash`` ID and the current time
    as ``created_at``. It is inserted into ``manual_payment_requests`` and
    returned.
    """
    from .models import ManualPaymentRequest

    new_request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    # Bind every inserted column straight from the model's attributes.
    columns = ("id", "user_id", "amount", "description", "status", "created_at")
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        {column: getattr(new_request, column) for column in columns},
    )
    return new_request
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Return the manual payment request with ID ``request_id``, or ``None``."""
    from .models import ManualPaymentRequest

    params = {"id": request_id}
    return await db.fetchone(
        "SELECT * FROM manual_payment_requests WHERE id = :id",
        params,
        ManualPaymentRequest,
    )
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` manual payment requests of ``user_id``, newest first."""
    from .models import ManualPaymentRequest

    params = {"user_id": user_id, "limit": limit}
    sql = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
        """
    return await db.fetchall(sql, params, ManualPaymentRequest)
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """
    Return up to ``limit`` manual payment requests, newest first.

    A truthy ``status`` restricts the result to requests in that status;
    ``None`` (or an empty string) returns requests in any status.
    """
    from .models import ManualPaymentRequest

    if not status:
        return await db.fetchall(
            """
            SELECT * FROM manual_payment_requests
            ORDER BY created_at DESC
            LIMIT :limit
            """,
            {"limit": limit},
            ManualPaymentRequest,
        )

    return await db.fetchall(
        """
        SELECT * FROM manual_payment_requests
        WHERE status = :status
        ORDER BY created_at DESC
        LIMIT :limit
        """,
        {"status": status, "limit": limit},
        ManualPaymentRequest,
    )
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as approved and link its journal entry.

    Sets ``status = 'approved'`` and stamps ``reviewed_at`` with the current
    time. Also records the reviewer and the journal entry ID. The previous
    status is not checked before updating.

    Returns:
        The request re-read from the DB, or ``None`` if ``request_id`` does
        not exist.
    """
    review = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
        "journal_entry_id": journal_entry_id,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        review,
    )
    return await get_manual_payment_request(request_id)
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as rejected.

    Sets ``status = 'rejected'``, stamps ``reviewed_at`` with the current
    time and records the reviewer. The previous status is not checked.

    Returns:
        The request re-read from the DB, or ``None`` if ``request_id`` does
        not exist.
    """
    review = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        review,
    )
    return await get_manual_payment_request(request_id)
# ===== BALANCE ASSERTION OPERATIONS =====
async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """
    Create a new balance assertion in PENDING status and store it.

    Decimal fields are stored as strings. ``expected_balance_fiat`` is stored
    whenever it is not ``None``. Previously a truthiness test was used, and
    ``Decimal("0")`` is falsy, so an expected fiat balance of exactly zero
    was written as NULL. That made ``check_balance_assertion`` silently skip
    the fiat comparison.

    Args:
        data: Assertion parameters. If ``data.date`` is unset, the current
            time is used.
        created_by: Identifier recorded as the creator.

    Returns:
        The assertion as inserted.
    """
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )
    # Manually insert with Decimal fields converted to strings
    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        ) VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            # `is not None`, not truthiness: Decimal("0") is a valid expectation.
            "expected_balance_fiat": (
                str(assertion.expected_balance_fiat)
                if assertion.expected_balance_fiat is not None
                else None
            ),
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            "tolerance_fiat": str(assertion.tolerance_fiat),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )
    return assertion
async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """
    Load one balance assertion by ID, or return ``None`` when no row matches.

    Decimal-valued columns are stored as text and converted back to
    ``Decimal`` here. Falsy stored values become ``None``, except
    ``tolerance_fiat``, which falls back to ``Decimal("0")``.
    """
    from decimal import Decimal

    def to_decimal(raw, default=None):
        # Falsy raw values (NULL / empty) map to `default`.
        return Decimal(raw) if raw else default

    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not row:
        return None

    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=to_decimal(row["expected_balance_fiat"]),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=to_decimal(row["tolerance_fiat"], Decimal("0")),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=to_decimal(row["checked_balance_fiat"]),
        difference_sats=row["difference_sats"],
        difference_fiat=to_decimal(row["difference_fiat"]),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )
async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """
    List balance assertions, newest ``date`` first, at most ``limit`` rows.

    Args:
        account_id: If truthy, only assertions on this account are returned.
        status: If truthy, only assertions in this status are returned.
        limit: Maximum number of rows.

    Returns:
        BalanceAssertion objects. Decimal columns (stored as text) are
        converted back to ``Decimal``. Falsy values become ``None``, except
        ``tolerance_fiat``, which falls back to ``Decimal("0")``.
    """
    from decimal import Decimal

    # Build the WHERE clause from fixed fragments; all values stay bound parameters.
    filters: list[str] = []
    params: dict = {"limit": limit}
    if account_id:
        filters.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        filters.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(filters)}" if filters else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )

    def to_decimal(raw, default=None):
        # Falsy raw values (NULL / empty) map to `default`.
        return Decimal(raw) if raw else default

    return [
        BalanceAssertion(
            id=row["id"],
            date=row["date"],
            account_id=row["account_id"],
            expected_balance_sats=row["expected_balance_sats"],
            expected_balance_fiat=to_decimal(row["expected_balance_fiat"]),
            fiat_currency=row["fiat_currency"],
            tolerance_sats=row["tolerance_sats"],
            tolerance_fiat=to_decimal(row["tolerance_fiat"], Decimal("0")),
            checked_balance_sats=row["checked_balance_sats"],
            checked_balance_fiat=to_decimal(row["checked_balance_fiat"]),
            difference_sats=row["difference_sats"],
            difference_fiat=to_decimal(row["difference_fiat"]),
            status=AssertionStatus(row["status"]),
            created_by=row["created_by"],
            created_at=row["created_at"],
            checked_at=row["checked_at"],
        )
        for row in rows
    ]
async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Evaluate a balance assertion against the balances reported by Fava.

    The sats balance comes from ``fava.get_account_balance(account.name)``.
    A fiat balance is fetched only when the assertion has a ``fiat_currency``
    and the account has a ``user_id``. The fiat comparison runs only when
    both the expected and the actual fiat value are present. The assertion
    is PASSED when every comparison performed is within its tolerance, and
    FAILED otherwise. The results are written back to the row, which is then
    re-read and returned.

    NOTE(review): the actual fiat value comes from the *user-level*
    ``get_user_balance(account.user_id)["fiat_balances"]``, not from the
    asserted account. Confirm that this is intended.

    Raises:
        ValueError: If the assertion or its account does not exist.
    """
    from decimal import Decimal
    from .fava_client import get_fava_client

    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    fava = get_fava_client()
    actual_sats = (await fava.get_account_balance(account.name))["sats"]

    actual_fiat = None
    if assertion.fiat_currency and account.user_id:
        user_balances = await fava.get_user_balance(account.user_id)
        actual_fiat = user_balances["fiat_balances"].get(assertion.fiat_currency, Decimal("0"))

    diff_sats = actual_sats - assertion.expected_balance_sats
    sats_ok = abs(diff_sats) <= assertion.tolerance_sats

    diff_fiat = None
    fiat_ok = True
    can_compare_fiat = assertion.expected_balance_fiat is not None and actual_fiat is not None
    if can_compare_fiat:
        diff_fiat = actual_fiat - assertion.expected_balance_fiat
        fiat_ok = abs(diff_fiat) <= assertion.tolerance_fiat

    outcome = AssertionStatus.PASSED if sats_ok and fiat_ok else AssertionStatus.FAILED

    # Decimal results are persisted as text, mirroring how assertions are stored.
    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_sats,
            "checked_fiat": None if actual_fiat is None else str(actual_fiat),
            "diff_sats": diff_sats,
            "diff_fiat": None if diff_fiat is None else str(diff_fiat),
            "status": outcome.value,
            "checked_at": datetime.now(),
        },
    )
    return await get_balance_assertion(assertion_id)
async def delete_balance_assertion(assertion_id: str) -> None:
    """Remove the balance assertion with ID ``assertion_id`` (no-op if absent)."""
    params = {"id": assertion_id}
    await db.execute("DELETE FROM balance_assertions WHERE id = :id", params)
# User Equity Status CRUD operations
async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]:
    """Return the equity-eligibility record of ``user_id``, or ``None`` if there is none."""
    from .models import UserEquityStatus

    record = await db.fetchone(
        """
        SELECT * FROM user_equity_status
        WHERE user_id = :user_id
        """,
        {"user_id": user_id},
    )
    if not record:
        return None
    return UserEquityStatus(**record)
async def create_or_update_user_equity_status(
    data: "CreateUserEquityStatus", granted_by: str
) -> "UserEquityStatus":
    """
    Create or update a user's equity-eligibility record.

    When ``data.is_equity_eligible`` is true:
    - an ``Equity:User-<first 8 chars of user_id>`` account is created if
      none exists with that name;
    - that name is written into ``data.equity_account_name``.

    The account is created through ``create_account``, so it gets the same
    ``urlsafe_short_hash`` ID scheme and cache invalidation as every other
    account. (Previously it used a hand-written INSERT with a uuid4 ID.)
    NOTE(review): unlike ``get_or_create_user_account``, no Open directive is
    added in Fava/Beancount here. Confirm whether one is needed.

    Args:
        data: The eligibility flag, optional notes and target user.
        granted_by: Recorded as ``granted_by``. Also overwritten when the
            record is updated to not eligible.

    Returns:
        The record re-read from the DB.

    Raises:
        ValueError: If the record cannot be read back after writing.
    """
    if data.is_equity_eligible:
        # Generate equity account name: Equity:User-{first 8 chars of user_id}
        equity_account_name = f"Equity:User-{data.user_id[:8]}"
        if not await get_account_by_name(equity_account_name):
            await create_account(
                CreateAccount(
                    name=equity_account_name,
                    account_type=AccountType.EQUITY,
                    description=f"Equity contributions for user {data.user_id[:8]}",
                    user_id=data.user_id,
                )
            )
        # Auto-populate equity_account_name in the data
        data.equity_account_name = equity_account_name

    existing = await get_user_equity_status(data.user_id)
    if existing:
        # Update existing record; revoked_at is cleared when (re)granting.
        await db.execute(
            """
            UPDATE user_equity_status
            SET is_equity_eligible = :is_equity_eligible,
                equity_account_name = :equity_account_name,
                notes = :notes,
                granted_by = :granted_by,
                granted_at = :granted_at,
                revoked_at = :revoked_at
            WHERE user_id = :user_id
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": datetime.now(),
                "revoked_at": None if data.is_equity_eligible else datetime.now(),
            },
        )
    else:
        # Create new record
        await db.execute(
            """
            INSERT INTO user_equity_status (
                user_id, is_equity_eligible, equity_account_name,
                notes, granted_by, granted_at
            )
            VALUES (
                :user_id, :is_equity_eligible, :equity_account_name,
                :notes, :granted_by, :granted_at
            )
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": datetime.now(),
            },
        )

    result = await get_user_equity_status(data.user_id)
    if not result:
        raise ValueError(f"Failed to create/update equity status for user {data.user_id}")
    return result
async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]:
    """
    Clear ``is_equity_eligible`` for ``user_id`` and stamp ``revoked_at`` with now.

    Returns:
        The record re-read from the DB, or ``None`` if the user has no record.
    """
    params = {"user_id": user_id, "revoked_at": datetime.now()}
    await db.execute(
        """
        UPDATE user_equity_status
        SET is_equity_eligible = FALSE,
            revoked_at = :revoked_at
        WHERE user_id = :user_id
        """,
        params,
    )
    return await get_user_equity_status(user_id)
async def get_all_equity_eligible_users() -> list["UserEquityStatus"]:
    """Return every record with ``is_equity_eligible = TRUE``, most recently granted first."""
    from .models import UserEquityStatus

    eligible: list = []
    for record in await db.fetchall(
        """
        SELECT * FROM user_equity_status
        WHERE is_equity_eligible = TRUE
        ORDER BY granted_at DESC
        """
    ):
        eligible.append(UserEquityStatus(**record))
    return eligible
# ===== ACCOUNT PERMISSION OPERATIONS =====
async def create_account_permission(
    data: "CreateAccountPermission", granted_by: str
) -> "AccountPermission":
    """
    Grant a user a permission on an account and persist it.

    Args:
        data: Target user and account, permission type, optional expiry and notes.
        granted_by: Identifier recorded as the grantor.

    Returns:
        The stored AccountPermission.

    Raises:
        ValueError: If account is inactive or doesn't exist
    """
    from .models import AccountPermission

    target = await get_account(data.account_id)
    if not target:
        raise ValueError(f"Account {data.account_id} not found")
    if not target.is_active:
        raise ValueError(
            f"Cannot grant permission on inactive account: {target.name}"
        )

    grant = AccountPermission(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    values = {
        "id": grant.id,
        "user_id": grant.user_id,
        "account_id": grant.account_id,
        "permission_type": grant.permission_type.value,
        "granted_by": grant.granted_by,
        "granted_at": grant.granted_at,
        "expires_at": grant.expires_at,
        "notes": grant.notes,
    }
    await db.execute(
        """
        INSERT INTO account_permissions (
            id, user_id, account_id, permission_type, granted_by,
            granted_at, expires_at, notes
        )
        VALUES (
            :id, :user_id, :account_id, :permission_type, :granted_by,
            :granted_at, :expires_at, :notes
        )
        """,
        values,
    )

    # Drop both the unfiltered and the type-filtered cached lists for this user.
    # The Cache class has no delete(), so pop from its backing dict.
    base_key = f"permissions:user:{grant.user_id}"
    for key in (base_key, f"{base_key}:{grant.permission_type.value}"):
        permission_cache._values.pop(key, None)

    return grant
async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]:
    """Return the account permission with ID ``permission_id``, or ``None`` if absent."""
    from .models import AccountPermission, PermissionType

    row = await db.fetchone(
        "SELECT * FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )
    if not row:
        return None

    # Every column is copied verbatim except permission_type, which becomes the enum.
    passthrough = ("id", "user_id", "account_id", "granted_by", "granted_at", "expires_at", "notes")
    return AccountPermission(
        **{column: row[column] for column in passthrough},
        permission_type=PermissionType(row["permission_type"]),
    )
async def get_user_permissions(
    user_id: str, permission_type: Optional["PermissionType"] = None
) -> list["AccountPermission"]:
    """
    Return the direct, non-expired permissions of ``user_id``, newest grant first.

    Args:
        user_id: User whose permissions are listed.
        permission_type: If given, only permissions of this type are returned.

    Results are cached in ``permission_cache`` for ``PERMISSION_CACHE_TTL``
    seconds, under a key that includes the type filter. The same cached list
    object is returned on a cache hit.
    """
    from .models import AccountPermission, PermissionType

    cache_key = f"permissions:user:{user_id}" + (
        f":{permission_type.value}" if permission_type else ""
    )
    cached = permission_cache.get(cache_key)
    if cached is not None:
        return cached

    params = {"user_id": user_id, "now": datetime.now()}
    if permission_type:
        params["permission_type"] = permission_type.value
        sql = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND permission_type = :permission_type
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
    else:
        sql = """
            SELECT * FROM account_permissions
            WHERE user_id = :user_id
            AND (expires_at IS NULL OR expires_at > :now)
            ORDER BY granted_at DESC
            """
    rows = await db.fetchall(sql, params)

    passthrough = ("id", "user_id", "account_id", "granted_by", "granted_at", "expires_at", "notes")
    permissions = [
        AccountPermission(
            **{column: row[column] for column in passthrough},
            permission_type=PermissionType(row["permission_type"]),
        )
        for row in rows
    ]

    permission_cache.set(cache_key, permissions, PERMISSION_CACHE_TTL)
    return permissions
async def get_account_permissions(account_id: str) -> list["AccountPermission"]:
    """Return the non-expired permissions granted on ``account_id``, newest grant first (uncached)."""
    from .models import AccountPermission, PermissionType

    rows = await db.fetchall(
        """
        SELECT * FROM account_permissions
        WHERE account_id = :account_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        {"account_id": account_id, "now": datetime.now()},
    )

    passthrough = ("id", "user_id", "account_id", "granted_by", "granted_at", "expires_at", "notes")
    result = []
    for row in rows:
        result.append(
            AccountPermission(
                **{column: row[column] for column in passthrough},
                permission_type=PermissionType(row["permission_type"]),
            )
        )
    return result
async def delete_account_permission(permission_id: str) -> None:
    """
    Revoke (delete) an account permission and drop the affected cache entries.

    The permission is loaded before deletion so the owning user's cached
    lists can be invalidated. If it did not exist, only the DELETE runs.
    """
    existing = await get_account_permission(permission_id)

    await db.execute(
        "DELETE FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )

    if not existing:
        return

    # The Cache class has no delete(); pop from its backing dict.
    base_key = f"permissions:user:{existing.user_id}"
    permission_cache._values.pop(base_key, None)
    permission_cache._values.pop(f"{base_key}:{existing.permission_type.value}", None)
async def check_user_has_permission(
    user_id: str, account_id: str, permission_type: "PermissionType"
) -> bool:
    """
    Return whether ``user_id`` holds a non-expired ``permission_type`` grant on ``account_id``.

    Only direct grants in ``account_permissions`` are considered. Parent
    accounts and roles are not checked, and the cache is bypassed.
    """
    params = {
        "user_id": user_id,
        "account_id": account_id,
        "permission_type": permission_type.value,
        "now": datetime.now(),
    }
    match = await db.fetchone(
        """
        SELECT id FROM account_permissions
        WHERE user_id = :user_id
        AND account_id = :account_id
        AND permission_type = :permission_type
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        params,
    )
    return match is not None
async def get_user_permissions_with_inheritance(
    user_id: str, account_name: str, permission_type: "PermissionType"
) -> list[tuple["AccountPermission | RolePermission", Optional[str]]]:
    """
    Get all permissions for a user on an account, including inherited ones.

    Both direct grants (from ``get_user_permissions``) and grants that come
    through the user's roles (from ``get_user_permissions_from_roles``) are
    considered. A grant on "Expenses:Food" applies to "Expenses:Food" itself
    and to any descendant such as "Expenses:Food:Groceries". It does NOT
    apply to an account that merely shares a string prefix, such as
    "Expenses:FoodTruck", because a ":" separator is required.

    Args:
        user_id: User whose permissions are checked.
        account_name: Full colon-separated name of the target account.
        permission_type: Only permissions of this type are returned.

    Returns:
        List of ``(permission, parent_account_name)`` tuples.
        ``permission`` is the direct-grant object for direct permissions, or
        a RolePermission for role-based ones. Role permissions carry
        ``role_id`` rather than ``user_id``, so callers must not assume every
        entry is an AccountPermission. ``parent_account_name`` is None when
        the grant is on the target account itself; otherwise it is the name
        of the ancestor account the grant was made on.

    Example:
        A grant on "Expenses:Food" checked against "Expenses:Food:Groceries"
        yields ``[(permission_on_food, "Expenses:Food")]``.
    """
    direct_permissions = await get_user_permissions(user_id, permission_type)

    role_permissions = []
    for _role, perms in await get_user_permissions_from_roles(user_id):
        role_permissions.extend(
            p for p in perms if p.permission_type == permission_type
        )

    # The same account frequently appears in several grants (e.g. granted
    # both directly and via a role), so resolve each account name only once.
    # A value of None marks an account_id that no longer resolves.
    account_names: dict[str, Optional[str]] = {}

    applicable_permissions = []
    for perm in [*direct_permissions, *role_permissions]:
        if perm.account_id not in account_names:
            account = await get_account(perm.account_id)
            account_names[perm.account_id] = account.name if account else None
        granted_on = account_names[perm.account_id]
        if granted_on is None:
            continue

        if account_name == granted_on:
            # Grant is on the target account itself.
            applicable_permissions.append((perm, None))
        elif account_name.startswith(granted_on + ":"):
            # Grant is on an ancestor; the colon prevents prefix false positives.
            applicable_permissions.append((perm, granted_on))

    return applicable_permissions
# ===== ROLE-BASED ACCESS CONTROL (RBAC) OPERATIONS =====
async def create_role(data: CreateRole, created_by: str) -> Role:
    """Create and persist a new role.

    If the new role is flagged as default, every other role's ``is_default``
    flag is cleared, so at most one default role exists and the
    "WHERE is_default = TRUE LIMIT 1" lookup for the default role is
    unambiguous. The clearing runs *after* the INSERT. If the insert fails
    (for example on a constraint violation), the previous default is left
    untouched.

    Args:
        data: Name, description and default flag for the role.
        created_by: Identifier recorded as the role's creator.

    Returns:
        The Role that was inserted.
    """
    role = Role(
        id=urlsafe_short_hash(),
        name=data.name,
        description=data.description,
        is_default=data.is_default,
        created_by=created_by,
        created_at=datetime.now(),
    )
    await db.execute(
        """
        INSERT INTO roles (id, name, description, is_default, created_by, created_at)
        VALUES (:id, :name, :description, :is_default, :created_by, :created_at)
        """,
        {
            "id": role.id,
            "name": role.name,
            "description": role.description,
            "is_default": role.is_default,
            "created_by": role.created_by,
            "created_at": role.created_at,
        },
    )

    if role.is_default:
        # Keep the "single default role" invariant, as update_role does.
        await db.execute(
            "UPDATE roles SET is_default = FALSE WHERE id != :id",
            {"id": role.id},
        )

    return role
async def get_role(role_id: str) -> Optional[Role]:
    """Look up a single role by its ID.

    Returns:
        The matching Role, or None when no row has this ID.
    """
    record = await db.fetchone(
        "SELECT * FROM roles WHERE id = :id",
        {"id": role_id},
    )
    if not record:
        return None

    role_fields = ("id", "name", "description", "is_default", "created_by", "created_at")
    return Role(**{field: record[field] for field in role_fields})
async def get_role_by_name(name: str) -> Optional[Role]:
    """Look up a single role by its (exact) name.

    Returns:
        The matching Role, or None when no role has this name.
    """
    record = await db.fetchone(
        "SELECT * FROM roles WHERE name = :name",
        {"name": name},
    )
    return (
        Role(
            id=record["id"],
            name=record["name"],
            description=record["description"],
            is_default=record["is_default"],
            created_by=record["created_by"],
            created_at=record["created_at"],
        )
        if record
        else None
    )
async def get_all_roles() -> list[Role]:
    """Return every role, sorted alphabetically by name."""
    records = await db.fetchall(
        "SELECT * FROM roles ORDER BY name",
    )

    roles = []
    for record in records:
        roles.append(
            Role(
                id=record["id"],
                name=record["name"],
                description=record["description"],
                is_default=record["is_default"],
                created_by=record["created_by"],
                created_at=record["created_at"],
            )
        )
    return roles
async def get_default_role() -> Optional[Role]:
    """Get the default role that is auto-assigned to new users.

    Normally at most one role is flagged as default. If several rows are
    flagged anyway, a bare ``LIMIT 1`` would return an arbitrary one, because
    SQL gives no ordering guarantee without ``ORDER BY``. Picking the most
    recently created default, with ``id`` as a tie-breaker, makes the result
    deterministic.

    Returns:
        The default Role, or None if no role is flagged as default.
    """
    row = await db.fetchone(
        """
        SELECT * FROM roles
        WHERE is_default = TRUE
        ORDER BY created_at DESC, id
        LIMIT 1
        """,
    )
    if not row:
        return None
    return Role(
        id=row["id"],
        name=row["name"],
        description=row["description"],
        is_default=row["is_default"],
        created_by=row["created_by"],
        created_at=row["created_at"],
    )
async def update_role(role_id: str, data: UpdateRole) -> Optional[Role]:
    """Apply a partial update to a role.

    Only fields of ``data`` that are not None are written. When
    ``data.is_default`` is True, every *other* role's default flag is cleared
    so at most one default role exists.

    Nothing is modified when ``role_id`` does not exist. Previously, an
    unknown ID combined with ``is_default=True`` still cleared the default
    flag on all existing roles.

    Args:
        role_id: ID of the role to update.
        data: Fields to change; None means "leave unchanged".

    Returns:
        The role as stored after the update, or None if it does not exist.
    """
    if await get_role(role_id) is None:
        return None

    # Column names come from this fixed list only; values are bound as
    # parameters, so building the SET clause with an f-string is safe.
    updates = []
    params: dict[str, object] = {"id": role_id}
    if data.name is not None:
        updates.append("name = :name")
        params["name"] = data.name
    if data.description is not None:
        updates.append("description = :description")
        params["description"] = data.description
    if data.is_default is not None:
        updates.append("is_default = :is_default")
        params["is_default"] = data.is_default

    if updates:
        await db.execute(
            f"UPDATE roles SET {', '.join(updates)} WHERE id = :id",
            params,
        )

    if data.is_default is True:
        # Cleared only after this role's own update succeeded, so a failed
        # update cannot leave the system without any default role.
        await db.execute(
            "UPDATE roles SET is_default = FALSE WHERE id != :id",
            {"id": role_id},
        )

    return await get_role(role_id)
async def delete_role(role_id: str) -> None:
    """Delete a role together with its role_permissions and user_roles rows.

    The dependent rows are removed explicitly instead of relying only on a
    foreign-key ``ON DELETE CASCADE``. SQLite ignores foreign-key actions
    unless ``PRAGMA foreign_keys`` is enabled on the connection. Without the
    explicit deletes, orphaned ``user_roles`` rows would still be counted as
    active roles, which would, for example, stop the default role from being
    auto-assigned. When the cascade *is* enforced, these extra statements
    are harmless.
    """
    await db.execute(
        "DELETE FROM role_permissions WHERE role_id = :role_id",
        {"role_id": role_id},
    )
    await db.execute(
        "DELETE FROM user_roles WHERE role_id = :role_id",
        {"role_id": role_id},
    )
    await db.execute(
        "DELETE FROM roles WHERE id = :id",
        {"id": role_id},
    )
# ===== ROLE PERMISSION OPERATIONS =====
async def create_role_permission(data: CreateRolePermission) -> RolePermission:
    """Attach a new permission on an account to a role and persist it.

    Args:
        data: Role ID, account ID, permission type and optional notes.

    Returns:
        The RolePermission that was inserted, with a freshly generated ID and
        the current time as ``created_at``.
    """
    new_permission = RolePermission(
        id=urlsafe_short_hash(),
        role_id=data.role_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        notes=data.notes,
        created_at=datetime.now(),
    )

    # The enum is stored by its value, not by the enum object itself.
    insert_params = {
        "id": new_permission.id,
        "role_id": new_permission.role_id,
        "account_id": new_permission.account_id,
        "permission_type": new_permission.permission_type.value,
        "notes": new_permission.notes,
        "created_at": new_permission.created_at,
    }
    await db.execute(
        """
        INSERT INTO role_permissions (id, role_id, account_id, permission_type, notes, created_at)
        VALUES (:id, :role_id, :account_id, :permission_type, :notes, :created_at)
        """,
        insert_params,
    )
    return new_permission
async def get_role_permissions(role_id: str) -> list[RolePermission]:
    """Return all permissions attached to a role, newest first."""
    records = await db.fetchall(
        """
        SELECT * FROM role_permissions
        WHERE role_id = :role_id
        ORDER BY created_at DESC
        """,
        {"role_id": role_id},
    )

    permissions = []
    for record in records:
        permissions.append(
            RolePermission(
                id=record["id"],
                role_id=record["role_id"],
                account_id=record["account_id"],
                # Stored as the enum's value; convert back to PermissionType.
                permission_type=PermissionType(record["permission_type"]),
                notes=record["notes"],
                created_at=record["created_at"],
            )
        )
    return permissions
async def delete_role_permission(permission_id: str) -> None:
    """Remove a single permission entry from its role by permission ID.

    The function does not check whether a row was actually removed.
    """
    delete_sql = "DELETE FROM role_permissions WHERE id = :id"
    await db.execute(delete_sql, {"id": permission_id})
# ===== USER ROLE OPERATIONS =====
async def assign_user_role(data: AssignUserRole, granted_by: str) -> UserRole:
    """Record that a user holds a role and return the stored assignment.

    Args:
        data: User ID, role ID, optional expiry and notes.
        granted_by: Identifier recorded as the granter of the role.

    Returns:
        The inserted UserRole, with a freshly generated ID and the current
        time as ``granted_at``.
    """
    assignment = UserRole(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        role_id=data.role_id,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    column_names = ("id", "user_id", "role_id", "granted_by", "granted_at", "expires_at", "notes")
    await db.execute(
        """
        INSERT INTO user_roles (id, user_id, role_id, granted_by, granted_at, expires_at, notes)
        VALUES (:id, :user_id, :role_id, :granted_by, :granted_at, :expires_at, :notes)
        """,
        {column: getattr(assignment, column) for column in column_names},
    )
    return assignment
async def get_user_roles(user_id: str) -> list[UserRole]:
    """Return a user's role assignments that have not expired, newest first.

    Assignments with a NULL ``expires_at`` never expire.
    """
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE user_id = :user_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        {"user_id": user_id, "now": datetime.now()},
    )

    user_role_fields = ("id", "user_id", "role_id", "granted_by", "granted_at", "expires_at", "notes")
    return [
        UserRole(**{field: record[field] for field in user_role_fields})
        for record in records
    ]
async def get_all_user_roles() -> list[UserRole]:
    """Return every non-expired role assignment across all users.

    Results are grouped by ``user_id``; within each user the most recently
    granted assignment comes first.
    """
    query_params = {"now": datetime.now()}
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE (expires_at IS NULL OR expires_at > :now)
        ORDER BY user_id, granted_at DESC
        """,
        query_params,
    )

    assignments = []
    for record in records:
        assignments.append(
            UserRole(
                id=record["id"],
                user_id=record["user_id"],
                role_id=record["role_id"],
                granted_by=record["granted_by"],
                granted_at=record["granted_at"],
                expires_at=record["expires_at"],
                notes=record["notes"],
            )
        )
    return assignments
async def get_role_users(role_id: str) -> list[UserRole]:
    """Return the non-expired assignments of one role, newest grant first."""
    current_time = datetime.now()
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE role_id = :role_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        {"role_id": role_id, "now": current_time},
    )

    def _to_user_role(record) -> UserRole:
        # Map one user_roles row onto the UserRole model.
        return UserRole(
            id=record["id"],
            user_id=record["user_id"],
            role_id=record["role_id"],
            granted_by=record["granted_by"],
            granted_at=record["granted_at"],
            expires_at=record["expires_at"],
            notes=record["notes"],
        )

    return [_to_user_role(record) for record in records]
async def revoke_user_role(user_role_id: str) -> None:
    """Delete one role assignment, identified by the user_roles row ID.

    Note that the argument is the assignment's own ID, not a user ID or a
    role ID. The function does not check whether a row was actually removed.
    """
    delete_sql = "DELETE FROM user_roles WHERE id = :id"
    await db.execute(delete_sql, {"id": user_role_id})
async def get_role_count_for_user(user_id: str) -> int:
    """Count a user's role assignments that have not expired.

    Returns 0 if the query yields no row.
    """
    result = await db.fetchone(
        """
        SELECT COUNT(*) as count FROM user_roles
        WHERE user_id = :user_id
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        {"user_id": user_id, "now": datetime.now()},
    )
    if not result:
        return 0
    return result["count"]
async def auto_assign_default_role(user_id: str, assigned_by: str) -> UserRole | None:
    """
    Give a user the default role, but only if they hold no active roles yet.

    Args:
        user_id: User who may receive the default role.
        assigned_by: Recorded as ``granted_by`` on the new assignment.

    Returns:
        The newly created UserRole. Returns None when the user already has at
        least one non-expired role, or when no role is flagged as default.
    """
    from loguru import logger

    logger.info(f"[AUTO-ASSIGN] Checking auto-assignment for user {user_id}")

    # NOTE(review): the count check and the later insert are separate
    # statements, so two concurrent calls for the same user could both
    # assign the default role.
    existing_roles = await get_role_count_for_user(user_id)
    logger.info(f"[AUTO-ASSIGN] User {user_id} has {existing_roles} roles")
    if existing_roles > 0:
        logger.info(f"[AUTO-ASSIGN] User {user_id} already has roles, skipping auto-assignment")
        return None

    fallback_role = await get_default_role()
    if not fallback_role:
        logger.warning(f"[AUTO-ASSIGN] No default role found, cannot auto-assign for user {user_id}")
        return None
    logger.info(f"[AUTO-ASSIGN] Found default role: {fallback_role.name} (id: {fallback_role.id})")

    assignment = await assign_user_role(
        AssignUserRole(
            user_id=user_id,
            role_id=fallback_role.id,
            notes="Auto-assigned default role on first access",
        ),
        assigned_by,
    )
    logger.info(f"[AUTO-ASSIGN] Successfully assigned role {fallback_role.name} to user {user_id}")
    return assignment
async def get_user_count_for_role(role_id: str) -> int:
    """Count the non-expired assignments of a role (0 if no row comes back)."""
    count_params = {"role_id": role_id, "now": datetime.now()}
    result = await db.fetchone(
        """
        SELECT COUNT(*) as count FROM user_roles
        WHERE role_id = :role_id
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        count_params,
    )
    if not result:
        return 0
    return result["count"]
# ===== RBAC HELPER FUNCTIONS =====
async def get_user_permissions_from_roles(
    user_id: str,
) -> list[tuple[Role, list[RolePermission]]]:
    """
    Collect the permissions a user obtains through their active roles.

    Assignments whose role row can no longer be found are skipped. Roles are
    returned in the order of ``get_user_roles`` (newest assignment first).

    Returns:
        List of ``(role, permissions_of_that_role)`` tuples.
    """
    grouped = []
    for assignment in await get_user_roles(user_id):
        role = await get_role(assignment.role_id)
        if not role:
            continue
        grouped.append((role, await get_role_permissions(role.id)))
    return grouped
async def check_user_has_role_permission(
    user_id: str, account_id: str, permission_type: PermissionType
) -> bool:
    """Return True if any of the user's roles grants this permission type on this exact account.

    Only an exact ``account_id`` match counts. Permissions on parent
    accounts are not inherited by this check.
    """
    role_grants = await get_user_permissions_from_roles(user_id)
    return any(
        grant.account_id == account_id and grant.permission_type == permission_type
        for _role, grants in role_grants
        for grant in grants
    )