Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitdb16dc5

Browse files
mafredripull[bot]
authored andcommitted
fix: Improve use of context inwebsocket.NetConn code paths (#6198)
1 parentbce4a85 commitdb16dc5

File tree

5 files changed

+162
-19
lines changed

5 files changed

+162
-19
lines changed

‎coderd/workspaceagents.go‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,13 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
748748
})
749749
return
750750
}
751+
ctx,wsNetConn:=websocketNetConn(ctx,conn,websocket.MessageBinary)
752+
deferwsNetConn.Close()
753+
751754
gohttpapi.Heartbeat(ctx,conn)
752755

753756
deferconn.Close(websocket.StatusNormalClosure,"")
754-
err= (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx,conn,websocket.MessageBinary),uuid.New(),workspaceAgent.ID)
757+
err= (*api.TailnetCoordinator.Load()).ServeClient(wsNetConn,uuid.New(),workspaceAgent.ID)
755758
iferr!=nil {
756759
_=conn.Close(websocket.StatusInternalError,err.Error())
757760
return

‎codersdk/agentsdk/agentsdk.go‎

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
159159
returnnil,codersdk.ReadBodyAsError(res)
160160
}
161161

162+
ctx,wsNetConn:=websocketNetConn(ctx,conn,websocket.MessageBinary)
163+
162164
// Ping once every 30 seconds to ensure that the websocket is alive. If we
163165
// don't get a response within 30s we kill the websocket and reconnect.
164166
// See: https://github.com/coder/coder/pull/5824
@@ -195,7 +197,7 @@ func (c *Client) Listen(ctx context.Context) (net.Conn, error) {
195197
}
196198
}()
197199

198-
returnwebsocket.NetConn(ctx,conn,websocket.MessageBinary),nil
200+
returnwsNetConn,nil
199201
}
200202

201203
typePostAppHealthsRequeststruct {
@@ -529,3 +531,44 @@ type closeFunc func() error
529531
func (ccloseFunc)Close()error {
530532
returnc()
531533
}
534+
535+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
536+
// is called if a read or write error is encountered.
537+
typewsNetConnstruct {
538+
cancel context.CancelFunc
539+
net.Conn
540+
}
541+
542+
func (c*wsNetConn)Read(b []byte) (nint,errerror) {
543+
n,err=c.Conn.Read(b)
544+
iferr!=nil {
545+
c.cancel()
546+
}
547+
returnn,err
548+
}
549+
550+
func (c*wsNetConn)Write(b []byte) (nint,errerror) {
551+
n,err=c.Conn.Write(b)
552+
iferr!=nil {
553+
c.cancel()
554+
}
555+
returnn,err
556+
}
557+
558+
func (c*wsNetConn)Close()error {
559+
deferc.cancel()
560+
returnc.Conn.Close()
561+
}
562+
563+
// websocketNetConn wraps websocket.NetConn and returns a context that
564+
// is tied to the parent context and the lifetime of the conn. Any error
565+
// during read or write will cancel the context, but not close the
566+
// conn. Close should be called to release context resources.
567+
funcwebsocketNetConn(ctx context.Context,conn*websocket.Conn,msgType websocket.MessageType) (context.Context, net.Conn) {
568+
ctx,cancel:=context.WithCancel(ctx)
569+
nc:=websocket.NetConn(ctx,conn,msgType)
570+
returnctx,&wsNetConn{
571+
cancel:cancel,
572+
Conn:nc,
573+
}
574+
}

‎codersdk/provisionerdaemons.go‎

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9+
"net"
910
"net/http"
1011
"net/http/cookiejar"
1112
"net/url"
@@ -143,8 +144,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
143144
returnnil,nil,ReadBodyAsError(res)
144145
}
145146
logs:=make(chanProvisionerJobLog)
146-
decoder:=json.NewDecoder(websocket.NetConn(ctx,conn,websocket.MessageText))
147147
closed:=make(chanstruct{})
148+
ctx,wsNetConn:=websocketNetConn(ctx,conn,websocket.MessageText)
149+
decoder:=json.NewDecoder(wsNetConn)
148150
gofunc() {
149151
deferclose(closed)
150152
deferclose(logs)
@@ -163,13 +165,15 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
163165
}
164166
}()
165167
returnlogs,closeFunc(func()error {
166-
_=conn.Close(websocket.StatusNormalClosure,"")
168+
_=wsNetConn.Close()
167169
<-closed
168170
returnnil
169171
}),nil
170172
}
171173

