@@ -1,7 +1,6 @@ package wsnet import ( "bytes" "context" "crypto/rand" "errors" Expand All @@ -15,6 +14,8 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/pion/ice/v2" "github.com/pion/webrtc/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func ExampleDial_basic() { Expand Down Expand Up @@ -50,37 +51,30 @@ func ExampleDial_basic() { // You now have access to the proxied remote port in `conn`. } // nolint:gocognit,gocyclo func TestDial(t *testing.T) { t.Run("Ping", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) return } require.NoError(t, err) err = dialer.Ping(context.Background()) if err != nil { t.Error(err) } require.NoError(t, err) }) t.Run("Ping Close", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ ICEServers: []webrtc.ICEServer{{ Expand All @@ -90,167 +84,124 @@ func TestDial(t *testing.T) { CredentialType: webrtc.ICECredentialTypePassword, }}, }) if err != nil { t.Error(err) return } require.NoError(t, err) _ = dialer.Ping(context.Background()) closeTurn() err = dialer.Ping(context.Background()) if err != io.EOF { t.Error(err) return } assert.ErrorIs(t, err, io.EOF) }) t.Run("OPError", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) } require.NoError(t, err) _, err = dialer.DialContext(context.Background(), "tcp", "localhost:100") if err == nil { t.Error("should have gotten err") return } _, ok := err.(*net.OpError) if !ok { t.Error("invalid error type returned") return } assert.Error(t, err) // Double pointer intended. netErr := &net.OpError{} assert.ErrorAs(t, err, &netErr) }) t.Run("Proxy", func(t *testing.T) { t.Parallel() listener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) return } require.NoError(t, err) msg := []byte("Hello!") go func() { conn, err := listener.Accept() if err != nil { t.Error(err) } require.NoError(t, err) _, _ = conn.Write(msg) }() connectAddr, listenAddr := createDumbBroker(t) _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) return } require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) if err != nil { t.Error(err) return } require.NoError(t, err) rec := make([]byte, len(msg)) _, err = conn.Read(rec) if err != nil { t.Error(err) return } if !bytes.Equal(msg, rec) { t.Error("bytes were different", string(msg), string(rec)) } require.NoError(t, err) assert.Equal(t, msg, rec) }) // Expect that we'd get an EOF on the server closing. t.Run("EOF on Close", func(t *testing.T) { t.Parallel() listener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) return } require.NoError(t, err) go func() { _, _ = listener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) srv, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) } require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), listener.Addr().Network(), listener.Addr().String()) if err != nil { t.Error(err) return } go srv.Close() require.NoError(t, err) go l.Close() rec := make([]byte, 16) _, err = conn.Read(rec) if !errors.Is(err, io.EOF) { t.Error(err) return } assert.ErrorIs(t, err, io.EOF) }) t.Run("Disconnect", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) return } require.NoError(t, err) err = dialer.Close() if err != nil { t.Error(err) return } require.NoError(t, err) err = dialer.Ping(context.Background()) if err != webrtc.ErrConnectionClosed { t.Error(err) } assert.ErrorIs(t, err, webrtc.ErrConnectionClosed) }) t.Run("Disconnect DialContext", func(t *testing.T) { t.Parallel() tcpListener, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Error(err) return } require.NoError(t, err) go func() { _, _ = tcpListener.Accept() }() connectAddr, listenAddr := createDumbBroker(t) _, err = Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() turnAddr, closeTurn := createTURNServer(t, ice.SchemeTypeTURN) dialer, err := DialWebsocket(context.Background(), connectAddr, &DialOptions{ ICEServers: []webrtc.ICEServer{{ Expand All @@ -260,42 +211,32 @@ func TestDial(t *testing.T) { CredentialType: webrtc.ICECredentialTypePassword, }}, }) if err != nil { t.Error(err) return } require.NoError(t, err) conn, err := dialer.DialContext(context.Background(), "tcp", tcpListener.Addr().String()) if err != nil { t.Error(err) return } require.NoError(t, err) // Close the TURN server before reading... // WebRTC connections take a few seconds to timeout. closeTurn() _, err = conn.Read(make([]byte, 16)) if err != io.EOF { t.Error(err) return } assert.ErrorIs(t, err, io.EOF) }) t.Run("Closed", func(t *testing.T) { t.Parallel() connectAddr, listenAddr := createDumbBroker(t) _, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") if err != nil { t.Error(err) return } l, err := Listen(context.Background(), slogtest.Make(t, nil), listenAddr, "") require.NoError(t, err) defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { t.Error(err) return } require.NoError(t, err) go func() { _ = dialer.Close() }() select { case <-dialer.Closed(): case <-time.NewTimer(time.Second).C: Expand Down Expand Up @@ -334,11 +275,12 @@ func BenchmarkThroughput(b *testing.B) { } }() connectAddr, listenAddr := createDumbBroker(b) _ , err = Listen(context.Background(), slogtest.Make(b, nil), listenAddr, "")l , err: = Listen(context.Background(), slogtest.Make(b, nil), listenAddr, "")if err != nil { b.Error(err) return } defer l.Close() dialer, err := DialWebsocket(context.Background(), connectAddr, nil) if err != nil { Expand Down