11package coderd
22
33import (
4+ "context"
45"database/sql"
56"encoding/json"
67"fmt"
@@ -16,6 +17,7 @@ import (
1617"nhooyr.io/websocket"
1718
1819"cdr.dev/slog"
20+
1921"github.com/coder/coder/agent"
2022"github.com/coder/coder/coderd/database"
2123"github.com/coder/coder/coderd/httpapi"
@@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
6971})
7072return
7173}
72- defer func () {
73- _ = conn .Close (websocket .StatusNormalClosure ,"" )
74- }()
74+
75+ ctx ,wsNetConn := websocketNetConn (r .Context (),conn ,websocket .MessageBinary )
76+ defer wsNetConn .Close ()// Also closes conn.
77+
7578config := yamux .DefaultConfig ()
7679config .LogOutput = io .Discard
77- session ,err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) ,config )
80+ session ,err := yamux .Server (wsNetConn ,config )
7881if err != nil {
7982_ = conn .Close (websocket .StatusAbnormalClosure ,err .Error ())
8083return
8184}
82- err = peerbroker .ProxyListen (r . Context () ,session , peerbroker.ProxyOptions {
85+ err = peerbroker .ProxyListen (ctx ,session , peerbroker.ProxyOptions {
8386ChannelID :workspaceAgent .ID .String (),
8487Logger :api .Logger .Named ("peerbroker-proxy-dial" ),
8588Pubsub :api .Pubsub ,
@@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
193196return
194197}
195198
196- defer func () {
197- _ = conn .Close (websocket .StatusNormalClosure ,"" )
198- }()
199+ ctx ,wsNetConn := websocketNetConn (r .Context (),conn ,websocket .MessageBinary )
200+ defer wsNetConn .Close ()// Also closes conn.
199201
200202config := yamux .DefaultConfig ()
201203config .LogOutput = io .Discard
202- session ,err := yamux .Server (websocket . NetConn ( r . Context (), conn , websocket . MessageBinary ) ,config )
204+ session ,err := yamux .Server (wsNetConn ,config )
203205if err != nil {
204206_ = conn .Close (websocket .StatusAbnormalClosure ,err .Error ())
205207return
@@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
229231}
230232disconnectedAt := workspaceAgent .DisconnectedAt
231233updateConnectionTimes := func ()error {
232- err = api .Database .UpdateWorkspaceAgentConnectionByID (r . Context () , database.UpdateWorkspaceAgentConnectionByIDParams {
234+ err = api .Database .UpdateWorkspaceAgentConnectionByID (ctx , database.UpdateWorkspaceAgentConnectionByIDParams {
233235ID :workspaceAgent .ID ,
234236FirstConnectedAt :firstConnectedAt ,
235237LastConnectedAt :lastConnectedAt ,
@@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
255257return
256258}
257259
258- api .Logger .Info (r . Context () ,"accepting agent" ,slog .F ("resource" ,resource ),slog .F ("agent" ,workspaceAgent ))
260+ api .Logger .Info (ctx ,"accepting agent" ,slog .F ("resource" ,resource ),slog .F ("agent" ,workspaceAgent ))
259261
260262ticker := time .NewTicker (api .AgentConnectionUpdateFrequency )
261263defer ticker .Stop ()
@@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324326})
325327return
326328}
327- defer func () {
328- _ = wsConn . Close ( websocket . StatusNormalClosure , "" )
329- }()
330- netConn := websocket . NetConn ( r . Context (), wsConn , websocket . MessageBinary )
331- api .Logger .Debug (r . Context () ,"accepting turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
329+
330+ ctx , wsNetConn := websocketNetConn ( r . Context (), wsConn , websocket . MessageBinary )
331+ defer wsNetConn . Close () // Also closes conn.
332+
333+ api .Logger .Debug (ctx ,"accepting turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
332334select {
333- case <- api .TURNServer .Accept (netConn ,remoteAddress ,localAddress ).Closed ():
334- case <- r . Context () .Done ():
335+ case <- api .TURNServer .Accept (wsNetConn ,remoteAddress ,localAddress ).Closed ():
336+ case <- ctx .Done ():
335337}
336- api .Logger .Debug (r . Context () ,"completed turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
338+ api .Logger .Debug (ctx ,"completed turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
337339}
338340
339341// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
384386})
385387return
386388}
387- defer func () {
388- _ = conn .Close (websocket .StatusNormalClosure ,"ended" )
389- }()
390- // Accept text connections, because it's more developer friendly.
391- wsNetConn := websocket .NetConn (r .Context (),conn ,websocket .MessageBinary )
392- agentConn ,err := api .dialWorkspaceAgent (r ,workspaceAgent .ID )
389+
390+ ctx ,wsNetConn := websocketNetConn (r .Context (),conn ,websocket .MessageBinary )
391+ defer wsNetConn .Close ()// Also closes conn.
392+
393+ agentConn ,err := api .dialWorkspaceAgent (ctx ,r ,workspaceAgent .ID )
393394if err != nil {
394395_ = conn .Close (websocket .StatusInternalError ,httpapi .WebsocketCloseSprintf ("dial workspace agent: %s" ,err ))
395396return
@@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
408409_ ,_ = io .Copy (ptNetConn ,wsNetConn )
409410}
410411
411- // dialWorkspaceAgent connects to a workspace agent by ID.
412- func (api * API )dialWorkspaceAgent (r * http.Request ,agentID uuid.UUID ) (* agent.Conn ,error ) {
412+ // dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
413+ // r.Context() for cancellation if it's use is safe or r.Hijack() has
414+ // not been performed.
415+ func (api * API )dialWorkspaceAgent (ctx context.Context ,r * http.Request ,agentID uuid.UUID ) (* agent.Conn ,error ) {
413416client ,server := provisionersdk .TransportPipe ()
414417go func () {
415- _ = peerbroker .ProxyListen (r . Context () ,server , peerbroker.ProxyOptions {
418+ _ = peerbroker .ProxyListen (ctx ,server , peerbroker.ProxyOptions {
416419ChannelID :agentID .String (),
417420Logger :api .Logger .Named ("peerbroker-proxy-dial" ),
418421Pubsub :api .Pubsub ,
@@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
422425}()
423426
424427peerClient := proto .NewDRPCPeerBrokerClient (provisionersdk .Conn (client ))
425- stream ,err := peerClient .NegotiateConnection (r . Context () )
428+ stream ,err := peerClient .NegotiateConnection (ctx )
426429if err != nil {
427430return nil ,xerrors .Errorf ("negotiate: %w" ,err )
428431}
@@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
434437options .SettingEngine .SetICEProxyDialer (turnconn .ProxyDialer (func () (c net.Conn ,err error ) {
435438clientPipe ,serverPipe := net .Pipe ()
436439go func () {
437- <- r . Context () .Done ()
440+ <- ctx .Done ()
438441_ = clientPipe .Close ()
439442_ = serverPipe .Close ()
440443}()
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515518
516519return workspaceAgent ,nil
517520}
521+
522+ // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523+ // is called if a read or write error is encountered.
524+ type wsNetConn struct {
525+ cancel context.CancelFunc
526+ net.Conn
527+ }
528+
529+ func (c * wsNetConn )Read (b []byte ) (n int ,err error ) {
530+ n ,err = c .Conn .Read (b )
531+ if err != nil {
532+ c .cancel ()
533+ }
534+ return n ,err
535+ }
536+
537+ func (c * wsNetConn )Write (b []byte ) (n int ,err error ) {
538+ n ,err = c .Conn .Write (b )
539+ if err != nil {
540+ c .cancel ()
541+ }
542+ return n ,err
543+ }
544+
545+ func (c * wsNetConn )Close ()error {
546+ defer c .cancel ()
547+ return c .Conn .Close ()
548+ }
549+
550+ // websocketNetConn wraps websocket.NetConn and returns a context that
551+ // is tied to the parent context and the lifetime of the conn. Any error
552+ // during read or write will cancel the context, but not close the
553+ // conn. Close should be called to release context resources.
554+ func websocketNetConn (ctx context.Context ,conn * websocket.Conn ,msgType websocket.MessageType ) (context.Context , net.Conn ) {
555+ ctx ,cancel := context .WithCancel (ctx )
556+ nc := websocket .NetConn (ctx ,conn ,msgType )
557+ return ctx ,& wsNetConn {
558+ cancel :cancel ,
559+ Conn :nc ,
560+ }
561+ }