172-
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
174+
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
175+
// implementation. The context is during dial, not during the lifetime of the
176+
// client. Client should be closed after use.
173177
func (c*Client)ServeProvisionerDaemon(ctx context.Context,organization uuid.UUID,provisioners []ProvisionerType,tagsmap[string]string) (proto.DRPCProvisionerDaemonClient,error) {
174178
serverURL,err:=c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve",organization))
175179
iferr!=nil {
@@ -210,9 +214,55 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.U
210214

211215
config:=yamux.DefaultConfig()
212216
config.LogOutput=io.Discard
213-
session,err:=yamux.Client(websocket.NetConn(ctx,conn,websocket.MessageBinary),config)
217+
// Use background context because caller should close the client.
218+
_,wsNetConn:=websocketNetConn(context.Background(),conn,websocket.MessageBinary)
219+
session,err:=yamux.Client(wsNetConn,config)
214220
iferr!=nil {
221+
_=conn.Close(websocket.StatusGoingAway,"")
222+
_=wsNetConn.Close()
215223
returnnil,xerrors.Errorf("multiplex client: %w",err)
216224
}
217225
returnproto.NewDRPCProvisionerDaemonClient(provisionersdk.MultiplexedConn(session)),nil
218226
}
227+
228+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
229+
// is called if a read or write error is encountered.
230+
// @typescript-ignore wsNetConn
231+
typewsNetConnstruct {
232+
cancel context.CancelFunc
233+
net.Conn
234+
}
235+
236+
func (c*wsNetConn)Read(b []byte) (nint,errerror) {
237+
n,err=c.Conn.Read(b)
238+
iferr!=nil {
239+
c.cancel()
240+
}
241+
returnn,err
242+
}
243+
244+
func (c*wsNetConn)Write(b []byte) (nint,errerror) {
245+
n,err=c.Conn.Write(b)
246+
iferr!=nil {
247+
c.cancel()
248+
}
249+
returnn,err
250+
}
251+
252+
func (c*wsNetConn)Close()error {
253+
deferc.cancel()
254+
returnc.Conn.Close()
255+
}
256+
257+
// websocketNetConn wraps websocket.NetConn and returns a context that
258+
// is tied to the parent context and the lifetime of the conn. Any error
259+
// during read or write will cancel the context, but not close the
260+
// conn. Close should be called to release context resources.
261+
funcwebsocketNetConn(ctx context.Context,conn*websocket.Conn,msgType websocket.MessageType) (context.Context, net.Conn) {
262+
ctx,cancel:=context.WithCancel(ctx)
263+
nc:=websocket.NetConn(ctx,conn,msgType)
264+
returnctx,&wsNetConn{
265+
cancel:cancel,
266+
Conn:nc,
267+
}
268+
}

‎codersdk/workspaceagents.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, rec
257257
}
258258
returnnil,ReadBodyAsError(res)
259259
}
260-
returnwebsocket.NetConn(ctx,conn,websocket.MessageBinary),nil
260+
returnwebsocket.NetConn(context.Background(),conn,websocket.MessageBinary),nil
261261
}
262262

263263
// WorkspaceAgentListeningPorts returns a list of ports that are currently being

‎enterprise/coderd/provisionerdaemons.go‎

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"errors"
78
"fmt"
89
"io"
10+
"net"
911
"net/http"
1012
"strings"
1113

