| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- """
- Tunnel session management service for SSH and Dashboard tunnels.
- """
- import asyncio
- import os
- import re
- import signal
- import subprocess
- import uuid as uuid_module
- from datetime import datetime, timedelta
- from typing import Dict, Optional, List, Tuple
- from pydantic import BaseModel
class TunnelSession(BaseModel):
    """
    A single admin-initiated tunnel session to a device.
    """

    # --- identity ---------------------------------------------------------
    uuid: str
    device_id: str
    admin_user: str
    tunnel_type: str  # "ssh" | "dashboard"

    # --- lifetime ---------------------------------------------------------
    created_at: datetime
    expires_at: datetime
    # Stays None until the first heartbeat is recorded for this session.
    last_heartbeat: Optional[datetime] = None

    # --- ttyd / tunnel wiring --------------------------------------------
    # Port and pid of the ttyd process spawned for this session, if any.
    ttyd_port: Optional[int] = None
    ttyd_pid: Optional[int] = None
    # Tunnel port reported by the device for this session's tunnel type.
    device_tunnel_port: Optional[int] = None

    status: str = "waiting"  # "waiting" | "ready" | "failed"
class TunnelStatus(BaseModel):
    """
    Connection state of one tunnel type on one device.
    """

    device_id: str
    tunnel_type: str  # "ssh" | "dashboard"
    # Port the device reported for the tunnel; None while disconnected.
    allocated_port: Optional[int] = None
    status: str  # "connected" | "disconnected"
    connected_at: Optional[datetime] = None
    last_heartbeat: Optional[datetime] = None
