chore: refactor keycache implementation to reduce duplication (#15100)

This commit is contained in:
Jon Ayers
2024-10-16 20:01:45 +01:00
committed by GitHub
parent 8e254cbb07
commit f537193682
10 changed files with 512 additions and 1339 deletions

369
coderd/cryptokeys/cache.go Normal file
View File

@ -0,0 +1,369 @@
package cryptokeys
import (
"context"
"encoding/hex"
"io"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
var (
ErrKeyNotFound = xerrors.New("key not found")
ErrKeyInvalid = xerrors.New("key is invalid for use")
ErrClosed = xerrors.New("closed")
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
)
type Fetcher interface {
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
}
type EncryptionKeycache interface {
// EncryptingKey returns the latest valid key for encrypting payloads. A valid
// key is one that is both past its start time and before its deletion time.
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
// DecryptingKey returns the key with the provided id which maps to its sequence
// number. The key is valid for decryption as long as it is not deleted or past
// its deletion date. We must allow for keys prior to their start time to
// account for clock skew between peers (one key may be past its start time on
// one machine while another is not).
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}
type SigningKeycache interface {
// SigningKey returns the latest valid key for signing. A valid key is one
// that is both past its start time and before its deletion time.
SigningKey(ctx context.Context) (id string, key interface{}, err error)
// VerifyingKey returns the key with the provided id which should map to its
// sequence number. The key is valid for verifying as long as it is not deleted
// or past its deletion date. We must allow for keys prior to their start time
// to account for clock skew between peers (one key may be past its start time
// on one machine while another is not).
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}
const (
// latestSequence is a special sequence number that represents the latest key.
latestSequence = -1
// refreshInterval is the interval at which the key cache will refresh.
refreshInterval = time.Minute * 10
)
type DBFetcher struct {
DB database.Store
Feature database.CryptoKeyFeature
}
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
if err != nil {
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
}
return db2sdk.CryptoKeys(keys), nil
}
// cache implements the caching functionality for both signing and encryption keys.
type cache struct {
clock quartz.Clock
refreshCtx context.Context
refreshCancel context.CancelFunc
fetcher Fetcher
logger slog.Logger
feature codersdk.CryptoKeyFeature
mu sync.Mutex
keys map[int32]codersdk.CryptoKey
lastFetch time.Time
refresher *quartz.Timer
fetching bool
closed bool
cond *sync.Cond
}
type CacheOption func(*cache)
func WithCacheClock(clock quartz.Clock) CacheOption {
return func(d *cache) {
d.clock = clock
}
}
// NewSigningCache instantiates a cache. Close should be called to release resources
// associated with its internal timer.
func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
) (SigningKeycache, error) {
if !isSigningKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
}
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
) (EncryptionKeycache, error) {
if !isEncryptionKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
}
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
cache := &cache{
clock: quartz.NewReal(),
logger: logger,
fetcher: fetcher,
feature: feature,
}
for _, opt := range opts {
opt(cache)
}
cache.cond = sync.NewCond(&cache.mu)
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
keys, err := cache.cryptoKeys(ctx)
if err != nil {
cache.refreshCancel()
return nil, xerrors.Errorf("initial fetch: %w", err)
}
cache.keys = keys
return cache, nil
}
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
if !isEncryptionKeyFeature(c.feature) {
return "", nil, ErrInvalidFeature
}
return c.cryptoKey(ctx, latestSequence)
}
func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) {
if !isEncryptionKeyFeature(c.feature) {
return nil, ErrInvalidFeature
}
seq, err := strconv.ParseInt(id, 10, 64)
if err != nil {
return nil, xerrors.Errorf("parse id: %w", err)
}
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
}
return secret, nil
}
func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
if !isSigningKeyFeature(c.feature) {
return "", nil, ErrInvalidFeature
}
return c.cryptoKey(ctx, latestSequence)
}
func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) {
if !isSigningKeyFeature(c.feature) {
return nil, ErrInvalidFeature
}
seq, err := strconv.ParseInt(id, 10, 64)
if err != nil {
return nil, xerrors.Errorf("parse id: %w", err)
}
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
}
return secret, nil
}
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
}
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
switch feature {
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
return true
default:
return false
}
}
func idSecret(k codersdk.CryptoKey) (string, []byte, error) {
key, err := hex.DecodeString(k.Secret)
if err != nil {
return "", nil, xerrors.Errorf("decode key: %w", err)
}
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
}
func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return "", nil, ErrClosed
}
var key codersdk.CryptoKey
var ok bool
for key, ok = c.key(sequence); !ok && c.fetching && !c.closed; {
c.cond.Wait()
}
if c.closed {
return "", nil, ErrClosed
}
if ok {
return checkKey(key, sequence, c.clock.Now())
}
c.fetching = true
c.mu.Unlock()
keys, err := c.cryptoKeys(ctx)
if err != nil {
return "", nil, xerrors.Errorf("get keys: %w", err)
}
c.mu.Lock()
c.lastFetch = c.clock.Now()
c.refresher.Reset(refreshInterval)
c.keys = keys
c.fetching = false
c.cond.Broadcast()
key, ok = c.key(sequence)
if !ok {
return "", nil, ErrKeyNotFound
}
return checkKey(key, sequence, c.clock.Now())
}
func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) {
if sequence == latestSequence {
return c.keys[latestSequence], c.keys[latestSequence].CanSign(c.clock.Now())
}
key, ok := c.keys[sequence]
return key, ok
}
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) {
if sequence == latestSequence {
if !key.CanSign(now) {
return "", nil, ErrKeyInvalid
}
return idSecret(key)
}
if !key.CanVerify(now) {
return "", nil, ErrKeyInvalid
}
return idSecret(key)
}
// refresh fetches the keys and updates the cache.
func (c *cache) refresh() {
now := c.clock.Now("CryptoKeyCache", "refresh")
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
// If something's already fetching, we don't need to do anything.
if c.fetching {
return
}
// There's a window we must account for where the timer fires while a fetch
// is ongoing but prior to the timer getting reset. In this case we want to
// avoid double fetching.
if now.Sub(c.lastFetch) < refreshInterval {
return
}
c.fetching = true
c.mu.Unlock()
keys, err := c.cryptoKeys(c.refreshCtx)
if err != nil {
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
return
}
// We don't defer an unlock here due to the deferred unlock at the top of the function.
c.mu.Lock()
c.lastFetch = c.clock.Now()
c.refresher.Reset(refreshInterval)
c.keys = keys
c.fetching = false
c.cond.Broadcast()
}
// cryptoKeys queries the control plane for the crypto keys.
// Outside of initialization, this should only be called by fetch.
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
keys, err := c.fetcher.Fetch(ctx)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
}
cache := toKeyMap(keys, c.clock.Now())
return cache, nil
}
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
m := make(map[int32]codersdk.CryptoKey)
var latest codersdk.CryptoKey
for _, key := range keys {
m[key.Sequence] = key
if key.Sequence > latest.Sequence && key.CanSign(now) {
m[latestSequence] = key
}
}
return m
}
func (c *cache) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
c.refreshCancel()
c.refresher.Stop()
c.cond.Broadcast()
return nil
}

