| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- """
- Tunnel session management service for SSH and Dashboard tunnels.
- """
- import asyncio
- import os
- import signal
- import subprocess
- import uuid as uuid_module
- from datetime import datetime, timedelta
- from typing import Dict, Optional
- from pydantic import BaseModel
class TunnelSession(BaseModel):
    """One admin-initiated tunnel session, identified by its UUID."""

    # Identity: session UUID, target device and the admin who opened it.
    uuid: str
    device_id: str
    admin_user: str

    # Either "ssh" or "dashboard".
    tunnel_type: str

    # Creation time and hard expiry time of the session.
    created_at: datetime
    expires_at: datetime

    # Time of the most recent heartbeat; None until the first one arrives.
    last_heartbeat: Optional[datetime] = None

    # Port and pid of the ttyd process attached to this session, if any.
    ttyd_port: Optional[int] = None
    ttyd_pid: Optional[int] = None

    # Tunnel port reported by the device, once known.
    device_tunnel_port: Optional[int] = None

    # One of "waiting", "ready" or "failed".
    status: str = "waiting"
class TunnelStatus(BaseModel):
    """Connection state of one tunnel type on one device."""

    device_id: str

    # Either "ssh" or "dashboard".
    tunnel_type: str

    # Port reported by the device; None when no port is allocated.
    allocated_port: Optional[int] = None

    # Either "connected" or "disconnected".
    status: str

    # Timestamps recorded when the device reports a connection.
    connected_at: Optional[datetime] = None
    last_heartbeat: Optional[datetime] = None
