Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit11bda98

Browse files
fix: avoid writing messages after close and improve handshake (#476)
Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
1 parent1253b77 commit11bda98

File tree

5 files changed

+252
-65
lines changed

5 files changed

+252
-65
lines changed

‎close.go‎

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
100100
func (c*Conn)Close(codeStatusCode,reasonstring) (errerror) {
101101
defererrd.Wrap(&err,"failed to close WebSocket")
102102

103-
if!c.casClosing() {
103+
ifc.casClosing() {
104104
err=c.waitGoroutines()
105105
iferr!=nil {
106106
returnerr
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
133133
func (c*Conn)CloseNow() (errerror) {
134134
defererrd.Wrap(&err,"failed to immediately close WebSocket")
135135

136-
if!c.casClosing() {
136+
ifc.casClosing() {
137137
err=c.waitGoroutines()
138138
iferr!=nil {
139139
returnerr
@@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
329329
}
330330

331331
func (c*Conn)casClosing()bool {
332-
c.closeMu.Lock()
333-
deferc.closeMu.Unlock()
334-
if!c.closing {
335-
c.closing=true
336-
returntrue
337-
}
338-
returnfalse
332+
returnc.closing.Swap(true)
339333
}
340334

341335
func (c*Conn)isClosed()bool {

‎conn.go‎

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ type Conn struct {
6969
writeHeaderBuf [8]byte
7070
writeHeaderheader
7171

72+
// Close handshake state.
73+
closeStateMu sync.RWMutex
74+
closeReceivedErrerror
75+
closeSentErrerror
76+
77+
// CloseRead state.
7278
closeReadMu sync.Mutex
7379
closeReadCtx context.Context
7480
closeReadDonechanstruct{}
7581

82+
closing atomic.Bool
83+
closeMu sync.Mutex// Protects following.
7684
closedchanstruct{}
77-
closeMu sync.Mutex
78-
closingbool
7985

8086
pingCounter atomic.Int64
8187
activePingsMu sync.Mutex

‎conn_test.go‎

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/http/httptest"
1314
"os"
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
460461
}
461462

462463
funcBenchmarkConn(b*testing.B) {
463-
varbenchCases= []struct {
464+
benchCases:= []struct {
464465
namestring
465466
mode websocket.CompressionMode
466467
}{
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
625626
}()
626627
}
627628
}
629+
630+
funcTestConnClosePropagation(t*testing.T) {
631+
t.Parallel()
632+
633+
want:= []byte("hello")
634+
keepWriting:=func(c*websocket.Conn)<-chanerror {
635+
returnxsync.Go(func()error {
636+
for {
637+
err:=c.Write(context.Background(),websocket.MessageText,want)
638+
iferr!=nil {
639+
returnerr
640+
}
641+
}
642+
})
643+
}
644+
keepReading:=func(c*websocket.Conn)<-chanerror {
645+
returnxsync.Go(func()error {
646+
for {
647+
_,got,err:=c.Read(context.Background())
648+
iferr!=nil {
649+
returnerr
650+
}
651+
if!bytes.Equal(want,got) {
652+
returnfmt.Errorf("unexpected message: want %q, got %q",want,got)
653+
}
654+
}
655+
})
656+
}
657+
checkReadErr:=func(t*testing.T,errerror) {
658+
// Check read error (output depends on when read is called in relation to connection closure).
659+
varce websocket.CloseError
660+
iferrors.As(err,&ce) {
661+
assert.Equal(t,"",websocket.StatusNormalClosure,ce.Code)
662+
}else {
663+
assert.ErrorIs(t,net.ErrClosed,err)
664+
}
665+
}
666+
checkConnErrs:=func(t*testing.T,conn...*websocket.Conn) {
667+
for_,c:=rangeconn {
668+
// Check write error.
669+
err:=c.Write(context.Background(),websocket.MessageText,want)
670+
assert.ErrorIs(t,net.ErrClosed,err)
671+
672+
_,_,err=c.Read(context.Background())
673+
checkReadErr(t,err)
674+
}
675+
}
676+
677+
t.Run("CloseOtherSideDuringWrite",func(t*testing.T) {
678+
tt,this,other:=newConnTest(t,nil,nil)
679+
680+
_=this.CloseRead(tt.ctx)
681+
thisWriteErr:=keepWriting(this)
682+
683+
_,got,err:=other.Read(tt.ctx)
684+
assert.Success(t,err)
685+
assert.Equal(t,"msg",want,got)
686+
687+
err=other.Close(websocket.StatusNormalClosure,"")
688+
assert.Success(t,err)
689+
690+
select {
691+
caseerr:=<-thisWriteErr:
692+
assert.ErrorIs(t,net.ErrClosed,err)
693+
case<-tt.ctx.Done():
694+
t.Fatal(tt.ctx.Err())
695+
}
696+
697+
checkConnErrs(t,this,other)
698+
})
699+
t.Run("CloseThisSideDuringWrite",func(t*testing.T) {
700+
tt,this,other:=newConnTest(t,nil,nil)
701+
702+
_=this.CloseRead(tt.ctx)
703+
thisWriteErr:=keepWriting(this)
704+
otherReadErr:=keepReading(other)
705+
706+
err:=this.Close(websocket.StatusNormalClosure,"")
707+
assert.Success(t,err)
708+
709+
select {
710+
caseerr:=<-thisWriteErr:
711+
assert.ErrorIs(t,net.ErrClosed,err)
712+
case<-tt.ctx.Done():
713+
t.Fatal(tt.ctx.Err())
714+
}
715+
716+
select {
717+
caseerr:=<-otherReadErr:
718+
checkReadErr(t,err)
719+
case<-tt.ctx.Done():
720+
t.Fatal(tt.ctx.Err())
721+
}
722+
723+
checkConnErrs(t,this,other)
724+
})
725+
t.Run("CloseOtherSideDuringRead",func(t*testing.T) {
726+
tt,this,other:=newConnTest(t,nil,nil)
727+
728+
_=other.CloseRead(tt.ctx)
729+
errs:=keepReading(this)
730+
731+
err:=other.Write(tt.ctx,websocket.MessageText,want)
732+
assert.Success(t,err)
733+
734+
err=other.Close(websocket.StatusNormalClosure,"")
735+
assert.Success(t,err)
736+
737+
select {
738+
caseerr:=<-errs:
739+
checkReadErr(t,err)
740+
case<-tt.ctx.Done():
741+
t.Fatal(tt.ctx.Err())
742+
}
743+
744+
checkConnErrs(t,this,other)
745+
})
746+
t.Run("CloseThisSideDuringRead",func(t*testing.T) {
747+
tt,this,other:=newConnTest(t,nil,nil)
748+
749+
thisReadErr:=keepReading(this)
750+
otherReadErr:=keepReading(other)
751+
752+
err:=other.Write(tt.ctx,websocket.MessageText,want)
753+
assert.Success(t,err)
754+
755+
err=this.Close(websocket.StatusNormalClosure,"")
756+
assert.Success(t,err)
757+
758+
select {
759+
caseerr:=<-thisReadErr:
760+
checkReadErr(t,err)
761+
case<-tt.ctx.Done():
762+
t.Fatal(tt.ctx.Err())
763+
}
764+
765+
select {
766+
caseerr:=<-otherReadErr:
767+
checkReadErr(t,err)
768+
case<-tt.ctx.Done():
769+
t.Fatal(tt.ctx.Err())
770+
}
771+
772+
checkConnErrs(t,this,other)
773+
})
774+
}

‎read.go‎

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
217217
}
218218
}
219219

220-
func (c*Conn)readFrameHeader(ctx context.Context) (header,error) {
220+
// prepareRead sets the readTimeout context and returns a done function
221+
// to be called after the read is done. It also returns an error if the
222+
// connection is closed. The reference to the error is used to assign
223+
// an error depending on if the connection closed or the context timed
224+
// out during use. Typically the referenced error is a named return
225+
// variable of the function calling this method.
226+
func (c*Conn)prepareRead(ctx context.Context,err*error) (func(),error) {
221227
select {
222228
case<-c.closed:
223-
returnheader{},net.ErrClosed
229+
returnnil,net.ErrClosed
224230
casec.readTimeout<-ctx:
225231
}
226232

227-
h,err:=readFrameHeader(c.br,c.readHeaderBuf[:])
228-
iferr!=nil {
233+
done:=func() {
229234
select {
230235
case<-c.closed:
231-
returnheader{},net.ErrClosed
232-
case<-ctx.Done():
233-
returnheader{},ctx.Err()
234-
default:
235-
returnheader{},err
236+
if*err!=nil {
237+
*err=net.ErrClosed
238+
}
239+
casec.readTimeout<-context.Background():
240+
}
241+
if*err!=nil&&ctx.Err()!=nil {
242+
*err=ctx.Err()
236243
}
237244
}
238245

239-
select {
240-
case<-c.closed:
241-
returnheader{},net.ErrClosed
242-
casec.readTimeout<-context.Background():
246+
c.closeStateMu.Lock()
247+
closeReceivedErr:=c.closeReceivedErr
248+
c.closeStateMu.Unlock()
249+
ifcloseReceivedErr!=nil {
250+
deferdone()
251+
returnnil,closeReceivedErr
243252
}
244253

245-
returnh,nil
254+
returndone,nil
246255
}
247256

248-
func (c*Conn)readFramePayload(ctx context.Context,p []byte) (int,error) {
249-
select {
250-
case<-c.closed:
251-
return0,net.ErrClosed
252-
casec.readTimeout<-ctx:
257+
func (c*Conn)readFrameHeader(ctx context.Context) (_header,errerror) {
258+
readDone,err:=c.prepareRead(ctx,&err)
259+
iferr!=nil {
260+
returnheader{},err
253261
}
262+
deferreadDone()
254263

255-
n,err:=io.ReadFull(c.br,p)
264+
h,err:=readFrameHeader(c.br,c.readHeaderBuf[:])
256265
iferr!=nil {
257-
select {
258-
case<-c.closed:
259-
returnn,net.ErrClosed
260-
case<-ctx.Done():
261-
returnn,ctx.Err()
262-
default:
263-
returnn,fmt.Errorf("failed to read frame payload: %w",err)
264-
}
266+
returnheader{},err
265267
}
266268

267-
select {
268-
case<-c.closed:
269-
returnn,net.ErrClosed
270-
casec.readTimeout<-context.Background():
269+
returnh,nil
270+
}
271+
272+
func (c*Conn)readFramePayload(ctx context.Context,p []byte) (_int,errerror) {
273+
readDone,err:=c.prepareRead(ctx,&err)
274+
iferr!=nil {
275+
return0,err
276+
}
277+
deferreadDone()
278+
279+
n,err:=io.ReadFull(c.br,p)
280+
iferr!=nil {
281+
returnn,fmt.Errorf("failed to read frame payload: %w",err)
271282
}
272283

273284
returnn,err
@@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
325336
}
326337

327338
err=fmt.Errorf("received close frame: %w",ce)
328-
c.writeClose(ce.Code,ce.Reason)
329-
c.readMu.unlock()
330-
c.close()
339+
c.closeStateMu.Lock()
340+
c.closeReceivedErr=err
341+
closeSent:=c.closeSentErr!=nil
342+
c.closeStateMu.Unlock()
343+
344+
// Only unlock readMu if this connection is being closed becaue
345+
// c.close will try to acquire the readMu lock. We unlock for
346+
// writeClose as well because it may also call c.close.
347+
if!closeSent {
348+
c.readMu.unlock()
349+
_=c.writeClose(ce.Code,ce.Reason)
350+
}
351+
if!c.casClosing() {
352+
c.readMu.unlock()
353+
_=c.close()
354+
}
331355
returnerr
332356
}
333357

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp