@@ -25,6 +25,14 @@ import (
25
25
"github.com/coder/serpent"
26
26
)
27
27
28
+ var (
29
+ // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify
30
+ // when the local address is not specified in port-forward flags.
31
+ noAddr netip.Addr
32
+ ipv6Loopback = netip .MustParseAddr ("::1" )
33
+ ipv4Loopback = netip .MustParseAddr ("127.0.0.1" )
34
+ )
35
+
28
36
func (r * RootCmd )portForward ()* serpent.Command {
29
37
var (
30
38
tcpForwards []string // <port>:<port>
@@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command {
122
130
// Start all listeners.
123
131
var (
124
132
wg = new (sync.WaitGroup )
125
- listeners = make ([]net.Listener ,len (specs ))
133
+ listeners = make ([]net.Listener ,0 , len (specs )* 2 )
126
134
closeAllListeners = func () {
127
135
logger .Debug (ctx ,"closing all listeners" )
128
136
for _ ,l := range listeners {
@@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command {
135
143
)
136
144
defer closeAllListeners ()
137
145
138
- for i ,spec := range specs {
146
+ for _ ,spec := range specs {
147
+ if spec .listenHost == noAddr {
148
+ // first, opportunistically try to listen on IPv6
149
+ spec6 := spec
150
+ spec6 .listenHost = ipv6Loopback
151
+ l6 ,err6 := listenAndPortForward (ctx ,inv ,conn ,wg ,spec6 ,logger )
152
+ if err6 != nil {
153
+ logger .Info (ctx ,"failed to opportunistically listen on IPv6" ,slog .F ("spec" ,spec ),slog .Error (err6 ))
154
+ }else {
155
+ listeners = append (listeners ,l6 )
156
+ }
157
+ spec .listenHost = ipv4Loopback
158
+ }
139
159
l ,err := listenAndPortForward (ctx ,inv ,conn ,wg ,spec ,logger )
140
160
if err != nil {
141
161
logger .Error (ctx ,"failed to listen" ,slog .F ("spec" ,spec ),slog .Error (err ))
142
162
return err
143
163
}
144
- listeners [ i ] = l
164
+ listeners = append ( listeners , l )
145
165
}
146
166
147
167
stopUpdating := client .UpdateWorkspaceUsageContext (ctx ,workspace .ID )
@@ -206,12 +226,19 @@ func listenAndPortForward(
206
226
spec portForwardSpec ,
207
227
logger slog.Logger ,
208
228
) (net.Listener ,error ) {
209
- logger = logger .With (slog .F ("network" ,spec .listenNetwork ),slog .F ("address" ,spec .listenAddress ))
210
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Forwarding '%v://%v' locally to '%v://%v' in the workspace\n " ,spec .listenNetwork ,spec .listenAddress ,spec .dialNetwork ,spec .dialAddress )
229
+ logger = logger .With (
230
+ slog .F ("network" ,spec .network ),
231
+ slog .F ("listen_host" ,spec .listenHost ),
232
+ slog .F ("listen_port" ,spec .listenPort ),
233
+ )
234
+ listenAddress := netip .AddrPortFrom (spec .listenHost ,spec .listenPort )
235
+ dialAddress := fmt .Sprintf ("127.0.0.1:%d" ,spec .dialPort )
236
+ _ ,_ = fmt .Fprintf (inv .Stderr ,"Forwarding '%s://%s' locally to '%s://%s' in the workspace\n " ,
237
+ spec .network ,listenAddress ,spec .network ,dialAddress )
211
238
212
- l ,err := inv .Net .Listen (spec .listenNetwork , spec . listenAddress )
239
+ l ,err := inv .Net .Listen (spec .network , listenAddress . String () )
213
240
if err != nil {
214
- return nil ,xerrors .Errorf ("listen '%v ://%v ': %w" ,spec .listenNetwork , spec . listenAddress ,err )
241
+ return nil ,xerrors .Errorf ("listen '%s ://%s ': %w" ,spec .network , listenAddress . String () ,err )
215
242
}
216
243
logger .Debug (ctx ,"listening" )
217
244
@@ -226,24 +253,31 @@ func listenAndPortForward(
226
253
logger .Debug (ctx ,"listener closed" )
227
254
return
228
255
}
229
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Error accepting connection from '%v://%v': %v\n " ,spec .listenNetwork ,spec .listenAddress ,err )
256
+ _ ,_ = fmt .Fprintf (inv .Stderr ,
257
+ "Error accepting connection from '%s://%s': %v\n " ,
258
+ spec .network ,listenAddress .String (),err )
230
259
_ ,_ = fmt .Fprintln (inv .Stderr ,"Killing listener" )
231
260
return
232
261
}
233
- logger .Debug (ctx ,"accepted connection" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
262
+ logger .Debug (ctx ,"accepted connection" ,
263
+ slog .F ("remote_addr" ,netConn .RemoteAddr ()))
234
264
235
265
go func (netConn net.Conn ) {
236
266
defer netConn .Close ()
237
- remoteConn ,err := conn .DialContext (ctx ,spec .dialNetwork , spec . dialAddress )
267
+ remoteConn ,err := conn .DialContext (ctx ,spec .network , dialAddress )
238
268
if err != nil {
239
- _ ,_ = fmt .Fprintf (inv .Stderr ,"Failed to dial '%v://%v' in workspace: %s\n " ,spec .dialNetwork ,spec .dialAddress ,err )
269
+ _ ,_ = fmt .Fprintf (inv .Stderr ,
270
+ "Failed to dial '%s://%s' in workspace: %s\n " ,
271
+ spec .network ,dialAddress ,err )
240
272
return
241
273
}
242
274
defer remoteConn .Close ()
243
- logger .Debug (ctx ,"dialed remote" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
275
+ logger .Debug (ctx ,
276
+ "dialed remote" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
244
277
245
278
agentssh .Bicopy (ctx ,netConn ,remoteConn )
246
- logger .Debug (ctx ,"connection closing" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
279
+ logger .Debug (ctx ,
280
+ "connection closing" ,slog .F ("remote_addr" ,netConn .RemoteAddr ()))
247
281
}(netConn )
248
282
}
249
283
}(spec )
@@ -252,58 +286,48 @@ func listenAndPortForward(
252
286
}
253
287
254
288
type portForwardSpec struct {
255
- listenNetwork string // tcp, udp
256
- listenAddress string // <ip>:<port> or path
257
-
258
- dialNetwork string // tcp, udp
259
- dialAddress string // <ip>:<port> or path
289
+ network string // tcp, udp
290
+ listenHost netip.Addr
291
+ listenPort ,dialPort uint16
260
292
}
261
293
262
294
func parsePortForwards (tcpSpecs ,udpSpecs []string ) ([]portForwardSpec ,error ) {
263
295
specs := []portForwardSpec {}
264
296
265
297
for _ ,specEntry := range tcpSpecs {
266
298
for _ ,spec := range strings .Split (specEntry ,"," ) {
267
- ports ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
299
+ pfSpecs ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
268
300
if err != nil {
269
301
return nil ,xerrors .Errorf ("failed to parse TCP port-forward specification %q: %w" ,spec ,err )
270
302
}
271
303
272
- for _ ,port := range ports {
273
- specs = append (specs ,portForwardSpec {
274
- listenNetwork :"tcp" ,
275
- listenAddress :port .local .String (),
276
- dialNetwork :"tcp" ,
277
- dialAddress :port .remote .String (),
278
- })
304
+ for _ ,pfSpec := range pfSpecs {
305
+ pfSpec .network = "tcp"
306
+ specs = append (specs ,pfSpec )
279
307
}
280
308
}
281
309
}
282
310
283
311
for _ ,specEntry := range udpSpecs {
284
312
for _ ,spec := range strings .Split (specEntry ,"," ) {
285
- ports ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
313
+ pfSpecs ,err := parseSrcDestPorts (strings .TrimSpace (spec ))
286
314
if err != nil {
287
315
return nil ,xerrors .Errorf ("failed to parse UDP port-forward specification %q: %w" ,spec ,err )
288
316
}
289
317
290
- for _ ,port := range ports {
291
- specs = append (specs ,portForwardSpec {
292
- listenNetwork :"udp" ,
293
- listenAddress :port .local .String (),
294
- dialNetwork :"udp" ,
295
- dialAddress :port .remote .String (),
296
- })
318
+ for _ ,pfSpec := range pfSpecs {
319
+ pfSpec .network = "udp"
320
+ specs = append (specs ,pfSpec )
297
321
}
298
322
}
299
323
}
300
324
301
325
// Check for duplicate entries.
302
326
locals := map [string ]struct {}{}
303
327
for _ ,spec := range specs {
304
- localStr := fmt .Sprintf ("%v:%v " ,spec .listenNetwork ,spec .listenAddress )
328
+ localStr := fmt .Sprintf ("%s:%s:%d " ,spec .network ,spec .listenHost , spec . listenPort )
305
329
if _ ,ok := locals [localStr ];ok {
306
- return nil ,xerrors .Errorf ("local %v %v is specified twice" ,spec .listenNetwork ,spec .listenAddress )
330
+ return nil ,xerrors .Errorf ("local %s host:%s port:%d is specified twice" ,spec .network ,spec .listenHost , spec . listenPort )
307
331
}
308
332
locals [localStr ]= struct {}{}
309
333
}
@@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) {
323
347
return uint16 (port ),nil
324
348
}
325
349
326
- type parsedSrcDestPort struct {
327
- local ,remote netip.AddrPort
328
- }
329
-
330
350
// specRegexp matches port specs. It handles all the following formats:
331
351
//
332
352
// 8000
@@ -347,21 +367,19 @@ type parsedSrcDestPort struct {
347
367
// 9: end or remote port range
348
368
var specRegexp = regexp .MustCompile (`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$` )
349
369
350
- func parseSrcDestPorts (in string ) ([]parsedSrcDestPort ,error ) {
351
- var (
352
- err error
353
- localAddr = netip .AddrFrom4 ([4 ]byte {127 ,0 ,0 ,1 })
354
- remoteAddr = netip .AddrFrom4 ([4 ]byte {127 ,0 ,0 ,1 })
355
- )
370
+ func parseSrcDestPorts (in string ) ([]portForwardSpec ,error ) {
356
371
groups := specRegexp .FindStringSubmatch (in )
357
372
if len (groups )== 0 {
358
373
return nil ,xerrors .Errorf ("invalid port specification %q" ,in )
359
374
}
375
+
376
+ var localAddr netip.Addr
360
377
if groups [2 ]!= "" {
361
- localAddr ,err = netip .ParseAddr (strings .Trim (groups [2 ],"[]" ))
378
+ parsedAddr ,err : =netip .ParseAddr (strings .Trim (groups [2 ],"[]" ))
362
379
if err != nil {
363
380
return nil ,xerrors .Errorf ("invalid IP address %q" ,groups [2 ])
364
381
}
382
+ localAddr = parsedAddr
365
383
}
366
384
367
385
local ,err := parsePortRange (groups [3 ],groups [5 ])
@@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) {
378
396
if len (local )!= len (remote ) {
379
397
return nil ,xerrors .Errorf ("port ranges must be the same length, got %d ports forwarded to %d ports" ,len (local ),len (remote ))
380
398
}
381
- var out []parsedSrcDestPort
399
+ var out []portForwardSpec
382
400
for i := range local {
383
- out = append (out ,parsedSrcDestPort {
384
- local :netip .AddrPortFrom (localAddr ,local [i ]),
385
- remote :netip .AddrPortFrom (remoteAddr ,remote [i ]),
401
+ out = append (out ,portForwardSpec {
402
+ listenHost :localAddr ,
403
+ listenPort :local [i ],
404
+ dialPort :remote [i ],
386
405
})
387
406
}
388
407
return out ,nil