Expand Up @@ -25,6 +25,14 @@ import ( "github.com/coder/serpent" ) var ( // noAddr is the zero-value of netip.Addr, and is not a valid address. We use it to identify // when the local address is not specified in port-forward flags. noAddr netip.Addr ipv6Loopback = netip.MustParseAddr("::1") ipv4Loopback = netip.MustParseAddr("127.0.0.1") ) func (r *RootCmd) portForward() *serpent.Command { var ( tcpForwards []string // <port>:<port> Expand Down Expand Up @@ -122,7 +130,7 @@ func (r *RootCmd) portForward() *serpent.Command { // Start all listeners. var ( wg = new(sync.WaitGroup) listeners = make([]net.Listener, len(specs)) listeners = make([]net.Listener,0, len(specs)*2 ) closeAllListeners = func() { logger.Debug(ctx, "closing all listeners") for _, l := range listeners { Expand All @@ -135,13 +143,25 @@ func (r *RootCmd) portForward() *serpent.Command { ) defer closeAllListeners() for i, spec := range specs { for _, spec := range specs { if spec.listenHost == noAddr { // first, opportunistically try to listen on IPv6 spec6 := spec spec6.listenHost = ipv6Loopback l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger) if err6 != nil { logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6)) } else { listeners = append(listeners, l6) } spec.listenHost = ipv4Loopback } l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) if err != nil { logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) return err } listeners[i] =l listeners =append(listeners, l) } stopUpdating := client.UpdateWorkspaceUsageContext(ctx, workspace.ID) Expand Down Expand Up @@ -206,12 +226,19 @@ func listenAndPortForward( spec portForwardSpec, logger slog.Logger, ) (net.Listener, error) { logger = logger.With(slog.F("network", spec.listenNetwork), slog.F("address", spec.listenAddress)) _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) logger = logger.With( slog.F("network", spec.network), slog.F("listen_host", spec.listenHost), slog.F("listen_port", spec.listenPort), ) listenAddress := netip.AddrPortFrom(spec.listenHost, spec.listenPort) dialAddress := fmt.Sprintf("127.0.0.1:%d", spec.dialPort) _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%s://%s' locally to '%s://%s' in the workspace\n", spec.network, listenAddress, spec.network, dialAddress) l, err := inv.Net.Listen(spec.listenNetwork, spec. listenAddress) l, err := inv.Net.Listen(spec.network, listenAddress.String() ) if err != nil { return nil, xerrors.Errorf("listen '%v ://%v ': %w", spec.listenNetwork, spec. listenAddress, err) return nil, xerrors.Errorf("listen '%s ://%s ': %w", spec.network, listenAddress.String() , err) } logger.Debug(ctx, "listening") Expand All @@ -226,24 +253,31 @@ func listenAndPortForward( logger.Debug(ctx, "listener closed") return } _, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err) _, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%s://%s': %v\n", spec.network, listenAddress.String(), err) _, _ = fmt.Fprintln(inv.Stderr, "Killing listener") return } logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr())) logger.Debug(ctx, "accepted connection", slog.F("remote_addr", netConn.RemoteAddr())) go func(netConn net.Conn) { defer netConn.Close() remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec. dialAddress) remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress) if err != nil { _, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) _, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%s://%s' in workspace: %s\n", spec.network, dialAddress, err) return } defer remoteConn.Close() logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) agentssh.Bicopy(ctx, netConn, remoteConn) logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) }(netConn) } }(spec) Expand All @@ -252,58 +286,48 @@ func listenAndPortForward( } type portForwardSpec struct { listenNetwork string // tcp, udp listenAddress string // <ip>:<port> or path dialNetwork string // tcp, udp dialAddress string // <ip>:<port> or path network string // tcp, udp listenHost netip.Addr listenPort, dialPort uint16 } func parsePortForwards(tcpSpecs, udpSpecs []string) ([]portForwardSpec, error) { specs := []portForwardSpec{} for _, specEntry := range tcpSpecs { for _, spec := range strings.Split(specEntry, ",") { ports , err := parseSrcDestPorts(strings.TrimSpace(spec))pfSpecs , err := parseSrcDestPorts(strings.TrimSpace(spec))if err != nil { return nil, xerrors.Errorf("failed to parse TCP port-forward specification %q: %w", spec, err) } for _, port := range ports { specs = append(specs, portForwardSpec{ listenNetwork: "tcp", listenAddress: port.local.String(), dialNetwork: "tcp", dialAddress: port.remote.String(), }) for _, pfSpec := range pfSpecs { pfSpec.network = "tcp" specs = append(specs, pfSpec) } } } for _, specEntry := range udpSpecs { for _, spec := range strings.Split(specEntry, ",") { ports , err := parseSrcDestPorts(strings.TrimSpace(spec))pfSpecs , err := parseSrcDestPorts(strings.TrimSpace(spec))if err != nil { return nil, xerrors.Errorf("failed to parse UDP port-forward specification %q: %w", spec, err) } for _, port := range ports { specs = append(specs, portForwardSpec{ listenNetwork: "udp", listenAddress: port.local.String(), dialNetwork: "udp", dialAddress: port.remote.String(), }) for _, pfSpec := range pfSpecs { pfSpec.network = "udp" specs = append(specs, pfSpec) } } } // Check for duplicate entries. locals := map[string]struct{}{} for _, spec := range specs { localStr := fmt.Sprintf("%v:%v ", spec.listenNetwork , spec.listenAddress ) localStr := fmt.Sprintf("%s:%s:%d ", spec.network , spec.listenHost, spec.listenPort ) if _, ok := locals[localStr]; ok { return nil, xerrors.Errorf("local %v %v is specified twice", spec.listenNetwork , spec.listenAddress ) return nil, xerrors.Errorf("local %s host:%s port:%d is specified twice", spec.network , spec.listenHost, spec.listenPort ) } locals[localStr] = struct{}{} } Expand All @@ -323,10 +347,6 @@ func parsePort(in string) (uint16, error) { return uint16(port), nil } type parsedSrcDestPort struct { local, remote netip.AddrPort } // specRegexp matches port specs. It handles all the following formats: // // 8000 Expand All @@ -347,21 +367,19 @@ type parsedSrcDestPort struct { // 9: end or remote port range var specRegexp = regexp.MustCompile(`^((\[[0-9a-fA-F:]+]|\d+\.\d+\.\d+\.\d+):)?(\d+)(-(\d+))?(:(\d+)(-(\d+))?)?$`) func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { var ( err error localAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) remoteAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) ) func parseSrcDestPorts(in string) ([]portForwardSpec, error) { groups := specRegexp.FindStringSubmatch(in) if len(groups) == 0 { return nil, xerrors.Errorf("invalid port specification %q", in) } var localAddr netip.Addr if groups[2] != "" { localAddr , err = netip.ParseAddr(strings.Trim(groups[2], "[]"))parsedAddr , err: = netip.ParseAddr(strings.Trim(groups[2], "[]"))if err != nil { return nil, xerrors.Errorf("invalid IP address %q", groups[2]) } localAddr = parsedAddr } local, err := parsePortRange(groups[3], groups[5]) Expand All @@ -378,11 +396,12 @@ func parseSrcDestPorts(in string) ([]parsedSrcDestPort, error) { if len(local) != len(remote) { return nil, xerrors.Errorf("port ranges must be the same length, got %d ports forwarded to %d ports", len(local), len(remote)) } var out []parsedSrcDestPort var out []portForwardSpec for i := range local { out = append(out, parsedSrcDestPort{ local: netip.AddrPortFrom(localAddr, local[i]), remote: netip.AddrPortFrom(remoteAddr, remote[i]), out = append(out, portForwardSpec{ listenHost: localAddr, listenPort: local[i], dialPort: remote[i], }) } return out, nil Expand Down