@@ -94,12 +96,14 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
9496
// @Success 101
9597
// @Router /organizations/{organization}/provisionerdaemons/serve [get]
9698
func (api*API)provisionerDaemonServe(rw http.ResponseWriter,r*http.Request) {
99+
ctx:=r.Context()
100+
97101
tags:=map[string]string{}
98102
ifr.URL.Query().Has("tag") {
99103
for_,tag:=ranger.URL.Query()["tag"] {
100104
parts:=strings.SplitN(tag,"=",2)
101105
iflen(parts)<2 {
102-
httpapi.Write(r.Context(),rw,http.StatusBadRequest, codersdk.Response{
106+
httpapi.Write(ctx,rw,http.StatusBadRequest, codersdk.Response{
103107
Message:fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.",tag),
104108
})
105109
return
@@ -108,7 +112,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
108112
}
109113
}
110114
if!r.URL.Query().Has("provisioner") {
111-
httpapi.Write(r.Context(),rw,http.StatusBadRequest, codersdk.Response{
115+
httpapi.Write(ctx,rw,http.StatusBadRequest, codersdk.Response{
112116
Message:"The provisioner query parameter must be specified.",
113117
})
114118
return
@@ -122,7 +126,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
122126
casestring(codersdk.ProvisionerTypeTerraform):
123127
provisionersMap[codersdk.ProvisionerTypeTerraform]=struct{}{}
124128
default:
125-
httpapi.Write(r.Context(),rw,http.StatusBadRequest, codersdk.Response{
129+
httpapi.Write(ctx,rw,http.StatusBadRequest, codersdk.Response{
126130
Message:fmt.Sprintf("Unknown provisioner type %q",provisioner),
127131
})
128132
return
@@ -137,7 +141,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
137141

138142
iftags[provisionerdserver.TagScope]==provisionerdserver.ScopeOrganization {
139143
if!api.AGPL.Authorize(r,rbac.ActionCreate,rbac.ResourceProvisionerDaemon) {
140-
httpapi.Write(r.Context(),rw,http.StatusForbidden, codersdk.Response{
144+
httpapi.Write(ctx,rw,http.StatusForbidden, codersdk.Response{
141145
Message:"You aren't allowed to create provisioner daemons for the organization.",
142146
})
143147
return
@@ -155,15 +159,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
155159
}
156160

157161
name:=namesgenerator.GetRandomName(1)
158-
daemon,err:=api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
162+
daemon,err:=api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
159163
ID:uuid.New(),
160164
CreatedAt:database.Now(),
161165
Name:name,
162166
Provisioners:provisioners,
163167
Tags:tags,
164168
})
165169
iferr!=nil {
166-
httpapi.Write(r.Context(),rw,http.StatusInternalServerError, codersdk.Response{
170+
httpapi.Write(ctx,rw,http.StatusInternalServerError, codersdk.Response{
167171
Message:"Internal error writing provisioner daemon.",
168172
Detail:err.Error(),
169173
})
@@ -172,7 +176,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
172176

173177
rawTags,err:=json.Marshal(daemon.Tags)
174178
iferr!=nil {
175-
httpapi.Write(r.Context(),rw,http.StatusInternalServerError, codersdk.Response{
179+
httpapi.Write(ctx,rw,http.StatusInternalServerError, codersdk.Response{
176180
Message:"Internal error marshaling daemon tags.",
177181
Detail:err.Error(),
178182
})
@@ -189,7 +193,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
189193
CompressionMode:websocket.CompressionDisabled,
190194
})
191195
iferr!=nil {
192-
httpapi.Write(r.Context(),rw,http.StatusBadRequest, codersdk.Response{
196+
httpapi.Write(ctx,rw,http.StatusBadRequest, codersdk.Response{
193197
Message:"Internal error accepting websocket connection.",
194198
Detail:err.Error(),
195199
})
@@ -203,7 +207,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
203207
// the same connection.
204208
config:=yamux.DefaultConfig()
205209
config.LogOutput=io.Discard
206-
session,err:=yamux.Server(websocket.NetConn(r.Context(),conn,websocket.MessageBinary),config)
210+
ctx,wsNetConn:=websocketNetConn(ctx,conn,websocket.MessageBinary)
211+
deferwsNetConn.Close()
212+
session,err:=yamux.Server(wsNetConn,config)
207213
iferr!=nil {
208214
_=conn.Close(websocket.StatusInternalError,httpapi.WebsocketCloseSprintf("multiplex server: %s",err))
209215
return
@@ -229,12 +235,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
229235
ifxerrors.Is(err,io.EOF) {
230236
return
231237
}
232-
api.Logger.Debug(r.Context(),"drpc server error",slog.Error(err))
238+
api.Logger.Debug(ctx,"drpc server error",slog.Error(err))
233239
},
234240
})
235-
err=server.Serve(r.Context(),session)
241+
err=server.Serve(ctx,session)
236242
iferr!=nil&&!xerrors.Is(err,io.EOF) {
237-
api.Logger.Debug(r.Context(),"provisioner daemon disconnected",slog.Error(err))
243+
api.Logger.Debug(ctx,"provisioner daemon disconnected",slog.Error(err))
238244
_=conn.Close(websocket.StatusInternalError,httpapi.WebsocketCloseSprintf("serve: %s",err))
239245
return
240246
}
@@ -254,3 +260,44 @@ func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.Provis
254260
}
255261
returnresult
256262
}
263+
264+
// wsNetConn wraps net.Conn created by websocket.NetConn(). Cancel func
265+
// is called if a read or write error is encountered.
266+
typewsNetConnstruct {
267+
cancel context.CancelFunc
268+
net.Conn
269+
}
270+
271+
func (c*wsNetConn)Read(b []byte) (nint,errerror) {
272+
n,err=c.Conn.Read(b)
273+
iferr!=nil {
274+
c.cancel()
275+
}
276+
returnn,err
277+
}
278+
279+
func (c*wsNetConn)Write(b []byte) (nint,errerror) {
280+
n,err=c.Conn.Write(b)
281+
iferr!=nil {
282+
c.cancel()
283+
}
284+
returnn,err
285+
}
286+
287+
func (c*wsNetConn)Close()error {
288+
deferc.cancel()
289+
returnc.Conn.Close()
290+
}
291+
292+
// websocketNetConn wraps websocket.NetConn and returns a context that
293+
// is tied to the parent context and the lifetime of the conn. Any error
294+
// during read or write will cancel the context, but not close the
295+
// conn. Close should be called to release context resources.
296+
funcwebsocketNetConn(ctx context.Context,conn*websocket.Conn,msgType websocket.MessageType) (context.Context, net.Conn) {
297+
ctx,cancel:=context.WithCancel(ctx)
298+
nc:=websocket.NetConn(ctx,conn,msgType)
299+
returnctx,&wsNetConn{
300+
cancel:cancel,
301+
Conn:nc,
302+
}
303+
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp