9
9
"io"
10
10
"net"
11
11
"os/exec"
12
+ "os/user"
13
+ "sync"
12
14
"time"
13
15
14
16
"cdr.dev/slog"
@@ -39,7 +41,6 @@ func DialSSHClient(conn *peer.Conn) (*gossh.Client, error) {
39
41
return nil ,err
40
42
}
41
43
sshConn ,channels ,requests ,err := gossh .NewClientConn (netConn ,"localhost:22" ,& gossh.ClientConfig {
42
- User :"kyle" ,
43
44
Config : gossh.Config {
44
45
Ciphers : []string {"arcfour" },
45
46
},
@@ -66,6 +67,7 @@ func New(dialer Dialer, options *Options) io.Closer {
66
67
clientDialer :dialer ,
67
68
options :options ,
68
69
closeCancel :cancelFunc ,
70
+ closed :make (chan struct {}),
69
71
}
70
72
server .init (ctx )
71
73
return server
@@ -76,6 +78,7 @@ type server struct {
76
78
options * Options
77
79
78
80
closeCancel context.CancelFunc
81
+ closeMutex sync.Mutex
79
82
closed chan struct {}
80
83
81
84
sshServer * ssh.Server
@@ -153,10 +156,19 @@ func (*server) handleSSHSession(session ssh.Session) error {
153
156
err error
154
157
)
155
158
159
+ username := session .User ()
160
+ if username == "" {
161
+ currentUser ,err := user .Current ()
162
+ if err != nil {
163
+ return xerrors .Errorf ("get current user: %w" ,err )
164
+ }
165
+ username = currentUser .Username
166
+ }
167
+
156
168
// gliderlabs/ssh returns a command slice of zero
157
169
// when a shell is requested.
158
170
if len (session .Command ())== 0 {
159
- command ,err = usershell .Get (session . User () )
171
+ command ,err = usershell .Get (username )
160
172
if err != nil {
161
173
return xerrors .Errorf ("get user shell: %w" ,err )
162
174
}
@@ -208,6 +220,7 @@ func (*server) handleSSHSession(session ssh.Session) error {
208
220
_ ,_ = io .Copy (session ,ptty .Output ())
209
221
}()
210
222
_ ,_ = process .Wait ()
223
+ _ = ptty .Close ()
211
224
return nil
212
225
}
213
226
@@ -254,7 +267,11 @@ func (s *server) run(ctx context.Context) {
254
267
for {
255
268
conn ,err := peerListener .Accept ()
256
269
if err != nil {
257
- // This is closed!
270
+ if s .isClosed () {
271
+ return
272
+ }
273
+ s .options .Logger .Debug (ctx ,"peer listener accept exited; restarting connection" ,slog .Error (err ))
274
+ s .run (ctx )
258
275
return
259
276
}
260
277
go s .handlePeerConn (ctx ,conn )
@@ -265,15 +282,21 @@ func (s *server) handlePeerConn(ctx context.Context, conn *peer.Conn) {
265
282
for {
266
283
channel ,err := conn .Accept (ctx )
267
284
if err != nil {
268
- // TODO: Log here!
285
+ if s .isClosed () {
286
+ return
287
+ }
288
+ s .options .Logger .Debug (ctx ,"accept channel from peer connection" ,slog .Error (err ))
269
289
return
270
290
}
271
291
272
292
switch channel .Protocol () {
273
293
case "ssh" :
274
294
s .sshServer .HandleConn (channel .NetConn ())
275
- case "proxy" :
276
- // Proxy the port provided.
295
+ default :
296
+ s .options .Logger .Warn (ctx ,"unhandled protocol from channel" ,
297
+ slog .F ("protocol" ,channel .Protocol ()),
298
+ slog .F ("label" ,channel .Label ()),
299
+ )
277
300
}
278
301
}
279
302
}
@@ -289,6 +312,13 @@ func (s *server) isClosed() bool {
289
312
}
290
313
291
314
func (s * server )Close ()error {
292
- s .sshServer .Close ()
315
+ s .closeMutex .Lock ()
316
+ defer s .closeMutex .Unlock ()
317
+ if s .isClosed () {
318
+ return nil
319
+ }
320
+ close (s .closed )
321
+ s .closeCancel ()
322
+ _ = s .sshServer .Close ()
293
323
return nil
294
324
}