- Add role management to "By User" tab - Show all users with roles and/or direct permissions - Add ability to assign/revoke roles from users - Display role chips as clickable and removable - Add "Assign Role" button for each user - Fix account_id validation error in permission granting - Extract account_id string from Quasar q-select object - Apply fix to grantPermission, bulkGrantPermissions, and addRolePermission - Fix role-based permission checking for expense submission - Update get_user_permissions_with_inheritance() to include role permissions - Ensures users with role-based permissions can submit expenses - Improve Vue reactivity for role details dialog - Use spread operator to create fresh arrays - Add $nextTick() before showing dialog 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1674 lines
52 KiB
Python
1674 lines
52 KiB
Python
import json
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from lnbits.db import Database
|
|
from lnbits.helpers import urlsafe_short_hash
|
|
from lnbits.utils.cache import Cache
|
|
|
|
from .models import (
|
|
Account,
|
|
AccountPermission,
|
|
AccountType,
|
|
AssertionStatus,
|
|
AssignUserRole,
|
|
BalanceAssertion,
|
|
CastleSettings,
|
|
CreateAccount,
|
|
CreateAccountPermission,
|
|
CreateBalanceAssertion,
|
|
CreateEntryLine,
|
|
CreateJournalEntry,
|
|
CreateRole,
|
|
CreateRolePermission,
|
|
CreateUserEquityStatus,
|
|
EntryLine,
|
|
JournalEntry,
|
|
PermissionType,
|
|
Role,
|
|
RolePermission,
|
|
RoleWithPermissions,
|
|
StoredUserWalletSettings,
|
|
UpdateRole,
|
|
UserBalance,
|
|
UserCastleSettings,
|
|
UserEquityStatus,
|
|
UserRole,
|
|
UserWalletSettings,
|
|
UserWithRoles,
|
|
)
|
|
|
|
# Import core accounting logic
|
|
from .core.validation import (
|
|
ValidationError,
|
|
validate_journal_entry,
|
|
validate_balance,
|
|
validate_receivable_entry,
|
|
validate_expense_entry,
|
|
validate_payment_entry,
|
|
)
|
|
|
|
# Shared handle to the Castle extension's database.
db = Database("ext_castle")

# ===== CACHING =====
# In-memory caches that spare the DB repeated account/permission lookups.
# No expiry is configured on the Cache objects themselves; a TTL is passed
# whenever a value is stored (e.g. ACCOUNT_CACHE_TTL in the account lookups).

ACCOUNT_CACHE_TTL = 300  # seconds (5 minutes) - accounts rarely change
PERMISSION_CACHE_TTL = 60  # seconds (1 minute) - permissions may change often

account_cache = Cache()
permission_cache = Cache()
|
|
|
|
|
|
# ===== ACCOUNT OPERATIONS =====
|
|
|
|
|
|
async def create_account(data: CreateAccount) -> Account:
    """
    Insert a new account row and evict any cache entries it could shadow.

    Args:
        data: Attributes of the account to create.

    Returns:
        The newly stored ``Account`` with a freshly generated id.
    """
    new_account = Account(
        id=urlsafe_short_hash(),
        name=data.name,
        account_type=data.account_type,
        description=data.description,
        user_id=data.user_id,
        is_virtual=data.is_virtual,
        created_at=datetime.now(),
    )
    await db.insert("accounts", new_account)

    # Cache exposes no public delete, so evict through its backing dict.
    for stale_key in (
        f"account:id:{new_account.id}",
        f"account:name:{data.name}",
    ):
        account_cache._values.pop(stale_key, None)

    return new_account
|
|
|
|
|
|
async def get_account(account_id: str) -> Optional[Account]:
    """
    Get an account by ID, serving repeat lookups from ``account_cache``.

    Only found accounts are cached (for ``ACCOUNT_CACHE_TTL`` seconds). A
    cached ``None`` could never be served, because the lookup treats ``None``
    as a miss, so misses are not written to the cache at all.

    Args:
        account_id: Primary key of the account.

    Returns:
        The ``Account``, or ``None`` if no row has this id.
    """
    cache_key = f"account:id:{account_id}"

    cached = account_cache.get(cache_key)
    if cached is not None:
        return cached

    account = await db.fetchone(
        "SELECT * FROM accounts WHERE id = :id",
        {"id": account_id},
        Account,
    )

    if account is not None:
        account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL)

    return account
|
|
|
|
|
|
async def get_account_by_name(name: str) -> Optional[Account]:
    """
    Get an account by its (hierarchical) name, with caching.

    Only found accounts are cached (for ``ACCOUNT_CACHE_TTL`` seconds). A
    cached ``None`` could never be served, because the lookup treats ``None``
    as a miss, so misses are not written to the cache at all.

    Args:
        name: Full account name, e.g. ``"Assets:Receivable:User-af983632"``.

    Returns:
        The ``Account``, or ``None`` if no row has this name.
    """
    cache_key = f"account:name:{name}"

    cached = account_cache.get(cache_key)
    if cached is not None:
        return cached

    account = await db.fetchone(
        "SELECT * FROM accounts WHERE name = :name",
        {"name": name},
        Account,
    )

    if account is not None:
        account_cache.set(cache_key, account, ACCOUNT_CACHE_TTL)

    return account
|
|
|
|
|
|
async def get_all_accounts(include_inactive: bool = False) -> list[Account]:
    """
    Return all accounts ordered by type, then name.

    Args:
        include_inactive: When False (the default), only rows with
            ``is_active = TRUE`` are returned.

    Returns:
        List of Account objects.
    """
    query = (
        "SELECT * FROM accounts ORDER BY account_type, name"
        if include_inactive
        else "SELECT * FROM accounts WHERE is_active = TRUE ORDER BY account_type, name"
    )
    return await db.fetchall(query, model=Account)
|
|
|
|
|
|
async def get_accounts_by_type(
    account_type: AccountType, include_inactive: bool = False
) -> list[Account]:
    """
    Return accounts of a single type, ordered by name.

    Args:
        account_type: Rows whose ``account_type`` column equals
            ``account_type.value`` are returned.
        include_inactive: When False (the default), only rows with
            ``is_active = TRUE`` are returned.

    Returns:
        List of Account objects.
    """
    type_filter = {"type": account_type.value}

    if not include_inactive:
        return await db.fetchall(
            "SELECT * FROM accounts WHERE account_type = :type AND is_active = TRUE ORDER BY name",
            type_filter,
            Account,
        )

    return await db.fetchall(
        "SELECT * FROM accounts WHERE account_type = :type ORDER BY name",
        type_filter,
        Account,
    )
|
|
|
|
|
|
async def update_account_is_active(account_id: str, is_active: bool) -> None:
    """
    Update the is_active status of an account (soft delete/reactivate).

    Both cache entries for the account are evicted: the id-keyed one and the
    name-keyed one. Previously only the id key was dropped, so
    ``get_account_by_name`` could keep returning the old ``is_active`` value
    until its entry expired.

    Args:
        account_id: Account ID to update.
        is_active: True to activate, False to deactivate.
    """
    # Resolve the name before updating so the name-keyed entry can be evicted.
    # The name itself is not changed by this function, so a cached copy is fine.
    account = await get_account(account_id)

    await db.execute(
        """
        UPDATE accounts
        SET is_active = :is_active
        WHERE id = :account_id
        """,
        {"account_id": account_id, "is_active": is_active},
    )

    # Cache has no public delete method; evict through its backing dict.
    account_cache._values.pop(f"account:id:{account_id}", None)
    if account is not None:
        account_cache._values.pop(f"account:name:{account.name}", None)
|
|
|
|
|
|
async def get_or_create_user_account(
    user_id: str, account_type: AccountType, base_name: str
) -> Account:
    """
    Get or create a user-specific account with hierarchical naming.

    The account is ensured in two places:

    1. Fava/Beancount. If a BQL query does not find the account, an Open
       directive is added via the Fava client. This step is always attempted
       and is best-effort: failures are logged, not raised.
    2. Castle's database. The row is synced from Beancount, or created directly
       as a fallback, so metadata (permissions, descriptions, etc.) can be
       attached to it.

    Examples:
        get_or_create_user_account("af983632", AccountType.ASSET, "Accounts Receivable")
        → "Assets:Receivable:User-af983632"

        get_or_create_user_account("af983632", AccountType.LIABILITY, "Accounts Payable")
        → "Liabilities:Payable:User-af983632"

    Returns:
        The Castle DB ``Account`` row for this user/type/name.

    Raises:
        Exception: Any database error from the fallback insert, other than a
            recoverable duplicate-name collision, is re-raised.
    """
    from loguru import logger

    from .account_utils import format_hierarchical_account_name

    account_name = format_hierarchical_account_name(account_type, base_name, user_id)

    account = await _fetch_user_account_row(user_id, account_type, account_name)
    logger.info(
        f"[ACCOUNT CHECK] User {user_id[:8]}, Account: {account_name}, "
        f"In Castle DB: {account is not None}"
    )

    # Always check/create in Fava, even if the account exists in Castle DB,
    # so that Beancount is guaranteed to have the Open directive.
    await _ensure_fava_open_directive(user_id, account_type, account_name)

    if account:
        logger.info(f"[CASTLE DB] Account already exists in Castle DB: {account_name}")
        return account

    return await _sync_or_create_castle_account(user_id, account_type, account_name)


async def _fetch_user_account_row(
    user_id: str, account_type: AccountType, account_name: str
) -> Optional[Account]:
    """Fetch the Castle DB account matching user, type and exact name (uncached)."""
    return await db.fetchone(
        """
        SELECT * FROM accounts
        WHERE user_id = :user_id AND account_type = :type AND name = :name
        """,
        {"user_id": user_id, "type": account_type.value, "name": account_name},
        Account,
    )


async def _ensure_fava_open_directive(
    user_id: str, account_type: AccountType, account_name: str
) -> None:
    """Best-effort: query Fava for the account and add an Open directive if missing."""
    from loguru import logger

    from .fava_client import get_fava_client

    fava = get_fava_client()
    try:
        # NOTE(review): account_name is interpolated into BQL unescaped; it is
        # built by format_hierarchical_account_name, so presumably trusted.
        query = f"SELECT account WHERE account = '{account_name}'"
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(
                f"{fava.base_url}/query",
                params={"query_string": query},
            )
            response.raise_for_status()
            result = response.json()

        fava_account_exists = len(result.get("data", {}).get("rows", [])) > 0
        logger.info(f"[FAVA CHECK] Account {account_name} exists in Fava: {fava_account_exists}")

        if not fava_account_exists:
            logger.info(f"[FAVA CREATE] Creating account in Fava: {account_name}")
            await fava.add_account(
                account_name=account_name,
                currencies=["EUR", "SATS", "USD"],  # Support common currencies
                metadata={
                    "user_id": user_id,
                    "description": f"User-specific {account_type.value} account",
                },
            )
            logger.info(f"[FAVA CREATE] Successfully created account in Fava: {account_name}")
    except Exception as e:
        # Deliberately best-effort: the Castle DB row is still useful for
        # metadata. logger.exception records the traceback. Loguru does not
        # honour the stdlib `exc_info=` kwarg; passing it makes loguru
        # str.format() the message, which can fail on braces in `e`.
        logger.exception(f"[FAVA ERROR] Could not check/create account in Fava: {e}")


def _is_unique_name_violation(error: Exception) -> bool:
    """Heuristically detect a unique-constraint violation (SQLite or PostgreSQL wording)."""
    message = str(error)
    if "UNIQUE constraint failed" in message and "accounts.name" in message:
        return True
    return "duplicate key value violates unique constraint" in message


async def _sync_or_create_castle_account(
    user_id: str, account_type: AccountType, account_name: str
) -> Account:
    """Sync the account from Beancount into Castle DB, creating it directly as a fallback."""
    from loguru import logger

    from .account_sync import sync_single_account_from_beancount

    logger.info(f"[CASTLE DB] Syncing account from Beancount to Castle DB: {account_name}")
    created = await sync_single_account_from_beancount(account_name)
    if created:
        logger.info(f"[CASTLE DB] Account synced from Beancount: {account_name}")
    else:
        logger.warning(f"[CASTLE DB] Failed to sync account from Beancount: {account_name}")

    account = await _fetch_user_account_row(user_id, account_type, account_name)
    if account:
        return account

    logger.error(f"[CASTLE DB] Account still not found after sync: {account_name}")
    logger.info(f"[CASTLE DB] Creating account directly in Castle DB: {account_name}")
    try:
        return await create_account(
            CreateAccount(
                name=account_name,
                account_type=account_type,
                description=f"User-specific {account_type.value} account",
                user_id=user_id,
            )
        )
    except Exception as e:
        if not _is_unique_name_violation(e):
            raise

        logger.warning(
            f"[CASTLE DB] Account already exists (UNIQUE constraint), fetching by name: {account_name}"
        )
        # Query by name only: the existing row may belong to another/no user.
        existing = await db.fetchone(
            "SELECT * FROM accounts WHERE name = :name",
            {"name": account_name},
            Account,
        )
        if existing is None:
            # The constraint that fired was not the name collision we can
            # recover from; surface the original error instead of returning None.
            raise

        logger.info(f"[CASTLE DB] Found existing account: {account_name} (user_id: {existing.user_id})")
        if existing.user_id == user_id:
            return existing

        logger.info(f"[CASTLE DB] Updating account user_id from {existing.user_id} to {user_id}")
        await db.execute(
            """
            UPDATE accounts
            SET user_id = :user_id
            WHERE name = :name
            """,
            {"user_id": user_id, "name": account_name},
        )
        # Cached copies still carry the old user_id; evict them.
        account_cache._values.pop(f"account:id:{existing.id}", None)
        account_cache._values.pop(f"account:name:{account_name}", None)

        return await db.fetchone(
            "SELECT * FROM accounts WHERE name = :name",
            {"name": account_name},
            Account,
        )