class TunnelService:
    """
    Tunnel management service.

    Tracks tunnel sessions (keyed by session UUID) and per-device tunnel
    status (keyed by ``"<device_id>:<tunnel_type>"``), and spawns/kills the
    ttyd processes that back SSH sessions.

    In-memory storage (can be replaced with Redis for multi-server setups).
    """

    # Seconds between two background cleanup passes.
    CLEANUP_INTERVAL_SECONDS = 300
    # Hard lifetime of a session, counted from its creation.
    SESSION_TTL = timedelta(hours=1)
    # A session whose last heartbeat is older than this is removed.
    INACTIVITY_TIMEOUT = timedelta(minutes=60)
    # Heartbeat age after which a dead ttyd process gets its session removed.
    HEARTBEAT_GRACE = timedelta(seconds=60)
    # How long a spawned ttyd may wait for its first heartbeat.
    FIRST_HEARTBEAT_GRACE = timedelta(minutes=2)
    # Inclusive port range scanned for ttyd listeners.
    TTYD_PORT_RANGE = (45000, 49999)

    def __init__(self):
        self.sessions: Dict[str, "TunnelSession"] = {}
        self.tunnel_status: Dict[str, "TunnelStatus"] = {}
        self.cleanup_task = None

    def start_background_cleanup(self):
        """Start the background cleanup task unless one is still running.

        Uses ``asyncio.create_task``, so it must be called while an event
        loop is running. A task that has already finished (cancelled or
        crashed) is replaced instead of blocking a restart forever.
        """
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def _cleanup_loop(self):
        """Run ``cleanup_inactive_sessions`` periodically, forever."""
        while True:
            await asyncio.sleep(self.CLEANUP_INTERVAL_SECONDS)
            try:
                await self.cleanup_inactive_sessions()
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # One failed pass must not stop all future cleanup.
                print(f"[tunnel] Cleanup pass failed: {exc}")

    async def cleanup_inactive_sessions(self):
        """
        Remove stale sessions and terminate their ttyd processes.

        A session is removed when, in this order of precedence:
          * it is past ``expires_at``;
          * ttyd was spawned but no heartbeat arrived within
            ``FIRST_HEARTBEAT_GRACE`` of creation;
          * its last heartbeat is older than ``INACTIVITY_TIMEOUT``;
          * its last heartbeat is older than ``HEARTBEAT_GRACE`` and its
            ttyd process is no longer alive (nothing is killed in this case).
        """
        now = datetime.now()
        inactive_threshold = now - self.INACTIVITY_TIMEOUT
        grace_period = now - self.HEARTBEAT_GRACE
        initial_grace = now - self.FIRST_HEARTBEAT_GRACE

        # Iterate over a snapshot because entries are deleted while looping.
        for session_uuid, session in list(self.sessions.items()):
            # Hard limit: the session outlived its expiry time.
            if now > session.expires_at:
                print(f"[tunnel] Session expired: {session_uuid}")
                self._kill_ttyd(session.ttyd_pid)
                self.sessions.pop(session_uuid, None)
                continue

            # ttyd was spawned but the tab was never opened.
            if (session.ttyd_pid and not session.last_heartbeat
                    and session.created_at < initial_grace):
                print(f"[tunnel] Session never opened (no heartbeat): {session_uuid}")
                self._kill_ttyd(session.ttyd_pid)
                self.sessions.pop(session_uuid, None)
                continue

            # Heartbeats stopped a long time ago.
            if session.last_heartbeat and session.last_heartbeat < inactive_threshold:
                print(f"[tunnel] Session inactive for 60 min: {session_uuid}")
                self._kill_ttyd(session.ttyd_pid)
                self.sessions.pop(session_uuid, None)
                continue

            # Heartbeats stopped recently: only drop the session if its ttyd
            # process has already gone away on its own.
            if session.last_heartbeat and session.last_heartbeat < grace_period:
                if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
                    print(f"[tunnel] ttyd process dead: {session_uuid}")
                    self.sessions.pop(session_uuid, None)

    def create_session(
        self,
        device_id: str,
        admin_user: str,
        tunnel_type: str
    ) -> "TunnelSession":
        """Create, register and return a new session in the "waiting" state.

        Also registers a "disconnected" tunnel status entry for the
        device/tunnel-type pair if none exists yet.
        """
        session_uuid = str(uuid_module.uuid4())
        now = datetime.now()

        session = TunnelSession(
            uuid=session_uuid,
            device_id=device_id,
            admin_user=admin_user,
            tunnel_type=tunnel_type,
            created_at=now,
            expires_at=now + self.SESSION_TTL,
            status="waiting"
        )
        self.sessions[session_uuid] = session

        # Make sure a status entry exists for this device/tunnel pair.
        status_key = f"{device_id}:{tunnel_type}"
        if status_key not in self.tunnel_status:
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                status="disconnected"
            )

        return session

    def get_session(self, session_uuid: str) -> Optional["TunnelSession"]:
        """Return the session with this UUID, or None if it is unknown."""
        return self.sessions.get(session_uuid)

    def update_heartbeat(self, session_uuid: str) -> bool:
        """Record a heartbeat for the session.

        Returns True when the session exists, False otherwise.
        """
        session = self.sessions.get(session_uuid)
        if not session:
            return False
        session.last_heartbeat = datetime.now()
        return True

    def report_device_port(
        self,
        device_id: str,
        tunnel_type: str,
        port: Optional[int],
        status: str
    ):
        """Handle a device report about its tunnel port.

        On ``status == "connected"`` with a truthy ``port`` the tunnel status
        is replaced with a connected entry and every waiting session for the
        same device/tunnel type becomes "ready"; SSH sessions without a ttyd
        port additionally get a ttyd process spawned (status "failed" if that
        raises). On ``status == "disconnected"`` an existing status entry is
        marked disconnected and its port cleared. Other inputs are ignored.

        NOTE(review): the readiness delays below use blocking ``time.sleep``
        once per matching session; if this method is called from an event
        loop thread it stalls the loop for that long.
        """
        # Function-scope import: `time` is not imported at module level.
        import time

        status_key = f"{device_id}:{tunnel_type}"

        if status == "connected" and port:
            now = datetime.now()
            self.tunnel_status[status_key] = TunnelStatus(
                device_id=device_id,
                tunnel_type=tunnel_type,
                allocated_port=port,
                status="connected",
                connected_at=now,
                last_heartbeat=now
            )

            # Promote every session that was waiting for this tunnel.
            for session in self.sessions.values():
                if not (session.device_id == device_id
                        and session.tunnel_type == tunnel_type
                        and session.status == "waiting"):
                    continue

                session.device_tunnel_port = port
                session.status = "ready"

                # Only SSH tunnels need ttyd; dashboard tunnels do not.
                if session.tunnel_type == "ssh" and not session.ttyd_port:
                    try:
                        # Wait a moment for SSH to be fully ready.
                        time.sleep(2)
                        ttyd_port = self.spawn_ttyd(
                            session_uuid=session.uuid,
                            device_tunnel_port=port
                        )
                        print(f"[tunnel] Auto-spawned ttyd for session {session.uuid} on port {ttyd_port}")
                    except Exception as e:
                        print(f"[tunnel] Failed to auto-spawn ttyd: {e}")
                        session.status = "failed"
                elif session.tunnel_type == "dashboard":
                    # Wait for the dashboard to be fully ready.
                    time.sleep(3)
                    print(f"[tunnel] Dashboard tunnel ready for session {session.uuid} on port {port}")

        elif status == "disconnected":
            tunnel = self.tunnel_status.get(status_key)
            if tunnel is not None:
                tunnel.status = "disconnected"
                tunnel.allocated_port = None

    def get_tunnel_status(
        self,
        device_id: str,
        tunnel_type: str
    ) -> Optional["TunnelStatus"]:
        """Return the tunnel status for the device/tunnel type, or None."""
        status_key = f"{device_id}:{tunnel_type}"
        return self.tunnel_status.get(status_key)

    def spawn_ttyd(
        self,
        session_uuid: str,
        device_tunnel_port: int,
        server_host: str = "localhost"
    ) -> int:
        """
        Spawn a ttyd process that runs ssh through the device tunnel port.

        Stores the ttyd port and pid on the session and returns the port.

        Raises:
            ValueError: if the session UUID is unknown.
            RuntimeError: if no port in ``TTYD_PORT_RANGE`` is free.
            OSError: if the log file or the ttyd process cannot be created.
        """
        session = self.sessions.get(session_uuid)
        if not session:
            raise ValueError(f"Session not found: {session_uuid}")

        ttyd_port = self._find_free_port(*self.TTYD_PORT_RANGE)

        # ttyd serves a terminal that runs ssh against the tunnel port.
        # NOTE(review): host key verification is disabled by the two
        # StrictHostKeyChecking/UserKnownHostsFile options below.
        cmd = [
            "ttyd",
            "--port", str(ttyd_port),
            "--writable",  # Allow input
            "ssh",
            "-p", str(device_tunnel_port),
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "ServerAliveInterval=30",
            "-o", "ServerAliveCountMax=3",
            f"root@{server_host}"
        ]

        # Log ttyd output for debugging.
        log_file = f"/tmp/ttyd_{ttyd_port}.log"
        with open(log_file, 'w') as f:
            f.write(f"Starting ttyd for session {session_uuid}\n")
            f.write(f"Command: {' '.join(cmd)}\n")

        # The child gets its own copy of the descriptor, so the parent's
        # handle is closed right after the spawn instead of being leaked.
        with open(log_file, 'a') as log:
            process = subprocess.Popen(
                cmd,
                stdout=log,
                stderr=subprocess.STDOUT
            )

        # NOTE(review): the Popen handle is not kept, so the child is never
        # waited on here; an exited ttyd may still look alive to a plain
        # signal-0 probe until it is reaped - confirm whether that matters.
        session.ttyd_port = ttyd_port
        session.ttyd_pid = process.pid

        print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
        return ttyd_port

    def _find_free_port(self, start: int, end: int) -> int:
        """Return the first usable port in ``[start, end]``.

        A port is usable when it can be bound and is not already assigned to
        a known session. The second condition avoids handing the same port to
        two sessions when an earlier ttyd has not bound its port yet.

        Raises:
            RuntimeError: if no port in the range qualifies.
        """
        import socket

        reserved = {
            s.ttyd_port for s in self.sessions.values() if s.ttyd_port
        }
        for port in range(start, end + 1):
            if port in reserved:
                continue
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(('', port))
                    return port
            except OSError:
                continue
        raise RuntimeError(f"No free ports in range {start}-{end}")

    def _kill_ttyd(self, pid: Optional[int]):
        """Send SIGTERM to the ttyd process; a falsy or gone pid is a no-op."""
        if not pid:
            return
        try:
            os.kill(pid, signal.SIGTERM)
            print(f"[tunnel] Killed ttyd process {pid}")
        except ProcessLookupError:
            # Already gone: nothing to do.
            pass
        except PermissionError:
            # The pid now belongs to a process we may not signal (for
            # example after pid reuse); report it instead of raising.
            print(f"[tunnel] Not permitted to signal ttyd process {pid}")

    def _is_process_alive(self, pid: int) -> bool:
        """Return True if a process with this pid exists."""
        try:
            os.kill(pid, 0)  # Signal 0 = existence check, nothing is sent
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists but is owned by someone else.
            return True
        return True
# Module-level TunnelService instance.
tunnel_service = TunnelService()
|