@@ -31,6 +31,11 @@ import (
31
31
32
32
var validProxyByHostnameRegex = regexp .MustCompile (`^[a-zA-Z0-9._-]+$` )
33
33
34
+ var errForeignKeyConstraint = & pq.Error {
35
+ Code :"23503" ,
36
+ Message :"update or delete on table violates foreign key constraint" ,
37
+ }
38
+
34
39
var errDuplicateKey = & pq.Error {
35
40
Code :"23505" ,
36
41
Message :"duplicate key value violates unique constraint" ,
@@ -45,6 +50,7 @@ func New() database.Store {
45
50
organizationMembers :make ([]database.OrganizationMember ,0 ),
46
51
organizations :make ([]database.Organization ,0 ),
47
52
users :make ([]database.User ,0 ),
53
+ dbcryptKeys :make ([]database.DBCryptKey ,0 ),
48
54
gitAuthLinks :make ([]database.GitAuthLink ,0 ),
49
55
groups :make ([]database.Group ,0 ),
50
56
groupMembers :make ([]database.GroupMember ,0 ),
@@ -117,6 +123,7 @@ type data struct {
117
123
// New tables
118
124
workspaceAgentStats []database.WorkspaceAgentStat
119
125
auditLogs []database.AuditLog
126
+ dbcryptKeys []database.DBCryptKey
120
127
files []database.File
121
128
gitAuthLinks []database.GitAuthLink
122
129
gitSSHKey []database.GitSSHKey
@@ -665,6 +672,19 @@ func (q *FakeQuerier) isEveryoneGroup(id uuid.UUID) bool {
665
672
return false
666
673
}
667
674
675
+ func (q * FakeQuerier )GetActiveDBCryptKeys (_ context.Context ) ([]database.DBCryptKey ,error ) {
676
+ q .mutex .RLock ()
677
+ defer q .mutex .RUnlock ()
678
+ ks := make ([]database.DBCryptKey ,0 ,len (q .dbcryptKeys ))
679
+ for _ ,k := range q .dbcryptKeys {
680
+ if ! k .ActiveKeyDigest .Valid {
681
+ continue
682
+ }
683
+ ks = append ([]database.DBCryptKey {},k )
684
+ }
685
+ return ks ,nil
686
+ }
687
+
668
688
func (* FakeQuerier )AcquireLock (_ context.Context ,_ int64 )error {
669
689
return xerrors .New ("AcquireLock must only be called within a transaction" )
670
690
}
@@ -1151,6 +1171,14 @@ func (q *FakeQuerier) GetAuthorizationUserRoles(_ context.Context, userID uuid.U
1151
1171
},nil
1152
1172
}
1153
1173
1174
+ func (q * FakeQuerier )GetDBCryptKeys (_ context.Context ) ([]database.DBCryptKey ,error ) {
1175
+ q .mutex .RLock ()
1176
+ defer q .mutex .RUnlock ()
1177
+ ks := make ([]database.DBCryptKey ,0 )
1178
+ ks = append (ks ,q .dbcryptKeys ... )
1179
+ return ks ,nil
1180
+ }
1181
+
1154
1182
func (q * FakeQuerier )GetDERPMeshKey (_ context.Context ) (string ,error ) {
1155
1183
q .mutex .RLock ()
1156
1184
defer q .mutex .RUnlock ()
@@ -1393,6 +1421,18 @@ func (q *FakeQuerier) GetGitAuthLink(_ context.Context, arg database.GetGitAuthL
1393
1421
return database.GitAuthLink {},sql .ErrNoRows
1394
1422
}
1395
1423
1424
+ func (q * FakeQuerier )GetGitAuthLinksByUserID (_ context.Context ,userID uuid.UUID ) ([]database.GitAuthLink ,error ) {
1425
+ q .mutex .RLock ()
1426
+ defer q .mutex .RUnlock ()
1427
+ gals := make ([]database.GitAuthLink ,0 )
1428
+ for _ ,gal := range q .gitAuthLinks {
1429
+ if gal .UserID == userID {
1430
+ gals = append (gals ,gal )
1431
+ }
1432
+ }
1433
+ return gals ,nil
1434
+ }
1435
+
1396
1436
func (q * FakeQuerier )GetGitSSHKey (_ context.Context ,userID uuid.UUID ) (database.GitSSHKey ,error ) {
1397
1437
q .mutex .RLock ()
1398
1438
defer q .mutex .RUnlock ()
@@ -2833,6 +2873,18 @@ func (q *FakeQuerier) GetUserLinkByUserIDLoginType(_ context.Context, params dat
2833
2873
return database.UserLink {},sql .ErrNoRows
2834
2874
}
2835
2875
2876
+ func (q * FakeQuerier )GetUserLinksByUserID (_ context.Context ,userID uuid.UUID ) ([]database.UserLink ,error ) {
2877
+ q .mutex .RLock ()
2878
+ defer q .mutex .RUnlock ()
2879
+ uls := make ([]database.UserLink ,0 )
2880
+ for _ ,ul := range q .userLinks {
2881
+ if ul .UserID == userID {
2882
+ uls = append (uls ,ul )
2883
+ }
2884
+ }
2885
+ return uls ,nil
2886
+ }
2887
+
2836
2888
func (q * FakeQuerier )GetUsers (_ context.Context ,params database.GetUsersParams ) ([]database.GetUsersRow ,error ) {
2837
2889
if err := validateDatabaseType (params );err != nil {
2838
2890
return nil ,err
@@ -3846,6 +3898,26 @@ func (q *FakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAudit
3846
3898
return alog ,nil
3847
3899
}
3848
3900
3901
+ func (q * FakeQuerier )InsertDBCryptKey (_ context.Context ,arg database.InsertDBCryptKeyParams )error {
3902
+ err := validateDatabaseType (arg )
3903
+ if err != nil {
3904
+ return err
3905
+ }
3906
+
3907
+ for _ ,key := range q .dbcryptKeys {
3908
+ if key .Number == arg .Number {
3909
+ return errDuplicateKey
3910
+ }
3911
+ }
3912
+
3913
+ q .dbcryptKeys = append (q .dbcryptKeys , database.DBCryptKey {
3914
+ Number :arg .Number ,
3915
+ ActiveKeyDigest : sql.NullString {String :arg .ActiveKeyDigest ,Valid :true },
3916
+ Test :arg .Test ,
3917
+ })
3918
+ return nil
3919
+ }
3920
+
3849
3921
func (q * FakeQuerier )InsertDERPMeshKey (_ context.Context ,id string )error {
3850
3922
q .mutex .Lock ()
3851
3923
defer q .mutex .Unlock ()
@@ -3892,13 +3964,15 @@ func (q *FakeQuerier) InsertGitAuthLink(_ context.Context, arg database.InsertGi
3892
3964
defer q .mutex .Unlock ()
3893
3965
// nolint:gosimple
3894
3966
gitAuthLink := database.GitAuthLink {
3895
- ProviderID :arg .ProviderID ,
3896
- UserID :arg .UserID ,
3897
- CreatedAt :arg .CreatedAt ,
3898
- UpdatedAt :arg .UpdatedAt ,
3899
- OAuthAccessToken :arg .OAuthAccessToken ,
3900
- OAuthRefreshToken :arg .OAuthRefreshToken ,
3901
- OAuthExpiry :arg .OAuthExpiry ,
3967
+ ProviderID :arg .ProviderID ,
3968
+ UserID :arg .UserID ,
3969
+ CreatedAt :arg .CreatedAt ,
3970
+ UpdatedAt :arg .UpdatedAt ,
3971
+ OAuthAccessToken :arg .OAuthAccessToken ,
3972
+ OAuthAccessTokenKeyID :arg .OAuthAccessTokenKeyID ,
3973
+ OAuthRefreshToken :arg .OAuthRefreshToken ,
3974
+ OAuthRefreshTokenKeyID :arg .OAuthRefreshTokenKeyID ,
3975
+ OAuthExpiry :arg .OAuthExpiry ,
3902
3976
}
3903
3977
q .gitAuthLinks = append (q .gitAuthLinks ,gitAuthLink )
3904
3978
return gitAuthLink ,nil
@@ -4362,12 +4436,14 @@ func (q *FakeQuerier) InsertUserLink(_ context.Context, args database.InsertUser
4362
4436
4363
4437
//nolint:gosimple
4364
4438
link := database.UserLink {
4365
- UserID :args .UserID ,
4366
- LoginType :args .LoginType ,
4367
- LinkedID :args .LinkedID ,
4368
- OAuthAccessToken :args .OAuthAccessToken ,
4369
- OAuthRefreshToken :args .OAuthRefreshToken ,
4370
- OAuthExpiry :args .OAuthExpiry ,
4439
+ UserID :args .UserID ,
4440
+ LoginType :args .LoginType ,
4441
+ LinkedID :args .LinkedID ,
4442
+ OAuthAccessToken :args .OAuthAccessToken ,
4443
+ OAuthAccessTokenKeyID :args .OAuthAccessTokenKeyID ,
4444
+ OAuthRefreshToken :args .OAuthRefreshToken ,
4445
+ OAuthRefreshTokenKeyID :args .OAuthRefreshTokenKeyID ,
4446
+ OAuthExpiry :args .OAuthExpiry ,
4371
4447
}
4372
4448
4373
4449
q .userLinks = append (q .userLinks ,link )
@@ -4793,6 +4869,46 @@ func (q *FakeQuerier) RegisterWorkspaceProxy(_ context.Context, arg database.Reg
4793
4869
return database.WorkspaceProxy {},sql .ErrNoRows
4794
4870
}
4795
4871
4872
+ func (q * FakeQuerier )RevokeDBCryptKey (_ context.Context ,activeKeyDigest string )error {
4873
+ q .mutex .Lock ()
4874
+ defer q .mutex .Unlock ()
4875
+
4876
+ for i := range q .dbcryptKeys {
4877
+ key := q .dbcryptKeys [i ]
4878
+
4879
+ // Is the key already revoked?
4880
+ if ! key .ActiveKeyDigest .Valid {
4881
+ continue
4882
+ }
4883
+
4884
+ if key .ActiveKeyDigest .String != activeKeyDigest {
4885
+ continue
4886
+ }
4887
+
4888
+ // Check for foreign key constraints.
4889
+ for _ ,ul := range q .userLinks {
4890
+ if (ul .OAuthAccessTokenKeyID .Valid && ul .OAuthAccessTokenKeyID .String == activeKeyDigest )||
4891
+ (ul .OAuthRefreshTokenKeyID .Valid && ul .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4892
+ return errForeignKeyConstraint
4893
+ }
4894
+ }
4895
+ for _ ,gal := range q .gitAuthLinks {
4896
+ if (gal .OAuthAccessTokenKeyID .Valid && gal .OAuthAccessTokenKeyID .String == activeKeyDigest )||
4897
+ (gal .OAuthRefreshTokenKeyID .Valid && gal .OAuthRefreshTokenKeyID .String == activeKeyDigest ) {
4898
+ return errForeignKeyConstraint
4899
+ }
4900
+ }
4901
+
4902
+ // Revoke the key.
4903
+ q .dbcryptKeys [i ].RevokedAt = sql.NullTime {Time :dbtime .Now (),Valid :true }
4904
+ q .dbcryptKeys [i ].RevokedKeyDigest = sql.NullString {String :key .ActiveKeyDigest .String ,Valid :true }
4905
+ q .dbcryptKeys [i ].ActiveKeyDigest = sql.NullString {}
4906
+ return nil
4907
+ }
4908
+
4909
+ return sql .ErrNoRows
4910
+ }
4911
+
4796
4912
func (* FakeQuerier )TryAcquireLock (_ context.Context ,_ int64 ) (bool ,error ) {
4797
4913
return false ,xerrors .New ("TryAcquireLock must only be called within a transaction" )
4798
4914
}
@@ -4834,7 +4950,9 @@ func (q *FakeQuerier) UpdateGitAuthLink(_ context.Context, arg database.UpdateGi
4834
4950
}
4835
4951
gitAuthLink .UpdatedAt = arg .UpdatedAt
4836
4952
gitAuthLink .OAuthAccessToken = arg .OAuthAccessToken
4953
+ gitAuthLink .OAuthAccessTokenKeyID = arg .OAuthAccessTokenKeyID
4837
4954
gitAuthLink .OAuthRefreshToken = arg .OAuthRefreshToken
4955
+ gitAuthLink .OAuthRefreshTokenKeyID = arg .OAuthRefreshTokenKeyID
4838
4956
gitAuthLink .OAuthExpiry = arg .OAuthExpiry
4839
4957
q .gitAuthLinks [index ]= gitAuthLink
4840
4958
@@ -5306,7 +5424,9 @@ func (q *FakeQuerier) UpdateUserLink(_ context.Context, params database.UpdateUs
5306
5424
for i ,link := range q .userLinks {
5307
5425
if link .UserID == params .UserID && link .LoginType == params .LoginType {
5308
5426
link .OAuthAccessToken = params .OAuthAccessToken
5427
+ link .OAuthAccessTokenKeyID = params .OAuthAccessTokenKeyID
5309
5428
link .OAuthRefreshToken = params .OAuthRefreshToken
5429
+ link .OAuthRefreshTokenKeyID = params .OAuthRefreshTokenKeyID
5310
5430
link .OAuthExpiry = params .OAuthExpiry
5311
5431
5312
5432
q .userLinks [i ]= link