Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9d3f404

Browse files
committed
feat(agent/agentssh): use tcp for X11 forwarding
Fixes#14198
1 parent5366f25 commit9d3f404

File tree

3 files changed

+105
-58
lines changed

3 files changed

+105
-58
lines changed

‎agent/agentssh/agentssh.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ type Config struct {
7979
// where users will land when they connect via SSH. Default is the home
8080
// directory of the user.
8181
WorkingDirectoryfunc()string
82-
//X11SocketDir is thedirectory where X11 sockets are created. Default is
83-
///tmp/.X11-unix.
84-
X11SocketDirstring
82+
//X11DisplayOffset is theoffset to add to the X11 display number.
83+
//Default is 10.
84+
X11DisplayOffset*int
8585
// BlockFileTransfer restricts use of file transfer applications.
8686
BlockFileTransferbool
8787
}
@@ -124,8 +124,9 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
124124
ifconfig==nil {
125125
config=&Config{}
126126
}
127-
ifconfig.X11SocketDir=="" {
128-
config.X11SocketDir=filepath.Join(os.TempDir(),".X11-unix")
127+
ifconfig.X11DisplayOffset==nil {
128+
offset:=X11DefaultDisplayOffset
129+
config.X11DisplayOffset=&offset
129130
}
130131
ifconfig.UpdateEnv==nil {
131132
config.UpdateEnv=func(current []string) ([]string,error) {returncurrent,nil }
@@ -273,13 +274,13 @@ func (s *Server) sessionHandler(session ssh.Session) {
273274
extraEnv:=make([]string,0)
274275
x11,hasX11:=session.X11()
275276
ifhasX11 {
276-
handled:=s.x11Handler(session.Context(),x11)
277+
display,handled:=s.x11Handler(session.Context(),x11)
277278
if!handled {
278279
_=session.Exit(1)
279280
logger.Error(ctx,"x11 handler failed")
280281
return
281282
}
282-
extraEnv=append(extraEnv,fmt.Sprintf("DISPLAY=:%d.0",x11.ScreenNumber))
283+
extraEnv=append(extraEnv,fmt.Sprintf("DISPLAY=localhost:%d.%d",display,x11.ScreenNumber))
283284
}
284285

285286
ifs.fileTransferBlocked(session) {

‎agent/agentssh/x11.go

Lines changed: 64 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"io"
10+
"math"
1011
"net"
1112
"os"
1213
"path/filepath"
@@ -22,61 +23,76 @@ import (
2223
"cdr.dev/slog"
2324
)
2425

25-
// x11Callback is called when the client requests X11 forwarding.
26-
// It adds an Xauthority entry to the Xauthority file.
27-
func (s*Server)x11Callback(ctx ssh.Context,x11 ssh.X11)bool {
28-
hostname,err:=os.Hostname()
29-
iferr!=nil {
30-
s.logger.Warn(ctx,"failed to get hostname",slog.Error(err))
31-
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
32-
returnfalse
33-
}
34-
35-
err=s.fs.MkdirAll(s.config.X11SocketDir,0o700)
36-
iferr!=nil {
37-
s.logger.Warn(ctx,"failed to make the x11 socket dir",slog.F("dir",s.config.X11SocketDir),slog.Error(err))
38-
s.metrics.x11HandlerErrors.WithLabelValues("socker_dir").Add(1)
39-
returnfalse
40-
}
26+
const (
27+
// X11StartPort is the starting port for X11 forwarding, this is the
28+
// port used for "DISPLAY=localhost:0".
29+
X11StartPort=6000
30+
// X11DefaultDisplayOffset is the default offset for X11 forwarding.
31+
X11DefaultDisplayOffset=10
32+
)
4133

42-
err=addXauthEntry(ctx,s.fs,hostname,strconv.Itoa(int(x11.ScreenNumber)),x11.AuthProtocol,x11.AuthCookie)
43-
iferr!=nil {
44-
s.logger.Warn(ctx,"failed to add Xauthority entry",slog.Error(err))
45-
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
46-
returnfalse
47-
}
34+
// x11Callback is called when the client requests X11 forwarding.
35+
func (*Server)x11Callback(_ ssh.Context,_ ssh.X11)bool {
36+
// Always allow.
4837
returntrue
4938
}
5039

5140
// x11Handler is called when a session has requested X11 forwarding.
5241
// It listens for X11 connections and forwards them to the client.
53-
func (s*Server)x11Handler(ctx ssh.Context,x11 ssh.X11)bool {
42+
func (s*Server)x11Handler(ctx ssh.Context,x11 ssh.X11)(displayint,handledbool) {
5443
serverConn,valid:=ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
5544
if!valid {
5645
s.logger.Warn(ctx,"failed to get server connection")
57-
returnfalse
46+
return-1,false
5847
}
59-
// We want to overwrite the socket so that subsequent connections will succeed.
60-
socketPath:=filepath.Join(s.config.X11SocketDir,fmt.Sprintf("X%d",x11.ScreenNumber))
61-
err:=os.Remove(socketPath)
62-
iferr!=nil&&!errors.Is(err,os.ErrNotExist) {
63-
s.logger.Warn(ctx,"failed to remove existing X11 socket",slog.Error(err))
64-
returnfalse
65-
}
66-
listener,err:=net.Listen("unix",socketPath)
48+
49+
hostname,err:=os.Hostname()
6750
iferr!=nil {
51+
s.logger.Warn(ctx,"failed to get hostname",slog.Error(err))
52+
s.metrics.x11HandlerErrors.WithLabelValues("hostname").Add(1)
53+
return-1,false
54+
}
55+
56+
var (
57+
lc net.ListenConfig
58+
ln net.Listener
59+
port=X11StartPort+*s.config.X11DisplayOffset
60+
)
61+
// Look for an open port to listen on..
62+
for ;port>=X11StartPort&&port<math.MaxUint16;port++ {
63+
ln,err=lc.Listen(ctx,"tcp",fmt.Sprintf("localhost:%d",port))
64+
iferr==nil {
65+
display=port-X11StartPort
66+
break
67+
}
68+
}
69+
ifln==nil {
6870
s.logger.Warn(ctx,"failed to listen for X11",slog.Error(err))
69-
returnfalse
71+
s.metrics.x11HandlerErrors.WithLabelValues("listen").Add(1)
72+
return-1,false
73+
}
74+
s.trackListener(ln,true)
75+
deferfunc() {
76+
if!handled {
77+
s.trackListener(ln,false)
78+
_=ln.Close()
79+
}
80+
}()
81+
82+
err=addXauthEntry(ctx,s.fs,hostname,strconv.Itoa(port),x11.AuthProtocol,x11.AuthCookie)
83+
iferr!=nil {
84+
s.logger.Warn(ctx,"failed to add Xauthority entry",slog.Error(err))
85+
s.metrics.x11HandlerErrors.WithLabelValues("xauthority").Add(1)
86+
return-1,false
7087
}
71-
s.trackListener(listener,true)
7288

7389
gofunc() {
74-
deferlistener.Close()
75-
defers.trackListener(listener,false)
90+
deferln.Close()
91+
defers.trackListener(ln,false)
7692
handledFirstConnection:=false
7793

7894
for {
79-
conn,err:=listener.Accept()
95+
conn,err:=ln.Accept()
8096
iferr!=nil {
8197
iferrors.Is(err,net.ErrClosed) {
8298
return
@@ -91,33 +107,37 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
91107
}
92108
handledFirstConnection=true
93109

94-
unixConn,ok:=conn.(*net.UnixConn)
110+
tcpConn,ok:=conn.(*net.TCPConn)
95111
if!ok {
96-
s.logger.Warn(ctx,fmt.Sprintf("failed to cast connection to UnixConn. got: %T",conn))
112+
s.logger.Warn(ctx,fmt.Sprintf("failed to cast connection to TCPConn. got: %T",conn))
113+
_=conn.Close()
97114
return
98115
}
99-
unixAddr,ok:=unixConn.LocalAddr().(*net.UnixAddr)
116+
tcpAddr,ok:=tcpConn.LocalAddr().(*net.TCPAddr)
100117
if!ok {
101-
s.logger.Warn(ctx,fmt.Sprintf("failed to cast local address to UnixAddr. got: %T",unixConn.LocalAddr()))
118+
s.logger.Warn(ctx,fmt.Sprintf("failed to cast local address to TCPAddr. got: %T",tcpConn.LocalAddr()))
119+
_=conn.Close()
102120
return
103121
}
104122

105123
channel,reqs,err:=serverConn.OpenChannel("x11",gossh.Marshal(struct {
106124
OriginatorAddressstring
107125
OriginatorPortuint32
108126
}{
109-
OriginatorAddress:unixAddr.Name,
110-
OriginatorPort:0,
127+
OriginatorAddress:tcpAddr.IP.String(),
128+
OriginatorPort:uint32(tcpAddr.Port),
111129
}))
112130
iferr!=nil {
113131
s.logger.Warn(ctx,"failed to open X11 channel",slog.Error(err))
132+
_=conn.Close()
114133
return
115134
}
116135
gogossh.DiscardRequests(reqs)
117136
goBicopy(ctx,conn,channel)
118137
}
119138
}()
120-
returntrue
139+
140+
returndisplay,true
121141
}
122142

123143
// addXauthEntry adds an Xauthority entry to the Xauthority file.

‎agent/agentssh/x11_test.go

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
package agentssh_test
22

33
import (
4+
"bufio"
5+
"bytes"
46
"context"
57
"encoding/hex"
8+
"fmt"
69
"net"
710
"os"
811
"path/filepath"
912
"runtime"
13+
"strconv"
14+
"strings"
1015
"testing"
1116

1217
"github.com/gliderlabs/ssh"
@@ -31,10 +36,7 @@ func TestServer_X11(t *testing.T) {
3136
ctx:=context.Background()
3237
logger:=slogtest.Make(t,nil).Leveled(slog.LevelDebug)
3338
fs:=afero.NewOsFs()
34-
dir:=t.TempDir()
35-
s,err:=agentssh.NewServer(ctx,logger,prometheus.NewRegistry(),fs,&agentssh.Config{
36-
X11SocketDir:dir,
37-
})
39+
s,err:=agentssh.NewServer(ctx,logger,prometheus.NewRegistry(),fs,&agentssh.Config{})
3840
require.NoError(t,err)
3941
defers.Close()
4042

@@ -53,21 +55,45 @@ func TestServer_X11(t *testing.T) {
5355
sess,err:=c.NewSession()
5456
require.NoError(t,err)
5557

58+
wantScreenNumber:=1
5659
reply,err:=sess.SendRequest("x11-req",true,gossh.Marshal(ssh.X11{
5760
AuthProtocol:"MIT-MAGIC-COOKIE-1",
5861
AuthCookie:hex.EncodeToString([]byte("cookie")),
59-
ScreenNumber:0,
62+
ScreenNumber:uint32(wantScreenNumber),
6063
}))
6164
require.NoError(t,err)
6265
assert.True(t,reply)
6366

64-
err=sess.Shell()
67+
// Want: ~DISPLAY=localhost:10.1
68+
out,err:=sess.Output("echo DISPLAY=$DISPLAY")
6569
require.NoError(t,err)
6670

71+
sc:=bufio.NewScanner(bytes.NewReader(out))
72+
displayNumber:=-1
73+
forsc.Scan() {
74+
line:=strings.TrimSpace(sc.Text())
75+
t.Log(line)
76+
ifstrings.HasPrefix(line,"DISPLAY=") {
77+
parts:=strings.SplitN(line,"=",2)
78+
display:=parts[1]
79+
parts=strings.SplitN(display,":",2)
80+
parts=strings.SplitN(parts[1],".",2)
81+
displayNumber,err=strconv.Atoi(parts[0])
82+
require.NoError(t,err)
83+
assert.GreaterOrEqual(t,displayNumber,10,"display number should be >= 10")
84+
gotScreenNumber,err:=strconv.Atoi(parts[1])
85+
require.NoError(t,err)
86+
assert.Equal(t,wantScreenNumber,gotScreenNumber,"screen number should match")
87+
break
88+
}
89+
}
90+
require.NoError(t,sc.Err())
91+
require.NotEqual(t,-1,displayNumber)
92+
6793
x11Chans:=c.HandleChannelOpen("x11")
6894
payload:="hello world"
6995
require.Eventually(t,func()bool {
70-
conn,err:=net.Dial("unix",filepath.Join(dir,"X0"))
96+
conn,err:=net.Dial("tcp",fmt.Sprintf("localhost:%d",agentssh.X11StartPort+displayNumber))
7197
iferr==nil {
7298
_,err=conn.Write([]byte(payload))
7399
assert.NoError(t,err)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp