 
 	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
-	"github.com/spf13/afero"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
 	"golang.org/x/sync/errgroup"
 
 	"cdr.dev/slog/sloggers/slogtest"
@@ -18,6 +18,7 @@ import (
1818"github.com/coder/coder/v2/coderd/database"
1919"github.com/coder/coder/v2/coderd/database/dbauthz"
2020"github.com/coder/coder/v2/coderd/database/dbgen"
21+ "github.com/coder/coder/v2/coderd/database/dbmock"
2122"github.com/coder/coder/v2/coderd/database/dbtestutil"
2223"github.com/coder/coder/v2/coderd/files"
2324"github.com/coder/coder/v2/coderd/rbac"
@@ -58,7 +59,7 @@ func TestCacheRBAC(t *testing.T) {
 	require.Equal(t, 0, cache.Count())
 	rec.Reset()
 
-	_, err := cache.Acquire(nobody, file.ID)
+	_, err := cache.Acquire(nobody, db, file.ID)
 	require.Error(t, err)
 	require.True(t, rbac.IsUnauthorizedError(err))
 
@@ -75,18 +76,18 @@ func TestCacheRBAC(t *testing.T) {
 	require.Equal(t, 0, cache.Count())
 
 	// Read the file with a file reader to put it into the cache.
-	a, err := cache.Acquire(cacheReader, file.ID)
+	a, err := cache.Acquire(cacheReader, db, file.ID)
 	require.NoError(t, err)
 	require.Equal(t, 1, cache.Count())
 
 	// "nobody" should not be able to read the file.
-	_, err = cache.Acquire(nobody, file.ID)
+	_, err = cache.Acquire(nobody, db, file.ID)
 	require.Error(t, err)
 	require.True(t, rbac.IsUnauthorizedError(err))
 	require.Equal(t, 1, cache.Count())
 
 	// UserReader can
-	b, err := cache.Acquire(userReader, file.ID)
+	b, err := cache.Acquire(userReader, db, file.ID)
 	require.NoError(t, err)
 	require.Equal(t, 1, cache.Count())
 
@@ -110,16 +111,21 @@ func TestConcurrency(t *testing.T) {
 	ctx := dbauthz.AsFileReader(t.Context())
 
 	const fileSize = 10
-	emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
 	var fetches atomic.Int64
 	reg := prometheus.NewRegistry()
-	c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {
+
+	dbM := dbmock.NewMockStore(gomock.NewController(t))
+	dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
 		fetches.Add(1)
-		// Wait long enough before returning to make sure that allof the goroutines
+		// Wait long enough before returning to make sure that all the goroutines
 		// will be waiting in line, ensuring that no one duplicated a fetch.
 		time.Sleep(testutil.IntervalMedium)
-		return files.CacheEntryValue{FS: emptyFS, Size: fileSize}, nil
-	}, reg, &coderdtest.FakeAuthorizer{})
+		return database.File{
+			Data: make([]byte, fileSize),
+		}, nil
+	}).AnyTimes()
+
+	c := files.New(reg, &coderdtest.FakeAuthorizer{})
 
 	batches := 1000
 	groups := make([]*errgroup.Group, 0, batches)
@@ -137,7 +143,7 @@ func TestConcurrency(t *testing.T) {
 		g.Go(func() error {
 			// We don't bother to Release these references because the Cache will be
 			// released at the end of the test anyway.
-			_, err := c.Acquire(ctx, id)
+			_, err := c.Acquire(ctx, dbM, id)
 			return err
 		})
 	}
@@ -164,14 +170,15 @@ func TestRelease(t *testing.T) {
 	ctx := dbauthz.AsFileReader(t.Context())
 
 	const fileSize = 10
-	emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
 	reg := prometheus.NewRegistry()
-	c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {
-		return files.CacheEntryValue{
-			FS:   emptyFS,
-			Size: fileSize,
+	dbM := dbmock.NewMockStore(gomock.NewController(t))
+	dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
+		return database.File{
+			Data: make([]byte, fileSize),
 		}, nil
-	}, reg, &coderdtest.FakeAuthorizer{})
+	}).AnyTimes()
+
+	c := files.New(reg, &coderdtest.FakeAuthorizer{})
 
 	batches := 100
 	ids := make([]uuid.UUID, 0, batches)
@@ -184,9 +191,8 @@ func TestRelease(t *testing.T) {
 	batchSize := 10
 	for openedIdx, id := range ids {
 		for batchIdx := range batchSize {
-			it, err := c.Acquire(ctx, id)
+			it, err := c.Acquire(ctx, dbM, id)
 			require.NoError(t, err)
-			require.Equal(t, emptyFS, it.FS)
 			releases[id] = append(releases[id], it.Close)
 
 			// Each time a new file is opened, the metrics should be updated as so:
@@ -257,7 +263,7 @@ func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.Re
 
 	// Dbauthz wrap the db
 	db = dbauthz.New(db, rec, logger, coderdtest.AccessControlStorePointer())
-	c := files.NewFromStore(db, reg, rec)
+	c := files.New(reg, rec)
 	return db, c, rec
 }
 