tunnel.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. package main
  2. import (
  3. "fmt"
  4. "io"
  5. "log"
  6. "net"
  7. "os"
  8. "sync"
  9. "time"
  10. "golang.org/x/crypto/ssh"
  11. )
  12. // SSHTunnel manages a reverse SSH tunnel
  13. type SSHTunnel struct {
  14. cfg *Config
  15. running bool
  16. mu sync.Mutex
  17. stop chan struct{}
  18. }
  19. // NewSSHTunnel creates a new SSH tunnel manager
  20. func NewSSHTunnel(cfg *Config) *SSHTunnel {
  21. return &SSHTunnel{
  22. cfg: cfg,
  23. stop: make(chan struct{}),
  24. }
  25. }
  26. // Start starts the tunnel maintenance loop
  27. func (t *SSHTunnel) Start() {
  28. t.mu.Lock()
  29. if t.running {
  30. t.mu.Unlock()
  31. return
  32. }
  33. t.running = true
  34. t.mu.Unlock()
  35. go t.maintainLoop()
  36. }
  37. // Stop stops the tunnel
  38. func (t *SSHTunnel) Stop() {
  39. t.mu.Lock()
  40. defer t.mu.Unlock()
  41. if t.running {
  42. close(t.stop)
  43. t.running = false
  44. }
  45. }
  46. func (t *SSHTunnel) maintainLoop() {
  47. for {
  48. select {
  49. case <-t.stop:
  50. return
  51. default:
  52. }
  53. t.mu.Lock()
  54. cfg := t.cfg
  55. t.mu.Unlock()
  56. if !cfg.SSHTunnel.Enabled {
  57. time.Sleep(10 * time.Second)
  58. continue
  59. }
  60. if err := t.runTunnel(); err != nil {
  61. log.Printf("SSH tunnel error: %v, reconnecting in %ds...", err, cfg.SSHTunnel.ReconnectDelay)
  62. }
  63. select {
  64. case <-t.stop:
  65. return
  66. case <-time.After(time.Duration(cfg.SSHTunnel.ReconnectDelay) * time.Second):
  67. }
  68. }
  69. }
  70. func (t *SSHTunnel) runTunnel() error {
  71. cfg := t.cfg.SSHTunnel
  72. // Load private key
  73. keyData, err := os.ReadFile(cfg.KeyPath)
  74. if err != nil {
  75. return fmt.Errorf("read key: %w", err)
  76. }
  77. signer, err := ssh.ParsePrivateKey(keyData)
  78. if err != nil {
  79. return fmt.Errorf("parse key: %w", err)
  80. }
  81. // SSH config
  82. sshConfig := &ssh.ClientConfig{
  83. User: cfg.User,
  84. Auth: []ssh.AuthMethod{
  85. ssh.PublicKeys(signer),
  86. },
  87. HostKeyCallback: ssh.InsecureIgnoreHostKey(), // TODO: use known_hosts
  88. Timeout: 30 * time.Second,
  89. }
  90. // Connect to server
  91. serverAddr := fmt.Sprintf("%s:%d", cfg.Server, cfg.Port)
  92. log.Printf("SSH tunnel connecting to %s...", serverAddr)
  93. client, err := ssh.Dial("tcp", serverAddr, sshConfig)
  94. if err != nil {
  95. return fmt.Errorf("ssh dial: %w", err)
  96. }
  97. defer client.Close()
  98. // Request remote port forwarding
  99. remoteAddr := fmt.Sprintf("127.0.0.1:%d", cfg.RemotePort)
  100. listener, err := client.Listen("tcp", remoteAddr)
  101. if err != nil {
  102. return fmt.Errorf("remote listen: %w", err)
  103. }
  104. defer listener.Close()
  105. log.Printf("SSH tunnel established: remote %s -> local :22", remoteAddr)
  106. // Handle incoming connections
  107. errChan := make(chan error, 1)
  108. go func() {
  109. for {
  110. conn, err := listener.Accept()
  111. if err != nil {
  112. errChan <- err
  113. return
  114. }
  115. go t.handleConnection(conn)
  116. }
  117. }()
  118. // Keepalive loop
  119. keepaliveTicker := time.NewTicker(time.Duration(cfg.KeepaliveInterval) * time.Second)
  120. defer keepaliveTicker.Stop()
  121. for {
  122. select {
  123. case <-t.stop:
  124. return nil
  125. case err := <-errChan:
  126. return err
  127. case <-keepaliveTicker.C:
  128. _, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
  129. if err != nil {
  130. return fmt.Errorf("keepalive: %w", err)
  131. }
  132. }
  133. }
  134. }
  135. func (t *SSHTunnel) handleConnection(remoteConn net.Conn) {
  136. defer remoteConn.Close()
  137. // Connect to local SSH
  138. localConn, err := net.Dial("tcp", "127.0.0.1:22")
  139. if err != nil {
  140. log.Printf("Failed to connect to local SSH: %v", err)
  141. return
  142. }
  143. defer localConn.Close()
  144. // Bidirectional copy
  145. var wg sync.WaitGroup
  146. wg.Add(2)
  147. go func() {
  148. defer wg.Done()
  149. io.Copy(localConn, remoteConn)
  150. }()
  151. go func() {
  152. defer wg.Done()
  153. io.Copy(remoteConn, localConn)
  154. }()
  155. wg.Wait()
  156. }
  157. // UpdateConfig updates the tunnel configuration
  158. func (t *SSHTunnel) UpdateConfig(cfg *Config) {
  159. t.mu.Lock()
  160. defer t.mu.Unlock()
  161. t.cfg = cfg
  162. }