@@ -103,7 +103,11 @@ func main() {
103103system ,noLoad ,debugMode bool
104104)
105105
106- defaultSocketPath := func ()string {
106+ envSocketPath := func ()string {
107+ if val ,ok := os .LookupEnv ("SSH_AUTH_SOCK" );ok && socketPath == "" {
108+ return val
109+ }
110+
107111dir := os .Getenv ("XDG_RUNTIME_DIR" )
108112if dir == "" {
109113dir = "/var/tmp"
@@ -113,7 +117,7 @@ func main() {
113117
114118var sockets SocketSet
115119
116- flag .StringVar (& socketPath ,"l" ,defaultSocketPath ,"path of the UNIX socket to listen on" )
120+ flag .StringVar (& socketPath ,"l" ,envSocketPath ,"path of the UNIX socket to listen on" )
117121flag .Var (& sockets ,"A" ,"fallback ssh-agent sockets" )
118122flag .BoolVar (& swtpmFlag ,"swtpm" ,false ,"use swtpm instead of actual tpm" )
119123flag .BoolVar (& printSocketFlag ,"print-socket" ,false ,"print path of UNIX socket to stdout" )
@@ -161,15 +165,6 @@ func main() {
161165keyDir = utils .SSHDir ()
162166}
163167
164- fi ,err := os .Lstat (keyDir )
165- if err != nil {
166- slog .Error (err .Error ())
167- os .Exit (1 )
168- }
169- if fi .Mode ()& os .ModeSymlink == os .ModeSymlink {
170- slog .Info ("Not following symbolic link" ,slog .String ("key_directory" ,keyDir ))
171- }
172-
173168if term .IsTerminal (int (os .Stdin .Fd ())) {
174169slog .Info ("Warning: ssh-tpm-agent is meant to run as a background daemon." )
175170slog .Info ("Running multiple instances is likely to lead to conflicts." )
@@ -187,44 +182,14 @@ func main() {
187182agents = append (agents ,sshagent .NewClient (conn ))
188183}
189184
190- var listener * net.UnixListener
191-
192- if os .Getenv ("LISTEN_FDS" )!= "" {
193- if err != nil {
194- slog .Error (err .Error ())
195- os .Exit (1 )
196- }
197-
198- file := os .NewFile (uintptr (3 ),"ssh-tpm-agent.socket" )
199- fl ,err := net .FileListener (file )
200- if err != nil {
201- slog .Error (err .Error ())
202- os .Exit (1 )
203- }
204- var ok bool
205- listener ,ok = fl .(* net.UnixListener )
206- if ! ok {
207- slog .Error ("Socket-activation FD isn't a unix socket" )
208- os .Exit (1 )
209- }
210-
211- slog .Info ("Socket activated agent." )
212- }else {
213- os .Remove (socketPath )
214- if err := os .MkdirAll (filepath .Dir (socketPath ),0o777 );err != nil {
215- slog .Error ("Failed to create UNIX socket folder:" ,err )
216- os .Exit (1 )
217- }
218- listener ,err = net .ListenUnix ("unix" ,& net.UnixAddr {Net :"unix" ,Name :socketPath })
219- if err != nil {
220- slog .Error ("Failed to listen on UNIX socket:" ,err )
221- os .Exit (1 )
222- }
223- slog .Info ("Listening on socket" ,slog .String ("path" ,socketPath ))
185+ listener ,err := createListener (socketPath )
186+ if err != nil {
187+ slog .Error ("creating listener" ,slog .String ("error" ,err .Error ()))
188+ os .Exit (1 )
224189}
225190
226- a := agent .NewAgent (listener ,
227- agents ,
191+ agent := agent .NewAgent (listener , agents ,
192+
228193// TPM Callback
229194func () (tpm transport.TPMCloser ) {
230195// the agent will close the TPM after this is called
@@ -248,13 +213,48 @@ func main() {
248213signal .Notify (c ,syscall .SIGHUP )
249214go func () {
250215for range c {
251- a .Stop ()
216+ agent .Stop ()
252217}
253218}()
254219
255220if ! noLoad {
256- a .LoadKeys (keyDir )
221+ if err := agent .LoadKeys (keyDir );err != nil {
222+ slog .Error ("loading keys" ,slog .String ("error" ,err .Error ()))
223+ }
224+ }
225+
226+ agent .Wait ()
227+ }
228+
229+ func createListener (socketPath string ) (* net.UnixListener ,error ) {
230+ if _ ,ok := os .LookupEnv ("LISTEN_FDS" );ok {
231+ f := os .NewFile (uintptr (3 ),"ssh-tpm-agent.socket" )
232+
233+ fListener ,err := net .FileListener (f )
234+ if err != nil {
235+ return nil ,err
236+ }
237+
238+ listener ,ok := fListener .(* net.UnixListener )
239+ if ! ok {
240+ return nil ,fmt .Errorf ("socket-activation file descriptor isn't an unix socket" )
241+ }
242+
243+ slog .Info ("Activated agent by socket" )
244+ return listener ,nil
245+ }
246+
247+ _ = os .Remove (socketPath )
248+
249+ if err := os .MkdirAll (filepath .Dir (socketPath ),0o770 );err != nil {
250+ return nil ,fmt .Errorf ("creating UNIX socket directory: %w" ,err )
251+ }
252+
253+ listener ,err := net .ListenUnix ("unix" ,& net.UnixAddr {Net :"unix" ,Name :socketPath })
254+ if err != nil {
255+ return nil ,err
257256}
258257
259- a .Wait ()
258+ slog .Info ("Listening on socket" ,slog .String ("path" ,socketPath ))
259+ return listener ,nil
260260}