|
|
|
|
|
|
# ===== JOURNAL ENTRY OPERATIONS =====
|
|
|
|
|
|
# ===== JOURNAL ENTRY OPERATIONS (REMOVED) =====
|
|
#
|
|
# All journal entry operations have been moved to Fava/Beancount.
|
|
# Castle no longer maintains its own journal_entries and entry_lines tables.
|
|
#
|
|
# For journal entry operations, see:
|
|
# - views_api.py: api_create_journal_entry() - writes to Fava via FavaClient
|
|
# - views_api.py: API endpoints query Fava via FavaClient for reading entries
|
|
#
|
|
# Migration: m016_drop_obsolete_journal_tables
|
|
# Removed functions:
|
|
# - create_journal_entry()
|
|
# - get_journal_entry()
|
|
# - get_journal_entry_by_reference()
|
|
# - get_entry_lines()
|
|
# - get_all_journal_entries()
|
|
# - get_journal_entries_by_user()
|
|
# - count_all_journal_entries()
|
|
# - count_journal_entries_by_user()
|
|
# - get_journal_entries_by_user_and_account_type()
|
|
# - count_journal_entries_by_user_and_account_type()
|
|
# - get_account_transactions()
|
|
|
|
|
|
# ===== SETTINGS =====
|
|
|
|
|
|
async def create_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Insert the Castle settings row keyed by ``user_id`` and return the stored model."""
    stored = UserCastleSettings(id=user_id, **data.dict())
    await db.insert("extension_settings", stored)
    return stored
|
|
|
|
|
|
async def get_castle_settings(user_id: str) -> Optional[CastleSettings]:
    """Return the Castle settings whose id is ``user_id``, or ``None`` if absent."""
    query = """
        SELECT * FROM extension_settings
        WHERE id = :user_id
    """
    params = {"user_id": user_id}
    return await db.fetchone(query, params, CastleSettings)
|
|
|
|
|
|
async def update_castle_settings(
    user_id: str, data: CastleSettings
) -> CastleSettings:
    """Overwrite the Castle settings row keyed by ``user_id`` and return the stored model."""
    stored = UserCastleSettings(id=user_id, **data.dict())
    await db.update("extension_settings", stored)
    return stored
|
|
|
|
|
|
# ===== USER WALLET SETTINGS =====
|
|
|
|
|
|
async def create_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Insert wallet settings for ``user_id`` (used as the row id) and return them."""
    stored = StoredUserWalletSettings(id=user_id, **data.dict())
    await db.insert("user_wallet_settings", stored)
    return stored
|
|
|
|
|
|
async def get_user_wallet_settings(user_id: str) -> Optional[UserWalletSettings]:
    """Return the wallet settings whose id is ``user_id``, or ``None`` if absent."""
    query = """
        SELECT * FROM user_wallet_settings
        WHERE id = :user_id
    """
    params = {"user_id": user_id}
    return await db.fetchone(query, params, UserWalletSettings)
|
|
|
|
|
|
async def update_user_wallet_settings(
    user_id: str, data: UserWalletSettings
) -> UserWalletSettings:
    """Overwrite the wallet settings row keyed by ``user_id`` and return the stored model."""
    stored = StoredUserWalletSettings(id=user_id, **data.dict())
    await db.update("user_wallet_settings", stored)
    return stored
|
|
|
|
|
|
async def get_all_user_wallet_settings() -> list[StoredUserWalletSettings]:
    """Return every stored wallet settings row, ordered by id (the user id)."""
    query = "SELECT * FROM user_wallet_settings ORDER BY id"
    no_params: dict = {}
    return await db.fetchall(query, no_params, StoredUserWalletSettings)
|
|
|
|
|
|
# ===== MANUAL PAYMENT REQUESTS =====
|
|
|
|
|
|
async def create_manual_payment_request(
    user_id: str, amount: int, description: str
) -> "ManualPaymentRequest":
    """
    Store a new manual payment request in ``pending`` status.

    Args:
        user_id: Requesting user.
        amount: Requested amount (stored as given).
        description: Free-text reason for the request.

    Returns:
        The created ``ManualPaymentRequest``.
    """
    from .models import ManualPaymentRequest

    new_request = ManualPaymentRequest(
        id=urlsafe_short_hash(),
        user_id=user_id,
        amount=amount,
        description=description,
        status="pending",
        created_at=datetime.now(),
    )

    columns = ("id", "user_id", "amount", "description", "status", "created_at")
    await db.execute(
        """
        INSERT INTO manual_payment_requests (id, user_id, amount, description, status, created_at)
        VALUES (:id, :user_id, :amount, :description, :status, :created_at)
        """,
        {column: getattr(new_request, column) for column in columns},
    )

    return new_request
|
|
|
|
|
|
async def get_manual_payment_request(request_id: str) -> Optional["ManualPaymentRequest"]:
    """Return the manual payment request with ``request_id``, or ``None``."""
    from .models import ManualPaymentRequest

    query = "SELECT * FROM manual_payment_requests WHERE id = :id"
    return await db.fetchone(query, {"id": request_id}, ManualPaymentRequest)
|
|
|
|
|
|
async def get_user_manual_payment_requests(
    user_id: str, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """Return up to ``limit`` of a user's manual payment requests, newest first."""
    from .models import ManualPaymentRequest

    query = """
        SELECT * FROM manual_payment_requests
        WHERE user_id = :user_id
        ORDER BY created_at DESC
        LIMIT :limit
    """
    params = {"user_id": user_id, "limit": limit}
    return await db.fetchall(query, params, ManualPaymentRequest)
|
|
|
|
|
|
async def get_all_manual_payment_requests(
    status: Optional[str] = None, limit: int = 100
) -> list["ManualPaymentRequest"]:
    """
    Return up to ``limit`` manual payment requests, newest first.

    Args:
        status: If truthy, only requests with this status are returned.
        limit: Maximum number of rows.
    """
    from .models import ManualPaymentRequest

    params: dict = {"limit": limit}
    if not status:
        query = """
            SELECT * FROM manual_payment_requests
            ORDER BY created_at DESC
            LIMIT :limit
        """
    else:
        query = """
            SELECT * FROM manual_payment_requests
            WHERE status = :status
            ORDER BY created_at DESC
            LIMIT :limit
        """
        params["status"] = status

    return await db.fetchall(query, params, ManualPaymentRequest)
|
|
|
|
|
|
async def approve_manual_payment_request(
    request_id: str, reviewed_by: str, journal_entry_id: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as approved and return its refreshed state.

    Sets status to ``'approved'`` and records the reviewer, review time and the
    linked journal entry. The current status is not checked before updating.
    """
    update_values = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
        "journal_entry_id": journal_entry_id,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'approved', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by, journal_entry_id = :journal_entry_id
        WHERE id = :id
        """,
        update_values,
    )
    return await get_manual_payment_request(request_id)
|
|
|
|
|
|
async def reject_manual_payment_request(
    request_id: str, reviewed_by: str
) -> Optional["ManualPaymentRequest"]:
    """
    Mark a manual payment request as rejected and return its refreshed state.

    Sets status to ``'rejected'`` and records the reviewer and review time. The
    current status is not checked before updating.
    """
    update_values = {
        "id": request_id,
        "reviewed_at": datetime.now(),
        "reviewed_by": reviewed_by,
    }
    await db.execute(
        """
        UPDATE manual_payment_requests
        SET status = 'rejected', reviewed_at = :reviewed_at, reviewed_by = :reviewed_by
        WHERE id = :id
        """,
        update_values,
    )
    return await get_manual_payment_request(request_id)
|
|
|
|
|
|
# ===== BALANCE ASSERTION OPERATIONS =====
|
|
|
|
|
|
async def create_balance_assertion(
    data: CreateBalanceAssertion, created_by: str
) -> BalanceAssertion:
    """
    Create and persist a new balance assertion in PENDING status.

    Decimal fiat amounts are written as strings (TEXT storage). ``date``
    defaults to now when the caller gives none.

    Args:
        data: Expected balances, tolerances and the target account.
        created_by: Identifier of the user creating the assertion.

    Returns:
        The created ``BalanceAssertion``.
    """
    assertion = BalanceAssertion(
        id=urlsafe_short_hash(),
        date=data.date if data.date else datetime.now(),
        account_id=data.account_id,
        expected_balance_sats=data.expected_balance_sats,
        expected_balance_fiat=data.expected_balance_fiat,
        fiat_currency=data.fiat_currency,
        tolerance_sats=data.tolerance_sats,
        tolerance_fiat=data.tolerance_fiat,
        status=AssertionStatus.PENDING,
        created_by=created_by,
        created_at=datetime.now(),
    )

    await db.execute(
        """
        INSERT INTO balance_assertions (
            id, date, account_id, expected_balance_sats, expected_balance_fiat,
            fiat_currency, tolerance_sats, tolerance_fiat, status, created_by, created_at
        ) VALUES (
            :id, :date, :account_id, :expected_balance_sats, :expected_balance_fiat,
            :fiat_currency, :tolerance_sats, :tolerance_fiat, :status, :created_by, :created_at
        )
        """,
        {
            "id": assertion.id,
            "date": assertion.date,
            "account_id": assertion.account_id,
            "expected_balance_sats": assertion.expected_balance_sats,
            # Compare against None, not truthiness. Decimal("0") is falsy, and
            # an asserted zero fiat balance must still be stored; otherwise the
            # fiat check is silently skipped.
            "expected_balance_fiat": (
                str(assertion.expected_balance_fiat)
                if assertion.expected_balance_fiat is not None
                else None
            ),
            "fiat_currency": assertion.fiat_currency,
            "tolerance_sats": assertion.tolerance_sats,
            # Never persist the string "None"; it cannot be parsed back into a
            # Decimal. A missing tolerance is stored as zero.
            "tolerance_fiat": (
                str(assertion.tolerance_fiat)
                if assertion.tolerance_fiat is not None
                else "0"
            ),
            "status": assertion.status.value,
            "created_by": assertion.created_by,
            "created_at": assertion.created_at,
        },
    )
    return assertion
|
|
|
|
|
|
async def get_balance_assertion(assertion_id: str) -> Optional[BalanceAssertion]:
    """
    Load one balance assertion by ID, or return ``None`` if it does not exist.

    Fiat amounts are stored as text and converted back to ``Decimal`` here.
    Empty or NULL values become ``None``, except ``tolerance_fiat``, which
    becomes ``Decimal("0")``.
    """
    from decimal import Decimal

    row = await db.fetchone(
        "SELECT * FROM balance_assertions WHERE id = :id",
        {"id": assertion_id},
    )
    if not row:
        return None

    def as_decimal(column: str, fallback=None):
        raw = row[column]
        return Decimal(raw) if raw else fallback

    return BalanceAssertion(
        id=row["id"],
        date=row["date"],
        account_id=row["account_id"],
        expected_balance_sats=row["expected_balance_sats"],
        expected_balance_fiat=as_decimal("expected_balance_fiat"),
        fiat_currency=row["fiat_currency"],
        tolerance_sats=row["tolerance_sats"],
        tolerance_fiat=as_decimal("tolerance_fiat", Decimal("0")),
        checked_balance_sats=row["checked_balance_sats"],
        checked_balance_fiat=as_decimal("checked_balance_fiat"),
        difference_sats=row["difference_sats"],
        difference_fiat=as_decimal("difference_fiat"),
        status=AssertionStatus(row["status"]),
        created_by=row["created_by"],
        created_at=row["created_at"],
        checked_at=row["checked_at"],
    )
|
|
|
|
|
|
async def get_balance_assertions(
    account_id: Optional[str] = None,
    status: Optional[AssertionStatus] = None,
    limit: int = 100,
) -> list[BalanceAssertion]:
    """
    List balance assertions, newest date first, with optional filters.

    Args:
        account_id: If truthy, restrict to this account.
        status: If truthy, restrict to this status.
        limit: Maximum number of rows.

    Returns:
        ``BalanceAssertion`` objects with text-stored fiat fields parsed to
        ``Decimal`` (empty → ``None``; ``tolerance_fiat`` empty → ``Decimal("0")``).
    """
    from decimal import Decimal

    # The WHERE fragments are fixed strings; user values only go in via params.
    conditions: list[str] = []
    params: dict = {"limit": limit}
    if account_id:
        conditions.append("account_id = :account_id")
        params["account_id"] = account_id
    if status:
        conditions.append("status = :status")
        params["status"] = status.value
    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

    rows = await db.fetchall(
        f"""
        SELECT * FROM balance_assertions
        {where_clause}
        ORDER BY date DESC
        LIMIT :limit
        """,
        params,
    )

    def as_decimal(raw, fallback=None):
        return Decimal(raw) if raw else fallback

    return [
        BalanceAssertion(
            id=row["id"],
            date=row["date"],
            account_id=row["account_id"],
            expected_balance_sats=row["expected_balance_sats"],
            expected_balance_fiat=as_decimal(row["expected_balance_fiat"]),
            fiat_currency=row["fiat_currency"],
            tolerance_sats=row["tolerance_sats"],
            tolerance_fiat=as_decimal(row["tolerance_fiat"], Decimal("0")),
            checked_balance_sats=row["checked_balance_sats"],
            checked_balance_fiat=as_decimal(row["checked_balance_fiat"]),
            difference_sats=row["difference_sats"],
            difference_fiat=as_decimal(row["difference_fiat"]),
            status=AssertionStatus(row["status"]),
            created_by=row["created_by"],
            created_at=row["created_at"],
            checked_at=row["checked_at"],
        )
        for row in rows
    ]
|
|
|
|
|
|
async def check_balance_assertion(assertion_id: str) -> BalanceAssertion:
    """
    Compare an assertion's expected balances against Fava and record the result.

    Sats come from ``fava.get_account_balance`` for the asserted account. A
    fiat balance is fetched only when the assertion has a ``fiat_currency`` and
    the account has a ``user_id``. It is read from ``fava.get_user_balance``,
    defaulting to 0 when the currency is absent, and compared only when
    ``expected_balance_fiat`` is set. The row is then updated with the checked
    balances, differences, PASSED/FAILED status and check time, and re-read.

    Raises:
        ValueError: If the assertion or its account does not exist.
    """
    from decimal import Decimal

    from .fava_client import get_fava_client

    assertion = await get_balance_assertion(assertion_id)
    if not assertion:
        raise ValueError(f"Balance assertion {assertion_id} not found")

    account = await get_account(assertion.account_id)
    if not account:
        raise ValueError(f"Account {assertion.account_id} not found")

    fava = get_fava_client()
    actual_sats = (await fava.get_account_balance(account.name))["sats"]

    actual_fiat = None
    if assertion.fiat_currency and account.user_id:
        # NOTE(review): this is the user's aggregate fiat balance, not the
        # asserted account's own fiat balance - confirm that is intended.
        fiat_balances = (await fava.get_user_balance(account.user_id))["fiat_balances"]
        actual_fiat = fiat_balances.get(assertion.fiat_currency, Decimal("0"))

    diff_sats = actual_sats - assertion.expected_balance_sats
    sats_ok = abs(diff_sats) <= assertion.tolerance_sats

    diff_fiat = None
    fiat_ok = True
    if assertion.expected_balance_fiat is not None and actual_fiat is not None:
        diff_fiat = actual_fiat - assertion.expected_balance_fiat
        fiat_ok = abs(diff_fiat) <= assertion.tolerance_fiat

    outcome = AssertionStatus.PASSED if sats_ok and fiat_ok else AssertionStatus.FAILED

    await db.execute(
        """
        UPDATE balance_assertions
        SET checked_balance_sats = :checked_sats,
            checked_balance_fiat = :checked_fiat,
            difference_sats = :diff_sats,
            difference_fiat = :diff_fiat,
            status = :status,
            checked_at = :checked_at
        WHERE id = :id
        """,
        {
            "id": assertion_id,
            "checked_sats": actual_sats,
            "checked_fiat": None if actual_fiat is None else str(actual_fiat),
            "diff_sats": diff_sats,
            "diff_fiat": None if diff_fiat is None else str(diff_fiat),
            "status": outcome.value,
            "checked_at": datetime.now(),
        },
    )

    return await get_balance_assertion(assertion_id)
|
|
|
|
|
|
async def delete_balance_assertion(assertion_id: str) -> None:
    """Remove the balance assertion row with ``assertion_id`` (no-op if absent)."""
    query = "DELETE FROM balance_assertions WHERE id = :id"
    await db.execute(query, {"id": assertion_id})
|
|
|
|
|
|
# User Equity Status CRUD operations
|
|
|
|
|
|
async def get_user_equity_status(user_id: str) -> Optional["UserEquityStatus"]:
    """Return the user's equity-eligibility record, or ``None`` if none exists."""
    from .models import UserEquityStatus

    row = await db.fetchone(
        """
        SELECT * FROM user_equity_status
        WHERE user_id = :user_id
        """,
        {"user_id": user_id},
    )
    if not row:
        return None
    return UserEquityStatus(**row)
|
|
|
|
|
|
async def create_or_update_user_equity_status(
    data: "CreateUserEquityStatus", granted_by: str
) -> "UserEquityStatus":
    """
    Upsert a user's equity-eligibility record.

    If ``data.is_equity_eligible`` is true, an
    ``Equity:User-<first 8 chars of user_id>`` row is inserted into Castle's
    ``accounts`` table when no account with that name exists. That name is
    then written into ``data.equity_account_name``; note that ``data`` is
    mutated.

    An existing record is updated in place. ``revoked_at`` is cleared when the
    user is eligible and set to now otherwise. If no record exists, a new one
    is inserted.

    Returns:
        The stored record, re-read from the database.

    Raises:
        ValueError: If the record cannot be read back after writing.
    """
    import uuid

    if data.is_equity_eligible:
        account_name = f"Equity:User-{data.user_id[:8]}"

        if not await get_account_by_name(account_name):
            await db.execute(
                """
                INSERT INTO accounts (id, name, account_type, description, user_id, created_at)
                VALUES (:id, :name, :type, :description, :user_id, :created_at)
                """,
                {
                    "id": str(uuid.uuid4()),
                    "name": account_name,
                    "type": AccountType.EQUITY.value,
                    "description": f"Equity contributions for user {data.user_id[:8]}",
                    "user_id": data.user_id,
                    "created_at": datetime.now(),
                },
            )

        data.equity_account_name = account_name

    if await get_user_equity_status(data.user_id):
        await db.execute(
            """
            UPDATE user_equity_status
            SET is_equity_eligible = :is_equity_eligible,
                equity_account_name = :equity_account_name,
                notes = :notes,
                granted_by = :granted_by,
                granted_at = :granted_at,
                revoked_at = :revoked_at
            WHERE user_id = :user_id
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": datetime.now(),
                "revoked_at": None if data.is_equity_eligible else datetime.now(),
            },
        )
    else:
        await db.execute(
            """
            INSERT INTO user_equity_status (
                user_id, is_equity_eligible, equity_account_name,
                notes, granted_by, granted_at
            )
            VALUES (
                :user_id, :is_equity_eligible, :equity_account_name,
                :notes, :granted_by, :granted_at
            )
            """,
            {
                "user_id": data.user_id,
                "is_equity_eligible": data.is_equity_eligible,
                "equity_account_name": data.equity_account_name,
                "notes": data.notes,
                "granted_by": granted_by,
                "granted_at": datetime.now(),
            },
        )

    stored = await get_user_equity_status(data.user_id)
    if not stored:
        raise ValueError(f"Failed to create/update equity status for user {data.user_id}")
    return stored
|
|
|
|
|
|
async def revoke_user_equity_eligibility(user_id: str) -> Optional["UserEquityStatus"]:
    """
    Clear the user's equity eligibility, stamp ``revoked_at``, and return the record.

    Returns ``None`` when the user has no equity status row.
    """
    revocation = {"user_id": user_id, "revoked_at": datetime.now()}
    await db.execute(
        """
        UPDATE user_equity_status
        SET is_equity_eligible = FALSE,
            revoked_at = :revoked_at
        WHERE user_id = :user_id
        """,
        revocation,
    )
    return await get_user_equity_status(user_id)
|
|
|
|
|
|
async def get_all_equity_eligible_users() -> list["UserEquityStatus"]:
    """Return every record with ``is_equity_eligible = TRUE``, most recently granted first."""
    from .models import UserEquityStatus

    eligible_rows = await db.fetchall(
        """
        SELECT * FROM user_equity_status
        WHERE is_equity_eligible = TRUE
        ORDER BY granted_at DESC
        """
    )

    statuses: list = []
    for eligible_row in eligible_rows:
        statuses.append(UserEquityStatus(**eligible_row))
    return statuses
|
|
|
|
|
|
# ===== ACCOUNT PERMISSION OPERATIONS =====
|
|
|
|
|
|
async def create_account_permission(
    data: "CreateAccountPermission", granted_by: str
) -> "AccountPermission":
    """
    Grant a user a permission on an account.

    The target account must exist and be active. After inserting, the user's
    cached permission lists (all types and this type) are evicted.

    Args:
        data: User, account, permission type, optional expiry and notes.
        granted_by: Identifier of the granting user.

    Returns:
        The created ``AccountPermission``.

    Raises:
        ValueError: If the account is missing or inactive.
    """
    account = await get_account(data.account_id)
    if not account:
        raise ValueError(f"Account {data.account_id} not found")
    if not account.is_active:
        raise ValueError(
            f"Cannot grant permission on inactive account: {account.name}"
        )

    grant = AccountPermission(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    await db.execute(
        """
        INSERT INTO account_permissions (
            id, user_id, account_id, permission_type, granted_by,
            granted_at, expires_at, notes
        )
        VALUES (
            :id, :user_id, :account_id, :permission_type, :granted_by,
            :granted_at, :expires_at, :notes
        )
        """,
        {
            "id": grant.id,
            "user_id": grant.user_id,
            "account_id": grant.account_id,
            "permission_type": grant.permission_type.value,
            "granted_by": grant.granted_by,
            "granted_at": grant.granted_at,
            "expires_at": grant.expires_at,
            "notes": grant.notes,
        },
    )

    # Cache has no public delete method; evict via its backing dict.
    user_key = f"permissions:user:{grant.user_id}"
    for stale_key in (user_key, f"{user_key}:{grant.permission_type.value}"):
        permission_cache._values.pop(stale_key, None)

    return grant
|
|
|
|
|
|
async def get_account_permission(permission_id: str) -> Optional["AccountPermission"]:
    """
    Look up a single account permission by its id.

    No expiry filter is applied here, so expired permissions are returned too.

    Returns:
        The AccountPermission, or None if no row has this id.
    """
    found = await db.fetchone(
        "SELECT * FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )

    return (
        AccountPermission(
            id=found["id"],
            user_id=found["user_id"],
            account_id=found["account_id"],
            permission_type=PermissionType(found["permission_type"]),
            granted_by=found["granted_by"],
            granted_at=found["granted_at"],
            expires_at=found["expires_at"],
            notes=found["notes"],
        )
        if found
        else None
    )
|
|
|
|
|
|
async def get_user_permissions(
    user_id: str, permission_type: Optional["PermissionType"] = None
) -> list["AccountPermission"]:
    """
    Get a user's non-expired direct account permissions, with caching.

    Args:
        user_id: The user whose permissions to load.
        permission_type: If given, only permissions of this type are returned.

    Returns:
        Permissions ordered by ``granted_at`` descending. Each call returns a
        new list, so callers may modify it without affecting the cache.

    Results are cached for PERMISSION_CACHE_TTL under
    ``permissions:user:{user_id}`` or ``permissions:user:{user_id}:{type}``.
    Expiry is evaluated when the query runs, so a permission that expires
    while cached stays visible until the cache entry lapses.
    """
    cache_key = f"permissions:user:{user_id}"
    if permission_type is not None:
        cache_key += f":{permission_type.value}"

    cached = permission_cache.get(cache_key)
    if cached is not None:
        # Hand out a copy: returning the cached list itself would let any
        # caller that mutates the result corrupt the shared cache entry.
        return list(cached)

    params = {"user_id": user_id, "now": datetime.now()}
    type_filter = ""
    if permission_type is not None:
        # Fixed SQL fragment; the value itself is still bound as a parameter.
        type_filter = "AND permission_type = :permission_type"
        params["permission_type"] = permission_type.value

    rows = await db.fetchall(
        f"""
        SELECT * FROM account_permissions
        WHERE user_id = :user_id
        {type_filter}
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        params,
    )

    permissions = [
        AccountPermission(
            id=row["id"],
            user_id=row["user_id"],
            account_id=row["account_id"],
            permission_type=PermissionType(row["permission_type"]),
            granted_by=row["granted_by"],
            granted_at=row["granted_at"],
            expires_at=row["expires_at"],
            notes=row["notes"],
        )
        for row in rows
    ]

    permission_cache.set(cache_key, permissions, PERMISSION_CACHE_TTL)

    return list(permissions)
|
|
|
|
|
|
async def get_account_permissions(account_id: str) -> list["AccountPermission"]:
    """
    List every non-expired permission granted on one account.

    Args:
        account_id: The account whose permissions to list.

    Returns:
        AccountPermission objects for all users, newest grant first.
    """
    query_params = {"account_id": account_id, "now": datetime.now()}
    records = await db.fetchall(
        """
        SELECT * FROM account_permissions
        WHERE account_id = :account_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        query_params,
    )

    granted: list["AccountPermission"] = []
    for record in records:
        granted.append(
            AccountPermission(
                id=record["id"],
                user_id=record["user_id"],
                account_id=record["account_id"],
                permission_type=PermissionType(record["permission_type"]),
                granted_by=record["granted_by"],
                granted_at=record["granted_at"],
                expires_at=record["expires_at"],
                notes=record["notes"],
            )
        )
    return granted
|
|
|
|
|
|
async def delete_account_permission(permission_id: str) -> None:
    """
    Revoke (delete) an account permission.

    The permission is loaded before it is deleted so that its owner's cached
    permission lists can be evicted afterwards. Deleting an unknown id issues
    the DELETE but touches no cache entries.
    """
    doomed = await get_account_permission(permission_id)

    await db.execute(
        "DELETE FROM account_permissions WHERE id = :id",
        {"id": permission_id},
    )

    if not doomed:
        return

    # The Cache class has no delete method, so evict through its backing dict.
    user_key = f"permissions:user:{doomed.user_id}"
    permission_cache._values.pop(user_key, None)
    permission_cache._values.pop(
        f"{user_key}:{doomed.permission_type.value}", None
    )
|
|
|
|
|
|
async def check_user_has_permission(
    user_id: str, account_id: str, permission_type: "PermissionType"
) -> bool:
    """
    Check whether a user holds a direct, non-expired permission on an account.

    Only exact matches in ``account_permissions`` count: permissions granted
    through roles or inherited from parent accounts are not considered.
    """
    lookup = {
        "user_id": user_id,
        "account_id": account_id,
        "permission_type": permission_type.value,
        "now": datetime.now(),
    }
    match = await db.fetchone(
        """
        SELECT id FROM account_permissions
        WHERE user_id = :user_id
        AND account_id = :account_id
        AND permission_type = :permission_type
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        lookup,
    )
    return match is not None
|
|
|
|
|
|
async def get_user_permissions_with_inheritance(
    user_id: str, account_name: str, permission_type: "PermissionType"
) -> list[tuple["AccountPermission", Optional[str]]]:
    """
    Get a user's permissions that apply to an account, directly or inherited.

    Both direct user permissions and permissions granted through the user's
    roles are considered. Account hierarchy is encoded in colon-separated
    names: a permission on "Expenses:Food" applies to "Expenses:Food" itself
    and is inherited by "Expenses:Food:Groceries".

    Args:
        user_id: The user whose permissions are checked.
        account_name: Full hierarchical name of the target account.
        permission_type: Only permissions of this type are considered.

    Returns:
        List of ``(permission, parent_account_name)`` tuples. The second item
        is None for a permission on the target account itself, otherwise the
        name of the ancestor account it was inherited from. Entries coming
        from roles are RolePermission objects (from
        get_user_permissions_from_roles), not AccountPermission; both expose
        ``account_id`` and ``permission_type``.

    Example:
        With a permission on "Expenses:Food" and account_name
        "Expenses:Food:Groceries", returns [(permission_on_food, "Expenses:Food")].
    """
    direct_permissions = await get_user_permissions(user_id, permission_type)

    role_permissions_list = await get_user_permissions_from_roles(user_id)
    role_perms = [
        perm
        for _role, perms in role_permissions_list
        for perm in perms
        if perm.permission_type == permission_type
    ]

    all_permissions = list(direct_permissions) + role_perms

    # A direct grant and a role grant (or several roles) often target the
    # same account; resolve each account name at most once per call instead
    # of issuing one get_account query per permission.
    names_by_account_id: dict[str, Optional[str]] = {}

    applicable_permissions = []
    for perm in all_permissions:
        if perm.account_id not in names_by_account_id:
            account = await get_account(perm.account_id)
            names_by_account_id[perm.account_id] = account.name if account else None

        granted_on = names_by_account_id[perm.account_id]
        if granted_on is None:
            # Permission points at an account that no longer exists.
            continue

        if account_name == granted_on:
            applicable_permissions.append((perm, None))
        elif account_name.startswith(granted_on + ":"):
            applicable_permissions.append((perm, granted_on))

    return applicable_permissions
|
|
|
|
|
|
# ===== ROLE-BASED ACCESS CONTROL (RBAC) OPERATIONS =====
|
|
|
|
|
|
async def create_role(data: CreateRole, created_by: str) -> Role:
    """
    Create and persist a new role.

    If the new role is flagged as default, the default flag is cleared on every
    other role, matching update_role. This keeps get_default_role unambiguous,
    because that lookup takes an arbitrary row when several are flagged.

    Args:
        data: Name, description and default flag for the role.
        created_by: Identifier recorded as the role's creator.

    Returns:
        The stored Role, with a fresh id and ``created_at`` set to now.
    """
    role = Role(
        id=urlsafe_short_hash(),
        name=data.name,
        description=data.description,
        is_default=data.is_default,
        created_by=created_by,
        created_at=datetime.now(),
    )

    await db.execute(
        """
        INSERT INTO roles (id, name, description, is_default, created_by, created_at)
        VALUES (:id, :name, :description, :is_default, :created_by, :created_at)
        """,
        {
            "id": role.id,
            "name": role.name,
            "description": role.description,
            "is_default": role.is_default,
            "created_by": role.created_by,
            "created_at": role.created_at,
        },
    )

    if role.is_default:
        # Runs after the insert, so an insert that raises leaves the existing
        # default role untouched.
        await db.execute(
            "UPDATE roles SET is_default = FALSE WHERE id != :id",
            {"id": role.id},
        )

    return role
|
|
|
|
|
|
async def get_role(role_id: str) -> Optional[Role]:
    """
    Fetch a role by its id.

    Returns:
        The Role, or None if no role has this id.
    """
    record = await db.fetchone(
        "SELECT * FROM roles WHERE id = :id",
        {"id": role_id},
    )
    if not record:
        return None

    fields = {
        key: record[key]
        for key in ("id", "name", "description", "is_default", "created_by", "created_at")
    }
    return Role(**fields)
|
|
|
|
|
|
async def get_role_by_name(name: str) -> Optional[Role]:
    """
    Fetch a role by its name (exact match).

    Returns:
        The first matching Role returned by the database, or None if no role
        has this name.
    """
    hit = await db.fetchone(
        "SELECT * FROM roles WHERE name = :name",
        {"name": name},
    )
    return (
        Role(
            id=hit["id"],
            name=hit["name"],
            description=hit["description"],
            is_default=hit["is_default"],
            created_by=hit["created_by"],
            created_at=hit["created_at"],
        )
        if hit
        else None
    )
|
|
|
|
|
|
async def get_all_roles() -> list[Role]:
    """Return every role, sorted alphabetically by name."""
    records = await db.fetchall(
        "SELECT * FROM roles ORDER BY name",
    )

    def to_role(record) -> Role:
        # Map one roles row onto the Role model.
        return Role(
            id=record["id"],
            name=record["name"],
            description=record["description"],
            is_default=record["is_default"],
            created_by=record["created_by"],
            created_at=record["created_at"],
        )

    return list(map(to_role, records))
|
|
|
|
|
|
async def get_default_role() -> Optional[Role]:
    """
    Return the role flagged as default, which is auto-assigned to new users.

    The query uses LIMIT 1 without ORDER BY, so if several roles are flagged
    as default, which one is returned is unspecified.

    Returns:
        The default Role, or None if no role is flagged as default.
    """
    flagged = await db.fetchone(
        "SELECT * FROM roles WHERE is_default = TRUE LIMIT 1",
    )
    if flagged:
        return Role(
            id=flagged["id"],
            name=flagged["name"],
            description=flagged["description"],
            is_default=flagged["is_default"],
            created_by=flagged["created_by"],
            created_at=flagged["created_at"],
        )
    return None
|
|
|
|
|
|
async def update_role(role_id: str, data: UpdateRole) -> Optional[Role]:
    """
    Apply a partial update to a role.

    Only fields of ``data`` that are not None are written. Setting
    ``is_default`` to True first clears the flag on every other role, so at
    most one role is the default.

    Args:
        role_id: Id of the role to update.
        data: Fields to change; None means "leave unchanged".

    Returns:
        The role as stored after the update, or None if no role with
        ``role_id`` exists.
    """
    # Check existence before touching any rows. Otherwise an unknown role_id
    # combined with is_default=True would clear the default flag on every role
    # while updating nothing.
    existing = await get_role(role_id)
    if not existing:
        return None

    if data.is_default is True:
        await db.execute(
            "UPDATE roles SET is_default = FALSE WHERE id != :id",
            {"id": role_id},
        )

    # The SET clause is built only from the fixed column names below, never
    # from caller input, so the f-string query is not an injection risk.
    updates = []
    params = {"id": role_id}

    if data.name is not None:
        updates.append("name = :name")
        params["name"] = data.name

    if data.description is not None:
        updates.append("description = :description")
        params["description"] = data.description

    if data.is_default is not None:
        updates.append("is_default = :is_default")
        params["is_default"] = data.is_default

    if not updates:
        return existing

    await db.execute(
        f"UPDATE roles SET {', '.join(updates)} WHERE id = :id",
        params,
    )

    return await get_role(role_id)
|
|
|
|
|
|
async def delete_role(role_id: str) -> None:
    """
    Delete a role together with its role permissions and user assignments.

    The child rows are deleted explicitly instead of relying only on an
    ON DELETE CASCADE foreign key. SQLite enforces foreign keys only when
    ``PRAGMA foreign_keys`` is on, and leftover ``user_roles`` rows would still
    be counted by get_role_count_for_user, which queries user_roles without
    joining roles. That would stop auto_assign_default_role from giving such
    users the default role. Where cascading is enforced, the end state is the
    same.

    NOTE(review): the three statements do not run in an explicit transaction.
    """
    await db.execute(
        "DELETE FROM role_permissions WHERE role_id = :role_id",
        {"role_id": role_id},
    )
    await db.execute(
        "DELETE FROM user_roles WHERE role_id = :role_id",
        {"role_id": role_id},
    )
    await db.execute(
        "DELETE FROM roles WHERE id = :id",
        {"id": role_id},
    )
|
|
|
|
|
|
# ===== ROLE PERMISSION OPERATIONS =====
|
|
|
|
|
|
async def create_role_permission(data: CreateRolePermission) -> RolePermission:
    """
    Attach a permission on an account to a role and persist it.

    Unlike create_account_permission, the account is not checked for
    existence or active status here.

    Returns:
        The stored RolePermission, with a fresh id and ``created_at`` set to now.
    """
    granted = RolePermission(
        id=urlsafe_short_hash(),
        role_id=data.role_id,
        account_id=data.account_id,
        permission_type=data.permission_type,
        notes=data.notes,
        created_at=datetime.now(),
    )

    row_values = dict(
        id=granted.id,
        role_id=granted.role_id,
        account_id=granted.account_id,
        permission_type=granted.permission_type.value,
        notes=granted.notes,
        created_at=granted.created_at,
    )
    await db.execute(
        """
        INSERT INTO role_permissions (id, role_id, account_id, permission_type, notes, created_at)
        VALUES (:id, :role_id, :account_id, :permission_type, :notes, :created_at)
        """,
        row_values,
    )
    return granted
|
|
|
|
|
|
async def get_role_permissions(role_id: str) -> list[RolePermission]:
    """Return all permissions attached to a role, newest first."""
    records = await db.fetchall(
        """
        SELECT * FROM role_permissions
        WHERE role_id = :role_id
        ORDER BY created_at DESC
        """,
        {"role_id": role_id},
    )

    permissions: list[RolePermission] = []
    for record in records:
        kind = PermissionType(record["permission_type"])
        permissions.append(
            RolePermission(
                id=record["id"],
                role_id=record["role_id"],
                account_id=record["account_id"],
                permission_type=kind,
                notes=record["notes"],
                created_at=record["created_at"],
            )
        )
    return permissions
|
|
|
|
|
|
async def delete_role_permission(permission_id: str) -> None:
    """Remove a single role permission by id (deleting an unknown id is a no-op)."""
    delete_sql = "DELETE FROM role_permissions WHERE id = :id"
    await db.execute(delete_sql, {"id": permission_id})
|
|
|
|
|
|
# ===== USER ROLE OPERATIONS =====
|
|
|
|
|
|
async def assign_user_role(data: AssignUserRole, granted_by: str) -> UserRole:
    """
    Assign a user to a role and persist the assignment.

    Args:
        data: User, role, optional expiry and notes for the assignment.
        granted_by: Identifier recorded as the grantor.

    Returns:
        The stored UserRole, with a fresh id and ``granted_at`` set to now.
    """
    assignment = UserRole(
        id=urlsafe_short_hash(),
        user_id=data.user_id,
        role_id=data.role_id,
        granted_by=granted_by,
        granted_at=datetime.now(),
        expires_at=data.expires_at,
        notes=data.notes,
    )

    columns = ("id", "user_id", "role_id", "granted_by", "granted_at", "expires_at", "notes")
    await db.execute(
        """
        INSERT INTO user_roles (id, user_id, role_id, granted_by, granted_at, expires_at, notes)
        VALUES (:id, :user_id, :role_id, :granted_by, :granted_at, :expires_at, :notes)
        """,
        {column: getattr(assignment, column) for column in columns},
    )
    return assignment
|
|
|
|
|
|
async def get_user_roles(user_id: str) -> list[UserRole]:
    """Return a user's non-expired role assignments, most recent grant first."""
    params = {"user_id": user_id, "now": datetime.now()}
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE user_id = :user_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        params,
    )

    assignments: list[UserRole] = []
    for record in records:
        assignments.append(
            UserRole(
                id=record["id"],
                user_id=record["user_id"],
                role_id=record["role_id"],
                granted_by=record["granted_by"],
                granted_at=record["granted_at"],
                expires_at=record["expires_at"],
                notes=record["notes"],
            )
        )
    return assignments
|
|
|
|
|
|
async def get_all_user_roles() -> list[UserRole]:
    """
    Return every non-expired user-role assignment.

    Results are grouped by ``user_id``, with each user's most recent grant first.
    """
    current_time = datetime.now()
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE (expires_at IS NULL OR expires_at > :now)
        ORDER BY user_id, granted_at DESC
        """,
        {"now": current_time},
    )

    def to_user_role(record) -> UserRole:
        # Map one user_roles row onto the UserRole model.
        return UserRole(
            id=record["id"],
            user_id=record["user_id"],
            role_id=record["role_id"],
            granted_by=record["granted_by"],
            granted_at=record["granted_at"],
            expires_at=record["expires_at"],
            notes=record["notes"],
        )

    return list(map(to_user_role, records))
|
|
|
|
|
|
async def get_role_users(role_id: str) -> list[UserRole]:
    """Return the non-expired assignments of a role, most recent grant first."""
    records = await db.fetchall(
        """
        SELECT * FROM user_roles
        WHERE role_id = :role_id
        AND (expires_at IS NULL OR expires_at > :now)
        ORDER BY granted_at DESC
        """,
        dict(role_id=role_id, now=datetime.now()),
    )
    members = [
        UserRole(
            id=rec["id"],
            user_id=rec["user_id"],
            role_id=rec["role_id"],
            granted_by=rec["granted_by"],
            granted_at=rec["granted_at"],
            expires_at=rec["expires_at"],
            notes=rec["notes"],
        )
        for rec in records
    ]
    return members
|
|
|
|
|
|
async def revoke_user_role(user_role_id: str) -> None:
    """
    Remove one user-role assignment by its assignment id.

    This takes the ``user_roles`` row id, not a user id or role id. Deleting
    an unknown id is a no-op.
    """
    delete_sql = "DELETE FROM user_roles WHERE id = :id"
    await db.execute(delete_sql, {"id": user_role_id})
|
|
|
|
|
|
async def get_role_count_for_user(user_id: str) -> int:
    """
    Count a user's non-expired role assignments.

    Counts ``user_roles`` rows directly, without checking that each referenced
    role still exists.
    """
    tally = await db.fetchone(
        """
        SELECT COUNT(*) as count FROM user_roles
        WHERE user_id = :user_id
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        {"user_id": user_id, "now": datetime.now()},
    )
    if not tally:
        return 0
    return tally["count"]
