@@ -10,6 +10,7 @@ import (
10
10
"net/http"
11
11
"net/http/httptest"
12
12
"strings"
13
+ "sync"
13
14
"testing"
14
15
15
16
"nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +143,42 @@ func TestAccept(t *testing.T) {
142
143
_ ,err := Accept (w ,r ,nil )
143
144
assert .Contains (t ,err ,`failed to hijack connection` )
144
145
})
146
+ t .Run ("closeRace" ,func (t * testing.T ) {
147
+ t .Parallel ()
148
+
149
+ server ,_ := net .Pipe ()
150
+
151
+ rw := bufio .NewReadWriter (bufio .NewReader (server ),bufio .NewWriter (server ))
152
+ newResponseWriter := func () http.ResponseWriter {
153
+ return mockHijacker {
154
+ ResponseWriter :httptest .NewRecorder (),
155
+ hijack :func () (net.Conn ,* bufio.ReadWriter ,error ) {
156
+ return server ,rw ,nil
157
+ },
158
+ }
159
+ }
160
+ w := newResponseWriter ()
161
+
162
+ r := httptest .NewRequest ("GET" ,"/" ,nil )
163
+ r .Header .Set ("Connection" ,"Upgrade" )
164
+ r .Header .Set ("Upgrade" ,"websocket" )
165
+ r .Header .Set ("Sec-WebSocket-Version" ,"13" )
166
+ r .Header .Set ("Sec-WebSocket-Key" ,xrand .Base64 (16 ))
167
+
168
+ c ,err := Accept (w ,r ,nil )
169
+ wg := & sync.WaitGroup {}
170
+ wg .Add (2 )
171
+ go func () {
172
+ c .Close (StatusInternalError ,"the sky is falling" )
173
+ wg .Done ()
174
+ }()
175
+ go func () {
176
+ c .CloseNow ()
177
+ wg .Done ()
178
+ }()
179
+ wg .Wait ()
180
+ assert .Success (t ,err )
181
+ })
145
182
}
146
183
147
184
func Test_verifyClientHandshake (t * testing.T ) {