Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitb4f5920

Browse files
authored
fix: Avoid use of r.Context() after r.Hijack() (#1978)
1 parent61aacff commitb4f5920

File tree

1 file changed

+74
-30
lines changed

1 file changed

+74
-30
lines changed

‎coderd/workspaceagents.go‎

Lines changed: 74 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -16,6 +17,7 @@ import (
1617
"nhooyr.io/websocket"
1718

1819
"cdr.dev/slog"
20+
1921
"github.com/coder/coder/agent"
2022
"github.com/coder/coder/coderd/database"
2123
"github.com/coder/coder/coderd/httpapi"
@@ -69,17 +71,18 @@ func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
6971
})
7072
return
7173
}
72-
deferfunc() {
73-
_=conn.Close(websocket.StatusNormalClosure,"")
74-
}()
74+
75+
ctx,wsNetConn:=websocketNetConn(r.Context(),conn,websocket.MessageBinary)
76+
deferwsNetConn.Close()// Also closes conn.
77+
7578
config:=yamux.DefaultConfig()
7679
config.LogOutput=io.Discard
77-
session,err:=yamux.Server(websocket.NetConn(r.Context(),conn,websocket.MessageBinary),config)
80+
session,err:=yamux.Server(wsNetConn,config)
7881
iferr!=nil {
7982
_=conn.Close(websocket.StatusAbnormalClosure,err.Error())
8083
return
8184
}
82-
err=peerbroker.ProxyListen(r.Context(),session, peerbroker.ProxyOptions{
85+
err=peerbroker.ProxyListen(ctx,session, peerbroker.ProxyOptions{
8386
ChannelID:workspaceAgent.ID.String(),
8487
Logger:api.Logger.Named("peerbroker-proxy-dial"),
8588
Pubsub:api.Pubsub,
@@ -193,13 +196,12 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
193196
return
194197
}
195198

196-
deferfunc() {
197-
_=conn.Close(websocket.StatusNormalClosure,"")
198-
}()
199+
ctx,wsNetConn:=websocketNetConn(r.Context(),conn,websocket.MessageBinary)
200+
deferwsNetConn.Close()// Also closes conn.
199201

200202
config:=yamux.DefaultConfig()
201203
config.LogOutput=io.Discard
202-
session,err:=yamux.Server(websocket.NetConn(r.Context(),conn,websocket.MessageBinary),config)
204+
session,err:=yamux.Server(wsNetConn,config)
203205
iferr!=nil {
204206
_=conn.Close(websocket.StatusAbnormalClosure,err.Error())
205207
return
@@ -229,7 +231,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
229231
}
230232
disconnectedAt:=workspaceAgent.DisconnectedAt
231233
updateConnectionTimes:=func()error {
232-
err=api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
234+
err=api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
233235
ID:workspaceAgent.ID,
234236
FirstConnectedAt:firstConnectedAt,
235237
LastConnectedAt:lastConnectedAt,
@@ -255,7 +257,7 @@ func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
255257
return
256258
}
257259

258-
api.Logger.Info(r.Context(),"accepting agent",slog.F("resource",resource),slog.F("agent",workspaceAgent))
260+
api.Logger.Info(ctx,"accepting agent",slog.F("resource",resource),slog.F("agent",workspaceAgent))
259261

260262
ticker:=time.NewTicker(api.AgentConnectionUpdateFrequency)
261263
deferticker.Stop()
@@ -324,16 +326,16 @@ func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
324326
})
325327
return
326328
}
327-
deferfunc() {
328-
_=wsConn.Close(websocket.StatusNormalClosure,"")
329-
}()
330-
netConn:=websocket.NetConn(r.Context(),wsConn,websocket.MessageBinary)
331-
api.Logger.Debug(r.Context(),"accepting turn connection",slog.F("remote-address",r.RemoteAddr),slog.F("local-address",localAddress))
329+
330+
ctx,wsNetConn:=websocketNetConn(r.Context(),wsConn,websocket.MessageBinary)
331+
deferwsNetConn.Close()// Also closes conn.
332+
333+
api.Logger.Debug(ctx,"accepting turn connection",slog.F("remote-address",r.RemoteAddr),slog.F("local-address",localAddress))
332334
select {
333-
case<-api.TURNServer.Accept(netConn,remoteAddress,localAddress).Closed():
334-
case<-r.Context().Done():
335+
case<-api.TURNServer.Accept(wsNetConn,remoteAddress,localAddress).Closed():
336+
case<-ctx.Done():
335337
}
336-
api.Logger.Debug(r.Context(),"completed turn connection",slog.F("remote-address",r.RemoteAddr),slog.F("local-address",localAddress))
338+
api.Logger.Debug(ctx,"completed turn connection",slog.F("remote-address",r.RemoteAddr),slog.F("local-address",localAddress))
337339
}
338340

