4
4
"context"
5
5
"crypto/rand"
6
6
"crypto/rsa"
7
+ "encoding/json"
7
8
"errors"
8
9
"fmt"
9
10
"io"
@@ -12,10 +13,14 @@ import (
12
13
"os/exec"
13
14
"os/user"
14
15
"runtime"
16
+ "strconv"
15
17
"strings"
16
18
"sync"
17
19
"time"
18
20
21
+ "github.com/google/uuid"
22
+ "github.com/smallnest/ringbuffer"
23
+
19
24
gsyslog"github.com/hashicorp/go-syslog"
20
25
"go.uber.org/atomic"
21
26
@@ -33,6 +38,11 @@ import (
33
38
"golang.org/x/xerrors"
34
39
)
35
40
41
+ type Options struct {
42
+ ReconnectingPTYTimeout time.Duration
43
+ Logger slog.Logger
44
+ }
45
+
36
46
type Metadata struct {
37
47
OwnerEmail string `json:"owner_email"`
38
48
OwnerUsername string `json:"owner_username"`
@@ -42,13 +52,20 @@ type Metadata struct {
42
52
43
53
type Dialer func (ctx context.Context ,logger slog.Logger ) (Metadata ,* peerbroker.Listener ,error )
44
54
45
- func New (dialer Dialer ,logger slog.Logger ) io.Closer {
55
+ func New (dialer Dialer ,options * Options ) io.Closer {
56
+ if options == nil {
57
+ options = & Options {}
58
+ }
59
+ if options .ReconnectingPTYTimeout == 0 {
60
+ options .ReconnectingPTYTimeout = 5 * time .Minute
61
+ }
46
62
ctx ,cancelFunc := context .WithCancel (context .Background ())
47
63
server := & agent {
48
- dialer :dialer ,
49
- logger :logger ,
50
- closeCancel :cancelFunc ,
51
- closed :make (chan struct {}),
64
+ dialer :dialer ,
65
+ reconnectingPTYTimeout :options .ReconnectingPTYTimeout ,
66
+ logger :options .Logger ,
67
+ closeCancel :cancelFunc ,
68
+ closed :make (chan struct {}),
52
69
}
53
70
server .init (ctx )
54
71
return server
@@ -58,6 +75,9 @@ type agent struct {
58
75
dialer Dialer
59
76
logger slog.Logger
60
77
78
+ reconnectingPTYs sync.Map
79
+ reconnectingPTYTimeout time.Duration
80
+
61
81
connCloseWait sync.WaitGroup
62
82
closeCancel context.CancelFunc
63
83
closeMutex sync.Mutex
@@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
196
216
switch channel .Protocol () {
197
217
case "ssh" :
198
218
go a .sshServer .HandleConn (channel .NetConn ())
219
+ case "reconnecting-pty" :
220
+ go a .handleReconnectingPTY (ctx ,channel .Label (),channel .NetConn ())
199
221
default :
200
222
a .logger .Warn (ctx ,"unhandled protocol from channel" ,
201
223
slog .F ("protocol" ,channel .Protocol ()),
@@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
282
304
go a .run (ctx )
283
305
}
284
306
285
- func (a * agent )handleSSHSession (session ssh.Session )error {
307
+ // createCommand processes raw command input with OpenSSH-like behavior.
308
+ // If the rawCommand provided is empty, it will default to the users shell.
309
+ // This injects environment variables specified by the user at launch too.
310
+ func (a * agent )createCommand (ctx context.Context ,rawCommand string ,env []string ) (* exec.Cmd ,error ) {
286
311
currentUser ,err := user .Current ()
287
312
if err != nil {
288
- return xerrors .Errorf ("get current user: %w" ,err )
313
+ return nil , xerrors .Errorf ("get current user: %w" ,err )
289
314
}
290
315
username := currentUser .Username
291
316
292
317
shell ,err := usershell .Get (username )
293
318
if err != nil {
294
- return xerrors .Errorf ("get user shell: %w" ,err )
319
+ return nil , xerrors .Errorf ("get user shell: %w" ,err )
295
320
}
296
321
297
322
// gliderlabs/ssh returns a command slice of zero
298
323
// when a shell is requested.
299
- command := session . RawCommand ()
300
- if len (session . Command () )== 0 {
324
+ command := rawCommand
325
+ if len (command )== 0 {
301
326
command = shell
302
327
}
303
328
@@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
307
332
if runtime .GOOS == "windows" {
308
333
caller = "/c"
309
334
}
310
- cmd := exec .CommandContext (session . Context () ,shell ,caller ,command )
311
- cmd .Env = append (os .Environ (),session . Environ () ... )
335
+ cmd := exec .CommandContext (ctx ,shell ,caller ,command )
336
+ cmd .Env = append (os .Environ (),env ... )
312
337
executablePath ,err := os .Executable ()
313
338
if err != nil {
314
- return xerrors .Errorf ("getting os executable: %w" ,err )
339
+ return nil , xerrors .Errorf ("getting os executable: %w" ,err )
315
340
}
316
341
// Git on Windows resolves with UNIX-style paths.
317
342
// If using backslashes, it's unable to find the executable.
@@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
332
357
}
333
358
}
334
359
}
360
+ return cmd ,nil
361
+ }
362
+
363
+ func (a * agent )handleSSHSession (session ssh.Session )error {
364
+ cmd ,err := a .createCommand (session .Context (),session .RawCommand (),session .Environ ())
365
+ if err != nil {
366
+ return err
367
+ }
335
368
336
369
sshPty ,windowSize ,isPty := session .Pty ()
337
370
if isPty {
@@ -381,6 +414,140 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
381
414
return cmd .Wait ()
382
415
}
383
416
417
+ func (a * agent )handleReconnectingPTY (ctx context.Context ,rawID string ,conn net.Conn ) {
418
+ defer conn .Close ()
419
+
420
+ idParts := strings .Split (rawID ,":" )
421
+ if len (idParts )!= 3 {
422
+ a .logger .Warn (ctx ,"client sent invalid id format" ,slog .F ("raw-id" ,rawID ))
423
+ return
424
+ }
425
+ id := idParts [0 ]
426
+ // Enforce a consistent format for IDs.
427
+ _ ,err := uuid .Parse (id )
428
+ if err != nil {
429
+ a .logger .Warn (ctx ,"client sent reconnection token that isn't a uuid" ,slog .F ("id" ,id ),slog .Error (err ))
430
+ return
431
+ }
432
+ height ,err := strconv .Atoi (idParts [1 ])
433
+ if err != nil {
434
+ a .logger .Warn (ctx ,"client sent invalid height" ,slog .F ("id" ,id ),slog .F ("height" ,idParts [1 ]))
435
+ return
436
+ }
437
+ width ,err := strconv .Atoi (idParts [2 ])
438
+ if err != nil {
439
+ a .logger .Warn (ctx ,"client sent invalid width" ,slog .F ("id" ,id ),slog .F ("width" ,idParts [2 ]))
440
+ return
441
+ }
442
+
443
+ var rpty * reconnectingPTY
444
+ rawRPTY ,ok := a .reconnectingPTYs .Load (id )
445
+ if ok {
446
+ rpty ,ok = rawRPTY .(* reconnectingPTY )
447
+ if ! ok {
448
+ a .logger .Warn (ctx ,"found invalid type in reconnecting pty map" ,slog .F ("id" ,id ))
449
+ }
450
+ }else {
451
+ // Empty command will default to the users shell!
452
+ cmd ,err := a .createCommand (ctx ,"" ,nil )
453
+ if err != nil {
454
+ a .logger .Warn (ctx ,"create reconnecting pty command" ,slog .Error (err ))
455
+ return
456
+ }
457
+ ptty ,_ ,err := pty .Start (cmd )
458
+ if err != nil {
459
+ a .logger .Warn (ctx ,"start reconnecting pty command" ,slog .F ("id" ,id ))
460
+ }
461
+
462
+ rpty = & reconnectingPTY {
463
+ activeConns :make (map [string ]net.Conn ),
464
+ ptty :ptty ,
465
+ timeout :time .NewTimer (a .reconnectingPTYTimeout ),
466
+ // Default to buffer 1MB.
467
+ ringBuffer :ringbuffer .New (1 << 20 ),
468
+ }
469
+ a .reconnectingPTYs .Store (id ,rpty )
470
+ go func () {
471
+ // Close if the inactive timeout occurs, or the context ends.
472
+ select {
473
+ case <- rpty .timeout .C :
474
+ a .logger .Info (ctx ,"killing reconnecting pty due to inactivity" ,slog .F ("id" ,id ))
475
+ case <- ctx .Done ():
476
+ }
477
+ rpty .Close ()
478
+ }()
479
+ go func () {
480
+ buffer := make ([]byte ,32 * 1024 )
481
+ for {
482
+ read ,err := rpty .ptty .Output ().Read (buffer )
483
+ if err != nil {
484
+ rpty .Close ()
485
+ break
486
+ }
487
+ part := buffer [:read ]
488
+ _ ,err = rpty .ringBuffer .Write (part )
489
+ if err != nil {
490
+ a .logger .Error (ctx ,"reconnecting pty write buffer" ,slog .Error (err ),slog .F ("id" ,id ))
491
+ return
492
+ }
493
+ rpty .activeConnsMutex .Lock ()
494
+ for _ ,conn := range rpty .activeConns {
495
+ _ ,_ = conn .Write (part )
496
+ }
497
+ rpty .activeConnsMutex .Unlock ()
498
+ }
499
+ // If we break from the loop, the reconnecting PTY ended.
500
+ a .reconnectingPTYs .Delete (id )
501
+ }()
502
+ }
503
+ err = rpty .ptty .Resize (uint16 (height ),uint16 (width ))
504
+ if err != nil {
505
+ // We can continue after this, it's not fatal!
506
+ a .logger .Error (ctx ,"resize reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
507
+ }
508
+
509
+ _ ,err = conn .Write (rpty .ringBuffer .Bytes ())
510
+ if err != nil {
511
+ a .logger .Warn (ctx ,"write reconnecting pty buffer" ,slog .F ("id" ,id ),slog .Error (err ))
512
+ return
513
+ }
514
+ connectionID := uuid .NewString ()
515
+ rpty .activeConnsMutex .Lock ()
516
+ rpty .activeConns [connectionID ]= conn
517
+ rpty .activeConnsMutex .Unlock ()
518
+ defer func () {
519
+ rpty .activeConnsMutex .Lock ()
520
+ delete (rpty .activeConns ,connectionID )
521
+ rpty .activeConnsMutex .Unlock ()
522
+ }()
523
+ decoder := json .NewDecoder (conn )
524
+ var req ReconnectingPTYRequest
525
+ for {
526
+ err = decoder .Decode (& req )
527
+ if xerrors .Is (err ,io .EOF ) {
528
+ return
529
+ }
530
+ if err != nil {
531
+ a .logger .Warn (ctx ,"reconnecting pty buffer read error" ,slog .F ("id" ,id ),slog .Error (err ))
532
+ return
533
+ }
534
+ _ ,err = rpty .ptty .Input ().Write ([]byte (req .Data ))
535
+ if err != nil {
536
+ a .logger .Warn (ctx ,"write to reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
537
+ return
538
+ }
539
+ // Check if a resize needs to happen!
540
+ if req .Height == 0 || req .Width == 0 {
541
+ continue
542
+ }
543
+ err = rpty .ptty .Resize (req .Height ,req .Width )
544
+ if err != nil {
545
+ // We can continue after this, it's not fatal!
546
+ a .logger .Error (ctx ,"resize reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
547
+ }
548
+ }
549
+ }
550
+
384
551
// isClosed returns whether the API is closed or not.
385
552
func (a * agent )isClosed ()bool {
386
553
select {
@@ -403,3 +570,22 @@ func (a *agent) Close() error {
403
570
a .connCloseWait .Wait ()
404
571
return nil
405
572
}
573
+
574
+ type reconnectingPTY struct {
575
+ activeConnsMutex sync.Mutex
576
+ activeConns map [string ]net.Conn
577
+
578
+ ringBuffer * ringbuffer.RingBuffer
579
+ timeout * time.Timer
580
+ ptty pty.PTY
581
+ }
582
+
583
+ func (r * reconnectingPTY )Close () {
584
+ r .activeConnsMutex .Lock ()
585
+ defer r .activeConnsMutex .Unlock ()
586
+ for _ ,conn := range r .activeConns {
587
+ _ = conn .Close ()
588
+ }
589
+ _ = r .ptty .Close ()
590
+ r .ringBuffer .Reset ()
591
+ }