88
99 "github.com/google/uuid"
1010 "github.com/prometheus/client_golang/prometheus"
11- "github.com/spf13/afero"
1211 "github.com/stretchr/testify/require"
12+ "go.uber.org/mock/gomock"
1313 "golang.org/x/sync/errgroup"
1414
1515 "cdr.dev/slog/sloggers/slogtest"
@@ -18,6 +18,7 @@ import (
1818 "github.com/coder/coder/v2/coderd/database"
1919 "github.com/coder/coder/v2/coderd/database/dbauthz"
2020 "github.com/coder/coder/v2/coderd/database/dbgen"
21+ "github.com/coder/coder/v2/coderd/database/dbmock"
2122 "github.com/coder/coder/v2/coderd/database/dbtestutil"
2223 "github.com/coder/coder/v2/coderd/files"
2324 "github.com/coder/coder/v2/coderd/rbac"
@@ -58,7 +59,7 @@ func TestCacheRBAC(t *testing.T) {
5859 require .Equal (t , 0 , cache .Count ())
5960 rec .Reset ()
6061
61- _ , err := cache .Acquire (nobody , file .ID )
62+ _ , err := cache .Acquire (nobody , db , file .ID )
6263 require .Error (t , err )
6364 require .True (t , rbac .IsUnauthorizedError (err ))
6465
@@ -75,18 +76,18 @@ func TestCacheRBAC(t *testing.T) {
7576 require .Equal (t , 0 , cache .Count ())
7677
7778 // Read the file with a file reader to put it into the cache.
78- a , err := cache .Acquire (cacheReader , file .ID )
79+ a , err := cache .Acquire (cacheReader , db , file .ID )
7980 require .NoError (t , err )
8081 require .Equal (t , 1 , cache .Count ())
8182
8283 // "nobody" should not be able to read the file.
83- _ , err = cache .Acquire (nobody , file .ID )
84+ _ , err = cache .Acquire (nobody , db , file .ID )
8485 require .Error (t , err )
8586 require .True (t , rbac .IsUnauthorizedError (err ))
8687 require .Equal (t , 1 , cache .Count ())
8788
8889 // UserReader can
89- b , err := cache .Acquire (userReader , file .ID )
90+ b , err := cache .Acquire (userReader , db , file .ID )
9091 require .NoError (t , err )
9192 require .Equal (t , 1 , cache .Count ())
9293
@@ -110,16 +111,21 @@ func TestConcurrency(t *testing.T) {
110111 ctx := dbauthz .AsFileReader (t .Context ())
111112
112113 const fileSize = 10
113- emptyFS := afero .NewIOFS (afero .NewReadOnlyFs (afero .NewMemMapFs ()))
114114 var fetches atomic.Int64
115115 reg := prometheus .NewRegistry ()
116- c := files .New (func (_ context.Context , _ uuid.UUID ) (files.CacheEntryValue , error ) {
116+
117+ dbM := dbmock .NewMockStore (gomock .NewController (t ))
118+ dbM .EXPECT ().GetFileByID (gomock .Any (), gomock .Any ()).DoAndReturn (func (mTx context.Context , fileID uuid.UUID ) (database.File , error ) {
117119 fetches .Add (1 )
118- // Wait long enough before returning to make sure that all of the goroutines
120+ // Wait long enough before returning to make sure that all the goroutines
119121 // will be waiting in line, ensuring that no one duplicated a fetch.
120122 time .Sleep (testutil .IntervalMedium )
121- return files.CacheEntryValue {FS : emptyFS , Size : fileSize }, nil
122- }, reg , & coderdtest.FakeAuthorizer {})
123+ return database.File {
124+ Data : make ([]byte , fileSize ),
125+ }, nil
126+ }).AnyTimes ()
127+
128+ c := files .New (reg , & coderdtest.FakeAuthorizer {})
123129
124130 batches := 1000
125131 groups := make ([]* errgroup.Group , 0 , batches )
@@ -137,7 +143,7 @@ func TestConcurrency(t *testing.T) {
137143 g .Go (func () error {
138144 // We don't bother to Release these references because the Cache will be
139145 // released at the end of the test anyway.
140- _ , err := c .Acquire (ctx , id )
146+ _ , err := c .Acquire (ctx , dbM , id )
141147 return err
142148 })
143149 }
@@ -164,14 +170,15 @@ func TestRelease(t *testing.T) {
164170 ctx := dbauthz .AsFileReader (t .Context ())
165171
166172 const fileSize = 10
167- emptyFS := afero .NewIOFS (afero .NewReadOnlyFs (afero .NewMemMapFs ()))
168173 reg := prometheus .NewRegistry ()
169- c := files . New ( func ( _ context. Context , _ uuid. UUID ) (files. CacheEntryValue , error ) {
170- return files. CacheEntryValue {
171- FS : emptyFS ,
172- Size : fileSize ,
174+ dbM := dbmock . NewMockStore ( gomock . NewController ( t ))
175+ dbM . EXPECT (). GetFileByID ( gomock . Any (), gomock . Any ()). DoAndReturn ( func ( mTx context. Context , fileID uuid. UUID ) (database. File , error ) {
176+ return database. File {
177+ Data : make ([] byte , fileSize ) ,
173178 }, nil
174- }, reg , & coderdtest.FakeAuthorizer {})
179+ }).AnyTimes ()
180+
181+ c := files .New (reg , & coderdtest.FakeAuthorizer {})
175182
176183 batches := 100
177184 ids := make ([]uuid.UUID , 0 , batches )
@@ -184,9 +191,8 @@ func TestRelease(t *testing.T) {
184191 batchSize := 10
185192 for openedIdx , id := range ids {
186193 for batchIdx := range batchSize {
187- it , err := c .Acquire (ctx , id )
194+ it , err := c .Acquire (ctx , dbM , id )
188195 require .NoError (t , err )
189- require .Equal (t , emptyFS , it .FS )
190196 releases [id ] = append (releases [id ], it .Close )
191197
192198 // Each time a new file is opened, the metrics should be updated as so:
@@ -257,7 +263,7 @@ func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.Re
257263
258264 // Dbauthz wrap the db
259265 db = dbauthz .New (db , rec , logger , coderdtest .AccessControlStorePointer ())
260- c := files .NewFromStore ( db , reg , rec )
266+ c := files .New ( reg , rec )
261267 return db , c , rec
262268}
263269
0 commit comments