|
|
|
|
|
|
async def auto_assign_default_role(user_id: str, assigned_by: str) -> UserRole | None:
    """
    Give a user the default role, but only if they currently hold no roles.

    Args:
        user_id: The user to check and possibly assign.
        assigned_by: Recorded as ``granted_by`` on the new assignment.

    Returns:
        The new UserRole, or None when the user already has at least one
        non-expired role or no role is flagged as default.

    NOTE(review): the role-count check and the insert are separate awaits, so
    two concurrent calls for the same user could both assign the default role.
    """
    from loguru import logger

    logger.info(f"[AUTO-ASSIGN] Checking auto-assignment for user {user_id}")

    existing_roles = await get_role_count_for_user(user_id)
    logger.info(f"[AUTO-ASSIGN] User {user_id} has {existing_roles} roles")
    if existing_roles > 0:
        logger.info(f"[AUTO-ASSIGN] User {user_id} already has roles, skipping auto-assignment")
        return None

    fallback_role = await get_default_role()
    if not fallback_role:
        logger.warning(f"[AUTO-ASSIGN] No default role found, cannot auto-assign for user {user_id}")
        return None

    logger.info(f"[AUTO-ASSIGN] Found default role: {fallback_role.name} (id: {fallback_role.id})")

    assignment = await assign_user_role(
        AssignUserRole(
            user_id=user_id,
            role_id=fallback_role.id,
            notes="Auto-assigned default role on first access",
        ),
        assigned_by,
    )
    logger.info(f"[AUTO-ASSIGN] Successfully assigned role {fallback_role.name} to user {user_id}")
    return assignment
