@@ -3,6 +3,7 @@ package wsproxy
3
3
import (
4
4
"context"
5
5
"sync"
6
+ "sync/atomic"
6
7
"time"
7
8
8
9
"golang.org/x/xerrors"
@@ -18,16 +19,19 @@ import (
18
19
var _ cryptokeys.Keycache = & CryptoKeyCache {}
19
20
20
21
type CryptoKeyCache struct {
21
- ctx context.Context
22
- cancel context.CancelFunc
23
- client * wsproxysdk.Client
24
- logger slog.Logger
25
- Clock quartz.Clock
26
-
27
- keysMu sync.RWMutex
28
- keys map [int32 ]codersdk.CryptoKey
29
- latest codersdk.CryptoKey
30
- closed bool
22
+ refreshCtx context.Context
23
+ refreshCancel context.CancelFunc
24
+ client * wsproxysdk.Client
25
+ logger slog.Logger
26
+ Clock quartz.Clock
27
+
28
+ keysMu sync.RWMutex
29
+ keys map [int32 ]codersdk.CryptoKey
30
+ latest codersdk.CryptoKey
31
+ fetchLock sync.RWMutex
32
+ lastFetch time.Time
33
+ refresher * quartz.Timer
34
+ closed atomic.Bool
31
35
}
32
36
33
37
func NewCryptoKeyCache (ctx context.Context ,log slog.Logger ,client * wsproxysdk.Client ,opts ... func (* CryptoKeyCache )) (* CryptoKeyCache ,error ) {
@@ -46,21 +50,17 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
46
50
return nil ,xerrors .Errorf ("initial fetch: %w" ,err )
47
51
}
48
52
cache .keys ,cache .latest = m ,latest
49
- cache .ctx ,cache .cancel = context .WithCancel (ctx )
50
-
51
- go cache .refresh ()
53
+ cache .refresher = cache .Clock .AfterFunc (time .Minute * 10 ,cache .refresh )
52
54
53
55
return cache ,nil
54
56
}
55
57
56
58
func (k * CryptoKeyCache )Signing (ctx context.Context ) (codersdk.CryptoKey ,error ) {
57
- k .keysMu .RLock ()
58
-
59
- if k .closed {
60
- k .keysMu .RUnlock ()
59
+ if k .isClosed () {
61
60
return codersdk.CryptoKey {},cryptokeys .ErrClosed
62
61
}
63
62
63
+ k .keysMu .RLock ()
64
64
latest := k .latest
65
65
k .keysMu .RUnlock ()
66
66
@@ -69,34 +69,31 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
69
69
return latest ,nil
70
70
}
71
71
72
- k .keysMu .Lock ()
73
- defer k .keysMu .Unlock ()
72
+ k .fetchLock .Lock ()
73
+ defer k .fetchLock .Unlock ()
74
74
75
- if k .closed {
75
+ if k .isClosed () {
76
76
return codersdk.CryptoKey {},cryptokeys .ErrClosed
77
77
}
78
78
79
+ k .keysMu .RLock ()
79
80
if k .latest .CanSign (now ) {
81
+ k .keysMu .RUnlock ()
80
82
return k .latest ,nil
81
83
}
82
84
83
- var err error
84
- k .keys ,k .latest ,err = k .fetch (ctx )
85
+ _ ,latest ,err := k .fetch (ctx )
85
86
if err != nil {
86
87
return codersdk.CryptoKey {},xerrors .Errorf ("fetch: %w" ,err )
87
88
}
88
89
89
- if ! k .latest .CanSign (now ) {
90
- return codersdk.CryptoKey {},cryptokeys .ErrKeyNotFound
91
- }
92
-
93
- return k .latest ,nil
90
+ return latest ,nil
94
91
}
95
92
96
93
func (k * CryptoKeyCache )Verifying (ctx context.Context ,sequence int32 ) (codersdk.CryptoKey ,error ) {
97
94
now := k .Clock .Now ()
98
95
k .keysMu .RLock ()
99
- if k .closed {
96
+ if k .isClosed () {
100
97
k .keysMu .RUnlock ()
101
98
return codersdk.CryptoKey {},cryptokeys .ErrClosed
102
99
}
@@ -110,7 +107,7 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd
110
107
k .keysMu .Lock ()
111
108
defer k .keysMu .Unlock ()
112
109
113
- if k .closed {
110
+ if k .isClosed () {
114
111
return codersdk.CryptoKey {},cryptokeys .ErrClosed
115
112
}
116
113
@@ -119,13 +116,12 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd
119
116
return validKey (key ,now )
120
117
}
121
118
122
- var err error
123
- k .keys ,k .latest ,err = k .fetch (ctx )
119
+ keys ,_ ,err := k .fetch (ctx )
124
120
if err != nil {
125
121
return codersdk.CryptoKey {},xerrors .Errorf ("fetch: %w" ,err )
126
122
}
127
123
128
- key ,ok = k . keys [sequence ]
124
+ key ,ok = keys [sequence ]
129
125
if ! ok {
130
126
return codersdk.CryptoKey {},cryptokeys .ErrKeyNotFound
131
127
}
@@ -134,28 +130,50 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd
134
130
}
135
131
136
132
func (k * CryptoKeyCache )refresh () {
137
- k .Clock .TickerFunc (k .ctx ,time .Minute * 10 ,func ()error {
138
- kmap ,latest ,err := k .fetch (k .ctx )
139
- if err != nil {
140
- k .logger .Error (k .ctx ,"failed to fetch crypto keys" ,slog .Error (err ))
141
- return nil
142
- }
133
+ if k .isClosed () {
134
+ return
135
+ }
136
+
137
+ k .keysMu .RLock ()
138
+ if k .Clock .Now ().Sub (k .lastFetch )< time .Minute * 10 {
139
+ k .keysMu .Unlock ()
140
+ return
141
+ }
142
+
143
+ k .fetchLock .Lock ()
144
+ defer k .fetchLock .Unlock ()
143
145
144
- k .keysMu .Lock ()
145
- defer k .keysMu .Unlock ()
146
- k .keys = kmap
147
- k .latest = latest
148
- return nil
149
- })
146
+ _ ,_ ,err := k .fetch (k .refreshCtx )
147
+ if err != nil {
148
+ k .logger .Error (k .refreshCtx ,"fetch crypto keys" ,slog .Error (err ))
149
+ return
150
+ }
150
151
}
151
152
152
153
func (k * CryptoKeyCache )fetch (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ,error ) {
154
+
153
155
keys ,err := k .client .CryptoKeys (ctx )
154
156
if err != nil {
155
157
return nil , codersdk.CryptoKey {},xerrors .Errorf ("get security keys: %w" ,err )
156
158
}
157
159
158
- kmap ,latest := toKeyMap (keys .CryptoKeys ,k .Clock .Now ())
160
+ if len (keys .CryptoKeys )== 0 {
161
+ return nil , codersdk.CryptoKey {},cryptokeys .ErrKeyNotFound
162
+ }
163
+
164
+ now := k .Clock .Now ()
165
+ kmap ,latest := toKeyMap (keys .CryptoKeys ,now )
166
+ if ! latest .CanSign (now ) {
167
+ return nil , codersdk.CryptoKey {},cryptokeys .ErrKeyInvalid
168
+ }
169
+
170
+ k .keysMu .Lock ()
171
+ defer k .keysMu .Unlock ()
172
+
173
+ k .lastFetch = k .Clock .Now ()
174
+ k .refresher .Reset (time .Minute * 10 )
175
+ k .keys ,k .latest = kmap ,latest
176
+
159
177
return kmap ,latest ,nil
160
178
}
161
179
@@ -179,14 +197,18 @@ func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error)
179
197
return key ,nil
180
198
}
181
199
200
+ func (k * CryptoKeyCache )isClosed ()bool {
201
+ return k .closed .Load ()
202
+ }
203
+
182
204
func (k * CryptoKeyCache )Close () {
183
205
k .keysMu .Lock ()
184
206
defer k .keysMu .Unlock ()
185
207
186
- if k .closed {
208
+ if k .isClosed () {
187
209
return
188
210
}
189
211
190
- k .cancel ()
191
- k .closed = true
212
+ k .refreshCancel ()
213
+ k .closed . Store ( true )
192
214
}