@@ -45,12 +45,14 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.
45
45
opt (cache )
46
46
}
47
47
48
- m ,latest ,err := cache .fetch (ctx )
48
+ cache .refreshCtx ,cache .refreshCancel = context .WithCancel (ctx )
49
+ cache .refresher = cache .Clock .AfterFunc (time .Minute * 10 ,cache .refresh )
50
+ m ,latest ,err := cache .fetchKeys (ctx )
49
51
if err != nil {
52
+ cache .refreshCancel ()
50
53
return nil ,xerrors .Errorf ("initial fetch: %w" ,err )
51
54
}
52
55
cache .keys ,cache .latest = m ,latest
53
- cache .refresher = cache .Clock .AfterFunc (time .Minute * 10 ,cache .refresh )
54
56
55
57
return cache ,nil
56
58
}
@@ -77,9 +79,12 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
77
79
}
78
80
79
81
k .keysMu .RLock ()
80
- if k .latest .CanSign (now ) {
81
- k .keysMu .RUnlock ()
82
- return k .latest ,nil
82
+ latest = k .latest
83
+ k .keysMu .RUnlock ()
84
+
85
+ now = k .Clock .Now ()
86
+ if latest .CanSign (now ) {
87
+ return latest ,nil
83
88
}
84
89
85
90
_ ,latest ,err := k .fetch (ctx )
@@ -91,27 +96,28 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error
91
96
}
92
97
93
98
func (k * CryptoKeyCache )Verifying (ctx context.Context ,sequence int32 ) (codersdk.CryptoKey ,error ) {
94
- now := k .Clock .Now ()
95
- k .keysMu .RLock ()
96
99
if k .isClosed () {
97
- k .keysMu .RUnlock ()
98
100
return codersdk.CryptoKey {},cryptokeys .ErrClosed
99
101
}
100
102
103
+ now := k .Clock .Now ()
104
+ k .keysMu .RLock ()
101
105
key ,ok := k .keys [sequence ]
102
106
k .keysMu .RUnlock ()
103
107
if ok {
104
108
return validKey (key ,now )
105
109
}
106
110
107
- k .keysMu .Lock ()
108
- defer k .keysMu .Unlock ()
111
+ k .fetchLock .Lock ()
112
+ defer k .fetchLock .Unlock ()
109
113
110
114
if k .isClosed () {
111
115
return codersdk.CryptoKey {},cryptokeys .ErrClosed
112
116
}
113
117
118
+ k .keysMu .RLock ()
114
119
key ,ok = k .keys [sequence ]
120
+ k .keysMu .RUnlock ()
115
121
if ok {
116
122
return validKey (key ,now )
117
123
}
@@ -134,14 +140,23 @@ func (k *CryptoKeyCache) refresh() {
134
140
return
135
141
}
136
142
137
- k .keysMu .RLock ()
138
- if k .Clock .Now ().Sub (k .lastFetch )< time .Minute * 10 {
139
- k .keysMu .Unlock ()
143
+ k .fetchLock .Lock ()
144
+ defer k .fetchLock .Unlock ()
145
+
146
+ if k .isClosed () {
140
147
return
141
148
}
142
149
143
- k .fetchLock .Lock ()
144
- defer k .fetchLock .Unlock ()
150
+ k .keysMu .RLock ()
151
+ lastFetch := k .lastFetch
152
+ k .keysMu .RUnlock ()
153
+
154
+ // There's a window we must account for where the timer fires while a fetch
155
+ // is ongoing but prior to the timer getting reset. In this case we want to
156
+ // avoid double fetching.
157
+ if k .Clock .Now ().Sub (lastFetch )< time .Minute * 10 {
158
+ return
159
+ }
145
160
146
161
_ ,_ ,err := k .fetch (k .refreshCtx )
147
162
if err != nil {
@@ -150,19 +165,28 @@ func (k *CryptoKeyCache) refresh() {
150
165
}
151
166
}
152
167
153
- func (k * CryptoKeyCache )fetch (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ,error ) {
154
-
168
+ func (k * CryptoKeyCache )fetchKeys (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ,error ) {
155
169
keys ,err := k .client .CryptoKeys (ctx )
156
170
if err != nil {
157
- return nil , codersdk.CryptoKey {},xerrors .Errorf ("get security keys: %w" ,err )
171
+ return nil , codersdk.CryptoKey {},xerrors .Errorf ("crypto keys: %w" ,err )
158
172
}
173
+ cache ,latest := toKeyMap (keys .CryptoKeys ,k .Clock .Now ())
174
+ return cache ,latest ,nil
175
+ }
159
176
160
- if len (keys .CryptoKeys )== 0 {
177
+ // fetch fetches the keys from the control plane and updates the cache. The fetchMu
178
+ // must be held when calling this function to avoid multiple concurrent fetches.
179
+ func (k * CryptoKeyCache )fetch (ctx context.Context ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ,error ) {
180
+ keys ,latest ,err := k .fetchKeys (ctx )
181
+ if err != nil {
182
+ return nil , codersdk.CryptoKey {},xerrors .Errorf ("fetch keys: %w" ,err )
183
+ }
184
+
185
+ if len (keys )== 0 {
161
186
return nil , codersdk.CryptoKey {},cryptokeys .ErrKeyNotFound
162
187
}
163
188
164
189
now := k .Clock .Now ()
165
- kmap ,latest := toKeyMap (keys .CryptoKeys ,now )
166
190
if ! latest .CanSign (now ) {
167
191
return nil , codersdk.CryptoKey {},cryptokeys .ErrKeyInvalid
168
192
}
@@ -172,9 +196,9 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe
172
196
173
197
k .lastFetch = k .Clock .Now ()
174
198
k .refresher .Reset (time .Minute * 10 )
175
- k .keys ,k .latest = kmap ,latest
199
+ k .keys ,k .latest = keys ,latest
176
200
177
- return kmap ,latest ,nil
201
+ return keys ,latest ,nil
178
202
}
179
203
180
204
func toKeyMap (keys []codersdk.CryptoKey ,now time.Time ) (map [int32 ]codersdk.CryptoKey , codersdk.CryptoKey ) {
@@ -202,6 +226,11 @@ func (k *CryptoKeyCache) isClosed() bool {
202
226
}
203
227
204
228
func (k * CryptoKeyCache )Close () {
229
+ // The fetch lock must always be held before holding the keys lock
230
+ // otherwise we risk a deadlock.
231
+ k .fetchLock .Lock ()
232
+ defer k .fetchLock .Unlock ()
233
+
205
234
k .keysMu .Lock ()
206
235
defer k .keysMu .Unlock ()
207
236