| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- package main
- import (
- "bufio"
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "log"
- "net/http"
- "os"
- "os/exec"
- "regexp"
- "strconv"
- "strings"
- "sync"
- "time"
- )
- // SSHTunnel manages reverse SSH tunnel to server
- type SSHTunnel struct {
- cfg *Config
- client *APIClient
- deviceID string
- cmd *exec.Cmd
- stopChan chan struct{}
- allocatedPort int
- mu sync.Mutex
- }
- // NewSSHTunnel creates a new SSH tunnel manager
- func NewSSHTunnel(cfg *Config, client *APIClient, deviceID string) *SSHTunnel {
- return &SSHTunnel{
- cfg: cfg,
- client: client,
- deviceID: deviceID,
- stopChan: make(chan struct{}),
- }
- }
- // Start initiates the SSH tunnel
- func (t *SSHTunnel) Start() error {
- if !t.cfg.SSHTunnel.Enabled {
- log.Println("[tunnel] SSH tunnel disabled")
- return nil
- }
- keyPath := "/etc/beacon/ssh_tunnel_ed25519"
- // Verify key exists
- if _, err := os.Stat(keyPath); os.IsNotExist(err) {
- return fmt.Errorf("SSH key not found: %s", keyPath)
- }
- args := []string{
- "-N", // No command execution
- "-v", // Verbose (to parse allocated port from stderr)
- "-R", "0:localhost:22", // Reverse tunnel with auto-allocated port
- "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.SSHTunnel.KeepaliveInterval),
- "-o", "ServerAliveCountMax=3",
- "-o", "ExitOnForwardFailure=yes",
- "-o", "StrictHostKeyChecking=accept-new",
- "-i", keyPath,
- "-p", fmt.Sprintf("%d", t.cfg.SSHTunnel.Port),
- fmt.Sprintf("%s@%s", t.cfg.SSHTunnel.User, t.cfg.SSHTunnel.Server),
- }
- t.cmd = exec.Command("ssh", args...)
- // Capture stderr to parse allocated port
- stderr, err := t.cmd.StderrPipe()
- if err != nil {
- return err
- }
- if err := t.cmd.Start(); err != nil {
- return err
- }
- log.Printf("[tunnel] SSH tunnel started (server=%s:%d, user=%s)",
- t.cfg.SSHTunnel.Server, t.cfg.SSHTunnel.Port, t.cfg.SSHTunnel.User)
- // Parse stderr for allocated port
- go t.parseStderr(stderr)
- // Monitor process
- go t.monitor()
- return nil
- }
- // parseStderr reads SSH debug output and extracts allocated port
- func (t *SSHTunnel) parseStderr(r io.Reader) {
- scanner := bufio.NewScanner(r)
- for scanner.Scan() {
- line := scanner.Text()
- // Log all SSH debug output
- if strings.Contains(line, "debug") {
- log.Printf("[tunnel] %s", line)
- }
- // Parse allocated port
- // Example: "Allocated port 12345 for remote forward to localhost:22"
- if strings.Contains(line, "Allocated port") {
- re := regexp.MustCompile(`Allocated port (\d+)`)
- if matches := re.FindStringSubmatch(line); len(matches) > 1 {
- port, _ := strconv.Atoi(matches[1])
- t.mu.Lock()
- t.allocatedPort = port
- t.mu.Unlock()
- log.Printf("[tunnel] Allocated port: %d", port)
- // Report to server
- go t.reportPort(port)
- }
- }
- }
- }
- // reportPort sends allocated port to server
- func (t *SSHTunnel) reportPort(port int) {
- body, err := json.Marshal(map[string]interface{}{
- "device_id": t.deviceID,
- "port": port,
- "status": "connected",
- })
- if err != nil {
- log.Printf("[tunnel] Failed to marshal port report: %v", err)
- return
- }
- req, err := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
- if err != nil {
- log.Printf("[tunnel] Failed to create request: %v", err)
- return
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+t.client.token)
- resp, err := t.client.httpClient.Do(req)
- if err != nil {
- log.Printf("[tunnel] Failed to report port: %v", err)
- return
- }
- defer resp.Body.Close()
- if resp.StatusCode != 200 && resp.StatusCode != 201 {
- respBody, _ := io.ReadAll(resp.Body)
- log.Printf("[tunnel] Failed to report port (status=%d): %s", resp.StatusCode, string(respBody))
- return
- }
- log.Printf("[tunnel] Port %d reported to server", port)
- }
- // monitor watches SSH process and handles reconnection
- func (t *SSHTunnel) monitor() {
- err := t.cmd.Wait()
- // Clear allocated port
- t.mu.Lock()
- t.allocatedPort = 0
- t.mu.Unlock()
- if err != nil {
- log.Printf("[tunnel] SSH tunnel exited with error: %v", err)
- } else {
- log.Println("[tunnel] SSH tunnel exited")
- }
- // Report disconnection
- go t.reportDisconnected()
- // Auto-reconnect after delay
- select {
- case <-t.stopChan:
- log.Println("[tunnel] Tunnel stopped, not reconnecting")
- return
- case <-time.After(time.Duration(t.cfg.SSHTunnel.ReconnectDelay) * time.Second):
- log.Printf("[tunnel] Reconnecting in %ds...", t.cfg.SSHTunnel.ReconnectDelay)
- if err := t.Start(); err != nil {
- log.Printf("[tunnel] Reconnect failed: %v", err)
- // Will retry after another delay via monitor()
- }
- }
- }
- // reportDisconnected notifies server that tunnel is down
- func (t *SSHTunnel) reportDisconnected() {
- body, _ := json.Marshal(map[string]interface{}{
- "device_id": t.deviceID,
- "status": "disconnected",
- })
- req, _ := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+t.client.token)
- resp, err := t.client.httpClient.Do(req)
- if err == nil {
- resp.Body.Close()
- }
- }
- // Stop terminates the SSH tunnel
- func (t *SSHTunnel) Stop() {
- log.Println("[tunnel] Stopping SSH tunnel...")
- close(t.stopChan)
- if t.cmd != nil && t.cmd.Process != nil {
- t.cmd.Process.Kill()
- }
- }
- // GetAllocatedPort returns currently allocated port
- func (t *SSHTunnel) GetAllocatedPort() int {
- t.mu.Lock()
- defer t.mu.Unlock()
- return t.allocatedPort
- }
- // Restart restarts the tunnel (useful when config changes)
- func (t *SSHTunnel) Restart() error {
- t.Stop()
- time.Sleep(2 * time.Second)
- return t.Start()
- }
- // UpdateConfig updates the tunnel configuration
- func (t *SSHTunnel) UpdateConfig(cfg *Config) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.cfg = cfg
- }
|