|
| 1 | +package cryptokeys |
| 2 | + |
| 3 | +import ( |
| 4 | +"context" |
| 5 | +"encoding/hex" |
| 6 | +"io" |
| 7 | +"strconv" |
| 8 | +"sync" |
| 9 | +"time" |
| 10 | + |
| 11 | +"golang.org/x/xerrors" |
| 12 | + |
| 13 | +"cdr.dev/slog" |
| 14 | +"github.com/coder/coder/v2/coderd/database" |
| 15 | +"github.com/coder/coder/v2/coderd/database/db2sdk" |
| 16 | +"github.com/coder/coder/v2/codersdk" |
| 17 | +"github.com/coder/quartz" |
| 18 | +) |
| 19 | + |
| 20 | +var ( |
| 21 | +ErrKeyNotFound=xerrors.New("key not found") |
| 22 | +ErrKeyInvalid=xerrors.New("key is invalid for use") |
| 23 | +ErrClosed=xerrors.New("closed") |
| 24 | +ErrInvalidFeature=xerrors.New("invalid feature for this operation") |
| 25 | +) |
| 26 | + |
| 27 | +typeFetcherinterface { |
| 28 | +Fetch(ctx context.Context) ([]codersdk.CryptoKey,error) |
| 29 | +} |
| 30 | + |
| 31 | +typeEncryptionKeycacheinterface { |
| 32 | +// EncryptingKey returns the latest valid key for encrypting payloads. A valid |
| 33 | +// key is one that is both past its start time and before its deletion time. |
| 34 | +EncryptingKey(ctx context.Context) (idstring,keyinterface{},errerror) |
| 35 | +// DecryptingKey returns the key with the provided id which maps to its sequence |
| 36 | +// number. The key is valid for decryption as long as it is not deleted or past |
| 37 | +// its deletion date. We must allow for keys prior to their start time to |
| 38 | +// account for clock skew between peers (one key may be past its start time on |
| 39 | +// one machine while another is not). |
| 40 | +DecryptingKey(ctx context.Context,idstring) (keyinterface{},errerror) |
| 41 | +io.Closer |
| 42 | +} |
| 43 | + |
| 44 | +typeSigningKeycacheinterface { |
| 45 | +// SigningKey returns the latest valid key for signing. A valid key is one |
| 46 | +// that is both past its start time and before its deletion time. |
| 47 | +SigningKey(ctx context.Context) (idstring,keyinterface{},errerror) |
| 48 | +// VerifyingKey returns the key with the provided id which should map to its |
| 49 | +// sequence number. The key is valid for verifying as long as it is not deleted |
| 50 | +// or past its deletion date. We must allow for keys prior to their start time |
| 51 | +// to account for clock skew between peers (one key may be past its start time |
| 52 | +// on one machine while another is not). |
| 53 | +VerifyingKey(ctx context.Context,idstring) (keyinterface{},errerror) |
| 54 | +io.Closer |
| 55 | +} |
| 56 | + |
| 57 | +const ( |
| 58 | +// latestSequence is a special sequence number that represents the latest key. |
| 59 | +latestSequence=-1 |
| 60 | +// refreshInterval is the interval at which the key cache will refresh. |
| 61 | +refreshInterval=time.Minute*10 |
| 62 | +) |
| 63 | + |
| 64 | +typeDBFetcherstruct { |
| 65 | +DB database.Store |
| 66 | +Feature database.CryptoKeyFeature |
| 67 | +} |
| 68 | + |
| 69 | +func (d*DBFetcher)Fetch(ctx context.Context) ([]codersdk.CryptoKey,error) { |
| 70 | +keys,err:=d.DB.GetCryptoKeysByFeature(ctx,d.Feature) |
| 71 | +iferr!=nil { |
| 72 | +returnnil,xerrors.Errorf("get crypto keys by feature: %w",err) |
| 73 | +} |
| 74 | + |
| 75 | +returndb2sdk.CryptoKeys(keys),nil |
| 76 | +} |
| 77 | + |
| 78 | +// cache implements the caching functionality for both signing and encryption keys. |
| 79 | +typecachestruct { |
| 80 | +clock quartz.Clock |
| 81 | +refreshCtx context.Context |
| 82 | +refreshCancel context.CancelFunc |
| 83 | +fetcherFetcher |
| 84 | +logger slog.Logger |
| 85 | +feature codersdk.CryptoKeyFeature |
| 86 | + |
| 87 | +mu sync.Mutex |
| 88 | +keysmap[int32]codersdk.CryptoKey |
| 89 | +lastFetch time.Time |
| 90 | +refresher*quartz.Timer |
| 91 | +fetchingbool |
| 92 | +closedbool |
| 93 | +cond*sync.Cond |
| 94 | +} |
| 95 | + |
| 96 | +typeCacheOptionfunc(*cache) |
| 97 | + |
| 98 | +funcWithCacheClock(clock quartz.Clock)CacheOption { |
| 99 | +returnfunc(d*cache) { |
| 100 | +d.clock=clock |
| 101 | +} |
| 102 | +} |
| 103 | + |
| 104 | +// NewSigningCache instantiates a cache. Close should be called to release resources |
| 105 | +// associated with its internal timer. |
| 106 | +funcNewSigningCache(ctx context.Context,logger slog.Logger,fetcherFetcher, |
| 107 | +feature codersdk.CryptoKeyFeature,opts...func(*cache), |
| 108 | +) (SigningKeycache,error) { |
| 109 | +if!isSigningKeyFeature(feature) { |
| 110 | +returnnil,xerrors.Errorf("invalid feature: %s",feature) |
| 111 | +} |
| 112 | +returnnewCache(ctx,logger,fetcher,feature,opts...) |
| 113 | +} |
| 114 | + |
| 115 | +funcNewEncryptionCache(ctx context.Context,logger slog.Logger,fetcherFetcher, |
| 116 | +feature codersdk.CryptoKeyFeature,opts...func(*cache), |
| 117 | +) (EncryptionKeycache,error) { |
| 118 | +if!isEncryptionKeyFeature(feature) { |
| 119 | +returnnil,xerrors.Errorf("invalid feature: %s",feature) |
| 120 | +} |
| 121 | +returnnewCache(ctx,logger,fetcher,feature,opts...) |
| 122 | +} |
| 123 | + |
| 124 | +funcnewCache(ctx context.Context,logger slog.Logger,fetcherFetcher,feature codersdk.CryptoKeyFeature,opts...func(*cache)) (*cache,error) { |
| 125 | +cache:=&cache{ |
| 126 | +clock:quartz.NewReal(), |
| 127 | +logger:logger, |
| 128 | +fetcher:fetcher, |
| 129 | +feature:feature, |
| 130 | +} |
| 131 | + |
| 132 | +for_,opt:=rangeopts { |
| 133 | +opt(cache) |
| 134 | +} |
| 135 | + |
| 136 | +cache.cond=sync.NewCond(&cache.mu) |
| 137 | +cache.refreshCtx,cache.refreshCancel=context.WithCancel(ctx) |
| 138 | +cache.refresher=cache.clock.AfterFunc(refreshInterval,cache.refresh) |
| 139 | + |
| 140 | +keys,err:=cache.cryptoKeys(ctx) |
| 141 | +iferr!=nil { |
| 142 | +cache.refreshCancel() |
| 143 | +returnnil,xerrors.Errorf("initial fetch: %w",err) |
| 144 | +} |
| 145 | +cache.keys=keys |
| 146 | +returncache,nil |
| 147 | +} |
| 148 | + |
| 149 | +func (c*cache)EncryptingKey(ctx context.Context) (string,interface{},error) { |
| 150 | +if!isEncryptionKeyFeature(c.feature) { |
| 151 | +return"",nil,ErrInvalidFeature |
| 152 | +} |
| 153 | + |
| 154 | +returnc.cryptoKey(ctx,latestSequence) |
| 155 | +} |
| 156 | + |
| 157 | +func (c*cache)DecryptingKey(ctx context.Context,idstring) (interface{},error) { |
| 158 | +if!isEncryptionKeyFeature(c.feature) { |
| 159 | +returnnil,ErrInvalidFeature |
| 160 | +} |
| 161 | + |
| 162 | +seq,err:=strconv.ParseInt(id,10,64) |
| 163 | +iferr!=nil { |
| 164 | +returnnil,xerrors.Errorf("parse id: %w",err) |
| 165 | +} |
| 166 | + |
| 167 | +_,secret,err:=c.cryptoKey(ctx,int32(seq)) |
| 168 | +iferr!=nil { |
| 169 | +returnnil,xerrors.Errorf("crypto key: %w",err) |
| 170 | +} |
| 171 | +returnsecret,nil |
| 172 | +} |
| 173 | + |
| 174 | +func (c*cache)SigningKey(ctx context.Context) (string,interface{},error) { |
| 175 | +if!isSigningKeyFeature(c.feature) { |
| 176 | +return"",nil,ErrInvalidFeature |
| 177 | +} |
| 178 | + |
| 179 | +returnc.cryptoKey(ctx,latestSequence) |
| 180 | +} |
| 181 | + |
| 182 | +func (c*cache)VerifyingKey(ctx context.Context,idstring) (interface{},error) { |
| 183 | +if!isSigningKeyFeature(c.feature) { |
| 184 | +returnnil,ErrInvalidFeature |
| 185 | +} |
| 186 | + |
| 187 | +seq,err:=strconv.ParseInt(id,10,64) |
| 188 | +iferr!=nil { |
| 189 | +returnnil,xerrors.Errorf("parse id: %w",err) |
| 190 | +} |
| 191 | + |
| 192 | +_,secret,err:=c.cryptoKey(ctx,int32(seq)) |
| 193 | +iferr!=nil { |
| 194 | +returnnil,xerrors.Errorf("crypto key: %w",err) |
| 195 | +} |
| 196 | + |
| 197 | +returnsecret,nil |
| 198 | +} |
| 199 | + |
| 200 | +funcisEncryptionKeyFeature(feature codersdk.CryptoKeyFeature)bool { |
| 201 | +returnfeature==codersdk.CryptoKeyFeatureWorkspaceApp |
| 202 | +} |
| 203 | + |
| 204 | +funcisSigningKeyFeature(feature codersdk.CryptoKeyFeature)bool { |
| 205 | +switchfeature { |
| 206 | +casecodersdk.CryptoKeyFeatureTailnetResume,codersdk.CryptoKeyFeatureOIDCConvert: |
| 207 | +returntrue |
| 208 | +default: |
| 209 | +returnfalse |
| 210 | +} |
| 211 | +} |
| 212 | + |
| 213 | +funcidSecret(k codersdk.CryptoKey) (string, []byte,error) { |
| 214 | +key,err:=hex.DecodeString(k.Secret) |
| 215 | +iferr!=nil { |
| 216 | +return"",nil,xerrors.Errorf("decode key: %w",err) |
| 217 | +} |
| 218 | + |
| 219 | +returnstrconv.FormatInt(int64(k.Sequence),10),key,nil |
| 220 | +} |
| 221 | + |
| 222 | +func (c*cache)cryptoKey(ctx context.Context,sequenceint32) (string, []byte,error) { |
| 223 | +c.mu.Lock() |
| 224 | +deferc.mu.Unlock() |
| 225 | + |
| 226 | +ifc.closed { |
| 227 | +return"",nil,ErrClosed |
| 228 | +} |
| 229 | + |
| 230 | +varkey codersdk.CryptoKey |
| 231 | +varokbool |
| 232 | +forkey,ok=c.key(sequence);!ok&&c.fetching&&!c.closed; { |
| 233 | +c.cond.Wait() |
| 234 | +} |
| 235 | + |
| 236 | +ifc.closed { |
| 237 | +return"",nil,ErrClosed |
| 238 | +} |
| 239 | + |
| 240 | +ifok { |
| 241 | +returncheckKey(key,sequence,c.clock.Now()) |
| 242 | +} |
| 243 | + |
| 244 | +c.fetching=true |
| 245 | +c.mu.Unlock() |
| 246 | + |
| 247 | +keys,err:=c.cryptoKeys(ctx) |
| 248 | +iferr!=nil { |
| 249 | +return"",nil,xerrors.Errorf("get keys: %w",err) |
| 250 | +} |
| 251 | + |
| 252 | +c.mu.Lock() |
| 253 | +c.lastFetch=c.clock.Now() |
| 254 | +c.refresher.Reset(refreshInterval) |
| 255 | +c.keys=keys |
| 256 | +c.fetching=false |
| 257 | +c.cond.Broadcast() |
| 258 | + |
| 259 | +key,ok=c.key(sequence) |
| 260 | +if!ok { |
| 261 | +return"",nil,ErrKeyNotFound |
| 262 | +} |
| 263 | + |
| 264 | +returncheckKey(key,sequence,c.clock.Now()) |
| 265 | +} |
| 266 | + |
| 267 | +func (c*cache)key(sequenceint32) (codersdk.CryptoKey,bool) { |
| 268 | +ifsequence==latestSequence { |
| 269 | +returnc.keys[latestSequence],c.keys[latestSequence].CanSign(c.clock.Now()) |
| 270 | +} |
| 271 | + |
| 272 | +key,ok:=c.keys[sequence] |
| 273 | +returnkey,ok |
| 274 | +} |
| 275 | + |
| 276 | +funccheckKey(key codersdk.CryptoKey,sequenceint32,now time.Time) (string, []byte,error) { |
| 277 | +ifsequence==latestSequence { |
| 278 | +if!key.CanSign(now) { |
| 279 | +return"",nil,ErrKeyInvalid |
| 280 | +} |
| 281 | +returnidSecret(key) |
| 282 | +} |
| 283 | + |
| 284 | +if!key.CanVerify(now) { |
| 285 | +return"",nil,ErrKeyInvalid |
| 286 | +} |
| 287 | + |
| 288 | +returnidSecret(key) |
| 289 | +} |
| 290 | + |
| 291 | +// refresh fetches the keys and updates the cache. |
| 292 | +func (c*cache)refresh() { |
| 293 | +now:=c.clock.Now("CryptoKeyCache","refresh") |
| 294 | +c.mu.Lock() |
| 295 | +deferc.mu.Unlock() |
| 296 | + |
| 297 | +ifc.closed { |
| 298 | +return |
| 299 | +} |
| 300 | + |
| 301 | +// If something's already fetching, we don't need to do anything. |
| 302 | +ifc.fetching { |
| 303 | +return |
| 304 | +} |
| 305 | + |
| 306 | +// There's a window we must account for where the timer fires while a fetch |
| 307 | +// is ongoing but prior to the timer getting reset. In this case we want to |
| 308 | +// avoid double fetching. |
| 309 | +ifnow.Sub(c.lastFetch)<refreshInterval { |
| 310 | +return |
| 311 | +} |
| 312 | + |
| 313 | +c.fetching=true |
| 314 | + |
| 315 | +c.mu.Unlock() |
| 316 | +keys,err:=c.cryptoKeys(c.refreshCtx) |
| 317 | +iferr!=nil { |
| 318 | +c.logger.Error(c.refreshCtx,"fetch crypto keys",slog.Error(err)) |
| 319 | +return |
| 320 | +} |
| 321 | + |
| 322 | +// We don't defer an unlock here due to the deferred unlock at the top of the function. |
| 323 | +c.mu.Lock() |
| 324 | + |
| 325 | +c.lastFetch=c.clock.Now() |
| 326 | +c.refresher.Reset(refreshInterval) |
| 327 | +c.keys=keys |
| 328 | +c.fetching=false |
| 329 | +c.cond.Broadcast() |
| 330 | +} |
| 331 | + |
| 332 | +// cryptoKeys queries the control plane for the crypto keys. |
| 333 | +// Outside of initialization, this should only be called by fetch. |
| 334 | +func (c*cache)cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey,error) { |
| 335 | +keys,err:=c.fetcher.Fetch(ctx) |
| 336 | +iferr!=nil { |
| 337 | +returnnil,xerrors.Errorf("crypto keys: %w",err) |
| 338 | +} |
| 339 | +cache:=toKeyMap(keys,c.clock.Now()) |
| 340 | +returncache,nil |
| 341 | +} |
| 342 | + |
| 343 | +functoKeyMap(keys []codersdk.CryptoKey,now time.Time)map[int32]codersdk.CryptoKey { |
| 344 | +m:=make(map[int32]codersdk.CryptoKey) |
| 345 | +varlatest codersdk.CryptoKey |
| 346 | +for_,key:=rangekeys { |
| 347 | +m[key.Sequence]=key |
| 348 | +ifkey.Sequence>latest.Sequence&&key.CanSign(now) { |
| 349 | +m[latestSequence]=key |
| 350 | +} |
| 351 | +} |
| 352 | +returnm |
| 353 | +} |
| 354 | + |
| 355 | +func (c*cache)Close()error { |
| 356 | +c.mu.Lock() |
| 357 | +deferc.mu.Unlock() |
| 358 | + |
| 359 | +ifc.closed { |
| 360 | +returnnil |
| 361 | +} |
| 362 | + |
| 363 | +c.closed=true |
| 364 | +c.refreshCancel() |
| 365 | +c.refresher.Stop() |
| 366 | +c.cond.Broadcast() |
| 367 | + |
| 368 | +returnnil |
| 369 | +} |