View File

@ -0,0 +1,517 @@
package cryptokeys_test
import (
"context"
"crypto/rand"
"encoding/hex"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestCryptoKeyCache(t *testing.T) {
t.Parallel()
t.Run("Signing", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 2,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{expected},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
id, got, err := cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, keyID(expected), id)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 1, ff.called)
})
t.Run("MissesCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: clock.Now().UTC(),
}
ff.keys = []codersdk.CryptoKey{expected}
id, got, err := cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, keyID(expected), id)
// 1 on startup + missing cache.
require.Equal(t, 2, ff.called)
// Ensure the cache gets hit this time.
id, got, err = cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, keyID(expected), id)
// 1 on startup + missing cache.
require.Equal(t, 2, ff.called)
})
t.Run("IgnoresInvalid", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 1,
StartsAt: clock.Now().UTC(),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 2,
StartsAt: now.Add(-time.Second),
DeletesAt: now,
},
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
id, got, err := cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, keyID(expected), id)
require.Equal(t, 1, ff.called)
})
t.Run("KeyNotFound", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume)
require.NoError(t, err)
_, _, err = cache.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
t.Run("Verifying", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 13,
StartsAt: now,
},
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
got, err := cache.VerifyingKey(ctx, keyID(expected))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 1, ff.called)
})
t.Run("MissesCache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: clock.Now().UTC(),
}
ff.keys = []codersdk.CryptoKey{expected}
got, err := cache.VerifyingKey(ctx, keyID(expected))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 2, ff.called)
// Ensure the cache gets hit this time.
got, err = cache.VerifyingKey(ctx, keyID(expected))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 2, ff.called)
})
t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now.Add(-time.Second),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
got, err := cache.VerifyingKey(ctx, keyID(expected))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 1, ff.called)
})
t.Run("KeyPastDeletesAt", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now.Add(-time.Second),
DeletesAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
_, err = cache.VerifyingKey(ctx, keyID(expected))
require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
require.Equal(t, 1, ff.called)
})
t.Run("KeyNotFound", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
_, err = cache.VerifyingKey(ctx, "1")
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
t.Run("CacheRefreshes", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now,
DeletesAt: now.Add(time.Minute * 10),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
id, got, err := cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, keyID(expected), id)
require.Equal(t, 1, ff.called)
newKey := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 13,
StartsAt: now,
}
ff.keys = []codersdk.CryptoKey{newKey}
// The ticker should fire and cause a request to coderd.
dur, advance := clock.AdvanceNext()
advance.MustWait(ctx)
require.Equal(t, 2, ff.called)
require.Equal(t, time.Minute*10, dur)
// Assert hits cache.
id, got, err = cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, keyID(newKey), id)
require.Equal(t, decodedSecret(t, newKey), got)
require.Equal(t, 2, ff.called)
// We check again to ensure the timer has been reset.
_, advance = clock.AdvanceNext()
advance.MustWait(ctx)
require.Equal(t, 3, ff.called)
require.Equal(t, time.Minute*10, dur)
})
// This test ensures that if the refresh timer races with an inflight request
// and loses that it doesn't cause a redundant fetch.
t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now().UTC()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now,
DeletesAt: now.Add(time.Minute * 10),
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
// Create a trap that blocks when the refresh timer fires.
trap := clock.Trap().Now("refresh")
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
_, wait := clock.AdvanceNext()
trapped := trap.MustWait(ctx)
newKey := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 13,
StartsAt: now,
}
ff.keys = []codersdk.CryptoKey{newKey}
key, err := cache.VerifyingKey(ctx, keyID(newKey))
require.NoError(t, err)
require.Equal(t, 2, ff.called)
require.Equal(t, decodedSecret(t, newKey), key)
trapped.Release()
wait.MustWait(ctx)
require.Equal(t, 2, ff.called)
trap.Close()
// The next timer should fire in 10 minutes.
dur, wait := clock.AdvanceNext()
wait.MustWait(ctx)
require.Equal(t, time.Minute*10, dur)
require.Equal(t, 3, ff.called)
})
t.Run("Closed", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
clock = quartz.NewMock(t)
)
now := clock.Now()
expected := codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeatureTailnetResume,
Secret: generateKey(t, 64),
Sequence: 12,
StartsAt: now,
}
ff := &fakeFetcher{
keys: []codersdk.CryptoKey{
expected,
},
}
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
require.NoError(t, err)
id, got, err := cache.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, keyID(expected), id)
require.Equal(t, decodedSecret(t, expected), got)
require.Equal(t, 1, ff.called)
key, err := cache.VerifyingKey(ctx, keyID(expected))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expected), key)
require.Equal(t, 1, ff.called)
cache.Close()
_, _, err = cache.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrClosed)
_, err = cache.VerifyingKey(ctx, keyID(expected))
require.ErrorIs(t, err, cryptokeys.ErrClosed)
})
}
type fakeFetcher struct {
keys []codersdk.CryptoKey
called int
}
func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
f.called++
return f.keys, nil
}
func keyID(key codersdk.CryptoKey) string {
return strconv.FormatInt(int64(key.Sequence), 10)
}
func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte {
t.Helper()
secret, err := hex.DecodeString(key.Secret)
require.NoError(t, err)
return secret
}
func generateKey(t *testing.T, size int) string {
t.Helper()
key := make([]byte, size)
_, err := rand.Read(key)
require.NoError(t, err)
return hex.EncodeToString(key)
}

