registration.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. """
  2. Device registration endpoint.
  3. """
  4. import asyncio
  5. import json
  6. import secrets
  7. from base64 import b64encode
  8. from datetime import datetime, timezone
  9. from pathlib import Path
  10. from typing import Annotated
  11. from fastapi import APIRouter, Depends, HTTPException, status
  12. from pydantic import BaseModel
  13. from sqlalchemy import select, update
  14. from sqlalchemy.ext.asyncio import AsyncSession
  15. from app.core.database import get_db
  16. from app.models.device import Device
  17. from app.models.settings import Settings
  18. from app.utils.ssh_keys import sync_authorized_keys
  19. router = APIRouter()
  20. # Path to default config file
  21. DEFAULT_CONFIG_PATH = Path(__file__).parent.parent.parent / "default_config.json"
  22. def _load_default_config() -> dict:
  23. """Load default config from file (not cached, always fresh)."""
  24. with open(DEFAULT_CONFIG_PATH, "r") as f:
  25. return json.load(f)
  26. class RegistrationRequest(BaseModel):
  27. """Device registration request."""
  28. device_id: str # MAC address
  29. ssh_public_key: str | None = None
  30. class RegistrationResponse(BaseModel):
  31. """Device registration response."""
  32. device_token: str
  33. device_password: str
  34. def _generate_token() -> str:
  35. """Generate 32-byte base64 token."""
  36. return b64encode(secrets.token_bytes(32)).decode("ascii")
  37. def _generate_password() -> str:
  38. """Generate 8-digit password."""
  39. n = secrets.randbelow(10**8)
  40. return f"{n:08d}"
  41. @router.post("/registration", response_model=RegistrationResponse, status_code=201)
  42. async def register_device(
  43. data: RegistrationRequest,
  44. db: Annotated[AsyncSession, Depends(get_db)],
  45. ):
  46. """
  47. Register new device or return existing credentials.
  48. Requires auto_registration to be enabled in settings.
  49. """
  50. mac_address = data.device_id.lower().strip()
  51. if not mac_address:
  52. raise HTTPException(
  53. status_code=status.HTTP_400_BAD_REQUEST,
  54. detail="Missing device_id",
  55. )
  56. # Check if device already exists
  57. result = await db.execute(select(Device).where(Device.mac_address == mac_address))
  58. device = result.scalar_one_or_none()
  59. if device:
  60. # Update SSH key if provided (device may have regenerated keys)
  61. ssh_key_updated = False
  62. if data.ssh_public_key:
  63. new_key = data.ssh_public_key.strip()
  64. old_key = (device.config or {}).get("ssh_public_key", "").strip()
  65. if new_key != old_key:
  66. # Update config with new key (preserve other settings)
  67. new_config = {**(device.config or {}), "ssh_public_key": new_key}
  68. device.config = new_config
  69. ssh_key_updated = True
  70. print(f"[REGISTRATION] Updated SSH key for device={mac_address}")
  71. # Re-generate credentials if missing
  72. if not device.device_token or not device.device_password:
  73. device.device_token = _generate_token()
  74. device.device_password = _generate_password()
  75. if ssh_key_updated or not device.device_token:
  76. await db.commit()
  77. # Sync SSH keys if updated
  78. if ssh_key_updated:
  79. asyncio.create_task(sync_authorized_keys())
  80. return RegistrationResponse(
  81. device_token=device.device_token,
  82. device_password=device.device_password,
  83. )
  84. # Check auto-registration setting
  85. settings_result = await db.execute(
  86. select(Settings).where(Settings.key == "auto_registration")
  87. )
  88. auto_reg_setting = settings_result.scalar_one_or_none()
  89. if not auto_reg_setting or not auto_reg_setting.value.get("enabled", False):
  90. raise HTTPException(
  91. status_code=status.HTTP_401_UNAUTHORIZED,
  92. detail="Registration disabled. Contact administrator.",
  93. )
  94. # Create new device with default config (read fresh from file)
  95. device_config = _load_default_config()
  96. # Add SSH public key if provided
  97. if data.ssh_public_key:
  98. device_config["ssh_public_key"] = data.ssh_public_key
  99. device = Device(
  100. mac_address=mac_address,
  101. organization_id=None, # Unassigned
  102. status="online",
  103. config=device_config,
  104. device_token=_generate_token(),
  105. device_password=_generate_password(),
  106. )
  107. db.add(device)
  108. await db.flush()
  109. # Update last_device_at
  110. auto_reg_setting.value["last_device_at"] = datetime.now(timezone.utc).isoformat()
  111. await db.commit()
  112. await db.refresh(device)
  113. print(f"[REGISTRATION] device={mac_address} simple_id={device.simple_id}")
  114. # Sync SSH keys to authorized_keys (background task)
  115. if data.ssh_public_key:
  116. asyncio.create_task(sync_authorized_keys())
  117. return RegistrationResponse(
  118. device_token=device.device_token,
  119. device_password=device.device_password,
  120. )