class TunnelService:
    """
    Tunnel management service.

    Keeps tunnel sessions and per-device tunnel status in process memory
    (this could be replaced with Redis for a multi-server deployment) and
    manages the ttyd processes that back SSH sessions.
    """

    # Hard lifetime of a session, counted from its creation.
    SESSION_TTL = timedelta(minutes=120)
    # A session whose last heartbeat is older than this is inactive.
    INACTIVITY_LIMIT = timedelta(minutes=60)
    # Time after the last heartbeat before a tab is presumed closed.
    GRACE_PERIOD = timedelta(seconds=60)
    # Time a spawned ttyd may wait for its first heartbeat.
    INITIAL_GRACE = timedelta(minutes=2)
    # Inclusive port range used for ttyd listeners.
    TTYD_PORT_RANGE = (45000, 49999)
    # Seconds to wait after a device reports its port before using it.
    SSH_READY_DELAY = 2
    DASHBOARD_READY_DELAY = 3

    def __init__(self):
        # Sessions keyed by session UUID.
        self.sessions: Dict[str, TunnelSession] = {}
        # Tunnel status keyed by "<device_id>:<tunnel_type>".
        self.tunnel_status: Dict[str, TunnelStatus] = {}

    # ------------------------------------------------------------------
    # Shared cleanup helpers
    # ------------------------------------------------------------------

    def _idle_reason(self, session, now: datetime) -> Optional[str]:
        """
        Return why a non-expired session has timed out, or None.

        Covers the two heartbeat-based timeouts shared by both cleanup
        passes: a ttyd that never received a heartbeat, and a session whose
        last heartbeat is older than INACTIVITY_LIMIT.
        """
        if session.last_heartbeat:
            if session.last_heartbeat < now - self.INACTIVITY_LIMIT:
                return "Session inactive for 60 min"
            return None
        if session.ttyd_pid and session.created_at < now - self.INITIAL_GRACE:
            return "Session never opened (no heartbeat)"
        return None

    @staticmethod
    def _mark_for_disable(targets: List[Tuple[str, str]], session) -> None:
        """Append the session's (device_id, tunnel_type) once, keeping order."""
        key = (session.device_id, session.tunnel_type)
        if key not in targets:
            targets.append(key)

    async def cleanup_inactive_sessions(self) -> List[Tuple[str, str]]:
        """
        Remove expired, never-opened, inactive and dead-ttyd sessions.

        A session is removed when it is past its expiry, when its ttyd was
        spawned but no heartbeat arrived within INITIAL_GRACE, when its last
        heartbeat is older than INACTIVITY_LIMIT, or when the heartbeat is
        older than GRACE_PERIOD and the ttyd process is no longer alive.

        Returns the (device_id, tunnel_type) pairs whose tunnels should be
        disabled, without duplicates.

        NOTE(review): a pair is reported even if another live session for the
        same device and tunnel type remains - confirm the caller handles it.
        """
        now = datetime.now()
        grace_cutoff = now - self.GRACE_PERIOD
        tunnels_to_disable: List[Tuple[str, str]] = []

        # Iterate over a snapshot because entries are deleted in the loop.
        for session_uuid, session in list(self.sessions.items()):
            if now > session.expires_at:
                reason = "Session expired"
            else:
                reason = self._idle_reason(session, now)

            # Tab probably closed: only clean up if ttyd itself has died.
            if (reason is None
                    and session.last_heartbeat
                    and session.last_heartbeat < grace_cutoff
                    and session.ttyd_pid
                    and not self._is_process_alive(session.ttyd_pid)):
                reason = "ttyd process dead"

            if reason is None:
                continue

            print(f"[tunnel] {reason}: {session_uuid}")
            self._kill_ttyd(session.ttyd_pid)
            del self.sessions[session_uuid]
            self._mark_for_disable(tunnels_to_disable, session)

        return tunnels_to_disable

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    def create_session(
        self,
        device_id: str,
        admin_user: str,
        tunnel_type: str
    ) -> "TunnelSession":
        """
        Create and store a new session in the "waiting" state.

        Also registers a "disconnected" tunnel status entry for the device
        and tunnel type if none exists yet.
        """
        session_uuid = str(uuid_module.uuid4())
        now = datetime.now()
        session = TunnelSession(
            uuid=session_uuid,
            device_id=device_id,
            admin_user=admin_user,
            tunnel_type=tunnel_type,
            created_at=now,
            expires_at=now + self.SESSION_TTL,
            status="waiting"
        )
        self.sessions[session_uuid] = session

        status_key = f"{device_id}:{tunnel_type}"
        if status_key not in self.tunnel_status:
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                status="disconnected"
            )
        return session

    def get_session(self, session_uuid: str) -> Optional["TunnelSession"]:
        """Return the session with this UUID, or None if unknown."""
        return self.sessions.get(session_uuid)

    def update_heartbeat(self, session_uuid: str) -> bool:
        """Stamp the session's heartbeat with the current time.

        Returns False when the session does not exist, True otherwise.
        """
        session = self.sessions.get(session_uuid)
        if not session:
            return False
        session.last_heartbeat = datetime.now()
        return True

    def report_device_port(
        self,
        device_id: str,
        tunnel_type: str,
        port: Optional[int],
        status: str
    ):
        """
        Record a device's report about its tunnel port.

        On "connected" with a port, the tunnel status is replaced and every
        waiting session for that device and tunnel type becomes "ready"; SSH
        sessions without a ttyd get one spawned (status "failed" if that
        raises). On "disconnected", an existing status entry is marked
        disconnected and its port cleared. Other inputs are ignored.

        NOTE(review): the readiness delays use time.sleep and block the
        calling thread once per matching session; if this is called from an
        async request handler it will stall the event loop - confirm.
        """
        import time  # function-scope: time is not a module-level import

        status_key = f"{device_id}:{tunnel_type}"

        if status == "connected" and port:
            now = datetime.now()
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                allocated_port=port,
                status="connected",
                connected_at=now,
                last_heartbeat=now
            )

            # Update all waiting sessions for this device.
            for session in list(self.sessions.values()):
                if not (session.device_id == device_id and
                        session.tunnel_type == tunnel_type and
                        session.status == "waiting"):
                    continue

                session.device_tunnel_port = port
                session.status = "ready"

                # Only SSH tunnels need ttyd; dashboard tunnels do not.
                if session.tunnel_type == "ssh" and not session.ttyd_port:
                    try:
                        # Wait a moment for SSH to be fully ready.
                        time.sleep(self.SSH_READY_DELAY)
                        ttyd_port = self.spawn_ttyd(
                            session_uuid=session.uuid,
                            device_tunnel_port=port
                        )
                        print(f"[tunnel] Auto-spawned ttyd for session {session.uuid} on port {ttyd_port}")
                    except Exception as e:
                        # Best effort: a spawn failure must not abort the report.
                        print(f"[tunnel] Failed to auto-spawn ttyd: {e}")
                        session.status = "failed"
                elif session.tunnel_type == "dashboard":
                    # Wait for the dashboard to be fully ready.
                    time.sleep(self.DASHBOARD_READY_DELAY)
                    print(f"[tunnel] Dashboard tunnel ready for session {session.uuid} on port {port}")

        elif status == "disconnected":
            current = self.tunnel_status.get(status_key)
            if current is not None:
                current.status = "disconnected"
                current.allocated_port = None

    def get_tunnel_status(
        self,
        device_id: str,
        tunnel_type: str
    ) -> Optional["TunnelStatus"]:
        """Return the tunnel status for the device and type, or None."""
        return self.tunnel_status.get(f"{device_id}:{tunnel_type}")

    # ------------------------------------------------------------------
    # ttyd process management
    # ------------------------------------------------------------------

    def spawn_ttyd(
        self,
        session_uuid: str,
        device_tunnel_port: int,
        server_host: str = "localhost"
    ) -> int:
        """
        Spawn a ttyd process that runs ssh through the device tunnel port.

        Stores the ttyd port and pid on the session and returns the port.

        Raises:
            ValueError: the session does not exist.
            RuntimeError: no free port in TTYD_PORT_RANGE.
            OSError: the log file or the ttyd executable cannot be opened.
        """
        session = self.sessions.get(session_uuid)
        if not session:
            raise ValueError(f"Session not found: {session_uuid}")

        ttyd_port = self._find_free_port(*self.TTYD_PORT_RANGE)

        # ttyd connects to the device via SSH through the tunnel port.
        # NOTE(review): host key checking is disabled for this connection.
        cmd = [
            "ttyd",
            "--port", str(ttyd_port),
            "--writable",  # Allow input
            "ssh",
            "-p", str(device_tunnel_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=3",
            f"root@{server_host}"
        ]

        # Log ttyd output for debugging.
        log_file = f"/tmp/ttyd_{ttyd_port}.log"
        with open(log_file, 'w') as f:
            f.write(f"Starting ttyd for session {session_uuid}\n")
            f.write(f"Command: {' '.join(cmd)}\n")

        # The child gets its own copy of the descriptor, so the parent's
        # handle is closed as soon as Popen returns instead of leaking.
        with open(log_file, 'a') as log_out:
            process = subprocess.Popen(
                cmd,
                stdout=log_out,
                stderr=subprocess.STDOUT
            )

        session.ttyd_port = ttyd_port
        session.ttyd_pid = process.pid
        print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
        return ttyd_port

    def _find_free_port(self, start: int, end: int) -> int:
        """
        Return the first port in [start, end] that can currently be bound.

        The probe socket is closed before returning, so the port is only
        known to have been free at the time of the check.

        Raises:
            RuntimeError: no port in the range could be bound.
        """
        import socket

        for port in range(start, end + 1):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(('', port))
                    return port
            except OSError:
                continue
        raise RuntimeError(f"No free ports in range {start}-{end}")

    def _kill_ttyd(self, pid: Optional[int]):
        """Send SIGTERM to a ttyd process; a missing pid is a no-op."""
        if not pid:
            return
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # Already gone - nothing to do.
            return
        except PermissionError:
            # Process exists but is not ours; do not abort the cleanup pass.
            print(f"[tunnel] Not permitted to kill ttyd process {pid}")
            return
        print(f"[tunnel] Killed ttyd process {pid}")

    def _is_process_alive(self, pid: int) -> bool:
        """Return True if a process with this pid exists."""
        try:
            os.kill(pid, 0)  # Signal 0 = existence check only
        except ProcessLookupError:
            return False
        except PermissionError:
            # The pid exists but belongs to another user.
            return True
        return True

    # ------------------------------------------------------------------
    # System inspection (ss / ps)
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_listening_ports(output: str) -> set:
        """
        Extract listening port numbers from `ss -tln` output.

        Only the Local Address:Port column is read (three fields after the
        LISTEN state), so digits inside IPv6 addresses or the peer column
        are not mistaken for ports.
        """
        ports = set()
        for line in output.splitlines():
            fields = line.split()
            if "LISTEN" not in fields:
                continue
            local_index = fields.index("LISTEN") + 3
            if local_index >= len(fields):
                continue
            port_text = fields[local_index].rpartition(":")[2]
            if port_text.isdigit():
                ports.add(int(port_text))
        return ports

    def _scan_listening_ports(self) -> Optional[set]:
        """Run `ss -tln` once; return the port set, or None if it failed."""
        try:
            result = subprocess.run(
                ["ss", "-tln"],
                capture_output=True,
                text=True,
                timeout=5,
                check=False
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f"[watchdog] Error getting listening ports: {e}")
            return None
        if result.returncode != 0:
            print(f"[watchdog] Error getting listening ports: ss exited with {result.returncode}")
            return None
        return self._parse_listening_ports(result.stdout)

    def _get_all_listening_ports(self) -> set:
        """
        Get all listening TCP ports with a single subprocess call.

        Returns a set of port numbers; the set is empty if `ss` failed.
        """
        ports = self._scan_listening_ports()
        return set() if ports is None else ports

    def _is_port_listening(self, port: int) -> bool:
        """
        Check if a port is listening - DEPRECATED, use _get_all_listening_ports.

        Matches the exact port number; returns False if `ss` failed.
        """
        return port in self._get_all_listening_ports()

    @classmethod
    def _parse_ttyd_ps_line(cls, line: str) -> Optional[Tuple[int, int, int]]:
        """
        Parse one `ps aux` line into (pid, ttyd_port, ssh_tunnel_port).

        Returns None unless the line is a ttyd command with a --port inside
        TTYD_PORT_RANGE and an `ssh -p <port>` argument.
        """
        if 'ttyd' not in line:
            return None
        port_match = re.search(r'--port (\d+)', line)
        ssh_port_match = re.search(r'ssh -p (\d+)', line)
        if not port_match or not ssh_port_match:
            return None

        ttyd_port = int(port_match.group(1))
        low, high = cls.TTYD_PORT_RANGE
        if not low <= ttyd_port <= high:
            return None

        # `ps aux` prints the pid in the second column.
        fields = line.split()
        if len(fields) < 2 or not fields[1].isdigit():
            return None
        return int(fields[1]), ttyd_port, int(ssh_port_match.group(1))

    def _get_running_ttyd_processes(self) -> List[Tuple[int, int, int]]:
        """
        Get all running ttyd processes for tunnels.

        Returns a list of (pid, ttyd_port, ssh_tunnel_port); the list is
        empty if `ps` could not be run.
        """
        try:
            result = subprocess.run(
                ["ps", "aux"],
                capture_output=True,
                text=True,
                timeout=5,
                check=False
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f"[watchdog] Error getting ttyd processes: {e}")
            return []

        parsed = (self._parse_ttyd_ps_line(line)
                  for line in result.stdout.split('\n'))
        return [entry for entry in parsed if entry is not None]

    # ------------------------------------------------------------------
    # Watchdog
    # ------------------------------------------------------------------

    async def watchdog_cleanup(self) -> List[Tuple[str, str]]:
        """
        Watchdog: kill orphaned ttyd processes and clean up stale sessions.

        The first pass works from `ps`/`ss` output rather than in-memory
        sessions, so it also finds ttyd processes this instance did not
        record. The second pass removes sessions that are expired, never
        opened, inactive, or whose heartbeat is older than GRACE_PERIOD.

        Returns the (device_id, tunnel_type) pairs to disable in config,
        without duplicates.
        """
        tunnels_to_disable: List[Tuple[str, str]] = []
        now = datetime.now()
        print("[watchdog] Running tunnel watchdog...")

        # 1. Check all running ttyd processes.
        ttyd_processes = self._get_running_ttyd_processes()
        print(f"[watchdog] Found {len(ttyd_processes)} ttyd processes")

        # Get all listening ports once for the whole pass.
        listening_ports = self._scan_listening_ports()
        killed_pids = set()

        # If the port scan failed we cannot tell which tunnels are closed,
        # so skip the orphan check rather than killing every ttyd.
        if listening_ports is not None:
            for pid, ttyd_port, tunnel_port in ttyd_processes:
                if tunnel_port in listening_ports:
                    continue
                print(f"[watchdog] Tunnel port {tunnel_port} closed, killing ttyd {pid} (port {ttyd_port})")
                self._kill_ttyd(pid)
                killed_pids.add(pid)

                # Find the owning session and mark it for a config update.
                for session in self.sessions.values():
                    if session.ttyd_pid == pid:
                        self._mark_for_disable(tunnels_to_disable, session)
                        break

        # 2. Check in-memory sessions.
        grace_cutoff = now - self.GRACE_PERIOD
        for session_uuid, session in list(self.sessions.items()):
            if now > session.expires_at:
                reason = "Session expired (120 min)"
            else:
                reason = self._idle_reason(session, now)

            # Tab closed: heartbeat stopped for longer than the grace period.
            if (reason is None
                    and session.last_heartbeat
                    and session.last_heartbeat < grace_cutoff):
                reason = "Tab closed (60s grace period)"

            if reason is None:
                continue

            print(f"[watchdog] {reason}: {session_uuid}")
            # Do not signal a pid that was already killed in step 1.
            if session.ttyd_pid not in killed_pids:
                self._kill_ttyd(session.ttyd_pid)
            del self.sessions[session_uuid]
            self._mark_for_disable(tunnels_to_disable, session)

        return tunnels_to_disable
# Global tunnel service instance, created once at import time.
# State is held in this process's memory only.
tunnel_service = TunnelService()
|