""" Device management service. """ import asyncio from datetime import datetime, timezone from sqlalchemy import String, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.models.device import Device from app.models.organization import Organization from app.schemas.device import DeviceCreate, DeviceUpdate from app.utils.ssh_keys import sync_authorized_keys async def create_device( db: AsyncSession, data: DeviceCreate, ) -> Device: """ Create a new device. Args: db: Database session data: Device creation data Returns: Created device """ # Check if MAC address already exists result = await db.execute( select(Device).where(Device.mac_address == data.mac_address) ) existing_device = result.scalar_one_or_none() if existing_device: raise ValueError(f"Device with MAC {data.mac_address} already exists") device = Device( mac_address=data.mac_address, organization_id=data.organization_id, status="offline", config=data.config or {}, # simple_id will be auto-generated by PostgreSQL sequence ) db.add(device) await db.commit() await db.refresh(device) return device async def get_device(db: AsyncSession, device_id: int) -> Device | None: """ Get device by ID. Args: db: Database session device_id: Device ID Returns: Device or None """ result = await db.execute(select(Device).where(Device.id == device_id)) return result.scalar_one_or_none() async def get_device_by_mac(db: AsyncSession, mac_address: str) -> Device | None: """ Get device by MAC address. Args: db: Database session mac_address: Device MAC address Returns: Device or None """ result = await db.execute( select(Device).where(Device.mac_address == mac_address) ) return result.scalar_one_or_none() async def list_devices( db: AsyncSession, skip: int = 0, limit: int = 100, organization_id: int | None = None, status: str | None = None, search: str | None = None, ) -> tuple[list[Device], int]: """ List devices with pagination and filters. Args: db: Database session skip: Number of records to skip limit: Maximum number of records to return organization_id: Filter by organization (optional) status: Filter by status (optional) search: Universal search across all fields (optional) Returns: Tuple of (devices list, total count) """ # Build query query = select(Device) # Base filters filters = [] if organization_id is not None: filters.append(Device.organization_id == organization_id) if status: filters.append(Device.status == status) # Universal search filter - requires join with Organization if search and len(search) >= 2: # Join with Organization for searching by org name/email query = query.outerjoin(Organization, Device.organization_id == Organization.id) # Search across multiple fields search_pattern = f"%{search}%" search_filters = [ Device.mac_address.ilike(search_pattern), func.cast(Device.simple_id, String).ilike(search_pattern), Organization.name.ilike(search_pattern), Organization.contact_email.ilike(search_pattern), ] filters.append(or_(*search_filters)) # Always load organization relationship query = query.options(joinedload(Device.organization)) # Apply all filters if filters: query = query.where(*filters) # Get total count count_query = select(func.count()).select_from(Device) if search and len(search) >= 2: count_query = count_query.join( Organization, Device.organization_id == Organization.id, isouter=True ) if filters: count_query = count_query.where(*filters) total_result = await db.execute(count_query) total = total_result.scalar_one() # Get paginated results query = query.offset(skip).limit(limit).order_by(Device.simple_id.desc()) result = await db.execute(query) devices = list(result.scalars().all()) return devices, total async def update_device( db: AsyncSession, device_id: int, data: DeviceUpdate, ) -> Device | None: """ Update device. Args: db: Database session device_id: Device ID data: Update data Returns: Updated device or None if not found """ result = await db.execute(select(Device).where(Device.id == device_id)) device = result.scalar_one_or_none() if not device: return None # Update fields update_data = data.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(device, field, value) await db.commit() await db.refresh(device) return device async def delete_device( db: AsyncSession, device_id: int, ) -> bool: """ Delete device. Args: db: Database session device_id: Device ID Returns: True if deleted, False if not found """ result = await db.execute(select(Device).where(Device.id == device_id)) device = result.scalar_one_or_none() if not device: return False await db.delete(device) await db.commit() # Sync SSH keys to remove deleted device's key from authorized_keys asyncio.create_task(sync_authorized_keys()) return True async def update_device_heartbeat( db: AsyncSession, mac_address: str, ) -> Device | None: """ Update device last_seen_at timestamp (heartbeat). Args: db: Database session mac_address: Device MAC address Returns: Updated device or None if not found """ result = await db.execute( select(Device).where(Device.mac_address == mac_address) ) device = result.scalar_one_or_none() if not device: return None device.last_seen_at = datetime.now(timezone.utc) device.status = "online" await db.commit() await db.refresh(device) return device