@@ -19,12 +19,16 @@ import (
1919"golang.org/x/xerrors"
2020"gopkg.in/natefinch/lumberjack.v2"
2121
22+ "github.com/coder/retry"
23+
2224"github.com/prometheus/client_golang/prometheus"
2325
2426"cdr.dev/slog"
2527"cdr.dev/slog/sloggers/sloghuman"
2628"cdr.dev/slog/sloggers/slogjson"
2729"cdr.dev/slog/sloggers/slogstackdriver"
30+ "github.com/coder/serpent"
31+
2832"github.com/coder/coder/v2/agent"
2933"github.com/coder/coder/v2/agent/agentcontainers"
3034"github.com/coder/coder/v2/agent/agentexec"
@@ -34,7 +38,6 @@ import (
3438"github.com/coder/coder/v2/cli/clilog"
3539"github.com/coder/coder/v2/codersdk"
3640"github.com/coder/coder/v2/codersdk/agentsdk"
37- "github.com/coder/serpent"
3841)
3942
4043func (r * RootCmd )workspaceAgent ()* serpent.Command {
@@ -63,8 +66,10 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
6366// This command isn't useful to manually execute.
6467Hidden :true ,
6568Handler :func (inv * serpent.Invocation )error {
66- ctx ,cancel := context .WithCancel (inv .Context ())
67- defer cancel ()
69+ ctx ,cancel := context .WithCancelCause (inv .Context ())
70+ defer func () {
71+ cancel (xerrors .New ("defer" ))
72+ }()
6873
6974var (
7075ignorePorts = map [int ]string {}
@@ -281,7 +286,6 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
281286return xerrors .Errorf ("add executable to $PATH: %w" ,err )
282287}
283288
284- prometheusRegistry := prometheus .NewRegistry ()
285289subsystemsRaw := inv .Environ .Get (agent .EnvAgentSubsystem )
286290subsystems := []codersdk.AgentSubsystem {}
287291for _ ,s := range strings .Split (subsystemsRaw ,"," ) {
@@ -328,46 +332,90 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
328332containerLister = agentcontainers .NewDocker (execer )
329333}
330334
331- agnt := agent .New (agent.Options {
332- Client :client ,
333- Logger :logger ,
334- LogDir :logDir ,
335- ScriptDataDir :scriptDataDir ,
336- // #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
337- TailnetListenPort :uint16 (tailnetListenPort ),
338- ExchangeToken :func (ctx context.Context ) (string ,error ) {
339- if exchangeToken == nil {
340- return client .SDK .SessionToken (),nil
335+ // TODO: confirm that a 24-hour timeout is acceptable for the reinit wait context.
336+ reinitCtx ,reinitCancel := context .WithTimeout (context .Background (),time .Hour * 24 )
337+ defer reinitCancel ()
338+ reinitEvents := make (chan agentsdk.ReinitializationResponse )
339+
340+ go func () {
341+ // Keep retrying the wait for reinit instructions; cancelling the main context stops the retrier.
342+ for retrier := retry .New (100 * time .Millisecond ,10 * time .Second );retrier .Wait (ctx ); {
343+ select {
344+ case <- reinitCtx .Done ():
345+ return
346+ default :
341347}
342- resp ,err := exchangeToken (ctx )
348+
349+ err := client .WaitForReinit (reinitCtx ,reinitEvents )
343350if err != nil {
344- return " " ,err
351+ logger . Error ( ctx , "failed to wait for reinit instructions, will retry " ,slog . Error ( err ))
345352}
346- client .SetSessionToken (resp .SessionToken )
347- return resp .SessionToken ,nil
348- },
349- EnvironmentVariables :environmentVariables ,
350- IgnorePorts :ignorePorts ,
351- SSHMaxTimeout :sshMaxTimeout ,
352- Subsystems :subsystems ,
353-
354- PrometheusRegistry :prometheusRegistry ,
355- BlockFileTransfer :blockFileTransfer ,
356- Execer :execer ,
357- ContainerLister :containerLister ,
358-
359- ExperimentalDevcontainersEnabled :experimentalDevcontainersEnabled ,
360- })
361-
362- promHandler := agent .PrometheusMetricsHandler (prometheusRegistry ,logger )
363- prometheusSrvClose := ServeHandler (ctx ,logger ,promHandler ,prometheusAddress ,"prometheus" )
364- defer prometheusSrvClose ()
365-
366- debugSrvClose := ServeHandler (ctx ,logger ,agnt .HTTPDebug (),debugAddress ,"debug" )
367- defer debugSrvClose ()
368-
369- <- ctx .Done ()
370- return agnt .Close ()
353+ }
354+ }()
355+
356+ var (
357+ lastErr error
358+ mustExit bool
359+ )
360+ for {
361+ prometheusRegistry := prometheus .NewRegistry ()
362+
363+ agnt := agent .New (agent.Options {
364+ Client :client ,
365+ Logger :logger ,
366+ LogDir :logDir ,
367+ ScriptDataDir :scriptDataDir ,
368+ // #nosec G115 - Safe conversion as tailnet listen port is within uint16 range (0-65535)
369+ TailnetListenPort :uint16 (tailnetListenPort ),
370+ ExchangeToken :func (ctx context.Context ) (string ,error ) {
371+ if exchangeToken == nil {
372+ return client .SDK .SessionToken (),nil
373+ }
374+ resp ,err := exchangeToken (ctx )
375+ if err != nil {
376+ return "" ,err
377+ }
378+ client .SetSessionToken (resp .SessionToken )
379+ return resp .SessionToken ,nil
380+ },
381+ EnvironmentVariables :environmentVariables ,
382+ IgnorePorts :ignorePorts ,
383+ SSHMaxTimeout :sshMaxTimeout ,
384+ Subsystems :subsystems ,
385+
386+ PrometheusRegistry :prometheusRegistry ,
387+ BlockFileTransfer :blockFileTransfer ,
388+ Execer :execer ,
389+ ContainerLister :containerLister ,
390+ ExperimentalDevcontainersEnabled :experimentalDevcontainersEnabled ,
391+ })
392+
393+ promHandler := agent .PrometheusMetricsHandler (prometheusRegistry ,logger )
394+ prometheusSrvClose := ServeHandler (ctx ,logger ,promHandler ,prometheusAddress ,"prometheus" )
395+
396+ debugSrvClose := ServeHandler (ctx ,logger ,agnt .HTTPDebug (),debugAddress ,"debug" )
397+
398+ select {
399+ case <- ctx .Done ():
400+ logger .Warn (ctx ,"agent shutting down" ,slog .Error (ctx .Err ()),slog .F ("cause" ,context .Cause (ctx )))
401+ mustExit = true
402+ case event := <- reinitEvents :
403+ logger .Warn (ctx ,"agent received instruction to reinitialize" ,
404+ slog .F ("message" ,event .Message ),slog .F ("reason" ,event .Reason ))
405+ }
406+
407+ lastErr = agnt .Close ()
408+ debugSrvClose ()
409+ prometheusSrvClose ()
410+
411+ if mustExit {
412+ reinitCancel ()
413+ break
414+ }
415+
416+ logger .Info (ctx ,"reinitializing..." )
417+ }
418+ return lastErr
371419},
372420}
373421