Browse Source

Add tunnel watchdog and improve SSH key management

Implemented tunnel lifecycle management with watchdog mechanism:
- Added a watchdog that checks for orphaned ttyd processes every minute
- The watchdog scans running ttyd processes and kills those whose SSH tunnel port is no longer listening
- Survives backend restarts (works independently of in-memory sessions)
- Checks heartbeats and session expiration (60 s grace period after the last heartbeat, 120 min hard timeout)

Centralized SSH key management:
- Created ssh_keys.py utility for authorized_keys synchronization
- Auto-sync SSH keys on device registration
- Auto-remove SSH keys on device deletion
- Fixes authentication failures caused by stale keys

Added configurable tunnel timeouts:
- timeout_minutes parameter in tunnel config (default: 120)
- Passed to device for client-side timeout enforcement

Tunnel shutdown mechanism (two independent parts):
1. Device side: reads enabled=false from the config and kills the SSH tunnel
2. Backend side: the watchdog kills ttyd when the tunnel is closed or the session has expired

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
root 4 weeks ago
parent
commit
94dbcf8fb2

BIN
2025-12-28_17-42.png


BIN
2025-12-28_23-01.png


BIN
2025-12-28_23-29.png


+ 3 - 40
backend/app/api/v1/registration.py

@@ -6,7 +6,6 @@ import asyncio
 import copy
 import json
 import secrets
-import subprocess
 from base64 import b64encode
 from datetime import datetime, timezone
 from pathlib import Path
@@ -17,9 +16,10 @@ from pydantic import BaseModel
 from sqlalchemy import select, update
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.core.database import async_session_maker, get_db
+from app.core.database import get_db
 from app.models.device import Device
 from app.models.settings import Settings
+from app.utils.ssh_keys import sync_authorized_keys
 
 router = APIRouter()
 
@@ -54,43 +54,6 @@ def _generate_password() -> str:
     return f"{n:08d}"
 
 
