5
5
"errors"
6
6
"io"
7
7
"net"
8
- "sync"
9
8
"time"
10
9
11
10
"github.com/armon/circbuf"
@@ -23,9 +22,6 @@ import (
23
22
type bufferedReconnectingPTY struct {
24
23
command * pty.Cmd
25
24
26
- // mutex protects writing to the circular buffer and connections.
27
- mutex sync.RWMutex
28
-
29
25
activeConns map [string ]net.Conn
30
26
circularBuffer * circbuf.Buffer
31
27
@@ -100,7 +96,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
100
96
break
101
97
}
102
98
part := buffer [:read ]
103
- rpty .mutex .Lock ()
99
+ rpty .state . cond . L .Lock ()
104
100
_ ,err = rpty .circularBuffer .Write (part )
105
101
if err != nil {
106
102
logger .Error (ctx ,"write to circular buffer" ,slog .Error (err ))
@@ -119,7 +115,7 @@ func newBuffered(ctx context.Context, cmd *pty.Cmd, options *Options, logger slo
119
115
rpty .metrics .WithLabelValues ("write" ).Add (1 )
120
116
}
121
117
}
122
- rpty .mutex .Unlock ()
118
+ rpty .state . cond . L .Unlock ()
123
119
}
124
120
}()
125
121
@@ -136,14 +132,29 @@ func (rpty *bufferedReconnectingPTY) lifecycle(ctx context.Context, logger slog.
136
132
logger .Debug (ctx ,"reconnecting pty ready" )
137
133
rpty .state .setState (StateReady ,nil )
138
134
139
- state ,reasonErr := rpty .state .waitForStateOrContext (ctx ,StateClosing )
135
+ state ,reasonErr := rpty .state .waitForStateOrContext (ctx ,StateClosing , nil )
140
136
if state < StateClosing {
141
137
// If we have not closed yet then the context is what unblocked us (which
142
138
// means the agent is shutting down) so move into the closing phase.
143
139
rpty .Close (reasonErr .Error ())
144
140
}
145
141
rpty .timer .Stop ()
146
142
143
+ rpty .state .cond .L .Lock ()
144
+ // Log these closes only for debugging since the connections or processes
145
+ // might have already closed on their own.
146
+ for _ ,conn := range rpty .activeConns {
147
+ err := conn .Close ()
148
+ if err != nil {
149
+ logger .Debug (ctx ,"closed conn with error" ,slog .Error (err ))
150
+ }
151
+ }
152
+ // Connections get removed once the pty closes but it is possible there is
153
+ // still some data that needs to be written so clear the map now to avoid
154
+ // writing to closed connections.
155
+ rpty .activeConns = map [string ]net.Conn {}
156
+ rpty .state .cond .L .Unlock ()
157
+
147
158
// Log close/kill only for debugging since the process might have already
148
159
// closed on its own.
149
160
err := rpty .ptty .Close ()
@@ -167,65 +178,49 @@ func (rpty *bufferedReconnectingPTY) Attach(ctx context.Context, connID string,
167
178
ctx ,cancel := context .WithCancel (ctx )
168
179
defer cancel ()
169
180
170
- state ,err := rpty .state .waitForStateOrContext (ctx ,StateReady )
171
- if state != StateReady {
172
- return xerrors .Errorf ("reconnecting pty ready wait: %w" ,err )
173
- }
181
+ // Once we are ready, attach the active connection while we hold the mutex.
182
+ _ ,err := rpty .state .waitForStateOrContext (ctx ,StateReady ,func (state State ,err error )error {
183
+ if state != StateReady {
184
+ return xerrors .Errorf ("reconnecting pty ready wait: %w" ,err )
185
+ }
186
+
187
+ go heartbeat (ctx ,rpty .timer ,rpty .timeout )
188
+
189
+ // Resize the PTY to initial height + width.
190
+ err = rpty .ptty .Resize (height ,width )
191
+ if err != nil {
192
+ // We can continue after this, it's not fatal!
193
+ logger .Warn (ctx ,"reconnecting PTY initial resize failed, but will continue" ,slog .Error (err ))
194
+ rpty .metrics .WithLabelValues ("resize" ).Add (1 )
195
+ }
174
196
175
- go heartbeat (ctx ,rpty .timer ,rpty .timeout )
197
+ // Write any previously stored data for the TTY and store the connection for
198
+ // future writes.
199
+ prevBuf := slices .Clone (rpty .circularBuffer .Bytes ())
200
+ _ ,err = conn .Write (prevBuf )
201
+ if err != nil {
202
+ rpty .metrics .WithLabelValues ("write" ).Add (1 )
203
+ return xerrors .Errorf ("write buffer to conn: %w" ,err )
204
+ }
205
+ rpty .activeConns [connID ]= conn
176
206
177
- err = rpty .doAttach (ctx ,connID ,conn ,height ,width ,logger )
207
+ return nil
208
+ })
178
209
if err != nil {
179
210
return err
180
211
}
181
212
182
- go func () {
183
- _ ,_ = rpty .state .waitForStateOrContext (ctx ,StateClosing )
184
- rpty .mutex .Lock ()
185
- defer rpty .mutex .Unlock ()
213
+ defer func () {
214
+ rpty .state .cond .L .Lock ()
215
+ defer rpty .state .cond .L .Unlock ()
186
216
delete (rpty .activeConns ,connID )
187
- // Log closes only for debugging since the connection might have already
188
- // closed on its own.
189
- err := conn .Close ()
190
- if err != nil {
191
- logger .Debug (ctx ,"closed conn with error" ,slog .Error (err ))
192
- }
193
217
}()
194
218
195
219
// Pipe conn -> pty and block. pty -> conn is handled in newBuffered().
196
220
readConnLoop (ctx ,conn ,rpty .ptty ,rpty .metrics ,logger )
197
221
return nil
198
222
}
199
223
200
- // doAttach adds the connection to the map, replays the buffer, and starts the
201
- // heartbeat. It exists separately only so we can defer the mutex unlock which
202
- // is not possible in Attach since it blocks.
203
- func (rpty * bufferedReconnectingPTY )doAttach (ctx context.Context ,connID string ,conn net.Conn ,height ,width uint16 ,logger slog.Logger )error {
204
- // Ensure we do not write to or close connections while we attach.
205
- rpty .mutex .Lock ()
206
- defer rpty .mutex .Unlock ()
207
-
208
- // Resize the PTY to initial height + width.
209
- err := rpty .ptty .Resize (height ,width )
210
- if err != nil {
211
- // We can continue after this, it's not fatal!
212
- logger .Warn (ctx ,"reconnecting PTY initial resize failed, but will continue" ,slog .Error (err ))
213
- rpty .metrics .WithLabelValues ("resize" ).Add (1 )
214
- }
215
-
216
- // Write any previously stored data for the TTY and store the connection for
217
- // future writes.
218
- prevBuf := slices .Clone (rpty .circularBuffer .Bytes ())
219
- _ ,err = conn .Write (prevBuf )
220
- if err != nil {
221
- rpty .metrics .WithLabelValues ("write" ).Add (1 )
222
- return xerrors .Errorf ("write buffer to conn: %w" ,err )
223
- }
224
- rpty .activeConns [connID ]= conn
225
-
226
- return nil
227
- }
228
-
229
224
func (rpty * bufferedReconnectingPTY )Wait () {
230
225
_ ,_ = rpty .state .waitForState (StateClosing )
231
226
}