"""
Tunnel session management service for SSH and Dashboard tunnels.
"""

import asyncio
import os
import re
import signal
import socket
import subprocess
import time
import uuid as uuid_module
from datetime import datetime, timedelta
from typing import Dict, Optional, List, Tuple

from pydantic import BaseModel


# Inclusive port range used for locally spawned ttyd web terminals.
TTYD_PORT_MIN = 45000
TTYD_PORT_MAX = 49999

# Session lifetime rules.
SESSION_TTL = timedelta(minutes=120)         # hard limit counted from creation
INACTIVITY_TIMEOUT = timedelta(minutes=60)   # no heartbeat for this long -> cleanup
TAB_CLOSE_GRACE = timedelta(seconds=60)      # heartbeat gap treated as "tab closed"
NEVER_OPENED_GRACE = timedelta(minutes=2)    # ttyd spawned but no heartbeat yet

# Seconds to let a freshly reported device tunnel settle before it is used.
_TUNNEL_SETTLE_SECONDS = {"ssh": 2, "dashboard": 3}

# ttyd's own listen port, e.g. "ttyd --port 45001 --writable ssh ...".
_TTYD_PORT_RE = re.compile(r"--port\s+(\d+)")
# Port passed to the wrapped ssh command, e.g. "... ssh -i KEY -p 50001 ...".
# "ssh" must be followed by whitespace so a path like "~/.ssh/key" never matches.
_SSH_PORT_RE = re.compile(r"\bssh\s(?:.*?\s)?-p\s+(\d+)")


class TunnelSession(BaseModel):
    """Tunnel session model"""
    uuid: str
    device_id: str
    admin_user: str
    tunnel_type: str  # "ssh" | "dashboard"
    created_at: datetime
    expires_at: datetime
    last_heartbeat: Optional[datetime] = None
    ttyd_port: Optional[int] = None
    ttyd_pid: Optional[int] = None
    device_tunnel_port: Optional[int] = None
    status: str = "waiting"  # "waiting" | "ready" | "failed"


class TunnelStatus(BaseModel):
    """Device tunnel status"""
    device_id: str
    tunnel_type: str  # "ssh" | "dashboard"
    allocated_port: Optional[int] = None
    status: str  # "connected" | "disconnected"
    connected_at: Optional[datetime] = None
    last_heartbeat: Optional[datetime] = None


