7
7
"errors"
8
8
"fmt"
9
9
"io"
10
+ "math"
10
11
"net"
11
12
"os"
12
13
"path/filepath"
@@ -22,61 +23,76 @@ import (
22
23
"cdr.dev/slog"
23
24
)
24
25
25
- // x11Callback is called when the client requests X11 forwarding.
26
- // It adds an Xauthority entry to the Xauthority file.
27
- func (s * Server )x11Callback (ctx ssh.Context ,x11 ssh.X11 )bool {
28
- hostname ,err := os .Hostname ()
29
- if err != nil {
30
- s .logger .Warn (ctx ,"failed to get hostname" ,slog .Error (err ))
31
- s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
32
- return false
33
- }
34
-
35
- err = s .fs .MkdirAll (s .config .X11SocketDir ,0o700 )
36
- if err != nil {
37
- s .logger .Warn (ctx ,"failed to make the x11 socket dir" ,slog .F ("dir" ,s .config .X11SocketDir ),slog .Error (err ))
38
- s .metrics .x11HandlerErrors .WithLabelValues ("socker_dir" ).Add (1 )
39
- return false
40
- }
26
+ const (
27
+ // X11StartPort is the starting port for X11 forwarding, this is the
28
+ // port used for "DISPLAY=localhost:0".
29
+ X11StartPort = 6000
30
+ // X11DefaultDisplayOffset is the default offset for X11 forwarding.
31
+ X11DefaultDisplayOffset = 10
32
+ )
41
33
42
- err = addXauthEntry (ctx ,s .fs ,hostname ,strconv .Itoa (int (x11 .ScreenNumber )),x11 .AuthProtocol ,x11 .AuthCookie )
43
- if err != nil {
44
- s .logger .Warn (ctx ,"failed to add Xauthority entry" ,slog .Error (err ))
45
- s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
46
- return false
47
- }
34
+ // x11Callback is called when the client requests X11 forwarding.
35
+ func (* Server )x11Callback (_ ssh.Context ,_ ssh.X11 )bool {
36
+ // Always allow.
48
37
return true
49
38
}
50
39
51
40
// x11Handler is called when a session has requested X11 forwarding.
52
41
// It listens for X11 connections and forwards them to the client.
53
- func (s * Server )x11Handler (ctx ssh.Context ,x11 ssh.X11 )bool {
42
+ func (s * Server )x11Handler (ctx ssh.Context ,x11 ssh.X11 )( display int , handled bool ) {
54
43
serverConn ,valid := ctx .Value (ssh .ContextKeyConn ).(* gossh.ServerConn )
55
44
if ! valid {
56
45
s .logger .Warn (ctx ,"failed to get server connection" )
57
- return false
46
+ return - 1 , false
58
47
}
59
- // We want to overwrite the socket so that subsequent connections will succeed.
60
- socketPath := filepath .Join (s .config .X11SocketDir ,fmt .Sprintf ("X%d" ,x11 .ScreenNumber ))
61
- err := os .Remove (socketPath )
62
- if err != nil && ! errors .Is (err ,os .ErrNotExist ) {
63
- s .logger .Warn (ctx ,"failed to remove existing X11 socket" ,slog .Error (err ))
64
- return false
65
- }
66
- listener ,err := net .Listen ("unix" ,socketPath )
48
+
49
+ hostname ,err := os .Hostname ()
67
50
if err != nil {
51
+ s .logger .Warn (ctx ,"failed to get hostname" ,slog .Error (err ))
52
+ s .metrics .x11HandlerErrors .WithLabelValues ("hostname" ).Add (1 )
53
+ return - 1 ,false
54
+ }
55
+
56
+ var (
57
+ lc net.ListenConfig
58
+ ln net.Listener
59
+ port = X11StartPort + * s .config .X11DisplayOffset
60
+ )
61
+ // Look for an open port to listen on..
62
+ for ;port >= X11StartPort && port < math .MaxUint16 ;port ++ {
63
+ ln ,err = lc .Listen (ctx ,"tcp" ,fmt .Sprintf ("localhost:%d" ,port ))
64
+ if err == nil {
65
+ display = port - X11StartPort
66
+ break
67
+ }
68
+ }
69
+ if ln == nil {
68
70
s .logger .Warn (ctx ,"failed to listen for X11" ,slog .Error (err ))
69
- return false
71
+ s .metrics .x11HandlerErrors .WithLabelValues ("listen" ).Add (1 )
72
+ return - 1 ,false
73
+ }
74
+ s .trackListener (ln ,true )
75
+ defer func () {
76
+ if ! handled {
77
+ s .trackListener (ln ,false )
78
+ _ = ln .Close ()
79
+ }
80
+ }()
81
+
82
+ err = addXauthEntry (ctx ,s .fs ,hostname ,strconv .Itoa (port ),x11 .AuthProtocol ,x11 .AuthCookie )
83
+ if err != nil {
84
+ s .logger .Warn (ctx ,"failed to add Xauthority entry" ,slog .Error (err ))
85
+ s .metrics .x11HandlerErrors .WithLabelValues ("xauthority" ).Add (1 )
86
+ return - 1 ,false
70
87
}
71
- s .trackListener (listener ,true )
72
88
73
89
go func () {
74
- defer listener .Close ()
75
- defer s .trackListener (listener ,false )
90
+ defer ln .Close ()
91
+ defer s .trackListener (ln ,false )
76
92
handledFirstConnection := false
77
93
78
94
for {
79
- conn ,err := listener .Accept ()
95
+ conn ,err := ln .Accept ()
80
96
if err != nil {
81
97
if errors .Is (err ,net .ErrClosed ) {
82
98
return
@@ -91,33 +107,37 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
91
107
}
92
108
handledFirstConnection = true
93
109
94
- unixConn ,ok := conn .(* net.UnixConn )
110
+ tcpConn ,ok := conn .(* net.TCPConn )
95
111
if ! ok {
96
- s .logger .Warn (ctx ,fmt .Sprintf ("failed to cast connection to UnixConn. got: %T" ,conn ))
112
+ s .logger .Warn (ctx ,fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" ,conn ))
113
+ _ = conn .Close ()
97
114
return
98
115
}
99
- unixAddr ,ok := unixConn .LocalAddr ().(* net.UnixAddr )
116
+ tcpAddr ,ok := tcpConn .LocalAddr ().(* net.TCPAddr )
100
117
if ! ok {
101
- s .logger .Warn (ctx ,fmt .Sprintf ("failed to cast local address to UnixAddr. got: %T" ,unixConn .LocalAddr ()))
118
+ s .logger .Warn (ctx ,fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" ,tcpConn .LocalAddr ()))
119
+ _ = conn .Close ()
102
120
return
103
121
}
104
122
105
123
channel ,reqs ,err := serverConn .OpenChannel ("x11" ,gossh .Marshal (struct {
106
124
OriginatorAddress string
107
125
OriginatorPort uint32
108
126
}{
109
- OriginatorAddress :unixAddr . Name ,
110
- OriginatorPort :0 ,
127
+ OriginatorAddress :tcpAddr . IP . String () ,
128
+ OriginatorPort :uint32 ( tcpAddr . Port ) ,
111
129
}))
112
130
if err != nil {
113
131
s .logger .Warn (ctx ,"failed to open X11 channel" ,slog .Error (err ))
132
+ _ = conn .Close ()
114
133
return
115
134
}
116
135
go gossh .DiscardRequests (reqs )
117
136
go Bicopy (ctx ,conn ,channel )
118
137
}
119
138
}()
120
- return true
139
+
140
+ return display ,true
121
141
}
122
142
123
143
// addXauthEntry adds an Xauthority entry to the Xauthority file.