Expand Up @@ -144,20 +144,19 @@ func TestPortForward(t *testing.T) { for _, c := range cases { //nolint:paralleltest // the `c := c` confuses the linter c := c // Avoid parallel test here because setupLocal reserves // a free open port which is not guaranteed to be free // after the listener closes. //nolint:paralleltest t.Run(c.name, func(t *testing.T) { t.Parallel() //nolint:paralleltest t.Run("OnePort", func(t *testing.T) { t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) _, workspace = runAgent(t, client, user.UserID) l1, p1 = setupTestListener(t, c.setupRemote(t))p1 = setupTestListener(t, c.setupRemote(t))) t.Cleanup(func() { _ = l1.Close() }) // Create a flag that forwards from local to listener 1. localAddress, localFlag := c.setupLocal(t) Expand All @@ -171,9 +170,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { err := cmd.ExecuteContext(ctx) assert.ErrorIs(t, err, context.Canceled) errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) Expand All @@ -188,21 +187,21 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) cancel() err = <-errC require.ErrorIs(t, err, context.Canceled) }) //nolint:paralleltest t.Run("TwoPorts", func(t *testing.T) { t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) _, workspace = runAgent(t, client, user.UserID) l1, p1 = setupTestListener(t, c.setupRemote(t))l2, p2 = setupTestListener(t, c.setupRemote(t))p1 = setupTestListener(t, c.setupRemote(t))p2 = setupTestListener(t, c.setupRemote(t))) t.Cleanup(func() { _ = l1.Close() _ = l2.Close() }) // Create a flags for listener 1 and listener 2. localAddress1, localFlag1 := c.setupLocal(t) Expand All @@ -218,9 +217,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { err := cmd.ExecuteContext(ctx) assert.ErrorIs(t, err, context.Canceled) errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) Expand All @@ -235,13 +234,17 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) cancel() err = <-errC require.ErrorIs(t, err, context.Canceled) }) }) } // Test doing a TCP -> Unix forward. //nolint:paralleltest t.Run("TCP2Unix", func(t *testing.T) { t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) Expand All @@ -253,11 +256,8 @@ func TestPortForward(t *testing.T) { unixCase = cases[2] // Setup remote Unix listener. l1, p1 = setupTestListener(t, unixCase.setupRemote(t))p1 = setupTestListener(t, unixCase.setupRemote(t)) ) t.Cleanup(func() { _ = l1.Close() }) // Create a flag that forwards from local TCP to Unix listener 1. // Notably this is a --unix flag. Expand All @@ -272,9 +272,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { err := cmd.ExecuteContext(ctx) assert.ErrorIs(t, err, context.Canceled) errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) Expand All @@ -289,11 +289,15 @@ func TestPortForward(t *testing.T) { defer c2.Close() testDial(t, c2) testDial(t, c1) cancel() err = <-errC require.ErrorIs(t, err, context.Canceled) }) // Test doing TCP, UDP and Unix at the same time. //nolint:paralleltest t.Run("All", func(t *testing.T) { t.Parallel() var ( client = coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) user = coderdtest.CreateFirstUser(t, client) Expand All @@ -311,10 +315,7 @@ func TestPortForward(t *testing.T) { continue } l, p := setupTestListener(t, c.setupRemote(t)) t.Cleanup(func() { _ = l.Close() }) p := setupTestListener(t, c.setupRemote(t)) localAddress, localFlag := c.setupLocal(t) dials = append(dials, addr{ Expand All @@ -332,10 +333,9 @@ func TestPortForward(t *testing.T) { cmd.SetOut(io.MultiWriter(buf, os.Stderr)) ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { err := cmd.ExecuteContext(ctx) assert.Error(t, err) assert.ErrorIs(t, err, context.Canceled) errC <- cmd.ExecuteContext(ctx) }() waitForPortForwardReady(t, buf) Expand All @@ -357,6 +357,10 @@ func TestPortForward(t *testing.T) { for i := len(conns) - 1; i >= 0; i-- { testDial(t, conns[i]) } cancel() err := <-errC require.ErrorIs(t, err, context.Canceled) }) } Expand Down Expand Up @@ -400,11 +404,15 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders // Start workspace agent in a goroutine cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) clitest.SetupConfig(t, client, root) errC := make(chan error) agentCtx, agentCancel := context.WithCancel(ctx) t.Cleanup(agentCancel) t.Cleanup(func() { agentCancel() err := <-errC require.NoError(t, err) }) go func() { err := cmd.ExecuteContext(agentCtx) assert.NoError(t, err) errC <- cmd.ExecuteContext(agentCtx) }() coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) Expand All @@ -416,18 +424,30 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) ([]coders // setupTestListener starts accepting connections and echoing a single packet. // Returns the listener and the listen port or Unix path. func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { func setupTestListener(t *testing.T, l net.Listener) string { // Wait for listener to completely exit before releasing. done := make(chan struct{}) t.Cleanup(func() { _ = l.Close() <-done }) go func() { defer close(done) // Guard against testAccept running require after test completion. var wg sync.WaitGroup defer wg.Wait() for { c, err := l.Accept() if err != nil { return } go testAccept(t, c) wg.Add(1) go func() { testAccept(t, c) wg.Done() }() } }() Expand All @@ -438,7 +458,7 @@ func setupTestListener(t *testing.T, l net.Listener) (net.Listener, string) { addr = port } returnl, addr return addr } var dialTestPayload = []byte("dean-was-here123") Expand Down Expand Up @@ -502,8 +522,10 @@ func newThreadSafeBuffer() *threadSafeBuffer { } } var _ io.Reader = &threadSafeBuffer{} var _ io.Writer = &threadSafeBuffer{} var ( _ io.Reader = &threadSafeBuffer{} _ io.Writer = &threadSafeBuffer{} ) // Read implements io.Reader. func (b *threadSafeBuffer) Read(p []byte) (int, error) { Expand Down