View File

@ -1,286 +0,0 @@
package cryptokeys
import (
"context"
"strconv"
"sync"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/quartz"
)
// never represents the maximum value for a time.Duration.
const never = 1<<63 - 1
// dbCache implements Keycache for callers with access to the database.
type dbCache struct {
db database.Store
feature database.CryptoKeyFeature
logger slog.Logger
clock quartz.Clock
// The following are initialized by NewDBCache.
keysMu sync.RWMutex
keys map[int32]database.CryptoKey
latestKey database.CryptoKey
timer *quartz.Timer
// invalidateAt is the time at which the keys cache should be invalidated.
invalidateAt time.Time
closed bool
}
type DBCacheOption func(*dbCache)
func WithDBCacheClock(clock quartz.Clock) DBCacheOption {
return func(d *dbCache) {
d.clock = clock
}
}
// NewSigningCache creates a new DBCache. Close should be called to
// release resources associated with its internal timer.
func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) {
if !isSigningKeyFeature(feature) {
return nil, ErrInvalidFeature
}
return newDBCache(logger, db, feature, opts...), nil
}
func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) {
if !isEncryptionKeyFeature(feature) {
return nil, ErrInvalidFeature
}
return newDBCache(logger, db, feature, opts...), nil
}
func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache {
d := &dbCache{
db: db,
feature: feature,
clock: quartz.NewReal(),
logger: logger,
}
for _, opt := range opts {
opt(d)
}
// Initialize the timer. This will get properly initialized the first time we fetch.
d.timer = d.clock.AfterFunc(never, d.clear)
return d
}
func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}
return d.latest(ctx)
}
func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}
return d.sequence(ctx, id)
}
func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}
return d.latest(ctx)
}
func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}
return d.sequence(ctx, id)
}
// sequence returns the CryptoKey with the given sequence number, provided that
// it is neither deleted nor has breached its deletion date. It should only be
// used for verifying or decrypting payloads. To sign/encrypt call Signing.
func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) {
sequence, err := strconv.ParseInt(id, 10, 32)
if err != nil {
return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err)
}
d.keysMu.RLock()
if d.closed {
d.keysMu.RUnlock()
return nil, ErrClosed
}
now := d.clock.Now()
key, ok := d.keys[int32(sequence)]
d.keysMu.RUnlock()
if ok {
return checkKey(key, now)
}
d.keysMu.Lock()
defer d.keysMu.Unlock()
if d.closed {
return nil, ErrClosed
}
key, ok = d.keys[int32(sequence)]
if ok {
return checkKey(key, now)
}
err = d.fetch(ctx)
if err != nil {
return nil, xerrors.Errorf("fetch: %w", err)
}
key, ok = d.keys[int32(sequence)]
if !ok {
return nil, ErrKeyNotFound
}
return checkKey(key, now)
}
// latest returns the latest valid key for signing. A valid key is one that is
// both past its start time and before its deletion time.
func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) {
d.keysMu.RLock()
if d.closed {
d.keysMu.RUnlock()
return "", nil, ErrClosed
}
latest := d.latestKey
d.keysMu.RUnlock()
now := d.clock.Now()
if latest.CanSign(now) {
return idSecret(latest)
}
d.keysMu.Lock()
defer d.keysMu.Unlock()
if d.closed {
return "", nil, ErrClosed
}
if d.latestKey.CanSign(now) {
return idSecret(d.latestKey)
}
// Refetch all keys for this feature so we can find the latest valid key.
err := d.fetch(ctx)
if err != nil {
return "", nil, xerrors.Errorf("fetch: %w", err)
}
return idSecret(d.latestKey)
}
// clear invalidates the cache. This forces the subsequent call to fetch fresh keys.
func (d *dbCache) clear() {
now := d.clock.Now("DBCache", "clear")
d.keysMu.Lock()
defer d.keysMu.Unlock()
// Check if we raced with a fetch. It's possible that the timer fired and we
// lost the race to the mutex. We want to avoid invalidating
// a cache that was just refetched.
if now.Before(d.invalidateAt) {
return
}
d.keys = nil
d.latestKey = database.CryptoKey{}
}
// fetch fetches all keys for the given feature and determines the latest key.
// It must be called while holding the keysMu lock.
func (d *dbCache) fetch(ctx context.Context) error {
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
if err != nil {
return xerrors.Errorf("get crypto keys by feature: %w", err)
}
now := d.clock.Now()
_ = d.timer.Reset(time.Minute * 10)
d.invalidateAt = now.Add(time.Minute * 10)
cache := make(map[int32]database.CryptoKey)
var latest database.CryptoKey
for _, key := range keys {
cache[key.Sequence] = key
if key.CanSign(now) && key.Sequence > latest.Sequence {
latest = key
}
}
if len(cache) == 0 {
return ErrKeyNotFound
}
if !latest.CanSign(now) {
return ErrKeyInvalid
}
d.keys, d.latestKey = cache, latest
return nil
}
func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) {
if !key.CanVerify(now) {
return nil, ErrKeyInvalid
}
return key.DecodeString()
}
func (d *dbCache) Close() error {
d.keysMu.Lock()
defer d.keysMu.Unlock()
if d.closed {
return nil
}
d.timer.Stop()
d.closed = true
return nil
}
func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool {
return feature == database.CryptoKeyFeatureWorkspaceApps
}
func isSigningKeyFeature(feature database.CryptoKeyFeature) bool {
switch feature {
case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert:
return true
default:
return false
}
}
func idSecret(k database.CryptoKey) (string, interface{}, error) {
key, err := k.DecodeString()
if err != nil {
return "", nil, xerrors.Errorf("decode key: %w", err)
}
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
}

View File

@ -1,490 +0,0 @@
package cryptokeys
import (
"database/sql"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_version(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil)
ctx = testutil.Context(t, testutil.WaitShort)
)
expectedKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
}
cache := map[int32]database.CryptoKey{
32: expectedKey,
}
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.keys = cache
secret, err := k.sequence(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expectedKey), secret)
})
t.Run("MissesCache", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
expectedKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
StartsAt: clock.Now(),
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
got, err := k.sequence(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expectedKey), got)
})
t.Run("InvalidCachedKey", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
cache := map[int32]database.CryptoKey{
32: {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
DeletesAt: sql.NullTime{
Time: clock.Now(),
Valid: true,
},
},
}
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.keys = cache
_, err := k.sequence(ctx, "32")
require.ErrorIs(t, err, ErrKeyInvalid)
})
t.Run("InvalidDBKey", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
invalidKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
DeletesAt: sql.NullTime{
Time: clock.Now(),
Valid: true,
},
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
_, err := k.sequence(ctx, keyID(invalidKey))
require.ErrorIs(t, err, ErrKeyInvalid)
})
}
func Test_latest(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
latestKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.latestKey = latestKey
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(latestKey), id)
require.Equal(t, decodedSecret(t, latestKey), secret)
})
t.Run("InvalidCachedKey", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
latestKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
invalidKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(-time.Hour),
DeletesAt: sql.NullTime{
Time: clock.Now(),
Valid: true,
},
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.latestKey = invalidKey
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(latestKey), id)
require.Equal(t, decodedSecret(t, latestKey), secret)
})
t.Run("UsesActiveKey", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
inactiveKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(time.Hour),
}
activeKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(activeKey), id)
require.Equal(t, decodedSecret(t, activeKey), secret)
})
t.Run("NoValidKeys", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
inactiveKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(time.Hour),
}
invalidKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(-time.Hour),
DeletesAt: sql.NullTime{
Time: clock.Now(),
Valid: true,
},
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
_, _, err := k.latest(ctx)
require.ErrorIs(t, err, ErrKeyInvalid)
})
}
func Test_clear(t *testing.T) {
t.Parallel()
t.Run("InvalidatesCache", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
activeKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil)
_, _, err := k.latest(ctx)
require.NoError(t, err)
dur, wait := clock.AdvanceNext()
wait.MustWait(ctx)
require.Equal(t, time.Minute*10, dur)
require.Len(t, k.keys, 0)
require.Equal(t, database.CryptoKey{}, k.latestKey)
})
t.Run("ResetsTimer", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
key := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil)
// Advance it five minutes so that we can test that the
// timer is reset and doesn't fire after another five minute.
clock.Advance(time.Minute * 5)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Advancing the clock now should require 10 minutes
// before the timer fires again.
dur, wait := clock.AdvanceNext()
wait.MustWait(ctx)
require.Equal(t, time.Minute*10, dur)
require.Len(t, k.keys, 0)
require.Equal(t, database.CryptoKey{}, k.latestKey)
})
// InvalidateAt tests that we have accounted for the race condition where a
// timer fires to invalidate the cache at the same time we are fetching new
// keys. In such cases we want to skip invalidation.
t.Run("InvalidateAt", func(t *testing.T) {
t.Parallel()
var (
ctrl = gomock.NewController(t)
mockDB = dbmock.NewMockStore(ctrl)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
trap := clock.Trap().Now("clear")
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
key := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2)
// Move us past the initial timer.
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Null these out so that we refetch.
k.keys = nil
k.latestKey = database.CryptoKey{}
// Initiate firing the timer.
dur, wait := clock.AdvanceNext()
require.Equal(t, time.Minute*10, dur)
// Trap the function just before acquiring the mutex.
call := trap.MustWait(ctx)
// Refetch keys.
id, secret, err = k.latest(ctx)
require.NoError(t, err)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Let the rest of the timer function run.
// It should see that we have refetched keys and
// not invalidate.
call.Release()
wait.MustWait(ctx)
require.Len(t, k.keys, 1)
require.Equal(t, key, k.latestKey)
trap.Close()
// Refetching the keys should've instantiated a new timer. This one should invalidate keys.
_, wait = clock.AdvanceNext()
wait.MustWait(ctx)
require.Len(t, k.keys, 0)
require.Equal(t, database.CryptoKey{}, k.latestKey)
})
}
func mustGenerateKey(t *testing.T) string {
t.Helper()
key, err := generateKey(64)
require.NoError(t, err)
return key
}
func keyID(key database.CryptoKey) string {
return strconv.FormatInt(int64(key.Sequence), 10)
}
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
t.Helper()
decoded, err := key.DecodeString()
require.NoError(t, err)
return decoded
}

