|
| 1 | +package agent |
| 2 | + |
| 3 | +import ( |
| 4 | +"context" |
| 5 | +"crypto/rand" |
| 6 | +"crypto/rsa" |
| 7 | +"errors" |
| 8 | +"fmt" |
| 9 | +"io" |
| 10 | +"net" |
| 11 | +"os/exec" |
| 12 | +"sync" |
| 13 | +"time" |
| 14 | + |
| 15 | +"cdr.dev/slog" |
| 16 | +"github.com/coder/coder/agent/usershell" |
| 17 | +"github.com/coder/coder/peer" |
| 18 | +"github.com/coder/coder/peerbroker" |
| 19 | +"github.com/coder/coder/pty" |
| 20 | +"github.com/coder/retry" |
| 21 | + |
| 22 | +"github.com/gliderlabs/ssh" |
| 23 | +gossh"golang.org/x/crypto/ssh" |
| 24 | +"golang.org/x/xerrors" |
| 25 | +) |
| 26 | + |
| 27 | +funcDialSSH(conn*peer.Conn) (net.Conn,error) { |
| 28 | +channel,err:=conn.Dial(context.Background(),"ssh",&peer.ChannelOptions{ |
| 29 | +Protocol:"ssh", |
| 30 | +}) |
| 31 | +iferr!=nil { |
| 32 | +returnnil,err |
| 33 | +} |
| 34 | +returnchannel.NetConn(),nil |
| 35 | +} |
| 36 | + |
| 37 | +funcDialSSHClient(conn*peer.Conn) (*gossh.Client,error) { |
| 38 | +netConn,err:=DialSSH(conn) |
| 39 | +iferr!=nil { |
| 40 | +returnnil,err |
| 41 | +} |
| 42 | +sshConn,channels,requests,err:=gossh.NewClientConn(netConn,"localhost:22",&gossh.ClientConfig{ |
| 43 | +User:"kyle", |
| 44 | +Config: gossh.Config{ |
| 45 | +Ciphers: []string{"arcfour"}, |
| 46 | +}, |
| 47 | +HostKeyCallback:gossh.InsecureIgnoreHostKey(), |
| 48 | +}) |
| 49 | +iferr!=nil { |
| 50 | +returnnil,err |
| 51 | +} |
| 52 | +returngossh.NewClient(sshConn,channels,requests),nil |
| 53 | +} |
| 54 | + |
| 55 | +typeOptionsstruct { |
| 56 | +Logger slog.Logger |
| 57 | +} |
| 58 | + |
| 59 | +typeDialerfunc(ctx context.Context) (*peerbroker.Listener,error) |
| 60 | + |
| 61 | +funcNew(dialerDialer,options*Options) io.Closer { |
| 62 | +ctx,cancelFunc:=context.WithCancel(context.Background()) |
| 63 | +server:=&server{ |
| 64 | +clientDialer:dialer, |
| 65 | +options:options, |
| 66 | +closeCancel:cancelFunc, |
| 67 | +} |
| 68 | +server.init(ctx) |
| 69 | +returnserver |
| 70 | +} |
| 71 | + |
| 72 | +typeserverstruct { |
| 73 | +clientDialerDialer |
| 74 | +options*Options |
| 75 | + |
| 76 | +closeCancel context.CancelFunc |
| 77 | +closeMutex sync.Mutex |
| 78 | +closedchanstruct{} |
| 79 | +closeErrorerror |
| 80 | + |
| 81 | +sshServer*ssh.Server |
| 82 | +} |
| 83 | + |
| 84 | +func (s*server)init(ctx context.Context) { |
| 85 | +// Clients' should ignore the host key when connecting. |
| 86 | +// The agent needs to authenticate with coderd to SSH, |
| 87 | +// so SSH authentication doesn't improve security. |
| 88 | +randomHostKey,err:=rsa.GenerateKey(rand.Reader,2048) |
| 89 | +iferr!=nil { |
| 90 | +panic(err) |
| 91 | +} |
| 92 | +randomSigner,err:=gossh.NewSignerFromKey(randomHostKey) |
| 93 | +iferr!=nil { |
| 94 | +panic(err) |
| 95 | +} |
| 96 | +sshLogger:=s.options.Logger.Named("ssh-server") |
| 97 | +forwardHandler:=&ssh.ForwardedTCPHandler{} |
| 98 | +s.sshServer=&ssh.Server{ |
| 99 | +ChannelHandlers:ssh.DefaultChannelHandlers, |
| 100 | +ConnectionFailedCallback:func(conn net.Conn,errerror) { |
| 101 | +sshLogger.Info(ctx,"ssh connection ended",slog.Error(err)) |
| 102 | +}, |
| 103 | +Handler:func(session ssh.Session) { |
| 104 | +err:=s.handleSSHSession(session) |
| 105 | +iferr!=nil { |
| 106 | +s.options.Logger.Debug(ctx,"ssh session failed",slog.Error(err)) |
| 107 | +_=session.Exit(1) |
| 108 | +return |
| 109 | +} |
| 110 | +}, |
| 111 | +HostSigners: []ssh.Signer{randomSigner}, |
| 112 | +LocalPortForwardingCallback:func(ctx ssh.Context,destinationHoststring,destinationPortuint32)bool { |
| 113 | +// Allow local port forwarding all! |
| 114 | +sshLogger.Debug(ctx,"local port forward", |
| 115 | +slog.F("destination-host",destinationHost), |
| 116 | +slog.F("destination-port",destinationPort)) |
| 117 | +returntrue |
| 118 | +}, |
| 119 | +PtyCallback:func(ctx ssh.Context,pty ssh.Pty)bool { |
| 120 | +returntrue |
| 121 | +}, |
| 122 | +ReversePortForwardingCallback:func(ctx ssh.Context,bindHoststring,bindPortuint32)bool { |
| 123 | +// Allow reverse port forwarding all! |
| 124 | +sshLogger.Debug(ctx,"local port forward", |
| 125 | +slog.F("bind-host",bindHost), |
| 126 | +slog.F("bind-port",bindPort)) |
| 127 | +returntrue |
| 128 | +}, |
| 129 | +RequestHandlers:map[string]ssh.RequestHandler{ |
| 130 | +"tcpip-forward":forwardHandler.HandleSSHRequest, |
| 131 | +"cancel-tcpip-forward":forwardHandler.HandleSSHRequest, |
| 132 | +}, |
| 133 | +ServerConfigCallback:func(ctx ssh.Context)*gossh.ServerConfig { |
| 134 | +return&gossh.ServerConfig{ |
| 135 | +Config: gossh.Config{ |
| 136 | +// "arcfour" is the fastest SSH cipher. We prioritize throughput |
| 137 | +// over encryption here, because the WebRTC connection is already |
| 138 | +// encrypted. If possible, we'd disable encryption entirely here. |
| 139 | +Ciphers: []string{"arcfour"}, |
| 140 | +}, |
| 141 | +NoClientAuth:true, |
| 142 | +} |
| 143 | +}, |
| 144 | +} |
| 145 | + |
| 146 | +gos.run(ctx) |
| 147 | +} |
| 148 | + |
| 149 | +func (*server)handleSSHSession(session ssh.Session)error { |
| 150 | +var ( |
| 151 | +commandstring |
| 152 | +args= []string{} |
| 153 | +errerror |
| 154 | +) |
| 155 | + |
| 156 | +// gliderlabs/ssh returns a command slice of zero |
| 157 | +// when a shell is requested. |
| 158 | +iflen(session.Command())==0 { |
| 159 | +command,err=usershell.Get(session.User()) |
| 160 | +iferr!=nil { |
| 161 | +returnxerrors.Errorf("get user shell: %w",err) |
| 162 | +} |
| 163 | +}else { |
| 164 | +command=session.Command()[0] |
| 165 | +iflen(session.Command())>1 { |
| 166 | +args=session.Command()[1:] |
| 167 | +} |
| 168 | +} |
| 169 | + |
| 170 | +signals:=make(chan ssh.Signal) |
| 171 | +breaks:=make(chanbool) |
| 172 | +deferclose(signals) |
| 173 | +deferclose(breaks) |
| 174 | +gofunc() { |
| 175 | +for { |
| 176 | +select { |
| 177 | +case<-session.Context().Done(): |
| 178 | +return |
| 179 | +// Ignore signals and breaks for now! |
| 180 | +case<-signals: |
| 181 | +case<-breaks: |
| 182 | +} |
| 183 | +} |
| 184 | +}() |
| 185 | + |
| 186 | +cmd:=exec.CommandContext(session.Context(),command,args...) |
| 187 | +cmd.Env=session.Environ() |
| 188 | + |
| 189 | +sshPty,windowSize,isPty:=session.Pty() |
| 190 | +ifisPty { |
| 191 | +cmd.Env=append(cmd.Env,fmt.Sprintf("TERM=%s",sshPty.Term)) |
| 192 | +ptty,process,err:=pty.Start(cmd) |
| 193 | +iferr!=nil { |
| 194 | +returnxerrors.Errorf("start command: %w",err) |
| 195 | +} |
| 196 | +gofunc() { |
| 197 | +forwin:=rangewindowSize { |
| 198 | +err:=ptty.Resize(uint16(win.Width),uint16(win.Height)) |
| 199 | +iferr!=nil { |
| 200 | +panic(err) |
| 201 | +} |
| 202 | +} |
| 203 | +}() |
| 204 | +gofunc() { |
| 205 | +_,_=io.Copy(ptty.Input(),session) |
| 206 | +}() |
| 207 | +gofunc() { |
| 208 | +_,_=io.Copy(session,ptty.Output()) |
| 209 | +}() |
| 210 | +_,err=process.Wait() |
| 211 | +returnerr |
| 212 | +} |
| 213 | + |
| 214 | +cmd.Stdout=session |
| 215 | +cmd.Stderr=session |
| 216 | +// This blocks forever until stdin is received if we don't |
| 217 | +// use StdinPipe. It's unknown what causes this. |
| 218 | +stdinPipe,err:=cmd.StdinPipe() |
| 219 | +iferr!=nil { |
| 220 | +returnxerrors.Errorf("create stdin pipe: %w",err) |
| 221 | +} |
| 222 | +gofunc() { |
| 223 | +_,_=io.Copy(stdinPipe,session) |
| 224 | +}() |
| 225 | +err=cmd.Start() |
| 226 | +iferr!=nil { |
| 227 | +returnxerrors.Errorf("start: %w",err) |
| 228 | +} |
| 229 | +returncmd.Wait() |
| 230 | +} |
| 231 | + |
| 232 | +func (s*server)run(ctx context.Context) { |
| 233 | +varpeerListener*peerbroker.Listener |
| 234 | +varerrerror |
| 235 | +// An exponential back-off occurs when the connection is failing to dial. |
| 236 | +// This is to prevent server spam in case of a coderd outage. |
| 237 | +forretrier:=retry.New(50*time.Millisecond,10*time.Second);retrier.Wait(ctx); { |
| 238 | +peerListener,err=s.clientDialer(ctx) |
| 239 | +iferr!=nil { |
| 240 | +iferrors.Is(err,context.Canceled) { |
| 241 | +return |
| 242 | +} |
| 243 | +ifs.isClosed() { |
| 244 | +return |
| 245 | +} |
| 246 | +s.options.Logger.Warn(context.Background(),"failed to dial",slog.Error(err)) |
| 247 | +continue |
| 248 | +} |
| 249 | +s.options.Logger.Debug(context.Background(),"connected") |
| 250 | +break |
| 251 | +} |
| 252 | + |
| 253 | +for { |
| 254 | +conn,err:=peerListener.Accept() |
| 255 | +iferr!=nil { |
| 256 | +// This is closed! |
| 257 | +return |
| 258 | +} |
| 259 | +gos.handlePeerConn(ctx,conn) |
| 260 | +} |
| 261 | +} |
| 262 | + |
| 263 | +func (s*server)handlePeerConn(ctx context.Context,conn*peer.Conn) { |
| 264 | +for { |
| 265 | +channel,err:=conn.Accept(ctx) |
| 266 | +iferr!=nil { |
| 267 | +// TODO: Log here! |
| 268 | +return |
| 269 | +} |
| 270 | + |
| 271 | +switchchannel.Protocol() { |
| 272 | +case"ssh": |
| 273 | +s.sshServer.HandleConn(channel.NetConn()) |
| 274 | +case"proxy": |
| 275 | +// Proxy the port provided. |
| 276 | +} |
| 277 | +} |
| 278 | +} |
| 279 | + |
| 280 | +// isClosed returns whether the API is closed or not. |
| 281 | +func (s*server)isClosed()bool { |
| 282 | +select { |
| 283 | +case<-s.closed: |
| 284 | +returntrue |
| 285 | +default: |
| 286 | +returnfalse |
| 287 | +} |
| 288 | +} |
| 289 | + |
| 290 | +func (s*server)Close()error { |
| 291 | +s.sshServer.Close() |
| 292 | +returnnil |
| 293 | +} |