@@ -51,9 +51,8 @@ type Conn struct {
5151br * bufio.Reader
5252bw * bufio.Writer
5353
54- readTimeout chan context.Context
55- writeTimeout chan context.Context
56- timeoutLoopDone chan struct {}
54+ readTimeoutStop atomic.Pointer [func ()bool ]
55+ writeTimeoutStop atomic.Pointer [func ()bool ]
5756
5857// Read state.
5958readMu * mu
@@ -113,10 +112,6 @@ func newConn(cfg connConfig) *Conn {
113112br :cfg .br ,
114113bw :cfg .bw ,
115114
116- readTimeout :make (chan context.Context ),
117- writeTimeout :make (chan context.Context ),
118- timeoutLoopDone :make (chan struct {}),
119-
120115closed :make (chan struct {}),
121116activePings :make (map [string ]chan <- struct {}),
122117onPingReceived :cfg .onPingReceived ,
@@ -144,8 +139,6 @@ func newConn(cfg connConfig) *Conn {
144139c .close ()
145140})
146141
147- go c .timeoutLoop ()
148-
149142return c
150143}
151144
@@ -175,27 +168,34 @@ func (c *Conn) close() error {
175168return err
176169}
177170
178- func (c * Conn )timeoutLoop () {
179- defer close (c .timeoutLoopDone )
171+ func (c * Conn )setupWriteTimeout (ctx context.Context ) {
172+ stop := context .AfterFunc (ctx ,func () {
173+ c .clearWriteTimeout ()
174+ c .close ()
175+ })
176+ swapTimeoutStop (& c .writeTimeoutStop ,& stop )
177+ }
180178
181- readCtx := context .Background ()
182- writeCtx := context .Background ()
179+ func (c * Conn )clearWriteTimeout () {
180+ swapTimeoutStop (& c .writeTimeoutStop ,nil )
181+ }
183182
184- for {
185- select {
186- case <- c .closed :
187- return
188-
189- case writeCtx = <- c .writeTimeout :
190- case readCtx = <- c .readTimeout :
191-
192- case <- readCtx .Done ():
193- c .close ()
194- return
195- case <- writeCtx .Done ():
196- c .close ()
197- return
198- }
183+ func (c * Conn )setupReadTimeout (ctx context.Context ) {
184+ stop := context .AfterFunc (ctx ,func () {
185+ c .clearReadTimeout ()
186+ c .close ()
187+ })
188+ swapTimeoutStop (& c .readTimeoutStop ,& stop )
189+ }
190+
191+ func (c * Conn )clearReadTimeout () {
192+ swapTimeoutStop (& c .readTimeoutStop ,nil )
193+ }
194+
195+ func swapTimeoutStop (p * atomic.Pointer [func ()bool ],newStop * func ()bool ) {
196+ oldStop := p .Swap (newStop )
197+ if oldStop != nil {
198+ (* oldStop )()
199199}
200200}
201201