|
8 | 8 | "context"
|
9 | 9 | "fmt"
|
10 | 10 | "net"
|
| 11 | +"os" |
11 | 12 | "os/user"
|
| 13 | +"path/filepath" |
12 | 14 | "runtime"
|
13 | 15 | "strings"
|
14 | 16 | "sync"
|
@@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
|
403 | 405 | })
|
404 | 406 | }
|
405 | 407 |
|
| 408 | +funcTestSSHServer_ClosesStdin(t*testing.T) { |
| 409 | +t.Parallel() |
| 410 | +ifruntime.GOOS=="windows" { |
| 411 | +t.Skip("bash doesn't exist on Windows") |
| 412 | +} |
| 413 | + |
| 414 | +ctx:=testutil.Context(t,testutil.WaitMedium) |
| 415 | +logger:=testutil.Logger(t) |
| 416 | +s,err:=agentssh.NewServer(ctx,logger,prometheus.NewRegistry(),afero.NewMemMapFs(),agentexec.DefaultExecer,nil) |
| 417 | +require.NoError(t,err) |
| 418 | +defers.Close() |
| 419 | +err=s.UpdateHostSigner(42) |
| 420 | +assert.NoError(t,err) |
| 421 | + |
| 422 | +ln,err:=net.Listen("tcp","127.0.0.1:0") |
| 423 | +require.NoError(t,err) |
| 424 | + |
| 425 | +done:=make(chanstruct{}) |
| 426 | +gofunc() { |
| 427 | +deferclose(done) |
| 428 | +err:=s.Serve(ln) |
| 429 | +assert.Error(t,err)// Server is closed. |
| 430 | +}() |
| 431 | +deferfunc() { |
| 432 | +err:=s.Close() |
| 433 | +require.NoError(t,err) |
| 434 | +<-done |
| 435 | +}() |
| 436 | + |
| 437 | +c:=sshClient(t,ln.Addr().String()) |
| 438 | + |
| 439 | +sess,err:=c.NewSession() |
| 440 | +require.NoError(t,err) |
| 441 | +stdout,err:=sess.StdoutPipe() |
| 442 | +require.NoError(t,err) |
| 443 | +stdin,err:=sess.StdinPipe() |
| 444 | +require.NoError(t,err) |
| 445 | +deferstdin.Close() |
| 446 | + |
| 447 | +dir:=t.TempDir() |
| 448 | +err=os.MkdirAll(dir,0o755) |
| 449 | +require.NoError(t,err) |
| 450 | +filePath:=filepath.Join(dir,"result.txt") |
| 451 | + |
| 452 | +// the shell command `read` will block until data is written to stdin, or closed. It will return |
| 453 | +// exit code 1 if it hits EOF, which is what we want to test. |
| 454 | +cmdErrCh:=make(chanerror,1) |
| 455 | +gofunc() { |
| 456 | +cmdErrCh<-sess.Start(fmt.Sprintf("echo started; read; echo\"read exit code: $?\" > %s",filePath)) |
| 457 | +}() |
| 458 | + |
| 459 | +cmdErr:=testutil.RequireReceive(ctx,t,cmdErrCh) |
| 460 | +require.NoError(t,cmdErr) |
| 461 | + |
| 462 | +readCh:=make(chanerror,1) |
| 463 | +gofunc() { |
| 464 | +buf:=make([]byte,8) |
| 465 | +_,err:=stdout.Read(buf) |
| 466 | +assert.Equal(t,"started\n",string(buf)) |
| 467 | +readCh<-err |
| 468 | +}() |
| 469 | +err=testutil.RequireReceive(ctx,t,readCh) |
| 470 | +require.NoError(t,err) |
| 471 | + |
| 472 | +sess.Close() |
| 473 | + |
| 474 | +varcontent []byte |
| 475 | +require.Eventually(t,func()bool { |
| 476 | +content,err=os.ReadFile(filePath) |
| 477 | +returnerr==nil |
| 478 | +},testutil.WaitMedium,testutil.IntervalFast) |
| 479 | +require.NoError(t,err) |
| 480 | +require.Equal(t,"read exit code: 1\n",string(content)) |
| 481 | +} |
| 482 | + |
406 | 483 | funcsshClient(t*testing.T,addrstring)*ssh.Client {
|
407 | 484 | conn,err:=net.Dial("tcp",addr)
|
408 | 485 | require.NoError(t,err)
|
|