tunnel_service.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. """
  2. Tunnel session management service for SSH and Dashboard tunnels.
  3. """
  4. import asyncio
  5. import os
  6. import re
  7. import signal
  8. import subprocess
  9. import uuid as uuid_module
  10. from datetime import datetime, timedelta
  11. from typing import Dict, Optional, List, Tuple
  12. from pydantic import BaseModel
class TunnelSession(BaseModel):
    """Tunnel session model.

    One admin's request to open a tunnel of a given type to a device.
    TunnelService creates sessions in "waiting" status. It moves them to
    "ready" when the device reports its tunnel port, or to "failed" if
    spawning ttyd for an SSH session raises.
    """
    uuid: str  # session identifier (create_session uses a uuid4 string)
    device_id: str  # device the tunnel leads to
    admin_user: str  # admin who requested the session
    tunnel_type: str  # "ssh" | "dashboard"
    created_at: datetime  # naive local time (create_session uses datetime.now())
    expires_at: datetime  # hard expiry; create_session sets created_at + 120 minutes
    last_heartbeat: Optional[datetime] = None  # last heartbeat from the client tab; None until the first one
    ttyd_port: Optional[int] = None  # local port the session's ttyd listens on (SSH sessions only)
    ttyd_pid: Optional[int] = None  # PID of the session's ttyd process, if spawned
    device_tunnel_port: Optional[int] = None  # tunnel port reported by the device
    status: str = "waiting"  # "waiting" | "ready" | "failed"
class TunnelStatus(BaseModel):
    """Device tunnel status.

    Tracks whether a device's tunnel of a given type is currently up, and on
    which port. TunnelService keys these by "{device_id}:{tunnel_type}".
    """
    device_id: str  # device that owns the tunnel
    tunnel_type: str  # "ssh" | "dashboard"
    allocated_port: Optional[int] = None  # tunnel port reported by the device; cleared on disconnect
    status: str  # "connected" | "disconnected" (required, no default)
    connected_at: Optional[datetime] = None  # when the device last reported "connected"
    last_heartbeat: Optional[datetime] = None  # set alongside connected_at when a connection is reported
  34. class TunnelService:
  35. """
  36. Tunnel management service
  37. In-memory storage (можно заменить на Redis для multi-server)
  38. """
  39. def __init__(self):
  40. self.sessions: Dict[str, TunnelSession] = {}
  41. self.tunnel_status: Dict[str, TunnelStatus] = {}
  42. async def cleanup_inactive_sessions(self):
  43. """
  44. Kill ttyd processes with no heartbeat for 60 minutes
  45. Remove expired sessions
  46. Returns list of (device_id, tunnel_type) tuples for tunnels that should be disabled
  47. """
  48. now = datetime.now()
  49. inactive_threshold = now - timedelta(minutes=60)
  50. grace_period = now - timedelta(seconds=60)
  51. initial_grace = now - timedelta(minutes=2)
  52. tunnels_to_disable = [] # List of (device_id, tunnel_type) to disable
  53. for session_uuid, session in list(self.sessions.items()):
  54. should_cleanup = False
  55. reason = ""
  56. # Check expiration (hard limit: 120 minutes)
  57. if now > session.expires_at:
  58. should_cleanup = True
  59. reason = "Session expired"
  60. # Check if tab was never opened (ttyd spawned but no heartbeat after 2 min)
  61. elif (session.ttyd_pid and not session.last_heartbeat and
  62. session.created_at < initial_grace):
  63. should_cleanup = True
  64. reason = "Session never opened (no heartbeat)"
  65. # Check inactivity (60 minutes without heartbeat)
  66. elif session.last_heartbeat and session.last_heartbeat < inactive_threshold:
  67. should_cleanup = True
  68. reason = "Session inactive for 60 min"
  69. # Grace period: if tab closed, wait 60 seconds before killing
  70. elif session.last_heartbeat and session.last_heartbeat < grace_period:
  71. if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
  72. should_cleanup = True
  73. reason = "ttyd process dead"
  74. if should_cleanup:
  75. print(f"[tunnel] {reason}: {session_uuid}")
  76. self._kill_ttyd(session.ttyd_pid)
  77. del self.sessions[session_uuid]
  78. # Mark tunnel for disabling on device
  79. tunnels_to_disable.append((session.device_id, session.tunnel_type))
  80. return tunnels_to_disable
  81. def create_session(
  82. self,
  83. device_id: str,
  84. admin_user: str,
  85. tunnel_type: str
  86. ) -> TunnelSession:
  87. """Create new tunnel session"""
  88. session_uuid = str(uuid_module.uuid4())
  89. now = datetime.now()
  90. session = TunnelSession(
  91. uuid=session_uuid,
  92. device_id=device_id,
  93. admin_user=admin_user,
  94. tunnel_type=tunnel_type,
  95. created_at=now,
  96. expires_at=now + timedelta(minutes=120), # 2 hours
  97. status="waiting"
  98. )
  99. self.sessions[session_uuid] = session
  100. # Create tunnel status key
  101. status_key = f"{device_id}:{tunnel_type}"
  102. if status_key not in self.tunnel_status:
  103. self.tunnel_status[status_key] = TunnelStatus(
  104. device_id=device_id,
  105. tunnel_type=tunnel_type,
  106. status="disconnected"
  107. )
  108. return session
  109. def get_session(self, session_uuid: str) -> Optional[TunnelSession]:
  110. """Get session by UUID"""
  111. return self.sessions.get(session_uuid)
  112. def update_heartbeat(self, session_uuid: str) -> bool:
  113. """Update session heartbeat"""
  114. session = self.sessions.get(session_uuid)
  115. if not session:
  116. return False
  117. session.last_heartbeat = datetime.now()
  118. return True
  119. def report_device_port(
  120. self,
  121. device_id: str,
  122. tunnel_type: str,
  123. port: Optional[int],
  124. status: str
  125. ):
  126. """Device reports tunnel port allocation"""
  127. status_key = f"{device_id}:{tunnel_type}"
  128. if status == "connected" and port:
  129. self.tunnel_status[status_key] = TunnelStatus(
  130. device_id=device_id,
  131. tunnel_type=tunnel_type,
  132. allocated_port=port,
  133. status="connected",
  134. connected_at=datetime.now(),
  135. last_heartbeat=datetime.now()
  136. )
  137. # Update all waiting sessions for this device
  138. for session in self.sessions.values():
  139. if (session.device_id == device_id and
  140. session.tunnel_type == tunnel_type and
  141. session.status == "waiting"):
  142. session.device_tunnel_port = port
  143. session.status = "ready"
  144. # Spawn ttyd only for SSH tunnels (dashboard doesn't need ttyd)
  145. if session.tunnel_type == "ssh" and not session.ttyd_port:
  146. try:
  147. # Wait a moment for SSH to be fully ready
  148. import time
  149. time.sleep(2)
  150. ttyd_port = self.spawn_ttyd(
  151. session_uuid=session.uuid,
  152. device_tunnel_port=port
  153. )
  154. print(f"[tunnel] Auto-spawned ttyd for session {session.uuid} on port {ttyd_port}")
  155. except Exception as e:
  156. print(f"[tunnel] Failed to auto-spawn ttyd: {e}")
  157. session.status = "failed"
  158. elif session.tunnel_type == "dashboard":
  159. # Wait for dashboard to be fully ready
  160. import time
  161. time.sleep(3)
  162. print(f"[tunnel] Dashboard tunnel ready for session {session.uuid} on port {port}")
  163. elif status == "disconnected":
  164. if status_key in self.tunnel_status:
  165. self.tunnel_status[status_key].status = "disconnected"
  166. self.tunnel_status[status_key].allocated_port = None
  167. def get_tunnel_status(
  168. self,
  169. device_id: str,
  170. tunnel_type: str
  171. ) -> Optional[TunnelStatus]:
  172. """Get tunnel status for device"""
  173. status_key = f"{device_id}:{tunnel_type}"
  174. return self.tunnel_status.get(status_key)
  175. def spawn_ttyd(
  176. self,
  177. session_uuid: str,
  178. device_tunnel_port: int,
  179. server_host: str = "localhost"
  180. ) -> int:
  181. """
  182. Spawn ttyd process for terminal access
  183. Returns ttyd port
  184. """
  185. session = self.sessions.get(session_uuid)
  186. if not session:
  187. raise ValueError(f"Session not found: {session_uuid}")
  188. # Find free port for ttyd (45000-49999)
  189. ttyd_port = self._find_free_port(45000, 49999)
  190. # Spawn ttyd process
  191. # ttyd connects to device via SSH through the tunnel port
  192. # Uses dedicated key for device access
  193. import os
  194. ssh_key = os.path.expanduser("~/.ssh/device_access")
  195. cmd = [
  196. "ttyd",
  197. "--port", str(ttyd_port),
  198. "--writable", # Allow input
  199. "ssh",
  200. "-i", ssh_key,
  201. "-p", str(device_tunnel_port),
  202. "-o", "StrictHostKeyChecking=no",
  203. "-o", "UserKnownHostsFile=/dev/null",
  204. "-o", "ServerAliveInterval=30",
  205. "-o", "ServerAliveCountMax=3",
  206. f"root@{server_host}"
  207. ]
  208. # Log ttyd output for debugging
  209. log_file = f"/tmp/ttyd_{ttyd_port}.log"
  210. with open(log_file, 'w') as f:
  211. f.write(f"Starting ttyd for session {session_uuid}\n")
  212. f.write(f"Command: {' '.join(cmd)}\n")
  213. process = subprocess.Popen(
  214. cmd,
  215. stdout=open(log_file, 'a'),
  216. stderr=subprocess.STDOUT
  217. )
  218. session.ttyd_port = ttyd_port
  219. session.ttyd_pid = process.pid
  220. print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
  221. return ttyd_port
  222. def _find_free_port(self, start: int, end: int) -> int:
  223. """Find free port in range"""
  224. import socket
  225. for port in range(start, end + 1):
  226. try:
  227. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  228. s.bind(('', port))
  229. return port
  230. except OSError:
  231. continue
  232. raise RuntimeError(f"No free ports in range {start}-{end}")
  233. def _kill_ttyd(self, pid: Optional[int]):
  234. """Kill ttyd process gracefully"""
  235. if not pid:
  236. return
  237. try:
  238. os.kill(pid, signal.SIGTERM)
  239. print(f"[tunnel] Killed ttyd process {pid}")
  240. except ProcessLookupError:
  241. pass
  242. def _is_process_alive(self, pid: int) -> bool:
  243. """Check if process is running"""
  244. try:
  245. os.kill(pid, 0) # Signal 0 = check existence
  246. return True
  247. except ProcessLookupError:
  248. return False
  249. def _get_all_listening_ports(self) -> set:
  250. """
  251. Get all listening TCP ports ONCE to avoid multiple subprocess calls.
  252. Returns set of port numbers.
  253. """
  254. try:
  255. result = subprocess.run(
  256. ["ss", "-tln"],
  257. capture_output=True,
  258. text=True,
  259. timeout=5,
  260. check=False
  261. )
  262. # Parse port numbers from ss output
  263. ports = set()
  264. for line in result.stdout.split('\n'):
  265. # Look for :PORT in LISTEN state
  266. # Example: LISTEN 0 128 0.0.0.0:50001
  267. if 'LISTEN' in line or '*:' in line or '0.0.0.0:' in line or ':::' in line:
  268. # Extract port number
  269. import re
  270. port_matches = re.findall(r':(\d+)', line)
  271. for match in port_matches:
  272. ports.add(int(match))
  273. return ports
  274. except Exception as e:
  275. print(f"[watchdog] Error getting listening ports: {e}")
  276. return set()
  277. def _is_port_listening(self, port: int) -> bool:
  278. """Check if port is listening (tunnel is open) - DEPRECATED, use _get_all_listening_ports"""
  279. try:
  280. result = subprocess.run(
  281. ["ss", "-tln"],
  282. capture_output=True,
  283. text=True,
  284. timeout=5,
  285. check=False
  286. )
  287. # Look for port in LISTEN state
  288. return f":{port}" in result.stdout
  289. except Exception as e:
  290. print(f"[watchdog] Error checking port {port}: {e}")
  291. return False
  292. def _get_running_ttyd_processes(self) -> List[Tuple[int, int, int]]:
  293. """
  294. Get all running ttyd processes for tunnels.
  295. Returns: List of (pid, ttyd_port, ssh_tunnel_port)
  296. """
  297. try:
  298. result = subprocess.run(
  299. ["ps", "aux"],
  300. capture_output=True,
  301. text=True,
  302. timeout=5,
  303. check=False
  304. )
  305. processes = []
  306. for line in result.stdout.split('\n'):
  307. # Look for: ttyd --port 45XXX --writable ssh -p 50XXX
  308. if 'ttyd' in line and '--port 45' in line:
  309. # Extract PID
  310. parts = line.split()
  311. pid = int(parts[1])
  312. # Extract ttyd port (--port 45XXX)
  313. port_match = re.search(r'--port (\d+)', line)
  314. # Extract SSH tunnel port (-p 50XXX or -p 60XXX)
  315. ssh_port_match = re.search(r'ssh -p (\d+)', line)
  316. if port_match and ssh_port_match:
  317. ttyd_port = int(port_match.group(1))
  318. ssh_port = int(ssh_port_match.group(1))
  319. processes.append((pid, ttyd_port, ssh_port))
  320. return processes
  321. except Exception as e:
  322. print(f"[watchdog] Error getting ttyd processes: {e}")
  323. return []
  324. async def watchdog_cleanup(self) -> List[Tuple[str, str]]:
  325. """
  326. Watchdog: Kill orphaned ttyd processes and cleanup stale sessions.
  327. This runs independently of in-memory sessions and survives restarts.
  328. Returns: List of (device_id, tunnel_type) to disable in config
  329. """
  330. tunnels_to_disable = []
  331. now = datetime.now()
  332. print("[watchdog] Running tunnel watchdog...")
  333. # 1. Check all running ttyd processes
  334. ttyd_processes = self._get_running_ttyd_processes()
  335. print(f"[watchdog] Found {len(ttyd_processes)} ttyd processes")
  336. # 2. Get all listening ports ONCE (optimization to avoid multiple subprocess calls)
  337. listening_ports = self._get_all_listening_ports()
  338. for pid, ttyd_port, tunnel_port in ttyd_processes:
  339. # Check if tunnel port is still open
  340. if tunnel_port not in listening_ports:
  341. print(f"[watchdog] Tunnel port {tunnel_port} closed, killing ttyd {pid} (port {ttyd_port})")
  342. self._kill_ttyd(pid)
  343. # Find session and mark for config update
  344. for session in self.sessions.values():
  345. if session.ttyd_pid == pid:
  346. tunnels_to_disable.append((session.device_id, session.tunnel_type))
  347. break
  348. # 2. Check in-memory sessions
  349. grace_period = now - timedelta(seconds=60)
  350. inactive_threshold = now - timedelta(minutes=60)
  351. initial_grace = now - timedelta(minutes=2)
  352. for session_uuid, session in list(self.sessions.items()):
  353. should_cleanup = False
  354. reason = ""
  355. # Check expiration (hard limit: 120 minutes)
  356. if now > session.expires_at:
  357. should_cleanup = True
  358. reason = "Session expired (120 min)"
  359. # Check if tab was never opened
  360. elif (session.ttyd_pid and not session.last_heartbeat and
  361. session.created_at < initial_grace):
  362. should_cleanup = True
  363. reason = "Session never opened (no heartbeat)"
  364. # Check inactivity (60 minutes without heartbeat)
  365. elif session.last_heartbeat and session.last_heartbeat < inactive_threshold:
  366. should_cleanup = True
  367. reason = "Session inactive for 60 min"
  368. # Grace period: if tab closed, wait 60 seconds
  369. elif session.last_heartbeat and session.last_heartbeat < grace_period:
  370. should_cleanup = True
  371. reason = "Tab closed (60s grace period)"
  372. if should_cleanup:
  373. print(f"[watchdog] {reason}: {session_uuid}")
  374. self._kill_ttyd(session.ttyd_pid)
  375. del self.sessions[session_uuid]
  376. tunnels_to_disable.append((session.device_id, session.tunnel_type))
  377. return tunnels_to_disable
# Global tunnel service instance. Every importer of this module shares this
# one TunnelService, so all session and tunnel state lives in this process.
tunnel_service = TunnelService()