class TunnelService:
    """
    Tunnel management service.

    Uses in-memory storage (could be replaced with Redis for multi-server
    deployments).
    """

    def __init__(self):
        self.sessions: Dict[str, TunnelSession] = {}
        self.tunnel_status: Dict[str, TunnelStatus] = {}
        # ttyd children spawned by this instance, keyed by PID. Holding the
        # Popen handle lets us reap them: os.kill(pid, 0) reports a zombie as
        # alive, so without reaping a dead ttyd would never be detected.
        self._ttyd_procs: Dict[int, subprocess.Popen] = {}

    # ------------------------------------------------------------------
    # Session cleanup
    # ------------------------------------------------------------------

    async def cleanup_inactive_sessions(self) -> List[Tuple[str, str]]:
        """
        Kill ttyd processes of stale sessions and remove those sessions.

        A session is stale when it is past its hard expiry (120 min), when
        ttyd was spawned but no heartbeat arrived within 2 minutes, when no
        heartbeat arrived for 60 minutes, or when the last heartbeat is older
        than 60 seconds and its ttyd process has died.

        Returns:
            Unique (device_id, tunnel_type) pairs whose tunnel should be
            disabled on the device. Pairs still used by a remaining
            non-failed session are not included.
        """
        return self._purge_stale_sessions(datetime.now(), watchdog=False)

    def _stale_reason(
        self, session: TunnelSession, now: datetime, *, watchdog: bool
    ) -> Optional[str]:
        """
        Return why `session` should be cleaned up, or None to keep it.

        With ``watchdog=True``, any heartbeat gap longer than TAB_CLOSE_GRACE
        counts as a closed tab. Otherwise such a gap only triggers cleanup
        once the session's ttyd process is gone.
        """
        if now > session.expires_at:
            return "Session expired (120 min)" if watchdog else "Session expired"
        # ttyd was spawned but the browser tab never started heartbeating.
        if (session.ttyd_pid and not session.last_heartbeat
                and session.created_at < now - NEVER_OPENED_GRACE):
            return "Session never opened (no heartbeat)"
        if session.last_heartbeat:
            if session.last_heartbeat < now - INACTIVITY_TIMEOUT:
                return "Session inactive for 60 min"
            if session.last_heartbeat < now - TAB_CLOSE_GRACE:
                if watchdog:
                    return "Tab closed (60s grace period)"
                if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
                    return "ttyd process dead"
        return None

    def _purge_stale_sessions(
        self, now: datetime, *, watchdog: bool
    ) -> List[Tuple[str, str]]:
        """
        Remove stale sessions (killing their ttyd) and return tunnels to disable.

        Each (device_id, tunnel_type) pair is returned at most once. A pair is
        omitted when a remaining non-failed session for it still exists.
        """
        prefix = "[watchdog]" if watchdog else "[tunnel]"
        removed: List[Tuple[str, str]] = []
        for session_uuid, session in list(self.sessions.items()):
            reason = self._stale_reason(session, now, watchdog=watchdog)
            if reason is None:
                continue
            print(f"{prefix} {reason}: {session_uuid}")
            self._kill_ttyd(session.ttyd_pid)
            del self.sessions[session_uuid]
            removed.append((session.device_id, session.tunnel_type))

        # Another admin may still be using the same device tunnel; disabling
        # it would cut that session off.
        still_needed = {
            (s.device_id, s.tunnel_type)
            for s in self.sessions.values()
            if s.status != "failed"
        }
        return [key for key in dict.fromkeys(removed) if key not in still_needed]

    # ------------------------------------------------------------------
    # Session CRUD
    # ------------------------------------------------------------------

    def create_session(
        self,
        device_id: str,
        admin_user: str,
        tunnel_type: str
    ) -> TunnelSession:
        """
        Create a new "waiting" session that expires after 120 minutes.

        Also registers a "disconnected" tunnel status for the device and
        tunnel type if none exists yet.
        """
        session_uuid = str(uuid_module.uuid4())
        now = datetime.now()

        session = TunnelSession(
            uuid=session_uuid,
            device_id=device_id,
            admin_user=admin_user,
            tunnel_type=tunnel_type,
            created_at=now,
            expires_at=now + SESSION_TTL,  # 2 hours
            status="waiting"
        )
        self.sessions[session_uuid] = session

        status_key = f"{device_id}:{tunnel_type}"
        if status_key not in self.tunnel_status:
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                status="disconnected"
            )

        return session

    def get_session(self, session_uuid: str) -> Optional[TunnelSession]:
        """Get session by UUID (None if unknown)."""
        return self.sessions.get(session_uuid)

    def update_heartbeat(self, session_uuid: str) -> bool:
        """Set the session's last heartbeat to now; False if the session is unknown."""
        session = self.sessions.get(session_uuid)
        if not session:
            return False

        session.last_heartbeat = datetime.now()
        return True

    # ------------------------------------------------------------------
    # Device tunnel status
    # ------------------------------------------------------------------

    def report_device_port(
        self,
        device_id: str,
        tunnel_type: str,
        port: Optional[int],
        status: str
    ):
        """
        Record a device's tunnel port report.

        When status is "connected" and a port is given, the tunnel status is
        refreshed and every session waiting on this device/tunnel type is
        marked "ready". SSH sessions also get a ttyd terminal, and a session
        whose ttyd spawn fails is marked "failed".

        When status is "disconnected", the stored status is set to
        disconnected and its port is cleared. Any other report is ignored.
        """
        status_key = f"{device_id}:{tunnel_type}"

        if status == "connected" and port:
            now = datetime.now()
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                allocated_port=port,
                status="connected",
                connected_at=now,
                last_heartbeat=now
            )

            waiting = [
                s for s in self.sessions.values()
                if s.device_id == device_id
                and s.tunnel_type == tunnel_type
                and s.status == "waiting"
            ]
            if not waiting:
                return

            # Let the freshly opened tunnel settle once per report, before any
            # session is marked ready (previously this slept once per session).
            # NOTE(review): time.sleep blocks the calling thread; if this runs
            # on the asyncio event loop it stalls all other requests meanwhile.
            settle_seconds = _TUNNEL_SETTLE_SECONDS.get(tunnel_type, 0)
            if settle_seconds:
                time.sleep(settle_seconds)

            for session in waiting:
                session.device_tunnel_port = port
                session.status = "ready"

                # Only SSH tunnels need a ttyd web terminal.
                if tunnel_type == "ssh" and not session.ttyd_port:
                    try:
                        ttyd_port = self.spawn_ttyd(
                            session_uuid=session.uuid,
                            device_tunnel_port=port
                        )
                        print(f"[tunnel] Auto-spawned ttyd for session {session.uuid} on port {ttyd_port}")
                    except Exception as e:
                        # Best effort: surface the failure through the session status.
                        print(f"[tunnel] Failed to auto-spawn ttyd: {e}")
                        session.status = "failed"
                elif tunnel_type == "dashboard":
                    print(f"[tunnel] Dashboard tunnel ready for session {session.uuid} on port {port}")

        elif status == "disconnected":
            current = self.tunnel_status.get(status_key)
            if current is not None:
                current.status = "disconnected"
                current.allocated_port = None

    def get_tunnel_status(
        self,
        device_id: str,
        tunnel_type: str
    ) -> Optional[TunnelStatus]:
        """Get tunnel status for a device (None if never registered)."""
        status_key = f"{device_id}:{tunnel_type}"
        return self.tunnel_status.get(status_key)

    # ------------------------------------------------------------------
    # ttyd process management
    # ------------------------------------------------------------------

    def spawn_ttyd(
        self,
        session_uuid: str,
        device_tunnel_port: int,
        server_host: str = "localhost"
    ) -> int:
        """
        Spawn a ttyd web terminal that SSHes to the device through its tunnel.

        Args:
            session_uuid: Session that will own the ttyd process.
            device_tunnel_port: Local port of the device's SSH tunnel.
            server_host: Host the tunnel port is bound on.

        Returns:
            The port ttyd listens on (allocated from 45000-49999).

        Raises:
            ValueError: If the session does not exist.
            RuntimeError: If no free ttyd port is available.
        """
        session = self.sessions.get(session_uuid)
        if not session:
            raise ValueError(f"Session not found: {session_uuid}")

        ttyd_port = self._find_free_port(TTYD_PORT_MIN, TTYD_PORT_MAX)

        # ttyd runs ssh through the tunnel port using a dedicated device key.
        ssh_key = os.path.expanduser("~/.ssh/device_access")

        cmd = [
            "ttyd",
            "--port", str(ttyd_port),
            "--writable",  # Allow input
            "ssh",
            "-i", ssh_key,
            "-p", str(device_tunnel_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=3",
            f"root@{server_host}"
        ]

        # NOTE(review): predictable path in /tmp; if /tmp is shared with
        # untrusted users, consider a private log directory instead.
        log_file = f"/tmp/ttyd_{ttyd_port}.log"
        with open(log_file, 'w') as log:
            log.write(f"Starting ttyd for session {session_uuid}\n")
            log.write(f"Command: {' '.join(cmd)}\n")
            # The header must reach the file before the child starts writing.
            log.flush()
            # The child receives its own copy of the descriptor, so closing
            # the parent's handle at the end of this block leaks nothing.
            process = subprocess.Popen(
                cmd,
                stdout=log,
                stderr=subprocess.STDOUT
            )

        self._ttyd_procs[process.pid] = process
        session.ttyd_port = ttyd_port
        session.ttyd_pid = process.pid

        print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
        return ttyd_port

    def _find_free_port(self, start: int, end: int) -> int:
        """
        Return the first port in [start, end] that can currently be bound.

        Raises:
            RuntimeError: If every port in the range is in use.
        """
        for port in range(start, end + 1):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(('', port))
                    return port
            except OSError:
                continue
        raise RuntimeError(f"No free ports in range {start}-{end}")

    def _kill_ttyd(self, pid: Optional[int]):
        """
        Terminate a ttyd process (no-op for a falsy pid).

        Children spawned by this instance are terminated through their Popen
        handle and reaped. Other PIDs, such as ones found by the watchdog
        after a restart, are sent SIGTERM directly.
        """
        if not pid:
            return

        proc = self._ttyd_procs.pop(pid, None)
        if proc is not None:
            # poll() is not None means the child already exited and was
            # reaped, so its PID may belong to another process now: don't
            # signal it.
            if proc.poll() is None:
                proc.terminate()
                try:
                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
                print(f"[tunnel] Killed ttyd process {pid}")
            return

        try:
            os.kill(pid, signal.SIGTERM)
            print(f"[tunnel] Killed ttyd process {pid}")
        except ProcessLookupError:
            pass
        except PermissionError:
            print(f"[tunnel] No permission to kill process {pid}")

    def _is_process_alive(self, pid: int) -> bool:
        """Return True if the process with `pid` is still running."""
        proc = self._ttyd_procs.get(pid)
        if proc is not None:
            # poll() also reaps an exited child, so a zombie counts as dead.
            return proc.poll() is None
        try:
            os.kill(pid, 0)  # Signal 0 = existence check only
            return True
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists but belongs to another user.
            return True

    # ------------------------------------------------------------------
    # Watchdog helpers
    # ------------------------------------------------------------------

    def _get_all_listening_ports(self) -> set:
        """
        Return all local TCP ports in LISTEN state, using a single `ss -tln` call.

        Returns an empty set if `ss` cannot be run.
        """
        try:
            result = subprocess.run(
                ["ss", "-tln"],
                capture_output=True,
                text=True,
                timeout=5,
                check=False
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f"[watchdog] Error getting listening ports: {e}")
            return set()

        ports = set()
        for line in result.stdout.splitlines():
            # Rows look like: LISTEN 0 128 [::1]:50001 [::]:*
            # Only the local address column (4th) carries our port. Splitting
            # on its LAST ':' keeps IPv6 address parts from being read as ports.
            fields = line.split()
            if len(fields) < 4 or fields[0] != "LISTEN":
                continue
            _, sep, port_text = fields[3].rpartition(":")
            if sep and port_text.isdigit():
                ports.add(int(port_text))
        return ports

    def _is_port_listening(self, port: int) -> bool:
        """Check if port is listening (tunnel is open) - DEPRECATED, use _get_all_listening_ports"""
        # Exact membership test: a substring check would let ":500" match ":5001".
        return port in self._get_all_listening_ports()

    def _get_running_ttyd_processes(self) -> List[Tuple[int, int, int]]:
        """
        Get all running ttyd processes that serve device tunnels.

        A process is reported only if its executable is ``ttyd``, its
        ``--port`` lies in the ttyd range, and it wraps an ``ssh ... -p PORT``
        command.

        Returns:
            List of (pid, ttyd_port, ssh_tunnel_port); empty if `ps` fails.
        """
        try:
            result = subprocess.run(
                ["ps", "-eo", "pid=,args="],
                capture_output=True,
                text=True,
                timeout=5,
                check=False
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f"[watchdog] Error getting ttyd processes: {e}")
            return []

        processes = []
        for line in result.stdout.splitlines():
            parts = line.split(None, 1)
            if len(parts) != 2 or not parts[0].isdigit():
                continue
            pid_text, args = parts
            # Match the executable itself, not any command that mentions "ttyd".
            if os.path.basename(args.split(None, 1)[0]) != "ttyd":
                continue

            port_match = _TTYD_PORT_RE.search(args)
            ssh_port_match = _SSH_PORT_RE.search(args)
            if not (port_match and ssh_port_match):
                continue

            ttyd_port = int(port_match.group(1))
            if not TTYD_PORT_MIN <= ttyd_port <= TTYD_PORT_MAX:
                continue
            processes.append((int(pid_text), ttyd_port, int(ssh_port_match.group(1))))
        return processes

    async def watchdog_cleanup(self) -> List[Tuple[str, str]]:
        """
        Watchdog: kill orphaned ttyd processes and clean up stale sessions.

        The orphan check scans OS processes rather than in-memory sessions, so
        it also catches ttyd processes left over from before a restart.

        Returns:
            Unique (device_id, tunnel_type) pairs to disable in the config.
        """
        tunnels_to_disable: List[Tuple[str, str]] = []
        now = datetime.now()

        print("[watchdog] Running tunnel watchdog...")

        # 1. Kill ttyd processes whose SSH tunnel port is no longer listening.
        ttyd_processes = self._get_running_ttyd_processes()
        print(f"[watchdog] Found {len(ttyd_processes)} ttyd processes")

        # Query listening ports once instead of once per process.
        listening_ports = self._get_all_listening_ports()

        if ttyd_processes and not listening_ports:
            # A live ttyd listens on its own --port, so an empty result most
            # likely means the `ss` probe failed. Don't read it as "every
            # tunnel closed" and kill all terminals.
            print("[watchdog] Could not read listening ports, skipping tunnel port check")
        else:
            for pid, ttyd_port, tunnel_port in ttyd_processes:
                if tunnel_port in listening_ports:
                    continue
                print(f"[watchdog] Tunnel port {tunnel_port} closed, killing ttyd {pid} (port {ttyd_port})")
                self._kill_ttyd(pid)

                for session in self.sessions.values():
                    if session.ttyd_pid == pid:
                        tunnels_to_disable.append((session.device_id, session.tunnel_type))
                        # Forget the dead ttyd so later cleanup never signals
                        # this PID again (the OS may reuse it).
                        session.ttyd_pid = None
                        session.ttyd_port = None
                        session.status = "failed"
                        break

        # 2. Remove stale in-memory sessions.
        tunnels_to_disable.extend(self._purge_stale_sessions(now, watchdog=True))

        return list(dict.fromkeys(tunnels_to_disable))


# Global tunnel service instance
tunnel_service = TunnelService()