|
@@ -0,0 +1,237 @@
|
|
|
|
|
+package main
|
|
|
|
|
+
|
|
|
|
|
+import (
|
|
|
|
|
+ "bufio"
|
|
|
|
|
+ "bytes"
|
|
|
|
|
+ "encoding/json"
|
|
|
|
|
+ "fmt"
|
|
|
|
|
+ "io"
|
|
|
|
|
+ "log"
|
|
|
|
|
+ "net/http"
|
|
|
|
|
+ "os"
|
|
|
|
|
+ "os/exec"
|
|
|
|
|
+ "regexp"
|
|
|
|
|
+ "strconv"
|
|
|
|
|
+ "strings"
|
|
|
|
|
+ "sync"
|
|
|
|
|
+ "time"
|
|
|
|
|
+)
|
|
|
|
|
+
|
|
|
|
|
+// SSHTunnel manages reverse SSH tunnel to server
|
|
|
|
|
+type SSHTunnel struct {
|
|
|
|
|
+ cfg *Config
|
|
|
|
|
+ client *APIClient
|
|
|
|
|
+ deviceID string
|
|
|
|
|
+ cmd *exec.Cmd
|
|
|
|
|
+ stopChan chan struct{}
|
|
|
|
|
+ allocatedPort int
|
|
|
|
|
+ mu sync.Mutex
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// NewSSHTunnel creates a new SSH tunnel manager
|
|
|
|
|
+func NewSSHTunnel(cfg *Config, client *APIClient, deviceID string) *SSHTunnel {
|
|
|
|
|
+ return &SSHTunnel{
|
|
|
|
|
+ cfg: cfg,
|
|
|
|
|
+ client: client,
|
|
|
|
|
+ deviceID: deviceID,
|
|
|
|
|
+ stopChan: make(chan struct{}),
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Start initiates the SSH tunnel
|
|
|
|
|
+func (t *SSHTunnel) Start() error {
|
|
|
|
|
+ if !t.cfg.SSHTunnel.Enabled {
|
|
|
|
|
+ log.Println("[tunnel] SSH tunnel disabled")
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ keyPath := "/etc/beacon/ssh_tunnel_ed25519"
|
|
|
|
|
+
|
|
|
|
|
+ // Verify key exists
|
|
|
|
|
+ if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
|
|
|
|
+ return fmt.Errorf("SSH key not found: %s", keyPath)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ args := []string{
|
|
|
|
|
+ "-N", // No command execution
|
|
|
|
|
+ "-v", // Verbose (to parse allocated port from stderr)
|
|
|
|
|
+ "-R", "0:localhost:22", // Reverse tunnel with auto-allocated port
|
|
|
|
|
+ "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.SSHTunnel.KeepaliveInterval),
|
|
|
|
|
+ "-o", "ServerAliveCountMax=3",
|
|
|
|
|
+ "-o", "ExitOnForwardFailure=yes",
|
|
|
|
|
+ "-o", "StrictHostKeyChecking=accept-new",
|
|
|
|
|
+ "-i", keyPath,
|
|
|
|
|
+ "-p", fmt.Sprintf("%d", t.cfg.SSHTunnel.Port),
|
|
|
|
|
+ fmt.Sprintf("%s@%s", t.cfg.SSHTunnel.User, t.cfg.SSHTunnel.Server),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ t.cmd = exec.Command("ssh", args...)
|
|
|
|
|
+
|
|
|
|
|
+ // Capture stderr to parse allocated port
|
|
|
|
|
+ stderr, err := t.cmd.StderrPipe()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if err := t.cmd.Start(); err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ log.Printf("[tunnel] SSH tunnel started (server=%s:%d, user=%s)",
|
|
|
|
|
+ t.cfg.SSHTunnel.Server, t.cfg.SSHTunnel.Port, t.cfg.SSHTunnel.User)
|
|
|
|
|
+
|
|
|
|
|
+ // Parse stderr for allocated port
|
|
|
|
|
+ go t.parseStderr(stderr)
|
|
|
|
|
+
|
|
|
|
|
+ // Monitor process
|
|
|
|
|
+ go t.monitor()
|
|
|
|
|
+
|
|
|
|
|
+ return nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// parseStderr reads SSH debug output and extracts allocated port
|
|
|
|
|
+func (t *SSHTunnel) parseStderr(r io.Reader) {
|
|
|
|
|
+ scanner := bufio.NewScanner(r)
|
|
|
|
|
+ for scanner.Scan() {
|
|
|
|
|
+ line := scanner.Text()
|
|
|
|
|
+
|
|
|
|
|
+ // Log all SSH debug output
|
|
|
|
|
+ if strings.Contains(line, "debug") {
|
|
|
|
|
+ log.Printf("[tunnel] %s", line)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Parse allocated port
|
|
|
|
|
+ // Example: "Allocated port 12345 for remote forward to localhost:22"
|
|
|
|
|
+ if strings.Contains(line, "Allocated port") {
|
|
|
|
|
+ re := regexp.MustCompile(`Allocated port (\d+)`)
|
|
|
|
|
+ if matches := re.FindStringSubmatch(line); len(matches) > 1 {
|
|
|
|
|
+ port, _ := strconv.Atoi(matches[1])
|
|
|
|
|
+ t.mu.Lock()
|
|
|
|
|
+ t.allocatedPort = port
|
|
|
|
|
+ t.mu.Unlock()
|
|
|
|
|
+
|
|
|
|
|
+ log.Printf("[tunnel] Allocated port: %d", port)
|
|
|
|
|
+
|
|
|
|
|
+ // Report to server
|
|
|
|
|
+ go t.reportPort(port)
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// reportPort sends allocated port to server
|
|
|
|
|
+func (t *SSHTunnel) reportPort(port int) {
|
|
|
|
|
+ body, err := json.Marshal(map[string]interface{}{
|
|
|
|
|
+ "device_id": t.deviceID,
|
|
|
|
|
+ "port": port,
|
|
|
|
|
+ "status": "connected",
|
|
|
|
|
+ })
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("[tunnel] Failed to marshal port report: %v", err)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ req, err := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("[tunnel] Failed to create request: %v", err)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
+ req.Header.Set("Authorization", "Bearer "+t.client.token)
|
|
|
|
|
+
|
|
|
|
|
+ resp, err := t.client.httpClient.Do(req)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("[tunnel] Failed to report port: %v", err)
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+ defer resp.Body.Close()
|
|
|
|
|
+
|
|
|
|
|
+ if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
|
|
|
+ respBody, _ := io.ReadAll(resp.Body)
|
|
|
|
|
+ log.Printf("[tunnel] Failed to report port (status=%d): %s", resp.StatusCode, string(respBody))
|
|
|
|
|
+ return
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ log.Printf("[tunnel] Port %d reported to server", port)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// monitor watches SSH process and handles reconnection
|
|
|
|
|
+func (t *SSHTunnel) monitor() {
|
|
|
|
|
+ err := t.cmd.Wait()
|
|
|
|
|
+
|
|
|
|
|
+ // Clear allocated port
|
|
|
|
|
+ t.mu.Lock()
|
|
|
|
|
+ t.allocatedPort = 0
|
|
|
|
|
+ t.mu.Unlock()
|
|
|
|
|
+
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ log.Printf("[tunnel] SSH tunnel exited with error: %v", err)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ log.Println("[tunnel] SSH tunnel exited")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Report disconnection
|
|
|
|
|
+ go t.reportDisconnected()
|
|
|
|
|
+
|
|
|
|
|
+ // Auto-reconnect after delay
|
|
|
|
|
+ select {
|
|
|
|
|
+ case <-t.stopChan:
|
|
|
|
|
+ log.Println("[tunnel] Tunnel stopped, not reconnecting")
|
|
|
|
|
+ return
|
|
|
|
|
+ case <-time.After(time.Duration(t.cfg.SSHTunnel.ReconnectDelay) * time.Second):
|
|
|
|
|
+ log.Printf("[tunnel] Reconnecting in %ds...", t.cfg.SSHTunnel.ReconnectDelay)
|
|
|
|
|
+ if err := t.Start(); err != nil {
|
|
|
|
|
+ log.Printf("[tunnel] Reconnect failed: %v", err)
|
|
|
|
|
+ // Will retry after another delay via monitor()
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// reportDisconnected notifies server that tunnel is down
|
|
|
|
|
+func (t *SSHTunnel) reportDisconnected() {
|
|
|
|
|
+ body, _ := json.Marshal(map[string]interface{}{
|
|
|
|
|
+ "device_id": t.deviceID,
|
|
|
|
|
+ "status": "disconnected",
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ req, _ := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
|
|
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
+ req.Header.Set("Authorization", "Bearer "+t.client.token)
|
|
|
|
|
+
|
|
|
|
|
+ resp, err := t.client.httpClient.Do(req)
|
|
|
|
|
+ if err == nil {
|
|
|
|
|
+ resp.Body.Close()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Stop terminates the SSH tunnel
|
|
|
|
|
+func (t *SSHTunnel) Stop() {
|
|
|
|
|
+ log.Println("[tunnel] Stopping SSH tunnel...")
|
|
|
|
|
+ close(t.stopChan)
|
|
|
|
|
+
|
|
|
|
|
+ if t.cmd != nil && t.cmd.Process != nil {
|
|
|
|
|
+ t.cmd.Process.Kill()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// GetAllocatedPort returns currently allocated port
|
|
|
|
|
+func (t *SSHTunnel) GetAllocatedPort() int {
|
|
|
|
|
+ t.mu.Lock()
|
|
|
|
|
+ defer t.mu.Unlock()
|
|
|
|
|
+ return t.allocatedPort
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// Restart restarts the tunnel (useful when config changes)
|
|
|
|
|
+func (t *SSHTunnel) Restart() error {
|
|
|
|
|
+ t.Stop()
|
|
|
|
|
+ time.Sleep(2 * time.Second)
|
|
|
|
|
+ return t.Start()
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+// UpdateConfig updates the tunnel configuration
|
|
|
|
|
+func (t *SSHTunnel) UpdateConfig(cfg *Config) {
|
|
|
|
|
+ t.mu.Lock()
|
|
|
|
|
+ defer t.mu.Unlock()
|
|
|
|
|
+ t.cfg = cfg
|
|
|
|
|
+}
|