@@ -37,12 +37,30 @@ const (
3737X11MaxPort = X11StartPort + X11MaxDisplays
3838)
3939
40+ // X11Network abstracts the creation of network listeners for X11 forwarding.
41+ // It is intended mainly for testing; production code uses the default
42+ // implementation backed by the operating system networking stack.
43+ type X11Network interface {
44+ Listen (network ,address string ) (net.Listener ,error )
45+ }
46+
47+ // osNet is the default X11Network implementation that uses the standard
48+ // library network stack.
49+ type osNet struct {}
50+
51+ func (osNet )Listen (network ,address string ) (net.Listener ,error ) {
52+ return net .Listen (network ,address )
53+ }
54+
4055type x11Forwarder struct {
4156logger slog.Logger
4257x11HandlerErrors * prometheus.CounterVec
4358fs afero.Fs
4459displayOffset int
4560
61+ // network creates X11 listener sockets. Defaults to osNet{}.
62+ network X11Network
63+
4664mu sync.Mutex
4765sessions map [* x11Session ]struct {}
4866connections map [net.Conn ]struct {}
@@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections(
147165x .closeAndRemoveSession (session )
148166}
149167
150- tcpConn ,ok := conn .(* net.TCPConn )
151- if ! ok {
152- x .logger .Warn (ctx ,fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" ,conn ))
153- _ = conn .Close ()
154- continue
168+ var originAddr string
169+ var originPort uint32
170+
171+ if tcpConn ,ok := conn .(* net.TCPConn );ok {
172+ if tcpAddr ,ok := tcpConn .LocalAddr ().(* net.TCPAddr );ok {
173+ originAddr = tcpAddr .IP .String ()
174+ // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
175+ originPort = uint32 (tcpAddr .Port )
176+ }
155177}
156- tcpAddr ,ok := tcpConn .LocalAddr ().(* net.TCPAddr )
157- if ! ok {
158- x .logger .Warn (ctx ,fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" ,tcpConn .LocalAddr ()))
159- _ = conn .Close ()
160- continue
178+ // Fallback values for in-memory or non-TCP connections.
179+ if originAddr == "" {
180+ originAddr = "127.0.0.1"
161181}
162182
163183channel ,reqs ,err := serverConn .OpenChannel ("x11" ,gossh .Marshal (struct {
164184OriginatorAddress string
165185OriginatorPort uint32
166186}{
167- OriginatorAddress :tcpAddr .IP .String (),
168- // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
169- OriginatorPort :uint32 (tcpAddr .Port ),
187+ OriginatorAddress :originAddr ,
188+ OriginatorPort :originPort ,
170189}))
171190if err != nil {
172191x .logger .Warn (ctx ,"failed to open X11 channel" ,slog .Error (err ))
@@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
287306// createX11Listener creates a listener for X11 forwarding, it will use
288307// the next available port starting from X11StartPort and displayOffset.
289308func (x * x11Forwarder )createX11Listener (ctx context.Context ) (ln net.Listener ,display int ,err error ) {
290- var lc net.ListenConfig
291309// Look for an open port to listen on.
292310for port := X11StartPort + x .displayOffset ;port <= X11MaxPort ;port ++ {
293311if ctx .Err ()!= nil {
294312return nil ,- 1 ,ctx .Err ()
295313}
296- ln ,err = lc .Listen (ctx ,"tcp" ,fmt .Sprintf ("localhost:%d" ,port ))
314+
315+ ln ,err = x .network .Listen ("tcp" ,fmt .Sprintf ("localhost:%d" ,port ))
297316if err == nil {
298317display = port - X11StartPort
299318return ln ,display ,nil