View File

@ -1,216 +0,0 @@
package cryptokeys_test
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDBKeyCache(t *testing.T) {
t.Parallel()
t.Run("VerifyingKey", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
key := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 1,
StartsAt: clock.Now().UTC(),
})
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
got, err := k.VerifyingKey(ctx, keyID(key))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, key), got)
})
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
_, err = k.VerifyingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
t.Run("Signing", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 10,
StartsAt: clock.Now().UTC(),
})
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 12,
StartsAt: clock.Now().UTC(),
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 2,
StartsAt: clock.Now().UTC(),
})
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
id, key, err := k.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, keyID(expectedKey), id)
require.Equal(t, decodedSecret(t, expectedKey), key)
})
t.Run("Closed", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
ctx = testutil.Context(t, testutil.WaitShort)
logger = slogtest.Make(t, nil)
)
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 10,
StartsAt: clock.Now(),
})
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
id, key, err := k.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, keyID(expectedKey), id)
require.Equal(t, decodedSecret(t, expectedKey), key)
key, err = k.VerifyingKey(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, decodedSecret(t, expectedKey), key)
k.Close()
_, _, err = k.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrClosed)
_, err = k.VerifyingKey(ctx, keyID(expectedKey))
require.ErrorIs(t, err, cryptokeys.ErrClosed)
})
t.Run("InvalidSigningFeature", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil)
ctx = testutil.Context(t, testutil.WaitShort)
)
_, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
// Instantiate a signing cache and try to use it as an encryption cache.
sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer sc.Close()
ec, ok := sc.(cryptokeys.EncryptionKeycache)
require.True(t, ok)
_, _, err = ec.EncryptingKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
_, err = ec.DecryptingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
})
t.Run("InvalidEncryptionFeature", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil)
ctx = testutil.Context(t, testutil.WaitShort)
)
_, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
// Instantiate an encryption cache and try to use it as a signing cache.
ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer ec.Close()
sc, ok := ec.(cryptokeys.SigningKeycache)
require.True(t, ok)
_, _, err = sc.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
_, err = sc.VerifyingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
})
}
func keyID(key database.CryptoKey) string {
return strconv.FormatInt(int64(key.Sequence), 10)
}
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
t.Helper()
secret, err := key.DecodeString()
require.NoError(t, err)
return secret
}

