tunnel_service.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. """
  2. Tunnel session management service for SSH and Dashboard tunnels.
  3. """
  4. import asyncio
  5. import os
  6. import signal
  7. import subprocess
  8. import uuid as uuid_module
  9. from datetime import datetime, timedelta
  10. from typing import Dict, Optional
  11. from pydantic import BaseModel
class TunnelSession(BaseModel):
    """Tunnel session model: one admin's request to reach one device.

    Created by TunnelService.create_session() with status "waiting". It is
    switched to "ready" (and given device_tunnel_port) when the device reports
    a connected tunnel of the same type; ttyd_port/ttyd_pid are filled in by
    TunnelService.spawn_ttyd(). Timestamps are naive datetimes produced by
    datetime.now().
    """
    uuid: str  # session id, str(uuid4()) generated in create_session()
    device_id: str
    admin_user: str
    tunnel_type: str # "ssh" | "dashboard"
    created_at: datetime
    expires_at: datetime  # created_at + 1 hour; cleanup drops the session after this
    last_heartbeat: Optional[datetime] = None  # refreshed by update_heartbeat()
    ttyd_port: Optional[int] = None  # port the spawned ttyd was told to listen on
    ttyd_pid: Optional[int] = None  # pid of the spawned ttyd process
    device_tunnel_port: Optional[int] = None  # tunnel port reported via report_device_port()
    status: str = "waiting" # "waiting" | "ready" | "failed" ("failed" is not set within this module)
