diff --git a/coderd/files/cache.go b/coderd/files/cache.go index 3698aac928..159f1b8aee 100644 --- a/coderd/files/cache.go +++ b/coderd/files/cache.go @@ -25,60 +25,61 @@ type FileAcquirer interface { // New returns a file cache that will fetch files from a database func New(registerer prometheus.Registerer, authz rbac.Authorizer) *Cache { - return (&Cache{ - lock: sync.Mutex{}, - data: make(map[uuid.UUID]*cacheEntry), - authz: authz, - }).registerMetrics(registerer) + return &Cache{ + lock: sync.Mutex{}, + data: make(map[uuid.UUID]*cacheEntry), + authz: authz, + cacheMetrics: newCacheMetrics(registerer), + } } -func (c *Cache) registerMetrics(registerer prometheus.Registerer) *Cache { +func newCacheMetrics(registerer prometheus.Registerer) cacheMetrics { subsystem := "file_cache" f := promauto.With(registerer) - c.currentCacheSize = f.NewGauge(prometheus.GaugeOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_files_size_bytes_current", - Help: "The current amount of memory of all files currently open in the file cache.", - }) + return cacheMetrics{ + currentCacheSize: f.NewGauge(prometheus.GaugeOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_files_size_bytes_current", + Help: "The current amount of memory of all files currently open in the file cache.", + }), - c.totalCacheSize = f.NewCounter(prometheus.CounterOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_files_size_bytes_total", - Help: "The total amount of memory ever opened in the file cache. This number never decrements.", - }) + totalCacheSize: f.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_files_size_bytes_total", + Help: "The total amount of memory ever opened in the file cache. This number never decrements.", + }), - c.currentOpenFiles = f.NewGauge(prometheus.GaugeOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_files_current", - Help: "The count of unique files currently open in the file cache.", - }) + currentOpenFiles: f.NewGauge(prometheus.GaugeOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_files_current", + Help: "The count of unique files currently open in the file cache.", + }), - c.totalOpenedFiles = f.NewCounter(prometheus.CounterOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_files_total", - Help: "The total count of unique files ever opened in the file cache.", - }) + totalOpenedFiles: f.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_files_total", + Help: "The total count of unique files ever opened in the file cache.", + }), - c.currentOpenFileReferences = f.NewGauge(prometheus.GaugeOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_file_refs_current", - Help: "The count of file references currently open in the file cache. Multiple references can be held for the same file.", - }) + currentOpenFileReferences: f.NewGauge(prometheus.GaugeOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_file_refs_current", + Help: "The count of file references currently open in the file cache. Multiple references can be held for the same file.", + }), - c.totalOpenFileReferences = f.NewCounterVec(prometheus.CounterOpts{ - Namespace: "coderd", - Subsystem: subsystem, - Name: "open_file_refs_total", - Help: "The total number of file references ever opened in the file cache. 
The 'hit' label indicates if the file was loaded from the cache.", - }, []string{"hit"}) - - return c + totalOpenFileReferences: f.NewCounterVec(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: subsystem, + Name: "open_file_refs_total", + Help: "The total number of file references ever opened in the file cache. The 'hit' label indicates if the file was loaded from the cache.", + }, []string{"hit"}), + } } // Cache persists the files for template versions, and is used by dynamic @@ -106,18 +107,23 @@ type cacheMetrics struct { totalCacheSize prometheus.Counter } +type cacheEntry struct { + // Safety: refCount must only be accessed while the Cache lock is held. + refCount int + value *lazy.ValueWithError[CacheEntryValue] + + // Safety: close must only be called while the Cache lock is held + close func() + // Safety: purge must only be called while the Cache lock is held + purge func() +} + type CacheEntryValue struct { fs.FS Object rbac.Object Size int64 } -type cacheEntry struct { - // refCount must only be accessed while the Cache lock is held. - refCount int - value *lazy.ValueWithError[CacheEntryValue] -} - var _ fs.FS = (*CloseFS)(nil) // CloseFS is a wrapper around fs.FS that implements io.Closer. The Close() @@ -129,106 +135,141 @@ type CloseFS struct { close func() } -func (f *CloseFS) Close() { f.close() } +func (f *CloseFS) Close() { + f.close() +} // Acquire will load the fs.FS for the given file. It guarantees that parallel // calls for the same fileID will only result in one fetch, and that parallel // calls for distinct fileIDs will fetch in parallel. // -// Safety: Every call to Acquire that does not return an error must have a -// matching call to Release. +// Safety: Every call to Acquire that does not return an error must call close +// on the returned value when it is done being used. func (c *Cache) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) { // It's important that this `Load` call occurs outside `prepare`, after the // mutex has been released, or we would continue to hold the lock until the // entire file has been fetched, which may be slow, and would prevent other // files from being fetched in parallel. - it, err := c.prepare(ctx, db, fileID).Load() + e := c.prepare(db, fileID) + ev, err := e.value.Load() if err != nil { - c.release(fileID) + c.lock.Lock() + defer c.lock.Unlock() + e.close() + e.purge() return nil, err } + cleanup := func() { + c.lock.Lock() + defer c.lock.Unlock() + e.close() + } + + // We always run the fetch under a system context and actor, so we need to + // check the caller's context (including the actor) manually before returning. + + // Check if the caller's context was canceled. Even though `Authorize` takes + // a context, we still check it manually first because none of our mock + // database implementations check for context cancellation. + if err := ctx.Err(); err != nil { + cleanup() + return nil, err + } + + // Check that the caller is authorized to access the file subject, ok := dbauthz.ActorFromContext(ctx) if !ok { + cleanup() return nil, dbauthz.ErrNoActor } - // Always check the caller can actually read the file. 
- if err := c.authz.Authorize(ctx, subject, policy.ActionRead, it.Object); err != nil { - c.release(fileID) + if err := c.authz.Authorize(ctx, subject, policy.ActionRead, ev.Object); err != nil { + cleanup() return nil, err } - var once sync.Once + var closeOnce sync.Once return &CloseFS{ - FS: it.FS, + FS: ev.FS, close: func() { // sync.Once makes the Close() idempotent, so we can call it // multiple times without worrying about double-releasing. - once.Do(func() { c.release(fileID) }) + closeOnce.Do(func() { + c.lock.Lock() + defer c.lock.Unlock() + e.close() + }) }, }, nil } -func (c *Cache) prepare(ctx context.Context, db database.Store, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] { +func (c *Cache) prepare(db database.Store, fileID uuid.UUID) *cacheEntry { c.lock.Lock() defer c.lock.Unlock() hitLabel := "true" entry, ok := c.data[fileID] if !ok { - value := lazy.NewWithError(func() (CacheEntryValue, error) { - val, err := fetch(ctx, db, fileID) + hitLabel = "false" - // Always add to the cache size the bytes of the file loaded. - if err == nil { + var purgeOnce sync.Once + entry = &cacheEntry{ + value: lazy.NewWithError(func() (CacheEntryValue, error) { + val, err := fetch(db, fileID) + if err != nil { + return val, err + } + + // Add the size of the file to the cache size metrics. c.currentCacheSize.Add(float64(val.Size)) c.totalCacheSize.Add(float64(val.Size)) - } - return val, err - }) + return val, err + }), - entry = &cacheEntry{ - value: value, - refCount: 0, + close: func() { + entry.refCount-- + c.currentOpenFileReferences.Dec() + if entry.refCount > 0 { + return + } + + entry.purge() + }, + + purge: func() { + purgeOnce.Do(func() { + c.purge(fileID) + }) + }, } c.data[fileID] = entry + c.currentOpenFiles.Inc() c.totalOpenedFiles.Inc() - hitLabel = "false" } c.currentOpenFileReferences.Inc() c.totalOpenFileReferences.WithLabelValues(hitLabel).Inc() entry.refCount++ - return entry.value + return entry } -// release decrements the reference count for the given fileID, and frees the -// backing data if there are no further references being held. -// -// release should only be called after a successful call to Acquire using the Release() -// method on the returned *CloseFS. -func (c *Cache) release(fileID uuid.UUID) { - c.lock.Lock() - defer c.lock.Unlock() - +// purge immediately removes an entry from the cache, even if it has open +// references. +// Safety: Must only be called while the Cache lock is held +func (c *Cache) purge(fileID uuid.UUID) { entry, ok := c.data[fileID] if !ok { - // If we land here, it's almost certainly because a bug already happened, - // and we're freeing something that's already been freed, or we're calling - // this function with an incorrect ID. Should this function return an error? - return - } - - c.currentOpenFileReferences.Dec() - entry.refCount-- - if entry.refCount > 0 { + // If we land here, it's probably because of a fetch attempt that + // resulted in an error, and got purged already. It may also be an + // erroneous extra close, but we can't really distinguish between those + // two cases currently. return } + // Purge the file from the cache. c.currentOpenFiles.Dec() - ev, err := entry.value.Load() if err == nil { c.currentCacheSize.Add(-1 * float64(ev.Size)) @@ -246,11 +287,18 @@ func (c *Cache) Count() int { return len(c.data) } -func fetch(ctx context.Context, store database.Store, fileID uuid.UUID) (CacheEntryValue, error) { - // Make sure the read does not fail due to authorization issues. 
-	// Authz is checked on the Acquire call, so this is safe.
+func fetch(store database.Store, fileID uuid.UUID) (CacheEntryValue, error) {
+	// Because many callers can be waiting on the same file fetch concurrently, we
+	// want to prevent failures specific to the caller who initiated the fetch
+	// from causing every waiting caller to receive an error.
+	// - We always run the fetch with an uncancelable context, and then check
+	//   context cancellation for each acquirer afterwards.
+	// - We always run the fetch as a system user, and then check authorization
+	//   for each acquirer afterwards.
+	// This prevents a canceled context or an unauthorized user from "holding up
+	// the queue".
 	//nolint:gocritic
-	file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID)
+	file, err := store.GetFileByID(dbauthz.AsFileReader(context.Background()), fileID)
 	if err != nil {
 		return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err)
 	}
diff --git a/coderd/files/cache_internal_test.go b/coderd/files/cache_internal_test.go
new file mode 100644
index 0000000000..89348c65a2
--- /dev/null
+++ b/coderd/files/cache_internal_test.go
@@ -0,0 +1,23 @@
+package files
+
+import (
+	"context"
+
+	"github.com/google/uuid"
+
+	"github.com/coder/coder/v2/coderd/database"
+)
+
+// LeakCache prevents entries from ever being released, to enable testing certain
+// behaviors.
+type LeakCache struct {
+	*Cache
+}
+
+func (c *LeakCache) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) {
+	// We need to call prepare first to both (1) leak a reference and (2) prevent
+	// the behavior of immediately closing on an error (as implemented in Acquire)
+	// from freeing the file.
+	c.prepare(db, fileID)
+	return c.Cache.Acquire(ctx, db, fileID)
+}
diff --git a/coderd/files/cache_test.go b/coderd/files/cache_test.go
index 8a8acfbc07..6f8f74e74f 100644
--- a/coderd/files/cache_test.go
+++ b/coderd/files/cache_test.go
@@ -2,12 +2,14 @@ package files_test
 
 import (
 	"context"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/google/uuid"
 	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
 	"golang.org/x/sync/errgroup"
@@ -26,6 +28,104 @@ import (
 	"github.com/coder/coder/v2/testutil"
 )
 
+func TestCancelledFetch(t *testing.T) {
+	t.Parallel()
+
+	fileID := uuid.New()
+	dbM := dbmock.NewMockStore(gomock.NewController(t))
+
+	// The file fetch should succeed.
+	dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
+		return database.File{
+			ID:   fileID,
+			Data: make([]byte, 100),
+		}, nil
+	})
+
+	cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
+
+	// Cancel the context; the Acquire call should fail.
+	//nolint:gocritic // Unit testing
+	ctx, cancel := context.WithCancel(dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort)))
+	cancel()
+	_, err := cache.Acquire(ctx, dbM, fileID)
+	assert.ErrorIs(t, err, context.Canceled)
+}
+
+// TestCancelledConcurrentFetch runs 2 Acquire calls. The first has a canceled
+// context and will get a ctx.Canceled error. The second call should reuse the
+// cache entry populated by the first call's fetch and succeed without fetching again.
+func TestCancelledConcurrentFetch(t *testing.T) {
+	t.Parallel()
+
+	fileID := uuid.New()
+	dbM := dbmock.NewMockStore(gomock.NewController(t))
+
+	// The file fetch should succeed.
+	dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
+		return database.File{
+			ID:   fileID,
+			Data: make([]byte, 100),
+		}, nil
+	})
+
+	cache := files.LeakCache{Cache: files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})}
+
+	//nolint:gocritic // Unit testing
+	ctx := dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort))
+
+	// Cancel the context for the first call; it should fail.
+	canceledCtx, cancel := context.WithCancel(ctx)
+	cancel()
+	_, err := cache.Acquire(canceledCtx, dbM, fileID)
+	require.ErrorIs(t, err, context.Canceled)
+
+	// The second call should succeed without fetching from the database again,
+	// since the cache was populated by the fetch the first request started,
+	// even though that request ultimately returned an error.
+	_, err = cache.Acquire(ctx, dbM, fileID)
+	require.NoError(t, err)
+}
+
+func TestConcurrentFetch(t *testing.T) {
+	t.Parallel()
+
+	fileID := uuid.New()
+
+	// Only allow one call, which should succeed
+	dbM := dbmock.NewMockStore(gomock.NewController(t))
+	dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
+		return database.File{ID: fileID}, nil
+	})
+
+	cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
+	//nolint:gocritic // Unit testing
+	ctx := dbauthz.AsFileReader(testutil.Context(t, testutil.WaitShort))
+
+	// Run 2 concurrent Acquire calls; hold makes both goroutines start together.
+	var (
+		hold sync.WaitGroup
+		wg   sync.WaitGroup
+	)
+
+	for range 2 {
+		hold.Add(1)
+		// TODO: wg.Go in Go 1.25
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			hold.Done()
+			hold.Wait()
+			_, err := cache.Acquire(ctx, dbM, fileID)
+			require.NoError(t, err)
+		}()
+	}
+
+	// Wait for both goroutines to assert their errors and finish.
+	wg.Wait()
+	require.Equal(t, 1, cache.Count())
+}
+
 // nolint:paralleltest,tparallel // Serially testing is easier
 func TestCacheRBAC(t *testing.T) {
 	t.Parallel()