package main import ( "fmt" "io" "log" "net" "os" "sync" "time" "golang.org/x/crypto/ssh" ) // SSHTunnel manages a reverse SSH tunnel type SSHTunnel struct { cfg *Config running bool mu sync.Mutex stop chan struct{} } // NewSSHTunnel creates a new SSH tunnel manager func NewSSHTunnel(cfg *Config) *SSHTunnel { return &SSHTunnel{ cfg: cfg, stop: make(chan struct{}), } } // Start starts the tunnel maintenance loop func (t *SSHTunnel) Start() { t.mu.Lock() if t.running { t.mu.Unlock() return } t.running = true t.mu.Unlock() go t.maintainLoop() } // Stop stops the tunnel func (t *SSHTunnel) Stop() { t.mu.Lock() defer t.mu.Unlock() if t.running { close(t.stop) t.running = false } } func (t *SSHTunnel) maintainLoop() { for { select { case <-t.stop: return default: } t.mu.Lock() cfg := t.cfg t.mu.Unlock() if !cfg.SSHTunnel.Enabled { time.Sleep(10 * time.Second) continue } if err := t.runTunnel(); err != nil { log.Printf("SSH tunnel error: %v, reconnecting in %ds...", err, cfg.SSHTunnel.ReconnectDelay) } select { case <-t.stop: return case <-time.After(time.Duration(cfg.SSHTunnel.ReconnectDelay) * time.Second): } } } func (t *SSHTunnel) runTunnel() error { cfg := t.cfg.SSHTunnel // Load private key keyData, err := os.ReadFile(cfg.KeyPath) if err != nil { return fmt.Errorf("read key: %w", err) } signer, err := ssh.ParsePrivateKey(keyData) if err != nil { return fmt.Errorf("parse key: %w", err) } // SSH config sshConfig := &ssh.ClientConfig{ User: cfg.User, Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: use known_hosts Timeout: 30 * time.Second, } // Connect to server serverAddr := fmt.Sprintf("%s:%d", cfg.Server, cfg.Port) log.Printf("SSH tunnel connecting to %s...", serverAddr) client, err := ssh.Dial("tcp", serverAddr, sshConfig) if err != nil { return fmt.Errorf("ssh dial: %w", err) } defer client.Close() // Request remote port forwarding remoteAddr := fmt.Sprintf("127.0.0.1:%d", cfg.RemotePort) listener, err := client.Listen("tcp", remoteAddr) if err != nil { return fmt.Errorf("remote listen: %w", err) } defer listener.Close() log.Printf("SSH tunnel established: remote %s -> local :22", remoteAddr) // Handle incoming connections errChan := make(chan error, 1) go func() { for { conn, err := listener.Accept() if err != nil { errChan <- err return } go t.handleConnection(conn) } }() // Keepalive loop keepaliveTicker := time.NewTicker(time.Duration(cfg.KeepaliveInterval) * time.Second) defer keepaliveTicker.Stop() for { select { case <-t.stop: return nil case err := <-errChan: return err case <-keepaliveTicker.C: _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) if err != nil { return fmt.Errorf("keepalive: %w", err) } } } } func (t *SSHTunnel) handleConnection(remoteConn net.Conn) { defer remoteConn.Close() // Connect to local SSH localConn, err := net.Dial("tcp", "127.0.0.1:22") if err != nil { log.Printf("Failed to connect to local SSH: %v", err) return } defer localConn.Close() // Bidirectional copy var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() io.Copy(localConn, remoteConn) }() go func() { defer wg.Done() io.Copy(remoteConn, localConn) }() wg.Wait() } // UpdateConfig updates the tunnel configuration func (t *SSHTunnel) UpdateConfig(cfg *Config) { t.mu.Lock() defer t.mu.Unlock() t.cfg = cfg }