Expand Up @@ -8,6 +8,7 @@ import ( "io" "net" "net/url" "os" "sync" "time" Expand All @@ -16,11 +17,18 @@ import ( "golang.org/x/net/proxy" "nhooyr.io/websocket" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" "cdr.dev/coder-cli/coder-sdk" ) // DialOptions are configurable options for a wsnet connection. type DialOptions struct { // Logger is an optional logger to use for logging mostly debug messages. If // set to nil, nothing will be logged. Log *slog.Logger // ICEServers is an array of STUN or TURN servers to use for negotiation purposes. // See: https://developer.mozilla.org/en-US/docs/Web/API/RTCConfiguration/iceServers ICEServers []webrtc.ICEServer Expand All @@ -36,6 +44,17 @@ type DialOptions struct { // DialWebsocket dials the broker with a WebSocket and negotiates a connection. func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsOpts *websocket.DialOptions) (*Dialer, error) { if netOpts == nil { netOpts = &DialOptions{} } if netOpts.Log == nil { // This logger will log nothing. log := slog.Make() netOpts.Log = &log } log := *netOpts.Log log.Debug(ctx, "connecting to broker", slog.F("broker", broker)) conn, resp, err := websocket.Dial(ctx, broker, wsOpts) if err != nil { if resp != nil { Expand All @@ -46,6 +65,8 @@ func DialWebsocket(ctx context.Context, broker string, netOpts *DialOptions, wsO } return nil, fmt.Errorf("dial websocket: %w", err) } log.Debug(ctx, "connected to broker") nconn := websocket.NetConn(ctx, conn, websocket.MessageBinary) defer func() { _ = nconn.Close() Expand All @@ -60,6 +81,11 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er if options == nil { options = &DialOptions{} } if options.Log == nil { log := slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelInfo).Named("wsnet_dial") options.Log = &log } log := *options.Log if options.ICEServers == nil { options.ICEServers = []webrtc.ICEServer{} } Expand All @@ -71,13 +97,20 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er token: options.TURNProxyAuthToken, } } log.Debug(ctx, "creating peer connection", slog.F("options", options), slog.F("turn_proxy", turnProxy)) rtc, err := newPeerConnection(options.ICEServers, turnProxy) if err != nil { return nil, fmt.Errorf("create peer connection: %w", err) } log.Debug(ctx, "created peer connection") rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { log.Debug(ctx, "connection state change", slog.F("state", pcs.String())) }) flushCandidates := proxyICECandidates(rtc, conn) log.Debug(ctx, "creating control channel", slog.F("proto", controlChannel)) ctrl, err := rtc.CreateDataChannel(controlChannel, &webrtc.DataChannelInit{ Protocol: stringPtr(controlChannel), Ordered: boolPtr(true), Expand All @@ -90,6 +123,7 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er if err != nil { return nil, fmt.Errorf("create offer: %w", err) } log.Debug(ctx, "created offer", slog.F("offer", offer)) err = rtc.SetLocalDescription(offer) if err != nil { return nil, fmt.Errorf("set local offer: %w", err) Expand All @@ -100,21 +134,25 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er turnProxyURL = options.TURNProxyURL.String() } offerMessage, err :=json.Marshal(& BrokerMessage{bmsg := BrokerMessage{Offer: &offer, Servers: options.ICEServers, TURNProxyURL: turnProxyURL, }) } log.Debug(ctx, "sending offer message", slog.F("msg", bmsg)) offerMessage, err := json.Marshal(&bmsg) if err != nil { return nil, fmt.Errorf("marshal offer message: %w", err) } _, err = conn.Write(offerMessage) if err != nil { return nil, fmt.Errorf("write offer: %w", err) } flushCandidates() dialer := &Dialer{ log: log, conn: conn, ctrl: ctrl, rtc: rtc, Expand All @@ -128,6 +166,7 @@ func Dial(ctx context.Context, conn net.Conn, options *DialOptions) (*Dialer, er // inside a workspace. The opposing end of the WebSocket messages // should be proxied with a Listener. type Dialer struct { log slog.Logger conn net.Conn ctrl *webrtc.DataChannel ctrlrw datachannel.ReadWriteCloser Expand All @@ -152,20 +191,25 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { defer func() { _ = d.conn.Close() }() err := waitForConnectionOpen(ctx, d.rtc) err := waitForConnectionOpen(context.Background(), d.rtc) if err != nil { d.log.Debug(ctx, "negotiation error", slog.Error(err)) if errors.Is(err, context.DeadlineExceeded) { _ = d.conn.Close() } errCh <- err errCh <-fmt.Errorf("wait for connection to open: %w", err) return } d.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { if pcs == webrtc.PeerConnectionStateConnected { d.log.Debug(ctx, "connected") return } // Close connections opened when RTC was alive. d.log.Warn(ctx, "closing connections due to connection state change", slog.F("pcs", pcs.String())) d.connClosersMut.Lock() defer d.connClosersMut.Unlock() for _, connCloser := range d.connClosers { Expand All @@ -175,6 +219,7 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { }) }() d.log.Debug(ctx, "beginning negotiation") for { var msg BrokerMessage err = decoder.Decode(&msg) Expand All @@ -184,6 +229,8 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("read: %w", err) } d.log.Debug(ctx, "got message from handshake conn", slog.F("msg", msg)) if msg.Candidate != "" { c := webrtc.ICECandidateInit{ Candidate: msg.Candidate, Expand All @@ -192,17 +239,22 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { pendingCandidates = append(pendingCandidates, c) continue } d.log.Debug(ctx, "adding remote ICE candidate", slog.F("c", c)) err = d.rtc.AddICECandidate(c) if err != nil { return fmt.Errorf("accept ice candidate: %s: %w", msg.Candidate, err) } continue } if msg.Answer != nil { d.log.Debug(ctx, "received answer", slog.F("a", *msg.Answer)) err = d.rtc.SetRemoteDescription(*msg.Answer) if err != nil { return fmt.Errorf("set answer: %w", err) } for _, candidate := range pendingCandidates { err = d.rtc.AddICECandidate(candidate) if err != nil { Expand All @@ -212,11 +264,15 @@ func (d *Dialer) negotiate(ctx context.Context) (err error) { pendingCandidates = nil continue } if msg.Error != "" { return errors.New(msg.Error) d.log.Debug(ctx, "got error from peer", slog.F("err", msg.Error)) return fmt.Errorf("error from peer: %v", msg.Error) } return fmt.Errorf("unhandled message: %+v", msg) } return <-errCh } Expand All @@ -234,6 +290,7 @@ func (d *Dialer) activeConnections() int { // Close closes the RTC connection. // All data channels dialed will be closed. func (d *Dialer) Close() error { d.log.Debug(context.Background(), "close called") return d.rtc.Close() } Expand All @@ -242,6 +299,7 @@ func (d *Dialer) Ping(ctx context.Context) error { if d.ctrl.ReadyState() == webrtc.DataChannelStateClosed || d.ctrl.ReadyState() == webrtc.DataChannelStateClosing { return webrtc.ErrConnectionClosed } // Since we control the client and server we could open this // data channel with `Negotiated` true to reduce traffic being // sent when the RTC connection is opened. Expand All @@ -257,6 +315,7 @@ func (d *Dialer) Ping(ctx context.Context) error { } d.pingMut.Lock() defer d.pingMut.Unlock() d.log.Debug(ctx, "sending ping") _, err = d.ctrlrw.Write([]byte{'a'}) if err != nil { return fmt.Errorf("write: %w", err) Expand All @@ -281,13 +340,18 @@ func (d *Dialer) Ping(ctx context.Context) error { // DialContext dials the network and address on the remote listener. func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { proto := fmt.Sprintf("%s:%s", network, address) ctx = slog.With(ctx, slog.F("proto", proto)) d.log.Debug(ctx, "opening data channel") dc, err := d.rtc.CreateDataChannel("proxy", &webrtc.DataChannelInit{ Ordered: boolPtr(network != "udp"), Protocol:stringPtr(fmt.Sprintf("%s:%s", network, address)) , Protocol:&proto , }) if err != nil { return nil, fmt.Errorf("create data channel: %w", err) } d.connClosersMut.Lock() d.connClosers = append(d.connClosers, dc) d.connClosersMut.Unlock() Expand All @@ -296,10 +360,18 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. if err != nil { return nil, fmt.Errorf("wait for open: %w", err) } ctx = slog.With(ctx, slog.F("dc_id", dc.ID())) d.log.Debug(ctx, "data channel opened") rw, err := dc.Detach() if err != nil { return nil, fmt.Errorf("detach: %w", err) } d.log.Debug(ctx, "data channel detached") ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() errCh := make(chan error) go func() { Expand All @@ -309,6 +381,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. errCh <- fmt.Errorf("read dial response: %w", err) return } d.log.Debug(ctx, "dial response", slog.F("res", res)) if res.Err == "" { close(errCh) return Expand All @@ -323,8 +396,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. } errCh <- err }() ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() select { case err := <-errCh: if err != nil { Expand All @@ -343,5 +415,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. rw: rw, } c.init() d.log.Debug(ctx, "dial channel ready") return c, nil }