@@ -3,10 +3,13 @@ package coderd_test
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "io"
7
+ "net"
6
8
"net/http"
7
9
"net/http/httptest"
8
- "net/netip"
9
10
"net/url"
11
+ "strconv"
12
+ "sync/atomic"
10
13
"testing"
11
14
12
15
"github.com/google/uuid"
@@ -35,9 +38,10 @@ func TestServerTailnet_AgentConn_OK(t *testing.T) {
35
38
defer cancel ()
36
39
37
40
// Connect through the ServerTailnet
38
- agentID ,_ ,serverTailnet := setupAgent (t ,nil )
41
+ agents ,serverTailnet := setupServerTailnetAgent (t ,1 )
42
+ a := agents [0 ]
39
43
40
- conn ,release ,err := serverTailnet .AgentConn (ctx ,agentID )
44
+ conn ,release ,err := serverTailnet .AgentConn (ctx ,a . id )
41
45
require .NoError (t ,err )
42
46
defer release ()
43
47
@@ -53,12 +57,13 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
53
57
ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitLong )
54
58
defer cancel ()
55
59
56
- agentID ,_ ,serverTailnet := setupAgent (t ,nil )
60
+ agents ,serverTailnet := setupServerTailnetAgent (t ,1 )
61
+ a := agents [0 ]
57
62
58
63
u ,err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" ,codersdk .WorkspaceAgentHTTPAPIServerPort ))
59
64
require .NoError (t ,err )
60
65
61
- rp := serverTailnet .ReverseProxy (u ,u ,agentID )
66
+ rp := serverTailnet .ReverseProxy (u ,u ,a . id )
62
67
63
68
rw := httptest .NewRecorder ()
64
69
req := httptest .NewRequest (
@@ -74,13 +79,147 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
74
79
assert .Equal (t ,http .StatusOK ,res .StatusCode )
75
80
})
76
81
82
+ t .Run ("HostRewrite" ,func (t * testing.T ) {
83
+ t .Parallel ()
84
+
85
+ ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitLong )
86
+ defer cancel ()
87
+
88
+ agents ,serverTailnet := setupServerTailnetAgent (t ,1 )
89
+ a := agents [0 ]
90
+
91
+ u ,err := url .Parse (fmt .Sprintf ("http://127.0.0.1:%d" ,codersdk .WorkspaceAgentHTTPAPIServerPort ))
92
+ require .NoError (t ,err )
93
+
94
+ rp := serverTailnet .ReverseProxy (u ,u ,a .id )
95
+
96
+ req ,err := http .NewRequestWithContext (ctx ,http .MethodGet ,u .String (),nil )
97
+ require .NoError (t ,err )
98
+
99
+ // Ensure the reverse proxy director rewrites the url host to the agent's IP.
100
+ rp .Director (req )
101
+ assert .Equal (t ,
102
+ fmt .Sprintf ("[%s]:%d" ,tailnet .IPFromUUID (a .id ).String (),codersdk .WorkspaceAgentHTTPAPIServerPort ),
103
+ req .URL .Host ,
104
+ )
105
+ })
106
+
107
+ t .Run ("CachesConnection" ,func (t * testing.T ) {
108
+ t .Parallel ()
109
+
110
+ ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitLong )
111
+ defer cancel ()
112
+
113
+ agents ,serverTailnet := setupServerTailnetAgent (t ,1 )
114
+ a := agents [0 ]
115
+ port := ":4444"
116
+ ln ,err := a .TailnetConn ().Listen ("tcp" ,port )
117
+ require .NoError (t ,err )
118
+ wln := & wrappedListener {Listener :ln }
119
+
120
+ serverClosed := make (chan struct {})
121
+ go func () {
122
+ defer close (serverClosed )
123
+ //nolint:gosec
124
+ _ = http .Serve (wln ,http .HandlerFunc (func (w http.ResponseWriter ,r * http.Request ) {
125
+ w .WriteHeader (http .StatusOK )
126
+ w .Write ([]byte ("hello from agent" ))
127
+ }))
128
+ }()
129
+ defer func () {
130
+ // wait for server to close
131
+ <- serverClosed
132
+ }()
133
+
134
+ defer ln .Close ()
135
+
136
+ u ,err := url .Parse ("http://127.0.0.1" + port )
137
+ require .NoError (t ,err )
138
+
139
+ rp := serverTailnet .ReverseProxy (u ,u ,a .id )
140
+
141
+ for i := 0 ;i < 5 ;i ++ {
142
+ rw := httptest .NewRecorder ()
143
+ req := httptest .NewRequest (
144
+ http .MethodGet ,
145
+ u .String (),
146
+ nil ,
147
+ ).WithContext (ctx )
148
+
149
+ rp .ServeHTTP (rw ,req )
150
+ res := rw .Result ()
151
+
152
+ _ ,_ = io .Copy (io .Discard ,res .Body )
153
+ res .Body .Close ()
154
+ assert .Equal (t ,http .StatusOK ,res .StatusCode )
155
+ }
156
+
157
+ assert .Equal (t ,1 ,wln .getDials ())
158
+ })
159
+
160
+ t .Run ("NotReusedBetweenAgents" ,func (t * testing.T ) {
161
+ t .Parallel ()
162
+
163
+ ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitLong )
164
+ defer cancel ()
165
+
166
+ agents ,serverTailnet := setupServerTailnetAgent (t ,2 )
167
+ port := ":4444"
168
+
169
+ for i ,ag := range agents {
170
+ i := i
171
+ ln ,err := ag .TailnetConn ().Listen ("tcp" ,port )
172
+ require .NoError (t ,err )
173
+ wln := & wrappedListener {Listener :ln }
174
+
175
+ serverClosed := make (chan struct {})
176
+ go func () {
177
+ defer close (serverClosed )
178
+ //nolint:gosec
179
+ _ = http .Serve (wln ,http .HandlerFunc (func (w http.ResponseWriter ,r * http.Request ) {
180
+ w .WriteHeader (http .StatusOK )
181
+ w .Write ([]byte (strconv .Itoa (i )))
182
+ }))
183
+ }()
184
+ defer func () {//nolint:revive
185
+ // wait for server to close
186
+ <- serverClosed
187
+ }()
188
+
189
+ defer ln .Close ()//nolint:revive
190
+ }
191
+
192
+ u ,err := url .Parse ("http://127.0.0.1" + port )
193
+ require .NoError (t ,err )
194
+
195
+ for i ,ag := range agents {
196
+ rp := serverTailnet .ReverseProxy (u ,u ,ag .id )
197
+
198
+ rw := httptest .NewRecorder ()
199
+ req := httptest .NewRequest (
200
+ http .MethodGet ,
201
+ u .String (),
202
+ nil ,
203
+ ).WithContext (ctx )
204
+
205
+ rp .ServeHTTP (rw ,req )
206
+ res := rw .Result ()
207
+
208
+ body ,_ := io .ReadAll (res .Body )
209
+ res .Body .Close ()
210
+ assert .Equal (t ,http .StatusOK ,res .StatusCode )
211
+ assert .Equal (t ,strconv .Itoa (i ),string (body ))
212
+ }
213
+ })
214
+
77
215
t .Run ("HTTPSProxy" ,func (t * testing.T ) {
78
216
t .Parallel ()
79
217
80
218
ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitLong )
81
219
defer cancel ()
82
220
83
- agentID ,_ ,serverTailnet := setupAgent (t ,nil )
221
+ agents ,serverTailnet := setupServerTailnetAgent (t ,1 )
222
+ a := agents [0 ]
84
223
85
224
const expectedResponseCode = 209
86
225
// Test that we can proxy HTTPS traffic.
@@ -92,7 +231,7 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
92
231
uri ,err := url .Parse (s .URL )
93
232
require .NoError (t ,err )
94
233
95
- rp := serverTailnet .ReverseProxy (uri ,uri ,agentID )
234
+ rp := serverTailnet .ReverseProxy (uri ,uri ,a . id )
96
235
97
236
rw := httptest .NewRecorder ()
98
237
req := httptest .NewRequest (
@@ -109,44 +248,74 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
109
248
})
110
249
}
111
250
112
- func setupAgent (t * testing.T ,agentAddresses []netip.Prefix ) (uuid.UUID , agent.Agent ,* coderd.ServerTailnet ) {
251
+ type wrappedListener struct {
252
+ net.Listener
253
+ dials int32
254
+ }
255
+
256
+ func (w * wrappedListener )Accept () (net.Conn ,error ) {
257
+ conn ,err := w .Listener .Accept ()
258
+ if err != nil {
259
+ return nil ,err
260
+ }
261
+
262
+ atomic .AddInt32 (& w .dials ,1 )
263
+ return conn ,nil
264
+ }
265
+
266
+ func (w * wrappedListener )getDials ()int {
267
+ return int (atomic .LoadInt32 (& w .dials ))
268
+ }
269
+
270
+ type agentWithID struct {
271
+ id uuid.UUID
272
+ agent.Agent
273
+ }
274
+
275
+ func setupServerTailnetAgent (t * testing.T ,agentNum int ) ([]agentWithID ,* coderd.ServerTailnet ) {
113
276
logger := slogtest .Make (t ,nil ).Leveled (slog .LevelDebug )
114
277
derpMap ,derpServer := tailnettest .RunDERPAndSTUN (t )
115
- manifest := agentsdk.Manifest {
116
- AgentID :uuid .New (),
117
- DERPMap :derpMap ,
118
- }
119
278
120
279
coord := tailnet .NewCoordinator (logger )
121
280
t .Cleanup (func () {
122
281
_ = coord .Close ()
123
282
})
124
283
125
- c := agenttest .NewClient (t ,logger ,manifest .AgentID ,manifest ,make (chan * agentsdk.Stats ,50 ),coord )
126
- t .Cleanup (c .Close )
284
+ agents := []agentWithID {}
127
285
128
- options := agent.Options {
129
- Client :c ,
130
- Filesystem :afero .NewMemMapFs (),
131
- Logger :logger .Named ("agent" ),
132
- Addresses :agentAddresses ,
133
- }
286
+ for i := 0 ;i < agentNum ;i ++ {
287
+ manifest := agentsdk.Manifest {
288
+ AgentID :uuid .New (),
289
+ DERPMap :derpMap ,
290
+ }
134
291
135
- ag := agent .New (options )
136
- t .Cleanup (func () {
137
- _ = ag .Close ()
138
- })
292
+ c := agenttest .NewClient (t ,logger ,manifest .AgentID ,manifest ,make (chan * agentsdk.Stats ,50 ),coord )
293
+ t .Cleanup (c .Close )
294
+
295
+ options := agent.Options {
296
+ Client :c ,
297
+ Filesystem :afero .NewMemMapFs (),
298
+ Logger :logger .Named ("agent" ),
299
+ }
139
300
140
- // Wait for the agent to connect.
141
- require .Eventually (t ,func ()bool {
142
- return coord .Node (manifest .AgentID )!= nil
143
- },testutil .WaitShort ,testutil .IntervalFast )
301
+ ag := agent .New (options )
302
+ t .Cleanup (func () {
303
+ _ = ag .Close ()
304
+ })
305
+
306
+ // Wait for the agent to connect.
307
+ require .Eventually (t ,func ()bool {
308
+ return coord .Node (manifest .AgentID )!= nil
309
+ },testutil .WaitShort ,testutil .IntervalFast )
310
+
311
+ agents = append (agents ,agentWithID {id :manifest .AgentID ,Agent :ag })
312
+ }
144
313
145
314
serverTailnet ,err := coderd .NewServerTailnet (
146
315
context .Background (),
147
316
logger ,
148
317
derpServer ,
149
- func ()* tailcfg.DERPMap {return manifest . DERPMap },
318
+ func ()* tailcfg.DERPMap {return derpMap },
150
319
false ,
151
320
func (context.Context ) (tailnet.MultiAgentConn ,error ) {return coord .ServeMultiAgent (uuid .New ()),nil },
152
321
trace .NewNoopTracerProvider (),
@@ -157,5 +326,5 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A
157
326
_ = serverTailnet .Close ()
158
327
})
159
328
160
- return manifest . AgentID , ag ,serverTailnet
329
+ return agents ,serverTailnet
161
330
}