ssh_tunnel.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. package main
import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"os/exec"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"
)
// SSHTunnel manages a reverse SSH tunnel from this device to the server.
//
// It runs the system ssh client with a remote forward of a server-allocated
// port to localhost:22. It learns that port from ssh's stderr, reports the
// port and the connection status to the API, and tries to reconnect when
// the ssh process exits.
type SSHTunnel struct {
	// cfg holds the tunnel settings (server, port, user, keepalive, reconnect
	// delay); UpdateConfig replaces it while holding mu.
	cfg *Config

	// client is used to POST port/status reports to the API.
	client *APIClient

	// deviceID is sent as "device_id" in every report.
	deviceID string

	// cmd is the current ssh process, set by Start.
	cmd *exec.Cmd

	// stopChan is closed by Stop so the monitor does not reconnect.
	stopChan chan struct{}

	// allocatedPort is the remote port parsed from ssh's stderr; it is 0
	// before a port is announced and is reset to 0 when ssh exits.
	allocatedPort int

	// mu is held around accesses to allocatedPort and the cfg swap in
	// UpdateConfig.
	mu sync.Mutex
}
  28. // NewSSHTunnel creates a new SSH tunnel manager
  29. func NewSSHTunnel(cfg *Config, client *APIClient, deviceID string) *SSHTunnel {
  30. return &SSHTunnel{
  31. cfg: cfg,
  32. client: client,
  33. deviceID: deviceID,
  34. stopChan: make(chan struct{}),
  35. }
  36. }
  37. // Start initiates the SSH tunnel
  38. func (t *SSHTunnel) Start() error {
  39. if !t.cfg.SSHTunnel.Enabled {
  40. log.Println("[tunnel] SSH tunnel disabled")
  41. return nil
  42. }
  43. keyPath := "/etc/beacon/ssh_tunnel_ed25519"
  44. // Verify key exists
  45. if _, err := os.Stat(keyPath); os.IsNotExist(err) {
  46. return fmt.Errorf("SSH key not found: %s", keyPath)
  47. }
  48. args := []string{
  49. "-N", // No command execution
  50. "-v", // Verbose (to parse allocated port from stderr)
  51. "-R", "0:localhost:22", // Reverse tunnel with auto-allocated port
  52. "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.SSHTunnel.KeepaliveInterval),
  53. "-o", "ServerAliveCountMax=3",
  54. "-o", "ExitOnForwardFailure=yes",
  55. "-o", "StrictHostKeyChecking=accept-new",
  56. "-i", keyPath,
  57. "-p", fmt.Sprintf("%d", t.cfg.SSHTunnel.Port),
  58. fmt.Sprintf("%s@%s", t.cfg.SSHTunnel.User, t.cfg.SSHTunnel.Server),
  59. }
  60. t.cmd = exec.Command("ssh", args...)
  61. // Capture stderr to parse allocated port
  62. stderr, err := t.cmd.StderrPipe()
  63. if err != nil {
  64. return err
  65. }
  66. if err := t.cmd.Start(); err != nil {
  67. return err
  68. }
  69. log.Printf("[tunnel] SSH tunnel started (server=%s:%d, user=%s)",
  70. t.cfg.SSHTunnel.Server, t.cfg.SSHTunnel.Port, t.cfg.SSHTunnel.User)
  71. // Parse stderr for allocated port
  72. go t.parseStderr(stderr)
  73. // Monitor process
  74. go t.monitor()
  75. return nil
  76. }
  77. // parseStderr reads SSH debug output and extracts allocated port
  78. func (t *SSHTunnel) parseStderr(r io.Reader) {
  79. scanner := bufio.NewScanner(r)
  80. for scanner.Scan() {
  81. line := scanner.Text()
  82. // Log all SSH debug output
  83. if strings.Contains(line, "debug") {
  84. log.Printf("[tunnel] %s", line)
  85. }
  86. // Parse allocated port
  87. // Example: "Allocated port 12345 for remote forward to localhost:22"
  88. if strings.Contains(line, "Allocated port") {
  89. re := regexp.MustCompile(`Allocated port (\d+)`)
  90. if matches := re.FindStringSubmatch(line); len(matches) > 1 {
  91. port, _ := strconv.Atoi(matches[1])
  92. t.mu.Lock()
  93. t.allocatedPort = port
  94. t.mu.Unlock()
  95. log.Printf("[tunnel] Allocated port: %d", port)
  96. // Report to server
  97. go t.reportPort(port)
  98. }
  99. }
  100. }
  101. }
  102. // reportPort sends allocated port to server
  103. func (t *SSHTunnel) reportPort(port int) {
  104. body, err := json.Marshal(map[string]interface{}{
  105. "device_id": t.deviceID,
  106. "port": port,
  107. "status": "connected",
  108. })
  109. if err != nil {
  110. log.Printf("[tunnel] Failed to marshal port report: %v", err)
  111. return
  112. }
  113. req, err := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
  114. if err != nil {
  115. log.Printf("[tunnel] Failed to create request: %v", err)
  116. return
  117. }
  118. req.Header.Set("Content-Type", "application/json")
  119. req.Header.Set("Authorization", "Bearer "+t.client.token)
  120. resp, err := t.client.httpClient.Do(req)
  121. if err != nil {
  122. log.Printf("[tunnel] Failed to report port: %v", err)
  123. return
  124. }
  125. defer resp.Body.Close()
  126. if resp.StatusCode != 200 && resp.StatusCode != 201 {
  127. respBody, _ := io.ReadAll(resp.Body)
  128. log.Printf("[tunnel] Failed to report port (status=%d): %s", resp.StatusCode, string(respBody))
  129. return
  130. }
  131. log.Printf("[tunnel] Port %d reported to server", port)
  132. }
  133. // monitor watches SSH process and handles reconnection
  134. func (t *SSHTunnel) monitor() {
  135. err := t.cmd.Wait()
  136. // Clear allocated port
  137. t.mu.Lock()
  138. t.allocatedPort = 0
  139. t.mu.Unlock()
  140. if err != nil {
  141. log.Printf("[tunnel] SSH tunnel exited with error: %v", err)
  142. } else {
  143. log.Println("[tunnel] SSH tunnel exited")
  144. }
  145. // Report disconnection
  146. go t.reportDisconnected()
  147. // Auto-reconnect after delay
  148. select {
  149. case <-t.stopChan:
  150. log.Println("[tunnel] Tunnel stopped, not reconnecting")
  151. return
  152. case <-time.After(time.Duration(t.cfg.SSHTunnel.ReconnectDelay) * time.Second):
  153. log.Printf("[tunnel] Reconnecting in %ds...", t.cfg.SSHTunnel.ReconnectDelay)
  154. if err := t.Start(); err != nil {
  155. log.Printf("[tunnel] Reconnect failed: %v", err)
  156. // Will retry after another delay via monitor()
  157. }
  158. }
  159. }
  160. // reportDisconnected notifies server that tunnel is down
  161. func (t *SSHTunnel) reportDisconnected() {
  162. body, _ := json.Marshal(map[string]interface{}{
  163. "device_id": t.deviceID,
  164. "status": "disconnected",
  165. })
  166. req, _ := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
  167. req.Header.Set("Content-Type", "application/json")
  168. req.Header.Set("Authorization", "Bearer "+t.client.token)
  169. resp, err := t.client.httpClient.Do(req)
  170. if err == nil {
  171. resp.Body.Close()
  172. }
  173. }
  174. // Stop terminates the SSH tunnel
  175. func (t *SSHTunnel) Stop() {
  176. log.Println("[tunnel] Stopping SSH tunnel...")
  177. close(t.stopChan)
  178. if t.cmd != nil && t.cmd.Process != nil {
  179. t.cmd.Process.Kill()
  180. }
  181. }
  182. // GetAllocatedPort returns currently allocated port
  183. func (t *SSHTunnel) GetAllocatedPort() int {
  184. t.mu.Lock()
  185. defer t.mu.Unlock()
  186. return t.allocatedPort
  187. }
  188. // Restart restarts the tunnel (useful when config changes)
  189. func (t *SSHTunnel) Restart() error {
  190. t.Stop()
  191. time.Sleep(2 * time.Second)
  192. return t.Start()
  193. }
  194. // UpdateConfig updates the tunnel configuration
  195. func (t *SSHTunnel) UpdateConfig(cfg *Config) {
  196. t.mu.Lock()
  197. defer t.mu.Unlock()
  198. t.cfg = cfg
  199. }