@@ -20,6 +20,7 @@ import (
2020"regexp"
2121"runtime"
2222"strings"
23+ "sync"
2324"testing"
2425"time"
2526
@@ -1318,9 +1319,6 @@ func TestSSH(t *testing.T) {
13181319
13191320tmpdir := tempDirUnixSocket (t )
13201321localSock := filepath .Join (tmpdir ,"local.sock" )
1321- l ,err := net .Listen ("unix" ,localSock )
1322- require .NoError (t ,err )
1323- defer l .Close ()
13241322remoteSock := filepath .Join (tmpdir ,"remote.sock" )
13251323
13261324inv ,root := clitest .New (t ,
@@ -1332,23 +1330,62 @@ func TestSSH(t *testing.T) {
13321330clitest .SetupConfig (t ,client ,root )
13331331pty := ptytest .New (t ).Attach (inv )
13341332inv .Stderr = pty .Output ()
1335- cmdDone := tGo (t ,func () {
1336- err := inv .WithContext (ctx ).Run ()
1337- assert .NoError (t ,err ,"ssh command failed" )
1338- })
13391333
1340- // Wait for the prompt or any output really to indicate the command has
1341- // started and accepting input on stdin.
1334+ w := clitest .StartWithWaiter (t ,inv .WithContext (ctx ))
1335+ defer w .Wait ()// We don't care about any exit error (exit code 255: SSH connection ended unexpectedly).
1336+
1337+ // Since something was output, it should be safe to write input.
1338+ // This could show a prompt or "running startup scripts", so it's
1339+ // not indicative of the SSH connection being ready.
13421340_ = pty .Peek (ctx ,1 )
13431341
1344- // This needs to support most shells on Linux or macOS
1345- // We can't include exactly what's expected in the input, as that will always be matched
1346- pty .WriteLine (fmt .Sprintf (`echo "results: $(netstat -an | grep %s | wc -l | tr -d ' ')"` ,remoteSock ))
1347- pty .ExpectMatchContext (ctx ,"results: 1" )
1342+ // Ensure the SSH connection is ready by testing the shell
1343+ // input/output.
1344+ pty .WriteLine ("echo ping' 'pong" )
1345+ pty .ExpectMatchContext (ctx ,"ping pong" )
1346+
1347+ // Start the listener on the "local machine".
1348+ l ,err := net .Listen ("unix" ,localSock )
1349+ require .NoError (t ,err )
1350+ defer l .Close ()
1351+ testutil .Go (t ,func () {
1352+ var wg sync.WaitGroup
1353+ defer wg .Wait ()
1354+ for {
1355+ fd ,err := l .Accept ()
1356+ if err != nil {
1357+ if ! errors .Is (err ,net .ErrClosed ) {
1358+ assert .NoError (t ,err ,"listener accept failed" )
1359+ }
1360+ return
1361+ }
1362+
1363+ wg .Add (1 )
1364+ go func () {
1365+ defer wg .Done ()
1366+ defer fd .Close ()
1367+ agentssh .Bicopy (ctx ,fd ,fd )
1368+ }()
1369+ }
1370+ })
1371+
1372+ // Dial the forwarded socket on the "remote machine".
1373+ d := & net.Dialer {}
1374+ fd ,err := d .DialContext (ctx ,"unix" ,remoteSock )
1375+ require .NoError (t ,err )
1376+ defer fd .Close ()
1377+
1378+ // Ping / pong to ensure the socket is working.
1379+ _ ,err = fd .Write ([]byte ("hello world" ))
1380+ require .NoError (t ,err )
1381+
1382+ buf := make ([]byte ,11 )
1383+ _ ,err = fd .Read (buf )
1384+ require .NoError (t ,err )
1385+ require .Equal (t ,"hello world" ,string (buf ))
13481386
13491387// And we're done.
13501388pty .WriteLine ("exit" )
1351- <- cmdDone
13521389})
13531390
13541391// Test that we can forward a local unix socket to a remote unix socket and
@@ -1377,6 +1414,8 @@ func TestSSH(t *testing.T) {
13771414require .NoError (t ,err )
13781415defer l .Close ()
13791416testutil .Go (t ,func () {
1417+ var wg sync.WaitGroup
1418+ defer wg .Wait ()
13801419for {
13811420fd ,err := l .Accept ()
13821421if err != nil {
@@ -1386,10 +1425,12 @@ func TestSSH(t *testing.T) {
13861425return
13871426}
13881427
1389- testutil .Go (t ,func () {
1428+ wg .Add (1 )
1429+ go func () {
1430+ defer wg .Done ()
13901431defer fd .Close ()
13911432agentssh .Bicopy (ctx ,fd ,fd )
1392- })
1433+ }( )
13931434}
13941435})
13951436
@@ -1522,6 +1563,8 @@ func TestSSH(t *testing.T) {
15221563require .NoError (t ,err )
15231564defer l .Close ()//nolint:revive // Defer is fine in this loop, we only run it twice.
15241565testutil .Go (t ,func () {
1566+ var wg sync.WaitGroup
1567+ defer wg .Wait ()
15251568for {
15261569fd ,err := l .Accept ()
15271570if err != nil {
@@ -1531,10 +1574,12 @@ func TestSSH(t *testing.T) {
15311574return
15321575}
15331576
1534- testutil .Go (t ,func () {
1577+ wg .Add (1 )
1578+ go func () {
1579+ defer wg .Done ()
15351580defer fd .Close ()
15361581agentssh .Bicopy (ctx ,fd ,fd )
1537- })
1582+ }( )
15381583}
15391584})
15401585