|
|
|
|
|
|
async def get_user_count_for_role(role_id: str) -> int:
    """Count the non-expired assignments of a role (users holding it)."""
    tally = await db.fetchone(
        """
        SELECT COUNT(*) as count FROM user_roles
        WHERE role_id = :role_id
        AND (expires_at IS NULL OR expires_at > :now)
        """,
        {"role_id": role_id, "now": datetime.now()},
    )
    if not tally:
        return 0
    return tally["count"]
|
|
|
|
|
|
# ===== RBAC HELPER FUNCTIONS =====
|
|
|
|
|
|
async def get_user_permissions_from_roles(
    user_id: str,
) -> list[tuple[Role, list[RolePermission]]]:
    """
    Collect the permissions a user receives through their role assignments.

    Only non-expired assignments (per get_user_roles) are considered.
    Assignments whose role row can no longer be found are skipped.

    Returns:
        One ``(role, permissions_of_that_role)`` tuple per assignment, in the
        order get_user_roles returns them.
    """
    assignments = await get_user_roles(user_id)

    collected: list[tuple[Role, list[RolePermission]]] = []
    for assignment in assignments:
        role = await get_role(assignment.role_id)
        if not role:
            continue
        collected.append((role, await get_role_permissions(role.id)))
    return collected
|
|
|
|
|
|
async def check_user_has_role_permission(
    user_id: str, account_id: str, permission_type: PermissionType
) -> bool:
    """
    Check whether any of the user's roles grants a permission on an account.

    Matching is exact on ``account_id``. Permissions on parent accounts are
    not inherited here.
    """
    granted_by_roles = await get_user_permissions_from_roles(user_id)
    return any(
        perm.account_id == account_id and perm.permission_type == permission_type
        for _role, perms in granted_by_roles
        for perm in perms
    )
|