1
1
package coderd
2
2
3
3
import (
4
+ "context"
4
5
"database/sql"
5
6
"encoding/json"
7
+ "errors"
6
8
"fmt"
7
9
"io"
8
10
"net"
@@ -16,6 +18,7 @@ import (
16
18
"nhooyr.io/websocket"
17
19
18
20
"cdr.dev/slog"
21
+
19
22
"github.com/coder/coder/agent"
20
23
"github.com/coder/coder/coderd/database"
21
24
"github.com/coder/coder/coderd/httpapi"
@@ -324,16 +327,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324
327
})
325
328
return
326
329
}
327
- defer func () {
328
- _ = wsConn . Close ( websocket . StatusNormalClosure , "" )
329
- }()
330
- netConn := websocket . NetConn ( r . Context (), wsConn , websocket . MessageBinary )
331
- api .Logger .Debug (r . Context () ,"accepting turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
330
+
331
+ ctx , wsNetConn := websocketNetConn ( r . Context (), wsConn , websocket . MessageBinary )
332
+ defer wsNetConn . Close () // Also closes conn.
333
+
334
+ api .Logger .Debug (ctx ,"accepting turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
332
335
select {
333
- case <- api .TURNServer .Accept (netConn ,remoteAddress ,localAddress ).Closed ():
334
- case <- r . Context () .Done ():
336
+ case <- api .TURNServer .Accept (wsNetConn ,remoteAddress ,localAddress ).Closed ():
337
+ case <- ctx .Done ():
335
338
}
336
- api .Logger .Debug (r . Context () ,"completed turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
339
+ api .Logger .Debug (ctx ,"completed turn connection" ,slog .F ("remote-address" ,r .RemoteAddr ),slog .F ("local-address" ,localAddress ))
337
340
}
338
341
339
342
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515
518
516
519
return workspaceAgent ,nil
517
520
}
521
+
522
+ // wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523
+ // is called if io.EOF is encountered.
524
+ type wsNetConn struct {
525
+ cancel context.CancelFunc
526
+ net.Conn
527
+ }
528
+
529
+ func (c * wsNetConn )Read (b []byte ) (n int ,err error ) {
530
+ n ,err = c .Conn .Read (b )
531
+ if errors .Is (err ,io .EOF ) {
532
+ c .cancel ()
533
+ }
534
+ return n ,err
535
+ }
536
+
537
+ func (c * wsNetConn )Write (b []byte ) (n int ,err error ) {
538
+ n ,err = c .Conn .Write (b )
539
+ if errors .Is (err ,io .EOF ) {
540
+ c .cancel ()
541
+ }
542
+ return n ,err
543
+ }
544
+
545
+ func (c * wsNetConn )Close ()error {
546
+ defer c .cancel ()
547
+ return c .Conn .Close ()
548
+ }
549
+
550
+ // websocketNetConn wraps websocket.NetConn and returns a context that
551
+ // is tied to the parent context and the lifetime of the conn. A io.EOF
552
+ // error during read or write will cancel the context, but not close the
553
+ // conn. Close should be called to release context resources.
554
+ func websocketNetConn (ctx context.Context ,conn * websocket.Conn ,msgType websocket.MessageType ) (context.Context , net.Conn ) {
555
+ ctx ,cancel := context .WithCancel (ctx )
556
+ nc := websocket .NetConn (ctx ,conn ,msgType )
557
+ return ctx ,& wsNetConn {
558
+ cancel :cancel ,
559
+ Conn :nc ,
560
+ }
561
+ }