- Notifications
You must be signed in to change notification settings - Fork1k
feat: Add workspace agent for SSH#318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
fa7489a
2606fda
b54e815
c484d47
cc2bcde
ae36c63
4932ba7
0c84636
5601b4d
94cf442
055ce11
94e03a2
f8566f3
70d3723
223540e
1a48bea
ee0ee70
30189aa
516b605
File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
package agent | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Will this agent be run as a standalone executable, or part of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Part of the I initially had this in I'd appreciate your thoughts here. It's more of an agent than a daemon, because it doesn't require system-level privileges, and is active rather than passive. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Thanks for all the helpful details@kylecarbs !
That sounds good!
Agreed, it seems like our convention here is to have packages at the top-level - don't see a compelling reason to switch that.
Makes sense to me 👍 | ||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net" | ||
"os/exec" | ||
"os/user" | ||
"sync" | ||
"time" | ||
"cdr.dev/slog" | ||
"github.com/coder/coder/agent/usershell" | ||
"github.com/coder/coder/peer" | ||
"github.com/coder/coder/peerbroker" | ||
"github.com/coder/coder/pty" | ||
"github.com/coder/retry" | ||
"github.com/gliderlabs/ssh" | ||
gossh "golang.org/x/crypto/ssh" | ||
"golang.org/x/xerrors" | ||
) | ||
func DialSSH(conn *peer.Conn) (net.Conn, error) { | ||
channel, err := conn.Dial(context.Background(), "ssh", &peer.ChannelOptions{ | ||
Protocol: "ssh", | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return channel.NetConn(), nil | ||
} | ||
func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) { | ||
netConn, err := DialSSH(conn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
sshConn, channels, requests, err := gossh.NewClientConn(netConn, "localhost:22", &gossh.ClientConfig{ | ||
Config: gossh.Config{ | ||
Ciphers: []string{"arcfour"}, | ||
}, | ||
// SSH host validation isn't helpful, because obtaining a peer | ||
// connection already signifies user-intent to dial a workspace. | ||
// #nosec | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return gossh.NewClient(sshConn, channels, requests), nil | ||
} | ||
type Options struct { | ||
Logger slog.Logger | ||
} | ||
type Dialer func(ctx context.Context) (*peerbroker.Listener, error) | ||
func New(dialer Dialer, options *Options) io.Closer { | ||
ctx, cancelFunc := context.WithCancel(context.Background()) | ||
server := &server{ | ||
clientDialer: dialer, | ||
options: options, | ||
closeCancel: cancelFunc, | ||
closed: make(chan struct{}), | ||
} | ||
server.init(ctx) | ||
return server | ||
} | ||
type server struct { | ||
clientDialer Dialer | ||
options *Options | ||
closeCancel context.CancelFunc | ||
closeMutex sync.Mutex | ||
closed chan struct{} | ||
sshServer *ssh.Server | ||
} | ||
func (s *server) init(ctx context.Context) { | ||
// Clients' should ignore the host key when connecting. | ||
// The agent needs to authenticate with coderd to SSH, | ||
// so SSH authentication doesn't improve security. | ||
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) | ||
if err != nil { | ||
panic(err) | ||
} | ||
randomSigner, err := gossh.NewSignerFromKey(randomHostKey) | ||
if err != nil { | ||
panic(err) | ||
} | ||
sshLogger := s.options.Logger.Named("ssh-server") | ||
forwardHandler := &ssh.ForwardedTCPHandler{} | ||
s.sshServer = &ssh.Server{ | ||
ChannelHandlers: ssh.DefaultChannelHandlers, | ||
ConnectionFailedCallback: func(conn net.Conn, err error) { | ||
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) | ||
}, | ||
Handler: func(session ssh.Session) { | ||
err := s.handleSSHSession(session) | ||
if err != nil { | ||
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err)) | ||
_ = session.Exit(1) | ||
return | ||
} | ||
}, | ||
HostSigners: []ssh.Signer{randomSigner}, | ||
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { | ||
// Allow local port forwarding all! | ||
sshLogger.Debug(ctx, "local port forward", | ||
slog.F("destination-host", destinationHost), | ||
slog.F("destination-port", destinationPort)) | ||
return true | ||
}, | ||
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { | ||
return true | ||
}, | ||
ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { | ||
// Allow reverse port forwarding all! | ||
sshLogger.Debug(ctx, "local port forward", | ||
slog.F("bind-host", bindHost), | ||
slog.F("bind-port", bindPort)) | ||
return true | ||
}, | ||
RequestHandlers: map[string]ssh.RequestHandler{ | ||
"tcpip-forward": forwardHandler.HandleSSHRequest, | ||
"cancel-tcpip-forward": forwardHandler.HandleSSHRequest, | ||
}, | ||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { | ||
return &gossh.ServerConfig{ | ||
Config: gossh.Config{ | ||
// "arcfour" is the fastest SSH cipher. We prioritize throughput | ||
// over encryption here, because the WebRTC connection is already | ||
// encrypted. If possible, we'd disable encryption entirely here. | ||
Ciphers: []string{"arcfour"}, | ||
}, | ||
NoClientAuth: true, | ||
} | ||
}, | ||
} | ||
go s.run(ctx) | ||
} | ||
func (*server) handleSSHSession(session ssh.Session) error { | ||
var ( | ||
command string | ||
args = []string{} | ||
err error | ||
) | ||
username := session.User() | ||
if username == "" { | ||
currentUser, err := user.Current() | ||
if err != nil { | ||
return xerrors.Errorf("get current user: %w", err) | ||
} | ||
username = currentUser.Username | ||
} | ||
// gliderlabs/ssh returns a command slice of zero | ||
// when a shell is requested. | ||
if len(session.Command()) == 0 { | ||
command, err = usershell.Get(username) | ||
if err != nil { | ||
return xerrors.Errorf("get user shell: %w", err) | ||
} | ||
} else { | ||
command = session.Command()[0] | ||
if len(session.Command()) > 1 { | ||
args = session.Command()[1:] | ||
} | ||
} | ||
signals := make(chan ssh.Signal) | ||
breaks := make(chan bool) | ||
defer close(signals) | ||
defer close(breaks) | ||
go func() { | ||
for { | ||
select { | ||
case <-session.Context().Done(): | ||
return | ||
// Ignore signals and breaks for now! | ||
case <-signals: | ||
case <-breaks: | ||
} | ||
} | ||
}() | ||
cmd := exec.CommandContext(session.Context(), command, args...) | ||
cmd.Env = session.Environ() | ||
sshPty, windowSize, isPty := session.Pty() | ||
if isPty { | ||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) | ||
ptty, process, err := pty.Start(cmd) | ||
if err != nil { | ||
return xerrors.Errorf("start command: %w", err) | ||
} | ||
go func() { | ||
for win := range windowSize { | ||
err := ptty.Resize(uint16(win.Width), uint16(win.Height)) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} | ||
}() | ||
go func() { | ||
_, _ = io.Copy(ptty.Input(), session) | ||
}() | ||
go func() { | ||
_, _ = io.Copy(session, ptty.Output()) | ||
}() | ||
_, _ = process.Wait() | ||
_ = ptty.Close() | ||
return nil | ||
} | ||
cmd.Stdout = session | ||
cmd.Stderr = session | ||
// This blocks forever until stdin is received if we don't | ||
// use StdinPipe. It's unknown what causes this. | ||
stdinPipe, err := cmd.StdinPipe() | ||
if err != nil { | ||
return xerrors.Errorf("create stdin pipe: %w", err) | ||
} | ||
go func() { | ||
_, _ = io.Copy(stdinPipe, session) | ||
}() | ||
err = cmd.Start() | ||
if err != nil { | ||
return xerrors.Errorf("start: %w", err) | ||
} | ||
_ = cmd.Wait() | ||
return nil | ||
} | ||
func (s *server) run(ctx context.Context) { | ||
var peerListener *peerbroker.Listener | ||
var err error | ||
// An exponential back-off occurs when the connection is failing to dial. | ||
// This is to prevent server spam in case of a coderd outage. | ||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); { | ||
peerListener, err = s.clientDialer(ctx) | ||
if err != nil { | ||
if errors.Is(err, context.Canceled) { | ||
return | ||
} | ||
if s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err)) | ||
continue | ||
} | ||
s.options.Logger.Debug(context.Background(), "connected") | ||
break | ||
} | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
default: | ||
} | ||
for { | ||
conn, err := peerListener.Accept() | ||
if err != nil { | ||
if s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err)) | ||
s.run(ctx) | ||
return | ||
} | ||
go s.handlePeerConn(ctx, conn) | ||
} | ||
} | ||
func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) { | ||
for { | ||
channel, err := conn.Accept(ctx) | ||
if err != nil { | ||
if errors.Is(err, peer.ErrClosed) || s.isClosed() { | ||
return | ||
} | ||
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err)) | ||
return | ||
} | ||
switch channel.Protocol() { | ||
case "ssh": | ||
s.sshServer.HandleConn(channel.NetConn()) | ||
default: | ||
s.options.Logger.Warn(ctx, "unhandled protocol from channel", | ||
slog.F("protocol", channel.Protocol()), | ||
slog.F("label", channel.Label()), | ||
) | ||
} | ||
} | ||
} | ||
// isClosed returns whether the API is closed or not. | ||
func (s *server) isClosed() bool { | ||
select { | ||
case <-s.closed: | ||
return true | ||
default: | ||
return false | ||
} | ||
} | ||
func (s *server) Close() error { | ||
s.closeMutex.Lock() | ||
defer s.closeMutex.Unlock() | ||
if s.isClosed() { | ||
return nil | ||
} | ||
close(s.closed) | ||
s.closeCancel() | ||
_ = s.sshServer.Close() | ||
return nil | ||
} |
Uh oh!
There was an error while loading.Please reload this page.