-async def _sync_authorized_keys():
-    """Sync all device SSH keys to /home/tunnel/.ssh/authorized_keys"""
-    try:
-        async with async_session_maker() as session:
-            result = await session.execute(select(Device))
-            devices = result.scalars().all()
-
-            keys = []
-            for device in devices:
-                if device.config and 'ssh_public_key' in device.config:
-                    ssh_key = device.config['ssh_public_key'].strip()
-                    if ssh_key:
-                        keys.append(f"{ssh_key} # {device.mac_address}")
-
-            authorized_keys_content = "\n".join(keys) + "\n" if keys else ""
-
-            # Write using sudo
-            subprocess.run(
-                ["sudo", "tee", "/home/tunnel/.ssh/authorized_keys"],
-                input=authorized_keys_content.encode(),
-                stdout=subprocess.DEVNULL,
-                check=True
-            )
-            subprocess.run(
-                ["sudo", "chmod", "600", "/home/tunnel/.ssh/authorized_keys"],
-                check=True
-            )
-            subprocess.run(
-                ["sudo", "chown", "tunnel:tunnel", "/home/tunnel/.ssh/authorized_keys"],
-                check=True
-            )
-
-            print(f"[SSH] Synced {len(keys)} keys to authorized_keys")
-    except Exception as e:
-        print(f"[SSH] Failed to sync authorized_keys: {e}")
-
-
 @router.post("/registration", response_model=RegistrationResponse, status_code=201)
 async def register_device(
     data: RegistrationRequest,
@@ -167,7 +130,7 @@ async def register_device(
 
     # Sync SSH keys to authorized_keys (background task)
     if data.ssh_public_key:
-        asyncio.create_task(_sync_authorized_keys())
+        asyncio.create_task(sync_authorized_keys())
 
     return RegistrationResponse(
         device_token=device.device_token,

+ 28 - 0
backend/app/api/v1/superadmin/tunnels.py

@@ -81,6 +81,33 @@ class TunnelStatusResponse(BaseModel):
     tunnel_url: Optional[str] = None
 
 
+@router.get("/sessions")
+async def list_all_sessions(
+    current_user: Annotated[User, Depends(get_current_superadmin)]
+):
+    """
+    Debug endpoint: List all active tunnel sessions in memory.
+    """
+    sessions = []
+    for uuid, session in tunnel_service.sessions.items():
+        sessions.append({
+            "uuid": uuid,
+            "device_id": session.device_id,
+            "tunnel_type": session.tunnel_type,
+            "status": session.status,
+            "created_at": session.created_at.isoformat(),
+            "expires_at": session.expires_at.isoformat(),
+            "last_heartbeat": session.last_heartbeat.isoformat() if session.last_heartbeat else None,
+            "ttyd_port": session.ttyd_port,
+            "ttyd_pid": session.ttyd_pid,
+            "device_tunnel_port": session.device_tunnel_port
+        })
+    return {
+        "total": len(sessions),
+        "sessions": sessions
+    }
+
+
 @router.post("/devices/{device_id}/{tunnel_type}")
 async def enable_tunnel(
     device_id: int,
@@ -145,6 +172,7 @@ async def enable_tunnel(
 
     device.config[tunnel_key]["enabled"] = True
     device.config[tunnel_key]["remote_port"] = allocated_port
+    device.config[tunnel_key]["timeout_minutes"] = 120  # Tunnel will auto-stop after 120 minutes
 
     # Copy other tunnel settings from default
     if "server" not in device.config[tunnel_key]:

+ 4 - 2
backend/app/default_config.json

@@ -21,7 +21,8 @@
     "port": 22,
     "user": "tunnel",
     "remote_port": 0,
-    "keepalive_interval": 30
+    "keepalive_interval": 30,
+    "timeout_minutes": 120
   },
   "dashboard_tunnel": {
     "enabled": false,
@@ -29,7 +30,8 @@
     "port": 22,
     "user": "tunnel",
     "remote_port": 0,
-    "keepalive_interval": 30
+    "keepalive_interval": 30,
+    "timeout_minutes": 120
   },
   "dashboard": {
     "enabled": true

+ 45 - 5
backend/app/main.py

@@ -52,17 +52,57 @@ from app.api.v1 import router as api_v1_router
 app.include_router(api_v1_router, prefix=settings.API_V1_PREFIX)
 
 
+# Background task for tunnel watchdog with DB updates
+async def tunnel_watchdog_with_db():
+    """
+    Tunnel watchdog: cleanup orphaned ttyd and stale sessions
+    Runs independently of in-memory state and survives restarts
+    """
+    import asyncio
+    from app.services.tunnel_service import tunnel_service
+    from app.core.database import async_session_maker
+    from app.models.device import Device
+    from sqlalchemy import select
+    from sqlalchemy.orm import attributes
+
+    while True:
+        await asyncio.sleep(60)  # Every minute
+
+        # Run watchdog and get list of tunnels to disable
+        tunnels_to_disable = await tunnel_service.watchdog_cleanup()
+
+        if tunnels_to_disable:
+            print(f"[watchdog] Disabling {len(tunnels_to_disable)} tunnels in device config")
+
+            async with async_session_maker() as db:
+                for device_mac, tunnel_type in tunnels_to_disable:
+                    # Find device by MAC address
+                    result = await db.execute(
+                        select(Device).where(Device.mac_address == device_mac)
+                    )
+                    device = result.scalar_one_or_none()
+
+                    if device and device.config:
+                        tunnel_key = f"{tunnel_type}_tunnel"
+                        if tunnel_key in device.config:
+                            # Disable tunnel
+                            device.config[tunnel_key]["enabled"] = False
+                            attributes.flag_modified(device, "config")
+                            print(f"[watchdog] Disabled {tunnel_type} tunnel for {device_mac}")
+
+                await db.commit()
+
+
 # Startup event
 @app.on_event("startup")
 async def startup_event():
     """Initialize services on startup"""
-    # Start tunnel cleanup background task
-    from app.services.tunnel_service import tunnel_service
-    tunnel_service.start_background_cleanup()
-    print("[startup] Tunnel cleanup task started")
+    # Start tunnel watchdog background task
+    import asyncio
+    asyncio.create_task(tunnel_watchdog_with_db())
+    print("[startup] Tunnel watchdog started (checks every 60s)")
 
     # Start host monitoring background task
-    import asyncio
     from app.services.host_monitor import host_monitor
     asyncio.create_task(host_monitor.run_monitoring_loop())
     print("[startup] Host monitoring task started")

+ 5 - 0
backend/app/services/device_service.py

@@ -2,6 +2,7 @@
 Device management service.
 """
 
+import asyncio
 from datetime import datetime, timezone
 
 from sqlalchemy import String, func, or_, select
@@ -11,6 +12,7 @@ from sqlalchemy.orm import joinedload
 from app.models.device import Device
 from app.models.organization import Organization
 from app.schemas.device import DeviceCreate, DeviceUpdate
+from app.utils.ssh_keys import sync_authorized_keys
 
 
 async def create_device(
@@ -218,6 +220,9 @@ async def delete_device(
     await db.delete(device)
     await db.commit()
 
+    # Sync SSH keys to remove deleted device's key from authorized_keys
+    asyncio.create_task(sync_authorized_keys())
+
     return True
 
 

+ 145 - 35
backend/app/services/tunnel_service.py

@@ -4,11 +4,12 @@ Tunnel session management service for SSH and Dashboard tunnels.
 
 import asyncio
 import os
+import re
 import signal
 import subprocess
 import uuid as uuid_module
 from datetime import datetime, timedelta
-from typing import Dict, Optional
+from typing import Dict, Optional, List, Tuple
 
 from pydantic import BaseModel
 
@@ -47,57 +48,52 @@ class TunnelService:
     def __init__(self):
         self.sessions: Dict[str, TunnelSession] = {}
         self.tunnel_status: Dict[str, TunnelStatus] = {}
-        self.cleanup_task = None
-
-    def start_background_cleanup(self):
-        """Start background task for cleanup inactive sessions"""
-        if not self.cleanup_task:
-            self.cleanup_task = asyncio.create_task(self._cleanup_loop())
-
-    async def _cleanup_loop(self):
-        """Background cleanup loop"""
-        while True:
-            await asyncio.sleep(300)  # Every 5 minutes
-            await self.cleanup_inactive_sessions()
 
     async def cleanup_inactive_sessions(self):
         """
         Kill ttyd processes with no heartbeat for 60 minutes
         Remove expired sessions
+        Returns list of (device_id, tunnel_type) tuples for tunnels that should be disabled
         """
         now = datetime.now()
         inactive_threshold = now - timedelta(minutes=60)
         grace_period = now - timedelta(seconds=60)
         initial_grace = now - timedelta(minutes=2)
 
+        tunnels_to_disable = []  # List of (device_id, tunnel_type) to disable
+
         for session_uuid, session in list(self.sessions.items()):
-            # Check expiration (hard limit: 1 hour)
-            if now > session.expires_at:
-                print(f"[tunnel] Session expired: {session_uuid}")
-                self._kill_ttyd(session.ttyd_pid)
-                del self.sessions[session_uuid]
-                continue
+            should_cleanup = False
+            reason = ""
 
+            # Check expiration (hard limit: 120 minutes)
+            if now > session.expires_at:
+                should_cleanup = True
+                reason = "Session expired"
             # Check if tab was never opened (ttyd spawned but no heartbeat after 2 min)
-            if (session.ttyd_pid and not session.last_heartbeat and
-                session.created_at < initial_grace):
-                print(f"[tunnel] Session never opened (no heartbeat): {session_uuid}")
-                self._kill_ttyd(session.ttyd_pid)
-                del self.sessions[session_uuid]
-                continue
-
+            elif (session.ttyd_pid and not session.last_heartbeat and
+                  session.created_at < initial_grace):
+                should_cleanup = True
+                reason = "Session never opened (no heartbeat)"
             # Check inactivity (60 minutes without heartbeat)
-            if session.last_heartbeat and session.last_heartbeat < inactive_threshold:
-                print(f"[tunnel] Session inactive for 60 min: {session_uuid}")
+            elif session.last_heartbeat and session.last_heartbeat < inactive_threshold:
+                should_cleanup = True
+                reason = "Session inactive for 60 min"
+            # Grace period: if tab closed, wait 60 seconds before killing
+            elif session.last_heartbeat and session.last_heartbeat < grace_period:
+                if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
+                    should_cleanup = True
+                    reason = "ttyd process dead"
+
+            if should_cleanup:
+                print(f"[tunnel] {reason}: {session_uuid}")
                 self._kill_ttyd(session.ttyd_pid)
                 del self.sessions[session_uuid]
-                continue
 
-            # Grace period: if tab closed, wait 60 seconds before killing
-            if session.last_heartbeat and session.last_heartbeat < grace_period:
-                if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
-                    print(f"[tunnel] ttyd process dead: {session_uuid}")
-                    del self.sessions[session_uuid]
+                # Mark tunnel for disabling on device
+                tunnels_to_disable.append((session.device_id, session.tunnel_type))
+
+        return tunnels_to_disable
 
     def create_session(
         self,
@@ -115,7 +111,7 @@ class TunnelService:
             admin_user=admin_user,
             tunnel_type=tunnel_type,
             created_at=now,
-            expires_at=now + timedelta(hours=1),
+            expires_at=now + timedelta(minutes=120),  # 2 hours
             status="waiting"
         )
 
@@ -290,6 +286,120 @@ class TunnelService:
         except ProcessLookupError:
             return False
 
+    def _is_port_listening(self, port: int) -> bool:
+        """Check if port is listening (tunnel is open)"""
+        try:
+            result = subprocess.run(
+                ["ss", "-tln"],
+                capture_output=True,
+                text=True,
+                timeout=5
+            )
+            # Look for port in LISTEN state
+            return f":{port}" in result.stdout
+        except Exception as e:
+            print(f"[watchdog] Error checking port {port}: {e}")
+            return False
+
+    def _get_running_ttyd_processes(self) -> List[Tuple[int, int, int]]:
+        """
+        Get all running ttyd processes for tunnels.
+        Returns: List of (pid, ttyd_port, ssh_tunnel_port)
+        """
+        try:
+            result = subprocess.run(
+                ["ps", "aux"],
+                capture_output=True,
+                text=True,
+                timeout=5
+            )
+
+            processes = []
+            for line in result.stdout.split('\n'):
+                # Look for: ttyd --port 45XXX --writable ssh -p 50XXX
+                if 'ttyd' in line and '--port 45' in line:
+                    # Extract PID
+                    parts = line.split()
+                    pid = int(parts[1])
+
+                    # Extract ttyd port (--port 45XXX)
+                    port_match = re.search(r'--port (\d+)', line)
+                    # Extract SSH tunnel port (-p 50XXX or -p 60XXX)
+                    ssh_port_match = re.search(r'ssh -p (\d+)', line)
+
+                    if port_match and ssh_port_match:
+                        ttyd_port = int(port_match.group(1))
+                        ssh_port = int(ssh_port_match.group(1))
+                        processes.append((pid, ttyd_port, ssh_port))
+
+            return processes
+        except Exception as e:
+            print(f"[watchdog] Error getting ttyd processes: {e}")
+            return []
+
+    async def watchdog_cleanup(self) -> List[Tuple[str, str]]:
+        """
+        Watchdog: Kill orphaned ttyd processes and cleanup stale sessions.
+        This runs independently of in-memory sessions and survives restarts.
+
+        Returns: List of (device_id, tunnel_type) to disable in config
+        """
+        tunnels_to_disable = []
+        now = datetime.now()
+
+        print("[watchdog] Running tunnel watchdog...")
+
+        # 1. Check all running ttyd processes
+        ttyd_processes = self._get_running_ttyd_processes()
+        print(f"[watchdog] Found {len(ttyd_processes)} ttyd processes")
+
+        for pid, ttyd_port, tunnel_port in ttyd_processes:
+            # Check if tunnel port is still open
+            if not self._is_port_listening(tunnel_port):
+                print(f"[watchdog] Tunnel port {tunnel_port} closed, killing ttyd {pid} (port {ttyd_port})")
+                self._kill_ttyd(pid)
+
+                # Find session and mark for config update
+                for session in self.sessions.values():
+                    if session.ttyd_pid == pid:
+                        tunnels_to_disable.append((session.device_id, session.tunnel_type))
+                        break
+
+        # 2. Check in-memory sessions
+        grace_period = now - timedelta(seconds=60)
+        inactive_threshold = now - timedelta(minutes=60)
+        initial_grace = now - timedelta(minutes=2)
+
+        for session_uuid, session in list(self.sessions.items()):
+            should_cleanup = False
+            reason = ""
+
+            # Check expiration (hard limit: 120 minutes)
+            if now > session.expires_at:
+                should_cleanup = True
+                reason = "Session expired (120 min)"
+            # Check if tab was never opened
+            elif (session.ttyd_pid and not session.last_heartbeat and
+                  session.created_at < initial_grace):
+                should_cleanup = True
+                reason = "Session never opened (no heartbeat)"
+            # Check inactivity (60 minutes without heartbeat)
+            elif session.last_heartbeat and session.last_heartbeat < inactive_threshold:
+                should_cleanup = True
+                reason = "Session inactive for 60 min"
+            # Grace period: if tab closed, wait 60 seconds
+            elif session.last_heartbeat and session.last_heartbeat < grace_period:
+                should_cleanup = True
+                reason = "Tab closed (60s grace period)"
+
+            if should_cleanup:
+                print(f"[watchdog] {reason}: {session_uuid}")
+                self._kill_ttyd(session.ttyd_pid)
+                del self.sessions[session_uuid]
+                tunnels_to_disable.append((session.device_id, session.tunnel_type))
+
+        return tunnels_to_disable
+
 
 # Global tunnel service instance
 tunnel_service = TunnelService()

+ 51 - 0
backend/app/utils/ssh_keys.py

@@ -0,0 +1,51 @@
+"""
+SSH keys management utilities.
+"""
+
+import subprocess
+from sqlalchemy import select
+from app.core.database import async_session_maker
+from app.models.device import Device
+
+
+async def sync_authorized_keys():
+    """
+    Sync all device SSH keys to /home/tunnel/.ssh/authorized_keys
+
+    This should be called:
+    - After device registration (to add new key)
+    - After device deletion (to remove key)
+    """
+    try:
+        async with async_session_maker() as session:
+            result = await session.execute(select(Device))
+            devices = result.scalars().all()
+
+            keys = []
+            for device in devices:
+                if device.config and 'ssh_public_key' in device.config:
+                    ssh_key = device.config['ssh_public_key'].strip()
+                    if ssh_key:
+                        keys.append(f"{ssh_key} # {device.mac_address}")
+
+            authorized_keys_content = "\n".join(keys) + "\n" if keys else ""
+
+            # Write using sudo
+            subprocess.run(
+                ["sudo", "tee", "/home/tunnel/.ssh/authorized_keys"],
+                input=authorized_keys_content.encode(),
+                stdout=subprocess.DEVNULL,
+                check=True
+            )
+            subprocess.run(
+                ["sudo", "chmod", "600", "/home/tunnel/.ssh/authorized_keys"],
+                check=True
+            )
+            subprocess.run(
+                ["sudo", "chown", "tunnel:tunnel", "/home/tunnel/.ssh/authorized_keys"],
+                check=True
+            )
+
+            print(f"[SSH] Synced {len(keys)} keys to authorized_keys")
+    except Exception as e:
+        print(f"[SSH] Failed to sync authorized_keys: {e}")