339341
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
@@ -384,12 +386,11 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
384386
})
385387
return
386388
}
387-
deferfunc() {
388-
_=conn.Close(websocket.StatusNormalClosure,"ended")
389-
}()
390-
// Accept text connections, because it's more developer friendly.
391-
wsNetConn:=websocket.NetConn(r.Context(),conn,websocket.MessageBinary)
392-
agentConn,err:=api.dialWorkspaceAgent(r,workspaceAgent.ID)
389+
390+
ctx,wsNetConn:=websocketNetConn(r.Context(),conn,websocket.MessageBinary)
391+
deferwsNetConn.Close()// Also closes conn.
392+
393+
agentConn,err:=api.dialWorkspaceAgent(ctx,r,workspaceAgent.ID)
393394
iferr!=nil {
394395
_=conn.Close(websocket.StatusInternalError,httpapi.WebsocketCloseSprintf("dial workspace agent: %s",err))
395396
return
@@ -408,11 +409,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
408409
_,_=io.Copy(ptNetConn,wsNetConn)
409410
}
410411

411-
// dialWorkspaceAgent connects to a workspace agent by ID.
412-
func (api*API)dialWorkspaceAgent(r*http.Request,agentID uuid.UUID) (*agent.Conn,error) {
412+
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
413+
// r.Context() for cancellation if it's use is safe or r.Hijack() has
414+
// not been performed.
415+
func (api*API)dialWorkspaceAgent(ctx context.Context,r*http.Request,agentID uuid.UUID) (*agent.Conn,error) {
413416
client,server:=provisionersdk.TransportPipe()
414417
gofunc() {
415-
_=peerbroker.ProxyListen(r.Context(),server, peerbroker.ProxyOptions{
418+
_=peerbroker.ProxyListen(ctx,server, peerbroker.ProxyOptions{
416419
ChannelID:agentID.String(),
417420
Logger:api.Logger.Named("peerbroker-proxy-dial"),
418421
Pubsub:api.Pubsub,
@@ -422,7 +425,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
422425
}()
423426

424427
peerClient:=proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
425-
stream,err:=peerClient.NegotiateConnection(r.Context())
428+
stream,err:=peerClient.NegotiateConnection(ctx)
426429
iferr!=nil {
427430
returnnil,xerrors.Errorf("negotiate: %w",err)
428431
}
@@ -434,7 +437,7 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
434437
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn,errerror) {
435438
clientPipe,serverPipe:=net.Pipe()
436439
gofunc() {
437-
<-r.Context().Done()
440+
<-ctx.Done()
438441
_=clientPipe.Close()
439442
_=serverPipe.Close()
440443
}()
@@ -515,3 +518,44 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
515518

516519
returnworkspaceAgent,nil
517520
}
521+
522+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
523+
// is called if a read or write error is encountered.
524+
typewsNetConnstruct {
525+
cancel context.CancelFunc
526+
net.Conn
527+
}
528+
529+
func (c*wsNetConn)Read(b []byte) (nint,errerror) {
530+
n,err=c.Conn.Read(b)
531+
iferr!=nil {
532+
c.cancel()
533+
}
534+
returnn,err
535+
}
536+
537+
func (c*wsNetConn)Write(b []byte) (nint,errerror) {
538+
n,err=c.Conn.Write(b)
539+
iferr!=nil {
540+
c.cancel()
541+
}
542+
returnn,err
543+
}
544+
545+
func (c*wsNetConn)Close()error {
546+
deferc.cancel()
547+
returnc.Conn.Close()
548+
}
549+
550+
// websocketNetConn wraps websocket.NetConn and returns a context that
551+
// is tied to the parent context and the lifetime of the conn. Any error
552+
// during read or write will cancel the context, but not close the
553+
// conn. Close should be called to release context resources.
554+
funcwebsocketNetConn(ctx context.Context,conn*websocket.Conn,msgType websocket.MessageType) (context.Context, net.Conn) {
555+
ctx,cancel:=context.WithCancel(ctx)
556+
nc:=websocket.NetConn(ctx,conn,msgType)
557+
returnctx,&wsNetConn{
558+
cancel:cancel,
559+
Conn:nc,
560+
}
561+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp