package main import ( "bufio" "bytes" "encoding/json" "fmt" "io" "log" "net/http" "os" "os/exec" "regexp" "strconv" "strings" "sync" "time" ) // SSHTunnel manages reverse SSH tunnel to server type SSHTunnel struct { cfg *Config client *APIClient deviceID string cmd *exec.Cmd stopChan chan struct{} allocatedPort int mu sync.Mutex } // NewSSHTunnel creates a new SSH tunnel manager func NewSSHTunnel(cfg *Config, client *APIClient, deviceID string) *SSHTunnel { return &SSHTunnel{ cfg: cfg, client: client, deviceID: deviceID, stopChan: make(chan struct{}), } } // Start initiates the SSH tunnel func (t *SSHTunnel) Start() error { if !t.cfg.SSHTunnel.Enabled { log.Println("[tunnel] SSH tunnel disabled") return nil } keyPath := "/etc/beacon/ssh_tunnel_ed25519" // Verify key exists if _, err := os.Stat(keyPath); os.IsNotExist(err) { return fmt.Errorf("SSH key not found: %s", keyPath) } args := []string{ "-N", // No command execution "-v", // Verbose (to parse allocated port from stderr) "-R", "0:localhost:22", // Reverse tunnel with auto-allocated port "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.SSHTunnel.KeepaliveInterval), "-o", "ServerAliveCountMax=3", "-o", "ExitOnForwardFailure=yes", "-o", "StrictHostKeyChecking=accept-new", "-i", keyPath, "-p", fmt.Sprintf("%d", t.cfg.SSHTunnel.Port), fmt.Sprintf("%s@%s", t.cfg.SSHTunnel.User, t.cfg.SSHTunnel.Server), } t.cmd = exec.Command("ssh", args...) // Capture stderr to parse allocated port stderr, err := t.cmd.StderrPipe() if err != nil { return err } if err := t.cmd.Start(); err != nil { return err } log.Printf("[tunnel] SSH tunnel started (server=%s:%d, user=%s)", t.cfg.SSHTunnel.Server, t.cfg.SSHTunnel.Port, t.cfg.SSHTunnel.User) // Parse stderr for allocated port go t.parseStderr(stderr) // Monitor process go t.monitor() return nil } // parseStderr reads SSH debug output and extracts allocated port func (t *SSHTunnel) parseStderr(r io.Reader) { scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() // Log all SSH debug output if strings.Contains(line, "debug") { log.Printf("[tunnel] %s", line) } // Parse allocated port // Example: "Allocated port 12345 for remote forward to localhost:22" if strings.Contains(line, "Allocated port") { re := regexp.MustCompile(`Allocated port (\d+)`) if matches := re.FindStringSubmatch(line); len(matches) > 1 { port, _ := strconv.Atoi(matches[1]) t.mu.Lock() t.allocatedPort = port t.mu.Unlock() log.Printf("[tunnel] Allocated port: %d", port) // Report to server go t.reportPort(port) } } } } // reportPort sends allocated port to server func (t *SSHTunnel) reportPort(port int) { body, err := json.Marshal(map[string]interface{}{ "device_id": t.deviceID, "port": port, "status": "connected", }) if err != nil { log.Printf("[tunnel] Failed to marshal port report: %v", err) return } req, err := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body)) if err != nil { log.Printf("[tunnel] Failed to create request: %v", err) return } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+t.client.token) resp, err := t.client.httpClient.Do(req) if err != nil { log.Printf("[tunnel] Failed to report port: %v", err) return } defer resp.Body.Close() if resp.StatusCode != 200 && resp.StatusCode != 201 { respBody, _ := io.ReadAll(resp.Body) log.Printf("[tunnel] Failed to report port (status=%d): %s", resp.StatusCode, string(respBody)) return } log.Printf("[tunnel] Port %d reported to server", port) } // monitor watches SSH process and handles reconnection func (t *SSHTunnel) monitor() { err := t.cmd.Wait() // Clear allocated port t.mu.Lock() t.allocatedPort = 0 t.mu.Unlock() if err != nil { log.Printf("[tunnel] SSH tunnel exited with error: %v", err) } else { log.Println("[tunnel] SSH tunnel exited") } // Report disconnection go t.reportDisconnected() // Auto-reconnect after delay select { case <-t.stopChan: log.Println("[tunnel] Tunnel stopped, not reconnecting") return case <-time.After(time.Duration(t.cfg.SSHTunnel.ReconnectDelay) * time.Second): log.Printf("[tunnel] Reconnecting in %ds...", t.cfg.SSHTunnel.ReconnectDelay) if err := t.Start(); err != nil { log.Printf("[tunnel] Reconnect failed: %v", err) // Will retry after another delay via monitor() } } } // reportDisconnected notifies server that tunnel is down func (t *SSHTunnel) reportDisconnected() { body, _ := json.Marshal(map[string]interface{}{ "device_id": t.deviceID, "status": "disconnected", }) req, _ := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+t.client.token) resp, err := t.client.httpClient.Do(req) if err == nil { resp.Body.Close() } } // Stop terminates the SSH tunnel func (t *SSHTunnel) Stop() { log.Println("[tunnel] Stopping SSH tunnel...") close(t.stopChan) if t.cmd != nil && t.cmd.Process != nil { t.cmd.Process.Kill() } } // GetAllocatedPort returns currently allocated port func (t *SSHTunnel) GetAllocatedPort() int { t.mu.Lock() defer t.mu.Unlock() return t.allocatedPort } // Restart restarts the tunnel (useful when config changes) func (t *SSHTunnel) Restart() error { t.Stop() time.Sleep(2 * time.Second) return t.Start() } // UpdateConfig updates the tunnel configuration func (t *SSHTunnel) UpdateConfig(cfg *Config) { t.mu.Lock() defer t.mu.Unlock() t.cfg = cfg }