| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- package main
- import (
- "fmt"
- "io"
- "log"
- "net"
- "os"
- "sync"
- "time"
- "golang.org/x/crypto/ssh"
- )
- // SSHTunnel manages a reverse SSH tunnel
- type SSHTunnel struct {
- cfg *Config
- running bool
- mu sync.Mutex
- stop chan struct{}
- }
- // NewSSHTunnel creates a new SSH tunnel manager
- func NewSSHTunnel(cfg *Config) *SSHTunnel {
- return &SSHTunnel{
- cfg: cfg,
- stop: make(chan struct{}),
- }
- }
- // Start starts the tunnel maintenance loop
- func (t *SSHTunnel) Start() {
- t.mu.Lock()
- if t.running {
- t.mu.Unlock()
- return
- }
- t.running = true
- t.mu.Unlock()
- go t.maintainLoop()
- }
- // Stop stops the tunnel
- func (t *SSHTunnel) Stop() {
- t.mu.Lock()
- defer t.mu.Unlock()
- if t.running {
- close(t.stop)
- t.running = false
- }
- }
- func (t *SSHTunnel) maintainLoop() {
- for {
- select {
- case <-t.stop:
- return
- default:
- }
- t.mu.Lock()
- cfg := t.cfg
- t.mu.Unlock()
- if !cfg.SSHTunnel.Enabled {
- time.Sleep(10 * time.Second)
- continue
- }
- if err := t.runTunnel(); err != nil {
- log.Printf("SSH tunnel error: %v, reconnecting in %ds...", err, cfg.SSHTunnel.ReconnectDelay)
- }
- select {
- case <-t.stop:
- return
- case <-time.After(time.Duration(cfg.SSHTunnel.ReconnectDelay) * time.Second):
- }
- }
- }
- func (t *SSHTunnel) runTunnel() error {
- cfg := t.cfg.SSHTunnel
- // Load private key
- keyData, err := os.ReadFile(cfg.KeyPath)
- if err != nil {
- return fmt.Errorf("read key: %w", err)
- }
- signer, err := ssh.ParsePrivateKey(keyData)
- if err != nil {
- return fmt.Errorf("parse key: %w", err)
- }
- // SSH config
- sshConfig := &ssh.ClientConfig{
- User: cfg.User,
- Auth: []ssh.AuthMethod{
- ssh.PublicKeys(signer),
- },
- HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: use known_hosts
- Timeout: 30 * time.Second,
- }
- // Connect to server
- serverAddr := fmt.Sprintf("%s:%d", cfg.Server, cfg.Port)
- log.Printf("SSH tunnel connecting to %s...", serverAddr)
- client, err := ssh.Dial("tcp", serverAddr, sshConfig)
- if err != nil {
- return fmt.Errorf("ssh dial: %w", err)
- }
- defer client.Close()
- // Request remote port forwarding
- remoteAddr := fmt.Sprintf("127.0.0.1:%d", cfg.RemotePort)
- listener, err := client.Listen("tcp", remoteAddr)
- if err != nil {
- return fmt.Errorf("remote listen: %w", err)
- }
- defer listener.Close()
- log.Printf("SSH tunnel established: remote %s -> local :22", remoteAddr)
- // Handle incoming connections
- errChan := make(chan error, 1)
- go func() {
- for {
- conn, err := listener.Accept()
- if err != nil {
- errChan <- err
- return
- }
- go t.handleConnection(conn)
- }
- }()
- // Keepalive loop
- keepaliveTicker := time.NewTicker(time.Duration(cfg.KeepaliveInterval) * time.Second)
- defer keepaliveTicker.Stop()
- for {
- select {
- case <-t.stop:
- return nil
- case err := <-errChan:
- return err
- case <-keepaliveTicker.C:
- _, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
- if err != nil {
- return fmt.Errorf("keepalive: %w", err)
- }
- }
- }
- }
- func (t *SSHTunnel) handleConnection(remoteConn net.Conn) {
- defer remoteConn.Close()
- // Connect to local SSH
- localConn, err := net.Dial("tcp", "127.0.0.1:22")
- if err != nil {
- log.Printf("Failed to connect to local SSH: %v", err)
- return
- }
- defer localConn.Close()
- // Bidirectional copy
- var wg sync.WaitGroup
- wg.Add(2)
- go func() {
- defer wg.Done()
- io.Copy(localConn, remoteConn)
- }()
- go func() {
- defer wg.Done()
- io.Copy(remoteConn, localConn)
- }()
- wg.Wait()
- }
- // UpdateConfig updates the tunnel configuration
- func (t *SSHTunnel) UpdateConfig(cfg *Config) {
- t.mu.Lock()
- defer t.mu.Unlock()
- t.cfg = cfg
- }
|