@@ -25,6 +25,14 @@ import (
25
25
"github.com/coder/serpent"
26
26
)
27
27
28
+ var (
29
+ // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30
+ // when the local address is not specified in port-forward flags.
31
+ noAddr netip.Addr
32
+ ipv6Loopback = netip .MustParseAddr ("::1" )
33
+ ipv4Loopback = netip .MustParseAddr ("127.0.0.1" )
34
+ )
35
+
28
36
func (r * RootCmd )portForward ()* serpent.Command {
29
37
var (
30
38
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122
130
// Start all listeners.
123
131
var (
124
132
wg = new (sync.WaitGroup )
125
- listeners = make ([]net.Listener ,len (specs ))
133
+ listeners = make ([]net.Listener ,0 , len (specs )* 2 )
126
134
closeAllListeners = func () {
127
135
logger .Debug (ctx ,"closing all listeners" )
128
136
for _ ,l := range listeners {
@@ -135,13 +143,26 @@ func (r *RootCmd) portForward() *serpent.Command {
135
143
)
136
144
defer closeAllListeners ()
137
145
138
- for i ,spec := range specs {
146
+ for _ ,spec := range specs {
147
+
148
+ if spec .listenHost == noAddr {
149
+ // first, opportunistically try to listen on IPv6
150
+ spec6 := spec
151
+ spec6 .listenHost = ipv6Loopback
152
+ l6 ,err6 := listenAndPortForward (ctx ,inv ,conn ,wg ,spec6 ,logger )
153
+ if err6 != nil {
154
+ logger .Info (ctx ,"failed to opportunistically listen on IPv6" ,slog .F ("spec" ,spec ),slog .Error (err6 ))
155
+ }else {
156
+ listeners = append (listeners ,l6 )
157
+ }
158
+ spec .listenHost = ipv4Loopback
159
+ }
139
160
l ,err := listenAndPortForward (ctx ,inv ,conn ,wg ,spec ,logger )
140
161
if err != nil {
141
162
logger .Error (ctx ,"failed to listen" ,slog .F ("spec" ,spec ),slog .Error (err ))
142
163
return err
143
164
}
144
- listeners [ i ] = l
165
+ listeners = append ( listeners , l )
145
166
}
146
167
147
168
stopUpdating := client .UpdateWorkspaceUsageContext (ctx ,workspace .ID )
@@ -206,12 +227,19 @@ func listenAndPortForward(
206
227
spec portForwardSpec ,
207
228
logger slog.Logger ,
208
229
) (net.Listener ,error ) {
209
- logger = logger .With (slog .F ("network" ,spec .listenNetwork ),slog .F ("address" ,spec .listenAddress ))
210
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Forwarding '%v://%v' locally to '%v://%v' in the workspace\n " ,spec .listenNetwork ,spec .listenAddress ,spec .dialNetwork ,spec .dialAddress )
230
+ logger = logger .With (
231
+ slog .F ("network" ,spec .network ),
232
+ slog .F ("listen_host" ,spec .listenHost ),
233
+ slog .F ("listen_port" ,spec .listenPort ),
234
+ )
235
+ listenAddress := netip .AddrPortFrom (spec .listenHost ,spec .listenPort )
236
+ dialAddress := fmt .Sprintf ("127.0.0.1:%d" ,spec .dialPort )
237
+ _ ,_ = fmt .Fprintf (inv .Stderr ,"Forwarding '%s://%s' locally to '%s://%s' in the workspace\n " ,
238
+ spec .network ,listenAddress ,spec .network ,dialAddress )
211
239
212
- l ,err := inv .Net .Listen (spec .listenNetwork , spec . listenAddress )
240
+ l ,err := inv .Net .Listen (spec .network , listenAddress . String () )
213
241
if err != nil {
214
- return nil ,xerrors .Errorf ("listen '%v ://%v ': %w" ,spec .listenNetwork , spec . listenAddress ,err )
242
+ return nil ,xerrors .Errorf ("listen '%s ://%s ': %w" ,spec .network , listenAddress . String () ,err )
215
243
}
216
244
logger .Debug (ctx ,"listening" )
217
245
@@ -226,24 +254,31 @@ func listenAndPortForward(
226
254
logger .Debug (ctx ,"listener closed" )
227
255
return
228
256
}
229
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Error accepting connection from '%v://%v': %v\n " ,spec .listenNetwork ,spec .listenAddress ,err )
257
+ _ ,_ = fmt .Fprintf (inv .Stderr ,
258
+ "Error accepting connection from '%s://%s': %v\n " ,
259
+ spec .network ,listenAddress .String (),err )
230
260
_ ,_ = fmt .Fprintln (inv .Stderr ,"Killing listener" )
231
261
return
232
262
}
233
- logger .Debug (ctx ,"accepted connection" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
263
+ logger .Debug (ctx ,"accepted connection" ,
264
+ slog .F ("remote_addr" ,netConn .RemoteAddr ()))
234
265
235
266
go func (netConn net.Conn ) {
236
267
defer netConn .Close ()
237
- remoteConn ,err := conn .DialContext (ctx ,spec .dialNetwork , spec . dialAddress )
268
+ remoteConn ,err := conn .DialContext (ctx ,spec .network , dialAddress )
238
269
if err != nil {
239
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Failed to dial '%v://%v' in workspace: %s\n " ,spec .dialNetwork ,spec .dialAddress ,err )
270
+ _ ,_ = fmt .Fprintf (inv .Stderr ,
271
+ "Failed to dial '%s://%s' in workspace: %s\n " ,
272
+ spec .network ,dialAddress ,err )
240
273
return
241
274
}
242
275
defer remoteConn .Close ()
243
- logger .Debug (ctx ,"dialed remote" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
276
+ logger .Debug (ctx ,
277
+ "dialed remote" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
244
278
245
279
agentssh .Bicopy (ctx ,netConn ,remoteConn )
246
- logger .Debug (ctx ,"connection closing" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
280
+ logger .Debug (ctx ,
281
+ "connection closing" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
247
282
}(netConn )
248
283
}
249
284
}(spec )
@@ -252,58 +287,48 @@ func listenAndPortForward(
252
287
}
253
288
254
289
type portForwardSpec struct {
255
- listenNetwork string // tcp, udp
256
- listenAddress string // <ip>:<port> or path
257
-
258
- dialNetwork string // tcp, udp
259
- dialAddress string // <ip>:<port> or path
290
+ network string // tcp, udp
291
+ listenHost netip.Addr
292
+ listenPort ,dialPort uint16
260
293
}
261
294
262
295
func parsePortForwards (tcpSpecs ,udpSpecs []string ) ([]portForwardSpec ,error ) {
263
296
specs := []portForwardSpec {}
264
297
265
298
for _ ,specEntry := range tcpSpecs {
266
299
for _ ,spec := range strings .Split (specEntry ,"," ) {
267
- ports ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
300
+ pfSpecs ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
268
301
if err != nil {
269
302
return nil ,xerrors .Errorf ("failed to parse TCP port-forward specification %q: %w" ,spec ,err )
270
303
}
271
304
272
- for _ ,port := range ports {
273
- specs = append (specs ,portForwardSpec {
274
- listenNetwork :"tcp" ,
275
- listenAddress :port .local .String (),
276
- dialNetwork :"tcp" ,
277
- dialAddress :port .remote .String (),
278
- })
305
+ for _ ,pfSpec := range pfSpecs {
306
+ pfSpec .network = "tcp"
307
+ specs = append (specs ,pfSpec )
279
308
}
280
309
}
281
310
}
282
311
283
312
for _ ,specEntry := range udpSpecs {
284
313
for _ ,spec := range strings .Split (specEntry ,"," ) {
285
- ports ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
314
+ pfSpecs ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
286
315
if err != nil {
287
316
return nil ,xerrors .Errorf ("failed to parse UDP port-forward specification %q: %w" ,spec ,err )
288
317
}
289
318
290
- for _ ,port := range ports {
291
- specs = append (specs ,portForwardSpec {
292
- listenNetwork :"udp" ,
293
- listenAddress :port .local .String (),
294
- dialNetwork :"udp" ,
295
- dialAddress :port .remote .String (),
296
- })
319
+ for _ ,pfSpec := range pfSpecs {
320
+ pfSpec .network = "udp"
321
+ specs = append (specs ,pfSpec )
297
322
}
298
323
}
299
324
}
300
325
301
326
// Check for duplicate entries.
302
327
locals := map [string ]struct {}{}
303
328
for _ ,spec := range specs {
304
- localStr := fmt .Sprintf ("%v:%v " ,spec .listenNetwork ,spec .listenAddress )
329
+ localStr := fmt .Sprintf ("%s:%s:%d " ,spec .network ,spec .listenHost , spec . listenPort )
305
330
if _ ,ok := locals [localStr ];ok {
306
- return nil ,xerrors .Errorf ("local %v %v is specified twice" ,spec .listenNetwork ,spec .listenAddress )
331
+ return nil ,xerrors .Errorf ("local %s host:%s port:%d is specified twice" ,spec .network ,spec .listenHost , spec . listenPort )
307
332
}
308
333
locals [localStr ]= struct {}{}
309
334
}
@@ -323,10 +348,6 @@ func parsePort(in string) (uint16, error) {
323
348
return uint16 (port ),nil
324
349
}
325
350
326
- type parsedSrcDestPort struct {
327
- local ,remote netip.AddrPort
328
- }
329
-
330
351
// specRegexp matches port specs. It handles all the following formats:
331
352
//
332
353
// 8000
@@ -347,21 +368,19 @@ type parsedSrcDestPort struct {
347
368
// 9: end or remote port range
348
369
var specRegexp = regexp .MustCompile (`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$` )
349
370
350
- func parseSrcDestPorts (in string ) ([]parsedSrcDestPort ,error ) {
351
- var (
352
- err error
353
- localAddr = netip .AddrFrom4 ([4 ]byte {127 ,0 ,0 ,1 })
354
- remoteAddr = netip .AddrFrom4 ([4 ]byte {127 ,0 ,0 ,1 })
355
- )
371
+ func parseSrcDestPorts (in string ) ([]portForwardSpec ,error ) {
356
372
groups := specRegexp .FindStringSubmatch (in )
357
373
if len (groups )== 0 {
358
374
return nil ,xerrors .Errorf ("invalid port specification %q" ,in )
359
375
}
376
+
377
+ var localAddr netip.Addr
360
378
if groups [2 ]!= "" {
361
- localAddr ,err = netip .ParseAddr (strings .Trim (groups [2 ],"[]" ))
379
+ parsedAddr ,err : =netip .ParseAddr (strings .Trim (groups [2 ],"[]" ))
362
380
if err != nil {
363
381
return nil ,xerrors .Errorf ("invalid IP address %q" ,groups [2 ])
364
382
}
383
+ localAddr = parsedAddr
365
384
}
366
385
367
386
local ,err := parsePortRange (groups [3 ],groups [5 ])
@@ -378,11 +397,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378
397
if len (local )!= len (remote ) {
379
398
return nil ,xerrors .Errorf ("port ranges must be the same length, got %d ports forwarded to %d ports" ,len (local ),len (remote ))
380
399
}
381
- var out []parsedSrcDestPort
400
+ var out []portForwardSpec
382
401
for i := range local {
383
- out = append (out ,parsedSrcDestPort {
384
- local :netip .AddrPortFrom (localAddr ,local [i ]),
385
- remote :netip .AddrPortFrom (remoteAddr ,remote [i ]),
402
+ out = append (out ,portForwardSpec {
403
+ listenHost :localAddr ,
404
+ listenPort :local [i ],
405
+ dialPort :remote [i ],
386
406
})
387
407
}
388
408
return out ,nil