device_service.py 6.1 KB


  1. """
  2. Device management service.
  3. """
  import asyncio
  import logging
  from datetime import datetime, timezone

  from sqlalchemy import String, func, or_, select
  from sqlalchemy.ext.asyncio import AsyncSession
  from sqlalchemy.orm import joinedload

  from app.models.device import Device
  from app.models.organization import Organization
  from app.schemas.device import DeviceCreate, DeviceUpdate
  from app.utils.ssh_keys import sync_authorized_keys
  async def create_device(
      db: AsyncSession,
      data: DeviceCreate,
  ) -> Device:
      """
      Persist a new device in the "offline" state.

      Args:
          db: Database session
          data: Device creation payload

      Returns:
          The newly created device, refreshed from the database

      Raises:
          ValueError: If a device with the same MAC address is already stored
      """
      duplicate_lookup = select(Device).where(Device.mac_address == data.mac_address)
      duplicate = (await db.execute(duplicate_lookup)).scalar_one_or_none()
      if duplicate:
          raise ValueError(f"Device with MAC {data.mac_address} already exists")

      # simple_id is deliberately not set here; it is assigned by a
      # PostgreSQL sequence on insert.
      new_device = Device(
          mac_address=data.mac_address,
          organization_id=data.organization_id,
          status="offline",
          config=data.config or {},
      )
      db.add(new_device)
      await db.commit()
      # Reload server-generated columns (e.g. simple_id) after the commit.
      await db.refresh(new_device)
      return new_device
  async def get_device(db: AsyncSession, device_id: int) -> Device | None:
      """
      Fetch a single device by its primary key.

      Args:
          db: Database session
          device_id: Device ID

      Returns:
          The matching device, or None when no device has that ID
      """
      by_id = select(Device).where(Device.id == device_id)
      rows = await db.execute(by_id)
      return rows.scalar_one_or_none()
  async def get_device_by_mac(db: AsyncSession, mac_address: str) -> Device | None:
      """
      Fetch a single device by its MAC address.

      Args:
          db: Database session
          mac_address: MAC address to match exactly

      Returns:
          The matching device, or None when no device has that MAC
      """
      by_mac = select(Device).where(Device.mac_address == mac_address)
      rows = await db.execute(by_mac)
      return rows.scalar_one_or_none()
  async def list_devices(
      db: AsyncSession,
      skip: int = 0,
      limit: int = 100,
      organization_id: int | None = None,
      status: str | None = None,
      search: str | None = None,
  ) -> tuple[list[Device], int]:
      """
      List devices with pagination and filters.

      Args:
          db: Database session
          skip: Number of records to skip
          limit: Maximum number of records to return
          organization_id: Filter by organization (optional)
          status: Filter by status (optional)
          search: Case-insensitive substring search over the device MAC
              address, simple_id, organization name and organization contact
              email (optional; ignored when shorter than 2 characters).
              ``%`` and ``_`` in the term are matched literally, not as
              LIKE wildcards.

      Returns:
          Tuple of (devices on the requested page ordered by simple_id
          descending, total number of devices matching the filters
          regardless of pagination)
      """
      filters = []
      if organization_id is not None:
          filters.append(Device.organization_id == organization_id)
      if status:
          filters.append(Device.status == status)

      use_search = bool(search) and len(search) >= 2
      if use_search:
          # Escape LIKE metacharacters so user input is matched literally;
          # otherwise "_" matches any single character and "%" any run.
          # The escape character itself is escaped first.
          like_escape = "/"
          escaped = (
              search.replace(like_escape, like_escape * 2)
              .replace("%", like_escape + "%")
              .replace("_", like_escape + "_")
          )
          pattern = f"%{escaped}%"
          filters.append(
              or_(
                  Device.mac_address.ilike(pattern, escape=like_escape),
                  func.cast(Device.simple_id, String).ilike(
                      pattern, escape=like_escape
                  ),
                  Organization.name.ilike(pattern, escape=like_escape),
                  Organization.contact_email.ilike(pattern, escape=like_escape),
              )
          )

      # The page query always eager-loads the organization relationship;
      # the count query only needs the bare row count.
      query = select(Device).options(joinedload(Device.organization))
      count_query = select(func.count()).select_from(Device)
      if use_search:
          # Search references Organization columns, so both queries need the
          # join. An outer join keeps devices with no matching organization row.
          query = query.outerjoin(
              Organization, Device.organization_id == Organization.id
          )
          count_query = count_query.join(
              Organization, Device.organization_id == Organization.id, isouter=True
          )
      if filters:
          query = query.where(*filters)
          count_query = count_query.where(*filters)

      total = (await db.execute(count_query)).scalar_one()

      page_query = query.order_by(Device.simple_id.desc()).offset(skip).limit(limit)
      result = await db.execute(page_query)
      devices = list(result.scalars().all())
      return devices, total
  async def update_device(
      db: AsyncSession,
      device_id: int,
      data: DeviceUpdate,
  ) -> Device | None:
      """
      Apply a partial update to an existing device.

      Only fields explicitly set on ``data`` (``exclude_unset=True``) are
      written; fields that were not provided keep their current values.

      Args:
          db: Database session
          device_id: ID of the device to update
          data: Partial update payload

      Returns:
          The updated device refreshed from the database, or None if no
          device has that ID
      """
      lookup = await db.execute(select(Device).where(Device.id == device_id))
      target = lookup.scalar_one_or_none()
      if not target:
          return None

      changes = data.model_dump(exclude_unset=True)
      for attr_name in changes:
          setattr(target, attr_name, changes[attr_name])

      await db.commit()
      await db.refresh(target)
      return target
  # Strong references to in-flight background tasks. asyncio keeps only weak
  # references to tasks, so an unreferenced task may be garbage-collected
  # before it completes.
  _background_tasks: set[asyncio.Task] = set()


  def _on_background_task_done(task: asyncio.Task) -> None:
      """Forget a finished background task and log any exception it raised."""
      _background_tasks.discard(task)
      if task.cancelled():
          return
      exc = task.exception()
      if exc is not None:
          logging.getLogger(__name__).error(
              "Background SSH key sync failed", exc_info=exc
          )


  async def delete_device(
      db: AsyncSession,
      device_id: int,
  ) -> bool:
      """
      Delete device.

      After the deletion is committed, ``sync_authorized_keys()`` is scheduled
      as a background task so the deleted device's key is removed from
      authorized_keys. This function does not wait for that sync; a failure in
      it is logged rather than raised to the caller.

      Args:
          db: Database session
          device_id: Device ID

      Returns:
          True if deleted, False if not found
      """
      result = await db.execute(select(Device).where(Device.id == device_id))
      device = result.scalar_one_or_none()
      if not device:
          return False
      await db.delete(device)
      await db.commit()
      # Fire-and-forget, but keep a reference until the task finishes and make
      # sure its exception (if any) is retrieved and logged.
      task = asyncio.create_task(sync_authorized_keys())
      _background_tasks.add(task)
      task.add_done_callback(_on_background_task_done)
      return True
  async def update_device_heartbeat(
      db: AsyncSession,
      mac_address: str,
  ) -> Device | None:
      """
      Record a heartbeat from a device.

      Sets ``last_seen_at`` to the current UTC time and marks the device
      "online".

      Args:
          db: Database session
          mac_address: MAC address of the reporting device

      Returns:
          The updated device refreshed from the database, or None if no
          device has that MAC address
      """
      by_mac = select(Device).where(Device.mac_address == mac_address)
      device = (await db.execute(by_mac)).scalar_one_or_none()
      if not device:
          return None

      seen_at = datetime.now(timezone.utc)
      device.last_seen_at = seen_at
      device.status = "online"

      await db.commit()
      await db.refresh(device)
      return device