|
8 | 8 | "errors"
|
9 | 9 | "fmt"
|
10 | 10 | "io"
|
| 11 | +"net" |
11 | 12 | "net/http"
|
12 | 13 | "net/http/httptest"
|
13 | 14 | "os"
|
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
|
460 | 461 | }
|
461 | 462 |
|
462 | 463 | funcBenchmarkConn(b*testing.B) {
|
463 |
| -varbenchCases= []struct { |
| 464 | +benchCases:= []struct { |
464 | 465 | namestring
|
465 | 466 | mode websocket.CompressionMode
|
466 | 467 | }{
|
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
|
625 | 626 | }()
|
626 | 627 | }
|
627 | 628 | }
|
| 629 | + |
| 630 | +funcTestConnClosePropagation(t*testing.T) { |
| 631 | +t.Parallel() |
| 632 | + |
| 633 | +want:= []byte("hello") |
| 634 | +keepWriting:=func(c*websocket.Conn)<-chanerror { |
| 635 | +returnxsync.Go(func()error { |
| 636 | +for { |
| 637 | +err:=c.Write(context.Background(),websocket.MessageText,want) |
| 638 | +iferr!=nil { |
| 639 | +returnerr |
| 640 | +} |
| 641 | +} |
| 642 | +}) |
| 643 | +} |
| 644 | +keepReading:=func(c*websocket.Conn)<-chanerror { |
| 645 | +returnxsync.Go(func()error { |
| 646 | +for { |
| 647 | +_,got,err:=c.Read(context.Background()) |
| 648 | +iferr!=nil { |
| 649 | +returnerr |
| 650 | +} |
| 651 | +if!bytes.Equal(want,got) { |
| 652 | +returnfmt.Errorf("unexpected message: want %q, got %q",want,got) |
| 653 | +} |
| 654 | +} |
| 655 | +}) |
| 656 | +} |
| 657 | +checkReadErr:=func(t*testing.T,errerror) { |
| 658 | +// Check read error (output depends on when read is called in relation to connection closure). |
| 659 | +varce websocket.CloseError |
| 660 | +iferrors.As(err,&ce) { |
| 661 | +assert.Equal(t,"",websocket.StatusNormalClosure,ce.Code) |
| 662 | +}else { |
| 663 | +assert.ErrorIs(t,net.ErrClosed,err) |
| 664 | +} |
| 665 | +} |
| 666 | +checkConnErrs:=func(t*testing.T,conn...*websocket.Conn) { |
| 667 | +for_,c:=rangeconn { |
| 668 | +// Check write error. |
| 669 | +err:=c.Write(context.Background(),websocket.MessageText,want) |
| 670 | +assert.ErrorIs(t,net.ErrClosed,err) |
| 671 | + |
| 672 | +_,_,err=c.Read(context.Background()) |
| 673 | +checkReadErr(t,err) |
| 674 | +} |
| 675 | +} |
| 676 | + |
| 677 | +t.Run("CloseOtherSideDuringWrite",func(t*testing.T) { |
| 678 | +tt,this,other:=newConnTest(t,nil,nil) |
| 679 | + |
| 680 | +_=this.CloseRead(tt.ctx) |
| 681 | +thisWriteErr:=keepWriting(this) |
| 682 | + |
| 683 | +_,got,err:=other.Read(tt.ctx) |
| 684 | +assert.Success(t,err) |
| 685 | +assert.Equal(t,"msg",want,got) |
| 686 | + |
| 687 | +err=other.Close(websocket.StatusNormalClosure,"") |
| 688 | +assert.Success(t,err) |
| 689 | + |
| 690 | +select { |
| 691 | +caseerr:=<-thisWriteErr: |
| 692 | +assert.ErrorIs(t,net.ErrClosed,err) |
| 693 | +case<-tt.ctx.Done(): |
| 694 | +t.Fatal(tt.ctx.Err()) |
| 695 | +} |
| 696 | + |
| 697 | +checkConnErrs(t,this,other) |
| 698 | +}) |
| 699 | +t.Run("CloseThisSideDuringWrite",func(t*testing.T) { |
| 700 | +tt,this,other:=newConnTest(t,nil,nil) |
| 701 | + |
| 702 | +_=this.CloseRead(tt.ctx) |
| 703 | +thisWriteErr:=keepWriting(this) |
| 704 | +otherReadErr:=keepReading(other) |
| 705 | + |
| 706 | +err:=this.Close(websocket.StatusNormalClosure,"") |
| 707 | +assert.Success(t,err) |
| 708 | + |
| 709 | +select { |
| 710 | +caseerr:=<-thisWriteErr: |
| 711 | +assert.ErrorIs(t,net.ErrClosed,err) |
| 712 | +case<-tt.ctx.Done(): |
| 713 | +t.Fatal(tt.ctx.Err()) |
| 714 | +} |
| 715 | + |
| 716 | +select { |
| 717 | +caseerr:=<-otherReadErr: |
| 718 | +checkReadErr(t,err) |
| 719 | +case<-tt.ctx.Done(): |
| 720 | +t.Fatal(tt.ctx.Err()) |
| 721 | +} |
| 722 | + |
| 723 | +checkConnErrs(t,this,other) |
| 724 | +}) |
| 725 | +t.Run("CloseOtherSideDuringRead",func(t*testing.T) { |
| 726 | +tt,this,other:=newConnTest(t,nil,nil) |
| 727 | + |
| 728 | +_=other.CloseRead(tt.ctx) |
| 729 | +errs:=keepReading(this) |
| 730 | + |
| 731 | +err:=other.Write(tt.ctx,websocket.MessageText,want) |
| 732 | +assert.Success(t,err) |
| 733 | + |
| 734 | +err=other.Close(websocket.StatusNormalClosure,"") |
| 735 | +assert.Success(t,err) |
| 736 | + |
| 737 | +select { |
| 738 | +caseerr:=<-errs: |
| 739 | +checkReadErr(t,err) |
| 740 | +case<-tt.ctx.Done(): |
| 741 | +t.Fatal(tt.ctx.Err()) |
| 742 | +} |
| 743 | + |
| 744 | +checkConnErrs(t,this,other) |
| 745 | +}) |
| 746 | +t.Run("CloseThisSideDuringRead",func(t*testing.T) { |
| 747 | +tt,this,other:=newConnTest(t,nil,nil) |
| 748 | + |
| 749 | +thisReadErr:=keepReading(this) |
| 750 | +otherReadErr:=keepReading(other) |
| 751 | + |
| 752 | +err:=other.Write(tt.ctx,websocket.MessageText,want) |
| 753 | +assert.Success(t,err) |
| 754 | + |
| 755 | +err=this.Close(websocket.StatusNormalClosure,"") |
| 756 | +assert.Success(t,err) |
| 757 | + |
| 758 | +select { |
| 759 | +caseerr:=<-thisReadErr: |
| 760 | +checkReadErr(t,err) |
| 761 | +case<-tt.ctx.Done(): |
| 762 | +t.Fatal(tt.ctx.Err()) |
| 763 | +} |
| 764 | + |
| 765 | +select { |
| 766 | +caseerr:=<-otherReadErr: |
| 767 | +checkReadErr(t,err) |
| 768 | +case<-tt.ctx.Done(): |
| 769 | +t.Fatal(tt.ctx.Err()) |
| 770 | +} |
| 771 | + |
| 772 | +checkConnErrs(t,this,other) |
| 773 | +}) |
| 774 | +} |