class TunnelStatus(BaseModel):
    """Device tunnel status for one (device, tunnel type) pair.

    TunnelService stores these in its tunnel_status dict under the key
    "{device_id}:{tunnel_type}" and updates them from device reports.
    """
    device_id: str
    tunnel_type: str # "ssh" | "dashboard"
    allocated_port: Optional[int] = None  # port reported by the device; reset to None on disconnect
    status: str # "connected" | "disconnected"
    connected_at: Optional[datetime] = None  # set when the device reports "connected"
    last_heartbeat: Optional[datetime] = None  # also set only at connect time within this module
  33. class TunnelService:
  34. """
  35. Tunnel management service
  36. In-memory storage (можно заменить на Redis для multi-server)
  37. """
  38. def __init__(self):
  39. self.sessions: Dict[str, TunnelSession] = {}
  40. self.tunnel_status: Dict[str, TunnelStatus] = {}
  41. self.cleanup_task = None
  42. def start_background_cleanup(self):
  43. """Start background task for cleanup inactive sessions"""
  44. if not self.cleanup_task:
  45. self.cleanup_task = asyncio.create_task(self._cleanup_loop())
  46. async def _cleanup_loop(self):
  47. """Background cleanup loop"""
  48. while True:
  49. await asyncio.sleep(300) # Every 5 minutes
  50. await self.cleanup_inactive_sessions()
  51. async def cleanup_inactive_sessions(self):
  52. """
  53. Kill ttyd processes with no heartbeat for 60 minutes
  54. Remove expired sessions
  55. """
  56. now = datetime.now()
  57. inactive_threshold = now - timedelta(minutes=60)
  58. grace_period = now - timedelta(seconds=60)
  59. for session_uuid, session in list(self.sessions.items()):
  60. # Check expiration (hard limit: 1 hour)
  61. if now > session.expires_at:
  62. print(f"[tunnel] Session expired: {session_uuid}")
  63. self._kill_ttyd(session.ttyd_pid)
  64. del self.sessions[session_uuid]
  65. continue
  66. # Check inactivity (60 minutes without heartbeat)
  67. if session.last_heartbeat and session.last_heartbeat < inactive_threshold:
  68. print(f"[tunnel] Session inactive for 60 min: {session_uuid}")
  69. self._kill_ttyd(session.ttyd_pid)
  70. del self.sessions[session_uuid]
  71. continue
  72. # Grace period: if tab closed, wait 60 seconds before killing
  73. if session.last_heartbeat and session.last_heartbeat < grace_period:
  74. if session.ttyd_pid and not self._is_process_alive(session.ttyd_pid):
  75. print(f"[tunnel] ttyd process dead: {session_uuid}")
  76. del self.sessions[session_uuid]
  77. def create_session(
  78. self,
  79. device_id: str,
  80. admin_user: str,
  81. tunnel_type: str
  82. ) -> TunnelSession:
  83. """Create new tunnel session"""
  84. session_uuid = str(uuid_module.uuid4())
  85. now = datetime.now()
  86. session = TunnelSession(
  87. uuid=session_uuid,
  88. device_id=device_id,
  89. admin_user=admin_user,
  90. tunnel_type=tunnel_type,
  91. created_at=now,
  92. expires_at=now + timedelta(hours=1),
  93. status="waiting"
  94. )
  95. self.sessions[session_uuid] = session
  96. # Create tunnel status key
  97. status_key = f"{device_id}:{tunnel_type}"
  98. if status_key not in self.tunnel_status:
  99. self.tunnel_status[status_key] = TunnelStatus(
  100. device_id=device_id,
  101. tunnel_type=tunnel_type,
  102. status="disconnected"
  103. )
  104. return session
  105. def get_session(self, session_uuid: str) -> Optional[TunnelSession]:
  106. """Get session by UUID"""
  107. return self.sessions.get(session_uuid)
  108. def update_heartbeat(self, session_uuid: str) -> bool:
  109. """Update session heartbeat"""
  110. session = self.sessions.get(session_uuid)
  111. if not session:
  112. return False
  113. session.last_heartbeat = datetime.now()
  114. return True
  115. def report_device_port(
  116. self,
  117. device_id: str,
  118. tunnel_type: str,
  119. port: Optional[int],
  120. status: str
  121. ):
  122. """Device reports tunnel port allocation"""
  123. status_key = f"{device_id}:{tunnel_type}"
  124. if status == "connected" and port:
  125. self.tunnel_status[status_key] = TunnelStatus(
  126. device_id=device_id,
  127. tunnel_type=tunnel_type,
  128. allocated_port=port,
  129. status="connected",
  130. connected_at=datetime.now(),
  131. last_heartbeat=datetime.now()
  132. )
  133. # Update all waiting sessions for this device
  134. for session in self.sessions.values():
  135. if (session.device_id == device_id and
  136. session.tunnel_type == tunnel_type and
  137. session.status == "waiting"):
  138. session.device_tunnel_port = port
  139. session.status = "ready"
  140. elif status == "disconnected":
  141. if status_key in self.tunnel_status:
  142. self.tunnel_status[status_key].status = "disconnected"
  143. self.tunnel_status[status_key].allocated_port = None
  144. def get_tunnel_status(
  145. self,
  146. device_id: str,
  147. tunnel_type: str
  148. ) -> Optional[TunnelStatus]:
  149. """Get tunnel status for device"""
  150. status_key = f"{device_id}:{tunnel_type}"
  151. return self.tunnel_status.get(status_key)
  152. def spawn_ttyd(
  153. self,
  154. session_uuid: str,
  155. device_tunnel_port: int,
  156. server_host: str = "localhost"
  157. ) -> int:
  158. """
  159. Spawn ttyd process for terminal access
  160. Returns ttyd port
  161. """
  162. session = self.sessions.get(session_uuid)
  163. if not session:
  164. raise ValueError(f"Session not found: {session_uuid}")
  165. # Find free port for ttyd (45000-49999)
  166. ttyd_port = self._find_free_port(45000, 49999)
  167. # Spawn ttyd process
  168. # ttyd connects to device via SSH through the tunnel port
  169. cmd = [
  170. "ttyd",
  171. "--port", str(ttyd_port),
  172. "--once", # Single session
  173. "--writable", # Allow input
  174. "ssh",
  175. "-p", str(device_tunnel_port),
  176. "-o", "StrictHostKeyChecking=no",
  177. "-o", "UserKnownHostsFile=/dev/null",
  178. f"root@{server_host}"
  179. ]
  180. process = subprocess.Popen(
  181. cmd,
  182. stdout=subprocess.DEVNULL,
  183. stderr=subprocess.DEVNULL
  184. )
  185. session.ttyd_port = ttyd_port
  186. session.ttyd_pid = process.pid
  187. print(f"[tunnel] Spawned ttyd on port {ttyd_port} (pid={process.pid})")
  188. return ttyd_port
  189. def _find_free_port(self, start: int, end: int) -> int:
  190. """Find free port in range"""
  191. import socket
  192. for port in range(start, end + 1):
  193. try:
  194. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  195. s.bind(('', port))
  196. return port
  197. except OSError:
  198. continue
  199. raise RuntimeError(f"No free ports in range {start}-{end}")
  200. def _kill_ttyd(self, pid: Optional[int]):
  201. """Kill ttyd process gracefully"""
  202. if not pid:
  203. return
  204. try:
  205. os.kill(pid, signal.SIGTERM)
  206. print(f"[tunnel] Killed ttyd process {pid}")
  207. except ProcessLookupError:
  208. pass
  209. def _is_process_alive(self, pid: int) -> bool:
  210. """Check if process is running"""
  211. try:
  212. os.kill(pid, 0) # Signal 0 = check existence
  213. return True
  214. except ProcessLookupError:
  215. return False
  216. # Global tunnel service instance
  217. tunnel_service = TunnelService()