4
4
"context"
5
5
"crypto/rand"
6
6
"crypto/rsa"
7
+ "encoding/json"
7
8
"errors"
8
9
"fmt"
9
10
"io"
@@ -12,10 +13,14 @@ import (
12
13
"os/exec"
13
14
"os/user"
14
15
"runtime"
16
+ "strconv"
15
17
"strings"
16
18
"sync"
17
19
"time"
18
20
21
+ "github.com/google/uuid"
22
+ "github.com/smallnest/ringbuffer"
23
+
19
24
gsyslog"github.com/hashicorp/go-syslog"
20
25
"go.uber.org/atomic"
21
26
@@ -33,6 +38,11 @@ import (
33
38
"golang.org/x/xerrors"
34
39
)
35
40
41
+ type Options struct {
42
+ ReconnectingPTYTimeout time.Duration
43
+ Logger slog.Logger
44
+ }
45
+
36
46
type Metadata struct {
37
47
OwnerEmail string `json:"owner_email"`
38
48
OwnerUsername string `json:"owner_username"`
@@ -42,13 +52,20 @@ type Metadata struct {
42
52
43
53
type Dialer func (ctx context.Context ,logger slog.Logger ) (Metadata ,* peerbroker.Listener ,error )
44
54
45
- func New (dialer Dialer ,logger slog.Logger ) io.Closer {
55
+ func New (dialer Dialer ,options * Options ) io.Closer {
56
+ if options == nil {
57
+ options = & Options {}
58
+ }
59
+ if options .ReconnectingPTYTimeout == 0 {
60
+ options .ReconnectingPTYTimeout = 5 * time .Minute
61
+ }
46
62
ctx ,cancelFunc := context .WithCancel (context .Background ())
47
63
server := & agent {
48
- dialer :dialer ,
49
- logger :logger ,
50
- closeCancel :cancelFunc ,
51
- closed :make (chan struct {}),
64
+ dialer :dialer ,
65
+ reconnectingPTYTimeout :options .ReconnectingPTYTimeout ,
66
+ logger :options .Logger ,
67
+ closeCancel :cancelFunc ,
68
+ closed :make (chan struct {}),
52
69
}
53
70
server .init (ctx )
54
71
return server
@@ -58,6 +75,9 @@ type agent struct {
58
75
dialer Dialer
59
76
logger slog.Logger
60
77
78
+ reconnectingPTYs sync.Map
79
+ reconnectingPTYTimeout time.Duration
80
+
61
81
connCloseWait sync.WaitGroup
62
82
closeCancel context.CancelFunc
63
83
closeMutex sync.Mutex
@@ -196,6 +216,8 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
196
216
switch channel .Protocol () {
197
217
case "ssh" :
198
218
go a .sshServer .HandleConn (channel .NetConn ())
219
+ case "reconnecting-pty" :
220
+ go a .handleReconnectingPTY (ctx ,channel .Label (),channel .NetConn ())
199
221
default :
200
222
a .logger .Warn (ctx ,"unhandled protocol from channel" ,
201
223
slog .F ("protocol" ,channel .Protocol ()),
@@ -282,22 +304,25 @@ func (a *agent) init(ctx context.Context) {
282
304
go a .run (ctx )
283
305
}
284
306
285
- func (a * agent )handleSSHSession (session ssh.Session )error {
307
+ // createCommand processes raw command input with OpenSSH-like behavior.
308
+ // If the rawCommand provided is empty, it will default to the users shell.
309
+ // This injects environment variables specified by the user at launch too.
310
+ func (a * agent )createCommand (ctx context.Context ,rawCommand string ,env []string ) (* exec.Cmd ,error ) {
286
311
currentUser ,err := user .Current ()
287
312
if err != nil {
288
- return xerrors .Errorf ("get current user: %w" ,err )
313
+ return nil , xerrors .Errorf ("get current user: %w" ,err )
289
314
}
290
315
username := currentUser .Username
291
316
292
317
shell ,err := usershell .Get (username )
293
318
if err != nil {
294
- return xerrors .Errorf ("get user shell: %w" ,err )
319
+ return nil , xerrors .Errorf ("get user shell: %w" ,err )
295
320
}
296
321
297
322
// gliderlabs/ssh returns a command slice of zero
298
323
// when a shell is requested.
299
- command := session . RawCommand ()
300
- if len (session . Command () )== 0 {
324
+ command := rawCommand
325
+ if len (command )== 0 {
301
326
command = shell
302
327
}
303
328
@@ -307,11 +332,11 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
307
332
if runtime .GOOS == "windows" {
308
333
caller = "/c"
309
334
}
310
- cmd := exec .CommandContext (session . Context () ,shell ,caller ,command )
311
- cmd .Env = append (os .Environ (),session . Environ () ... )
335
+ cmd := exec .CommandContext (ctx ,shell ,caller ,command )
336
+ cmd .Env = append (os .Environ (),env ... )
312
337
executablePath ,err := os .Executable ()
313
338
if err != nil {
314
- return xerrors .Errorf ("getting os executable: %w" ,err )
339
+ return nil , xerrors .Errorf ("getting os executable: %w" ,err )
315
340
}
316
341
// Git on Windows resolves with UNIX-style paths.
317
342
// If using backslashes, it's unable to find the executable.
@@ -332,6 +357,14 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
332
357
}
333
358
}
334
359
}
360
+ return cmd ,nil
361
+ }
362
+
363
+ func (a * agent )handleSSHSession (session ssh.Session )error {
364
+ cmd ,err := a .createCommand (session .Context (),session .RawCommand (),session .Environ ())
365
+ if err != nil {
366
+ return err
367
+ }
335
368
336
369
sshPty ,windowSize ,isPty := session .Pty ()
337
370
if isPty {
@@ -381,6 +414,144 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
381
414
return cmd .Wait ()
382
415
}
383
416
417
+ func (a * agent )handleReconnectingPTY (ctx context.Context ,rawID string ,conn net.Conn ) {
418
+ defer conn .Close ()
419
+
420
+ idParts := strings .Split (rawID ,":" )
421
+ if len (idParts )!= 3 {
422
+ a .logger .Warn (ctx ,"client sent invalid id format" ,slog .F ("raw-id" ,rawID ))
423
+ return
424
+ }
425
+ id := idParts [0 ]
426
+ // Enforce a consistent format for IDs.
427
+ _ ,err := uuid .Parse (id )
428
+ if err != nil {
429
+ a .logger .Warn (ctx ,"client sent reconnection token that isn't a uuid" ,slog .F ("id" ,id ),slog .Error (err ))
430
+ return
431
+ }
432
+ height ,err := strconv .Atoi (idParts [1 ])
433
+ if err != nil {
434
+ a .logger .Warn (ctx ,"client sent invalid height" ,slog .F ("id" ,id ),slog .F ("height" ,idParts [1 ]))
435
+ return
436
+ }
437
+ width ,err := strconv .Atoi (idParts [2 ])
438
+ if err != nil {
439
+ a .logger .Warn (ctx ,"client sent invalid width" ,slog .F ("id" ,id ),slog .F ("width" ,idParts [2 ]))
440
+ return
441
+ }
442
+
443
+ var rpty * reconnectingPTY
444
+ rawRPTY ,ok := a .reconnectingPTYs .Load (id )
445
+ if ok {
446
+ rpty ,ok = rawRPTY .(* reconnectingPTY )
447
+ if ! ok {
448
+ a .logger .Warn (ctx ,"found invalid type in reconnecting pty map" ,slog .F ("id" ,id ))
449
+ }
450
+ }else {
451
+ // Empty command will default to the users shell!
452
+ cmd ,err := a .createCommand (ctx ,"" ,nil )
453
+ if err != nil {
454
+ a .logger .Warn (ctx ,"create reconnecting pty command" ,slog .Error (err ))
455
+ return
456
+ }
457
+ ptty ,_ ,err := pty .Start (cmd )
458
+ if err != nil {
459
+ a .logger .Warn (ctx ,"start reconnecting pty command" ,slog .F ("id" ,id ))
460
+ }
461
+
462
+ a .closeMutex .Lock ()
463
+ a .connCloseWait .Add (1 )
464
+ a .closeMutex .Unlock ()
465
+ rpty = & reconnectingPTY {
466
+ activeConns :make (map [string ]net.Conn ),
467
+ ptty :ptty ,
468
+ timeout :time .NewTimer (a .reconnectingPTYTimeout ),
469
+ // Default to buffer 1MB.
470
+ ringBuffer :ringbuffer .New (1 << 20 ),
471
+ }
472
+ a .reconnectingPTYs .Store (id ,rpty )
473
+ go func () {
474
+ // Close if the inactive timeout occurs, or the context ends.
475
+ select {
476
+ case <- rpty .timeout .C :
477
+ a .logger .Info (ctx ,"killing reconnecting pty due to inactivity" ,slog .F ("id" ,id ))
478
+ case <- ctx .Done ():
479
+ }
480
+ rpty .Close ()
481
+ }()
482
+ go func () {
483
+ buffer := make ([]byte ,32 * 1024 )
484
+ for {
485
+ read ,err := rpty .ptty .Output ().Read (buffer )
486
+ if err != nil {
487
+ rpty .Close ()
488
+ break
489
+ }
490
+ part := buffer [:read ]
491
+ _ ,err = rpty .ringBuffer .Write (part )
492
+ if err != nil {
493
+ a .logger .Error (ctx ,"reconnecting pty write buffer" ,slog .Error (err ),slog .F ("id" ,id ))
494
+ break
495
+ }
496
+ rpty .activeConnsMutex .Lock ()
497
+ for _ ,conn := range rpty .activeConns {
498
+ _ ,_ = conn .Write (part )
499
+ }
500
+ rpty .activeConnsMutex .Unlock ()
501
+ }
502
+ // If we break from the loop, the reconnecting PTY ended.
503
+ a .reconnectingPTYs .Delete (id )
504
+ a .connCloseWait .Done ()
505
+ }()
506
+ }
507
+ err = rpty .ptty .Resize (uint16 (height ),uint16 (width ))
508
+ if err != nil {
509
+ // We can continue after this, it's not fatal!
510
+ a .logger .Error (ctx ,"resize reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
511
+ }
512
+
513
+ _ ,err = conn .Write (rpty .ringBuffer .Bytes ())
514
+ if err != nil {
515
+ a .logger .Warn (ctx ,"write reconnecting pty buffer" ,slog .F ("id" ,id ),slog .Error (err ))
516
+ return
517
+ }
518
+ connectionID := uuid .NewString ()
519
+ rpty .activeConnsMutex .Lock ()
520
+ rpty .activeConns [connectionID ]= conn
521
+ rpty .activeConnsMutex .Unlock ()
522
+ defer func () {
523
+ rpty .activeConnsMutex .Lock ()
524
+ delete (rpty .activeConns ,connectionID )
525
+ rpty .activeConnsMutex .Unlock ()
526
+ }()
527
+ decoder := json .NewDecoder (conn )
528
+ var req ReconnectingPTYRequest
529
+ for {
530
+ err = decoder .Decode (& req )
531
+ if xerrors .Is (err ,io .EOF ) {
532
+ return
533
+ }
534
+ if err != nil {
535
+ a .logger .Warn (ctx ,"reconnecting pty buffer read error" ,slog .F ("id" ,id ),slog .Error (err ))
536
+ return
537
+ }
538
+ _ ,err = rpty .ptty .Input ().Write ([]byte (req .Data ))
539
+ if err != nil {
540
+ a .logger .Warn (ctx ,"write to reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
541
+ return
542
+ }
543
+ // Check if a resize needs to happen!
544
+ if req .Height == 0 || req .Width == 0 {
545
+ continue
546
+ }
547
+ err = rpty .ptty .Resize (req .Height ,req .Width )
548
+ if err != nil {
549
+ // We can continue after this, it's not fatal!
550
+ a .logger .Error (ctx ,"resize reconnecting pty" ,slog .F ("id" ,id ),slog .Error (err ))
551
+ }
552
+ }
553
+ }
554
+
384
555
// isClosed returns whether the API is closed or not.
385
556
func (a * agent )isClosed ()bool {
386
557
select {
@@ -403,3 +574,22 @@ func (a *agent) Close() error {
403
574
a .connCloseWait .Wait ()
404
575
return nil
405
576
}
577
+
578
+ type reconnectingPTY struct {
579
+ activeConnsMutex sync.Mutex
580
+ activeConns map [string ]net.Conn
581
+
582
+ ringBuffer * ringbuffer.RingBuffer
583
+ timeout * time.Timer
584
+ ptty pty.PTY
585
+ }
586
+
587
+ func (r * reconnectingPTY )Close () {
588
+ r .activeConnsMutex .Lock ()
589
+ defer r .activeConnsMutex .Unlock ()
590
+ for _ ,conn := range r .activeConns {
591
+ _ = conn .Close ()
592
+ }
593
+ _ = r .ptty .Close ()
594
+ r .ringBuffer .Reset ()
595
+ }