|
|
@@ -1,47 +1,55 @@
|
|
|
package main
|
|
|
|
|
|
import (
|
|
|
- "bufio"
|
|
|
- "bytes"
|
|
|
- "encoding/json"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
"log"
|
|
|
- "net/http"
|
|
|
"os"
|
|
|
"os/exec"
|
|
|
- "regexp"
|
|
|
- "strconv"
|
|
|
- "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
+// TunnelConfig contains configuration for a specific tunnel
|
|
|
+type TunnelConfig struct {
|
|
|
+ Enabled bool
|
|
|
+ Server string
|
|
|
+ Port int
|
|
|
+ User string
|
|
|
+ RemotePort int
|
|
|
+ KeepaliveInterval int
|
|
|
+ ReconnectDelay int
|
|
|
+}
|
|
|
+
|
|
|
// SSHTunnel manages reverse SSH tunnel to server
|
|
|
type SSHTunnel struct {
|
|
|
- cfg *Config
|
|
|
- client *APIClient
|
|
|
- deviceID string
|
|
|
- cmd *exec.Cmd
|
|
|
- stopChan chan struct{}
|
|
|
- allocatedPort int
|
|
|
- mu sync.Mutex
|
|
|
+ name string // "ssh" or "dashboard"
|
|
|
+ localPort int // Local port to forward (22 for SSH, 8080 for dashboard)
|
|
|
+
|
|
|
+ cfg *TunnelConfig
|
|
|
+ cmd *exec.Cmd
|
|
|
+ stopChan chan struct{}
|
|
|
+ mu sync.Mutex
|
|
|
}
|
|
|
|
|
|
// NewSSHTunnel creates a new SSH tunnel manager
|
|
|
-func NewSSHTunnel(cfg *Config, client *APIClient, deviceID string) *SSHTunnel {
|
|
|
+func NewSSHTunnel(name string, localPort int, cfg *TunnelConfig) *SSHTunnel {
|
|
|
return &SSHTunnel{
|
|
|
- cfg: cfg,
|
|
|
- client: client,
|
|
|
- deviceID: deviceID,
|
|
|
- stopChan: make(chan struct{}),
|
|
|
+ name: name,
|
|
|
+ localPort: localPort,
|
|
|
+ cfg: cfg,
|
|
|
+ stopChan: make(chan struct{}),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Start initiates the SSH tunnel
|
|
|
func (t *SSHTunnel) Start() error {
|
|
|
- if !t.cfg.SSHTunnel.Enabled {
|
|
|
- log.Println("[tunnel] SSH tunnel disabled")
|
|
|
+ if !t.cfg.Enabled {
|
|
|
+ log.Printf("[%s-tunnel] Tunnel disabled", t.name)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if t.cfg.RemotePort == 0 {
|
|
|
+ log.Printf("[%s-tunnel] Remote port not allocated yet, waiting...", t.name)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
@@ -52,36 +60,29 @@ func (t *SSHTunnel) Start() error {
|
|
|
return fmt.Errorf("SSH key not found: %s", keyPath)
|
|
|
}
|
|
|
|
|
|
+ // Build reverse tunnel string: remote_port:localhost:local_port
|
|
|
+ reverseSpec := fmt.Sprintf("%d:localhost:%d", t.cfg.RemotePort, t.localPort)
|
|
|
+
|
|
|
args := []string{
|
|
|
"-N", // No command execution
|
|
|
- "-v", // Verbose (to parse allocated port from stderr)
|
|
|
- "-R", "0:localhost:22", // Reverse tunnel with auto-allocated port
|
|
|
- "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.SSHTunnel.KeepaliveInterval),
|
|
|
+ "-R", reverseSpec, // Reverse tunnel with fixed port
|
|
|
+ "-o", fmt.Sprintf("ServerAliveInterval=%d", t.cfg.KeepaliveInterval),
|
|
|
"-o", "ServerAliveCountMax=3",
|
|
|
"-o", "ExitOnForwardFailure=yes",
|
|
|
"-o", "StrictHostKeyChecking=accept-new",
|
|
|
"-i", keyPath,
|
|
|
- "-p", fmt.Sprintf("%d", t.cfg.SSHTunnel.Port),
|
|
|
- fmt.Sprintf("%s@%s", t.cfg.SSHTunnel.User, t.cfg.SSHTunnel.Server),
|
|
|
+ "-p", fmt.Sprintf("%d", t.cfg.Port),
|
|
|
+ fmt.Sprintf("%s@%s", t.cfg.User, t.cfg.Server),
|
|
|
}
|
|
|
|
|
|
t.cmd = exec.Command("ssh", args...)
|
|
|
|
|
|
- // Capture stderr to parse allocated port
|
|
|
- stderr, err := t.cmd.StderrPipe()
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
if err := t.cmd.Start(); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
- log.Printf("[tunnel] SSH tunnel started (server=%s:%d, user=%s)",
|
|
|
- t.cfg.SSHTunnel.Server, t.cfg.SSHTunnel.Port, t.cfg.SSHTunnel.User)
|
|
|
-
|
|
|
- // Parse stderr for allocated port
|
|
|
- go t.parseStderr(stderr)
|
|
|
+ log.Printf("[%s-tunnel] Started: %s:%d -> localhost:%d (remote_port=%d)",
|
|
|
+ t.name, t.cfg.Server, t.cfg.Port, t.localPort, t.cfg.RemotePort)
|
|
|
|
|
|
// Monitor process
|
|
|
go t.monitor()
|
|
|
@@ -89,148 +90,65 @@ func (t *SSHTunnel) Start() error {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-// parseStderr reads SSH debug output and extracts allocated port
|
|
|
-func (t *SSHTunnel) parseStderr(r io.Reader) {
|
|
|
- scanner := bufio.NewScanner(r)
|
|
|
- for scanner.Scan() {
|
|
|
- line := scanner.Text()
|
|
|
-
|
|
|
- // Log all SSH debug output
|
|
|
- if strings.Contains(line, "debug") {
|
|
|
- log.Printf("[tunnel] %s", line)
|
|
|
- }
|
|
|
-
|
|
|
- // Parse allocated port
|
|
|
- // Example: "Allocated port 12345 for remote forward to localhost:22"
|
|
|
- if strings.Contains(line, "Allocated port") {
|
|
|
- re := regexp.MustCompile(`Allocated port (\d+)`)
|
|
|
- if matches := re.FindStringSubmatch(line); len(matches) > 1 {
|
|
|
- port, _ := strconv.Atoi(matches[1])
|
|
|
- t.mu.Lock()
|
|
|
- t.allocatedPort = port
|
|
|
- t.mu.Unlock()
|
|
|
-
|
|
|
- log.Printf("[tunnel] Allocated port: %d", port)
|
|
|
-
|
|
|
- // Report to server
|
|
|
- go t.reportPort(port)
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-// reportPort sends allocated port to server
|
|
|
-func (t *SSHTunnel) reportPort(port int) {
|
|
|
- body, err := json.Marshal(map[string]interface{}{
|
|
|
- "device_id": t.deviceID,
|
|
|
- "port": port,
|
|
|
- "status": "connected",
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- log.Printf("[tunnel] Failed to marshal port report: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- req, err := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
|
|
|
- if err != nil {
|
|
|
- log.Printf("[tunnel] Failed to create request: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
- req.Header.Set("Authorization", "Bearer "+t.client.token)
|
|
|
-
|
|
|
- resp, err := t.client.httpClient.Do(req)
|
|
|
- if err != nil {
|
|
|
- log.Printf("[tunnel] Failed to report port: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
- defer resp.Body.Close()
|
|
|
-
|
|
|
- if resp.StatusCode != 200 && resp.StatusCode != 201 {
|
|
|
- respBody, _ := io.ReadAll(resp.Body)
|
|
|
- log.Printf("[tunnel] Failed to report port (status=%d): %s", resp.StatusCode, string(respBody))
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- log.Printf("[tunnel] Port %d reported to server", port)
|
|
|
-}
|
|
|
-
|
|
|
// monitor watches SSH process and handles reconnection
|
|
|
func (t *SSHTunnel) monitor() {
|
|
|
err := t.cmd.Wait()
|
|
|
|
|
|
- // Clear allocated port
|
|
|
- t.mu.Lock()
|
|
|
- t.allocatedPort = 0
|
|
|
- t.mu.Unlock()
|
|
|
-
|
|
|
if err != nil {
|
|
|
- log.Printf("[tunnel] SSH tunnel exited with error: %v", err)
|
|
|
+ log.Printf("[%s-tunnel] Tunnel exited with error: %v", t.name, err)
|
|
|
} else {
|
|
|
- log.Println("[tunnel] SSH tunnel exited")
|
|
|
+ log.Printf("[%s-tunnel] Tunnel exited", t.name)
|
|
|
}
|
|
|
|
|
|
- // Report disconnection
|
|
|
- go t.reportDisconnected()
|
|
|
-
|
|
|
// Auto-reconnect after delay
|
|
|
select {
|
|
|
case <-t.stopChan:
|
|
|
- log.Println("[tunnel] Tunnel stopped, not reconnecting")
|
|
|
+ log.Printf("[%s-tunnel] Stopped, not reconnecting", t.name)
|
|
|
return
|
|
|
- case <-time.After(time.Duration(t.cfg.SSHTunnel.ReconnectDelay) * time.Second):
|
|
|
- log.Printf("[tunnel] Reconnecting in %ds...", t.cfg.SSHTunnel.ReconnectDelay)
|
|
|
+ case <-time.After(time.Duration(t.cfg.ReconnectDelay) * time.Second):
|
|
|
+ log.Printf("[%s-tunnel] Reconnecting in %ds...", t.name, t.cfg.ReconnectDelay)
|
|
|
if err := t.Start(); err != nil {
|
|
|
- log.Printf("[tunnel] Reconnect failed: %v", err)
|
|
|
+ log.Printf("[%s-tunnel] Reconnect failed: %v", t.name, err)
|
|
|
// Will retry after another delay via monitor()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// reportDisconnected notifies server that tunnel is down
|
|
|
-func (t *SSHTunnel) reportDisconnected() {
|
|
|
- body, _ := json.Marshal(map[string]interface{}{
|
|
|
- "device_id": t.deviceID,
|
|
|
- "status": "disconnected",
|
|
|
- })
|
|
|
+// Stop terminates the SSH tunnel
|
|
|
+func (t *SSHTunnel) Stop() {
|
|
|
+ log.Printf("[%s-tunnel] Stopping...", t.name)
|
|
|
|
|
|
- req, _ := http.NewRequest("POST", t.client.baseURL+"/tunnel-port", bytes.NewReader(body))
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
- req.Header.Set("Authorization", "Bearer "+t.client.token)
|
|
|
+ t.mu.Lock()
|
|
|
+ defer t.mu.Unlock()
|
|
|
|
|
|
- resp, err := t.client.httpClient.Do(req)
|
|
|
- if err == nil {
|
|
|
- resp.Body.Close()
|
|
|
+ select {
|
|
|
+ case <-t.stopChan:
|
|
|
+ // Already closed
|
|
|
+ return
|
|
|
+ default:
|
|
|
+ close(t.stopChan)
|
|
|
}
|
|
|
-}
|
|
|
-
|
|
|
-// Stop terminates the SSH tunnel
|
|
|
-func (t *SSHTunnel) Stop() {
|
|
|
- log.Println("[tunnel] Stopping SSH tunnel...")
|
|
|
- close(t.stopChan)
|
|
|
|
|
|
if t.cmd != nil && t.cmd.Process != nil {
|
|
|
t.cmd.Process.Kill()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// GetAllocatedPort returns currently allocated port
|
|
|
-func (t *SSHTunnel) GetAllocatedPort() int {
|
|
|
- t.mu.Lock()
|
|
|
- defer t.mu.Unlock()
|
|
|
- return t.allocatedPort
|
|
|
-}
|
|
|
-
|
|
|
// Restart restarts the tunnel (useful when config changes)
|
|
|
func (t *SSHTunnel) Restart() error {
|
|
|
t.Stop()
|
|
|
time.Sleep(2 * time.Second)
|
|
|
+
|
|
|
+ // Recreate stop channel
|
|
|
+ t.mu.Lock()
|
|
|
+ t.stopChan = make(chan struct{})
|
|
|
+ t.mu.Unlock()
|
|
|
+
|
|
|
return t.Start()
|
|
|
}
|
|
|
|
|
|
// UpdateConfig updates the tunnel configuration
|
|
|
-func (t *SSHTunnel) UpdateConfig(cfg *Config) {
|
|
|
+func (t *SSHTunnel) UpdateConfig(cfg *TunnelConfig) {
|
|
|
t.mu.Lock()
|
|
|
defer t.mu.Unlock()
|
|
|
t.cfg = cfg
|