View File

@ -1,41 +0,0 @@
package cryptokeys
import (
"context"
"io"
"golang.org/x/xerrors"
)
var (
ErrKeyNotFound = xerrors.New("key not found")
ErrKeyInvalid = xerrors.New("key is invalid for use")
ErrClosed = xerrors.New("closed")
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
)
type EncryptionKeycache interface {
// EncryptingKey returns the latest valid key for encrypting payloads. A valid
// key is one that is both past its start time and before its deletion time.
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
// DecryptingKey returns the key with the provided id which maps to its sequence
// number. The key is valid for decryption as long as it is not deleted or past
// its deletion date. We must allow for keys prior to their start time to
// account for clock skew between peers (one key may be past its start time on
// one machine while another is not).
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}
type SigningKeycache interface {
// SigningKey returns the latest valid key for signing. A valid key is one
// that is both past its start time and before its deletion time.
SigningKey(ctx context.Context) (id string, key interface{}, err error)
// VerifyingKey returns the key with the provided id which should map to its
// sequence number. The key is valid for verifying as long as it is not deleted
// or past its deletion date. We must allow for keys prior to their start time
// to account for clock skew between peers (one key may be past its start time
// on one machine while another is not).
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}

View File

@ -27,7 +27,7 @@ type DecryptKeyProvider interface {
func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) {
id, key, err := e.EncryptingKey(ctx)
if err != nil {
return "", xerrors.Errorf("get signing key: %w", err)
return "", xerrors.Errorf("encrypting key: %w", err)
}
encrypter, err := jose.NewEncrypter(

View File

@ -18,6 +18,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
@ -238,10 +239,11 @@ func TestJWS(t *testing.T) {
Feature: database.CryptoKeyFeatureOidcConvert,
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
log = slogtest.Make(t, nil)
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert}
)
cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert)
cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
require.NoError(t, err)
claims := testClaims{
@ -328,9 +330,11 @@ func TestJWE(t *testing.T) {
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps}
)
cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps)
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp)
require.NoError(t, err)
claims := testClaims{