How to Create an SSH Tunnel in Go
Sometimes resources (such as database servers) are not publicly accessible. This is critical for security, but it can be a pain when writing scripts that need to access these resources for debugging and other ad-hoc tasks.
One solution is to create an SSH tunnel in bash and point your script to it. However:
- You may need to write scripts that are too complicated for bash.
- It can make your scripts brittle if you need to run multiple tunnels, or if you forget to clean them up for long-running processes.
- You may not have access to a separate terminal to run the SSH tunnel, such as when running under an automation script.
- You want to reuse all your existing Go code and just bolt on the tunnel.
- You dislike bash.
Well, here you go. The following code supports creating multiple hassle-free SSH tunnels in pure Go, using either private key or password authentication:
package mainimport (
"fmt"
"golang.org/x/crypto/ssh"
"io"
"io/ioutil"
"log"
"net"
"strconv"
"strings"
)type Endpoint struct {
Host string
Port int
User string
}func NewEndpoint(s string) *Endpoint {
endpoint := &Endpoint{
Host: s,
} if parts := strings.Split(endpoint.Host, "@"); len(parts) > 1 {
endpoint.User = parts[0]
endpoint.Host = parts[1]
} if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 {
endpoint.Host = parts[0]
endpoint.Port, _ = strconv.Atoi(parts[1])
} return endpoint
}func (endpoint *Endpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}type SSHTunnel struct {
Local *Endpoint
Server *Endpoint
Remote *Endpoint
Config *ssh.ClientConfig
Log *log.Logger
}func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
if tunnel.Log != nil {
tunnel.Log.Printf(fmt, args...)
}
}func (tunnel *SSHTunnel) Start() error {
listener, err := net.Listen("tcp", tunnel.Local.String())
if err != nil {
return err
}
defer listener.Close() tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port for {
conn, err := listener.Accept()
if err != nil {
return err
} tunnel.logf("accepted connection")
go tunnel.forward(conn)
}
}func (tunnel *SSHTunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
tunnel.logf("server dial error: %s", err)
return
} tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String()) remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
if err != nil {
tunnel.logf("remote dial error: %s", err)
return
} tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String()) copyConn := func(writer, reader net.Conn) {
_, err := io.Copy(writer, reader)
if err != nil {
tunnel.logf("io.Copy error: %s", err)
}
} go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
}func PrivateKeyFile(file string) ssh.AuthMethod {
buffer, err := ioutil.ReadFile(file)
if err != nil {
return nil
} key, err := ssh.ParsePrivateKey(buffer)
if err != nil {
return nil
} return ssh.PublicKeys(key)
}func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string) *SSHTunnel {
// A random port will be chosen for us.
localEndpoint := NewEndpoint("localhost:0") server := NewEndpoint(tunnel)
if server.Port == 0 {
server.Port = 22
} sshTunnel := &SSHTunnel{
Config: &ssh.ClientConfig{
User: server.User,
Auth: []ssh.AuthMethod{auth},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// Always accept key.
return nil
},
},
Local: localEndpoint,
Server: server,
Remote: NewEndpoint(destination),
} return sshTunnel
}
Here is an example of usage:
func main() {
	// This example also needs the "database/sql", "os" and "time" imports,
	// plus a registered PostgreSQL driver such as _ "github.com/lib/pq";
	// otherwise sql.Open("postgres", ...) fails with an unknown-driver error.

	// Set up the tunnel, but do not start it yet.
	tunnel := NewSSHTunnel(
		// User and host of the tunnel server. It will default to port 22
		// if not specified.
		"ec2-user@jumpbox.us-east-1.mydomain.com",

		// Pick ONE of the following authentication methods; NewSSHTunnel
		// takes exactly one:
		PrivateKeyFile("path/to/private/key.pem"), // 1. private key
		// ssh.Password("password"), // 2. password

		// The destination host and port of the actual server.
		"dqrsdfdssdfx.us-east-1.redshift.amazonaws.com:5439",
	)

	// You can provide a logger for debugging, or remove this line to
	// make it silent.
	tunnel.Log = log.New(os.Stdout, "", log.Ldate|log.Lmicroseconds)

	// Start the server in the background. Start only returns on failure
	// (for example, if it cannot bind the local port), so treat that as
	// fatal rather than silently dropping the error.
	go func() {
		if err := tunnel.Start(); err != nil {
			log.Fatalf("tunnel: %s", err)
		}
	}()

	// Wait a small amount of time for the tunnel to bind to the localhost
	// port before sending connections.
	time.Sleep(100 * time.Millisecond)

	// NewSSHTunnel will bind to a random port so that you can have
	// multiple SSH tunnels available. The port is available through:
	//
	//	tunnel.Local.Port
	//
	// You can use any normal Go code to connect to the destination
	// server through localhost. You may need to use 127.0.0.1 for
	// some libraries.
	//
	// Here is an example of connecting to a PostgreSQL server. Note that
	// the connection-string keyword for the login role is "user".
	conn := fmt.Sprintf("host=127.0.0.1 port=%d user=foo", tunnel.Local.Port)
	db, err := sql.Open("postgres", conn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// ...
}
A big thanks to Svetlin Ralchev who provided a lot of the original bits and pieces.
If you prefer, the code is also available as a package at github.com/elliotchance/sshtunnel.
Originally published at http://elliot.land on January 15, 2019.