tunnel_service.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """
  2. Tunnel session management service for SSH and Dashboard tunnels.
  3. """
  4. import asyncio
  5. import os
  6. import signal
  7. import subprocess
  8. import uuid as uuid_module
  9. from datetime import datetime, timedelta
  10. from typing import Dict, Optional
  11. from pydantic import BaseModel
class TunnelSession(BaseModel):
    """Tunnel session model.

    One admin's request to reach a device through a tunnel. A session starts
    in status "waiting"; TunnelService marks it "ready" once the device reports
    its tunnel port, or "failed" if spawning ttyd for an SSH session fails.
    """
    uuid: str  # session identifier (TunnelService.create_session uses str(uuid4()))
    device_id: str
    admin_user: str
    tunnel_type: str  # "ssh" | "dashboard"
    created_at: datetime
    expires_at: datetime  # hard expiry; create_session sets created_at + 1 hour
    last_heartbeat: Optional[datetime] = None  # None until the first heartbeat is received
    ttyd_port: Optional[int] = None  # port of the spawned ttyd web terminal (SSH sessions only)
    ttyd_pid: Optional[int] = None  # PID of that ttyd process
    device_tunnel_port: Optional[int] = None  # port reported by the device; ttyd's ssh connects to localhost on it
    status: str = "waiting"  # "waiting" | "ready" | "failed"
class TunnelStatus(BaseModel):
    """Device tunnel status.

    One record per (device_id, tunnel_type) pair; TunnelService stores it
    under the key "<device_id>:<tunnel_type>".
    """
    device_id: str
    tunnel_type: str  # "ssh" | "dashboard"
    allocated_port: Optional[int] = None  # port reported by the device; cleared on disconnect
    status: str  # "connected" | "disconnected"
    connected_at: Optional[datetime] = None  # time of the last "connected" report
    last_heartbeat: Optional[datetime] = None
  33. class TunnelService:
  34. """
  35. Tunnel management service
  36. In-memory storage (можно заменить на Redis для multi-server)
  37. """
  38. def __init__(self):
  39. self.sessions: Dict[str, TunnelSession] = {}
  40. self.tunnel_status: Dict[str, TunnelStatus] = {}
  41. self.cleanup_task = None
  42. def start_background_cleanup(self):
  43. """Start background task for cleanup inactive sessions"""
  44. if not self.cleanup_task:
  45. self.cleanup_task = asyncio.create_task(self._cleanup_loop())
  46. async def _cleanup_loop(self):
  47. """Background cleanup loop"""
  48. while True:
  49. await asyncio.sleep(300) # Every 5 minutes
  50. await self.cleanup_inactive_sessions()
  51. async def cleanup_inactive_sessions(self):
  52. """
  53. Kill ttyd processes with no heartbeat for 60 minutes
  54. Remove expired sessions
  55. """
  56. now = datetime.now()
  57. inactive_threshold = now - timedelta(minutes=60)
  58. grace_period = now - timedelta(seconds=60)
  59. initial_grace = now - timedelta(minutes=2)
  60. for session_uuid, session in list(self.sessions.items()):
  61. # Check expiration (hard limit: 1 hour)
  62. if now > session.expires_at:
  63. print(f"[tunnel] Session expired: {session_uuid}")
  64. self._kill_ttyd(session.ttyd_pid)
  65. del self.sessions[session_uuid]
  66. continue
  67. # Check if tab was never opened (ttyd spawned but no heartbeat after 2 min)
  68. if (session.ttyd_pid and not session.last_heartbeat and
  69. session.created_at < initial_grace):
  70. print(f"[tunnel] Session never opened (no heartbeat): {session_uuid}")
  71. self._kill_ttyd(session.ttyd_pid)
  72. del self.sessions[session_uuid]
  73. continue
  74. # Check inactivity (60 minutes without heartbeat)
  75. if session.last_heartbeat and session.last_heartbeat < inactive_threshold:
  76. print(f"[tunnel] Session inactive for 60 min: {session_uuid}")
  77. self._kill_ttyd(session.ttyd_pid)
  78. del self.sessions[session_uuid]
  79. continue
  80. # Grace period: if tab closed, wait 60 seconds before killing
  81. if session.last_heartbeat and session.last_heartbeat < grace_period:
  82. if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
  83. print(f"[tunnel] ttyd process dead: {session_uuid}")
  84. del self.sessions[session_uuid]
  85. def create_session(
  86. self,
  87. device_id: str,
  88. admin_user: str,
  89. tunnel_type: str
  90. ) -> TunnelSession:
  91. """Create new tunnel session"""
  92. session_uuid = str(uuid_module.uuid4())
  93. now = datetime.now()
  94. session = TunnelSession(
  95. uuid=session_uuid,
  96. device_id=device_id,
  97. admin_user=admin_user,
  98. tunnel_type=tunnel_type,
  99. created_at=now,
  100. expires_at=now + timedelta(hours=1),
  101. status="waiting"
  102. )
  103. self.sessions[session_uuid] = session
  104. # Create tunnel status key
  105. status_key = f"{device_id}:{tunnel_type}"
  106. if status_key not in self.tunnel_status:
  107. self.tunnel_status[status_key] = TunnelStatus(
  108. device_id=device_id,
  109. tunnel_type=tunnel_type,
  110. status="disconnected"
  111. )
  112. return session
  113. def get_session(self, session_uuid: str) -> Optional[TunnelSession]:
  114. """Get session by UUID"""
  115. return self.sessions.get(session_uuid)
  116. def update_heartbeat(self, session_uuid: str) -> bool:
  117. """Update session heartbeat"""
  118. session = self.sessions.get(session_uuid)
  119. if not session:
  120. return False
  121. session.last_heartbeat = datetime.now()
  122. return True
  123. def report_device_port(
  124. self,
  125. device_id: str,
  126. tunnel_type: str,
  127. port: Optional[int],
  128. status: str
  129. ):
  130. """Device reports tunnel port allocation"""
  131. status_key = f"{device_id}:{tunnel_type}"
  132. if status == "connected" and port:
  133. self.tunnel_status[status_key] = TunnelStatus(
  134. device_id=device_id,
  135. tunnel_type=tunnel_type,
  136. allocated_port=port,
  137. status="connected",
  138. connected_at=datetime.now(),
  139. last_heartbeat=datetime.now()
  140. )
  141. # Update all waiting sessions for this device
  142. for session in self.sessions.values():
  143. if (session.device_id == device_id and
  144. session.tunnel_type == tunnel_type and
  145. session.status == "waiting"):
  146. session.device_tunnel_port = port
  147. session.status = "ready"
  148. # Spawn ttyd only for SSH tunnels (dashboard doesn't need ttyd)
  149. if session.tunnel_type == "ssh" and not session.ttyd_port:
  150. try:
  151. # Wait a moment for SSH to be fully ready
  152. import time
  153. time.sleep(2)
  154. ttyd_port = self.spawn_ttyd(
  155. session_uuid=session.uuid,
  156. device_tunnel_port=port
  157. )
  158. print(f"[tunnel] Auto-spawned ttyd for session {session.uuid} on port {ttyd_port}")
  159. except Exception as e:
  160. print(f"[tunnel] Failed to auto-spawn ttyd: {e}")
  161. session.status = "failed"
  162. elif session.tunnel_type == "dashboard":
  163. # Wait for dashboard to be fully ready
  164. import time
  165. time.sleep(3)
  166. print(f"[tunnel] Dashboard tunnel ready for session {session.uuid} on port {port}")
  167. elif status == "disconnected":
  168. if status_key in self.tunnel_status:
  169. self.tunnel_status[status_key].status = "disconnected"
  170. self.tunnel_status[status_key].allocated_port = None
  171. def get_tunnel_status(
  172. self,
  173. device_id: str,
  174. tunnel_type: str
  175. ) -> Optional[TunnelStatus]:
  176. """Get tunnel status for device"""
  177. status_key = f"{device_id}:{tunnel_type}"
  178. return self.tunnel_status.get(status_key)
  179. def spawn_ttyd(
  180. self,
  181. session_uuid: str,
  182. device_tunnel_port: int,
  183. server_host: str = "localhost"
  184. ) -> int:
  185. """
  186. Spawn ttyd process for terminal access
  187. Returns ttyd port
  188. """
  189. session = self.sessions.get(session_uuid)
  190. if not session:
  191. raise ValueError(f"Session not found: {session_uuid}")
  192. # Find free port for ttyd (45000-49999)
  193. ttyd_port = self._find_free_port(45000, 49999)
  194. # Spawn ttyd process
  195. # ttyd connects to device via SSH through the tunnel port
  196. cmd = [
  197. "ttyd",
  198. "--port", str(ttyd_port),
  199. "--writable", # Allow input
  200. "ssh",
  201. "-p", str(device_tunnel_port),
  202. "-o", "StrictHostKeyChecking=no",
  203. "-o", "UserKnownHostsFile=/dev/null",
  204. "-o", "ServerAliveInterval=30",
  205. "-o", "ServerAliveCountMax=3",
  206. f"root@{server_host}"
  207. ]
  208. # Log ttyd output for debugging
  209. log_file = f"/tmp/ttyd_{ttyd_port}.log"
  210. with open(log_file, 'w') as f:
  211. f.write(f"Starting ttyd for session {session_uuid}\n")
  212. f.write(f"Command: {' '.join(cmd)}\n")
  213. process = subprocess.Popen(
  214. cmd,
  215. stdout=open(log_file, 'a'),
  216. stderr=subprocess.STDOUT
  217. )
  218. session.ttyd_port = ttyd_port
  219. session.ttyd_pid = process.pid
  220. print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
  221. return ttyd_port
  222. def _find_free_port(self, start: int, end: int) -> int:
  223. """Find free port in range"""
  224. import socket
  225. for port in range(start, end + 1):
  226. try:
  227. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  228. s.bind(('', port))
  229. return port
  230. except OSError:
  231. continue
  232. raise RuntimeError(f"No free ports in range {start}-{end}")
  233. def _kill_ttyd(self, pid: Optional[int]):
  234. """Kill ttyd process gracefully"""
  235. if not pid:
  236. return
  237. try:
  238. os.kill(pid, signal.SIGTERM)
  239. print(f"[tunnel] Killed ttyd process {pid}")
  240. except ProcessLookupError:
  241. pass
  242. def _is_process_alive(self, pid: int) -> bool:
  243. """Check if process is running"""
  244. try:
  245. os.kill(pid, 0) # Signal 0 = check existence
  246. return True
  247. except ProcessLookupError:
  248. return False
# Global tunnel service instance: one module-level TunnelService shared by
# every importer of this module.
# NOTE(review): its state is in-process memory, so each worker process running
# this app would get an independent copy — confirm single-worker deployment.
tunnel_service = TunnelService()