mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
chore: refactor keycache implementation to reduce duplication (#15100)
This commit is contained in:
369
coderd/cryptokeys/cache.go
Normal file
369
coderd/cryptokeys/cache.go
Normal file
@ -0,0 +1,369 @@
|
||||
package cryptokeys
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyNotFound = xerrors.New("key not found")
|
||||
ErrKeyInvalid = xerrors.New("key is invalid for use")
|
||||
ErrClosed = xerrors.New("closed")
|
||||
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
|
||||
)
|
||||
|
||||
type Fetcher interface {
|
||||
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
|
||||
}
|
||||
|
||||
type EncryptionKeycache interface {
|
||||
// EncryptingKey returns the latest valid key for encrypting payloads. A valid
|
||||
// key is one that is both past its start time and before its deletion time.
|
||||
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
|
||||
// DecryptingKey returns the key with the provided id which maps to its sequence
|
||||
// number. The key is valid for decryption as long as it is not deleted or past
|
||||
// its deletion date. We must allow for keys prior to their start time to
|
||||
// account for clock skew between peers (one key may be past its start time on
|
||||
// one machine while another is not).
|
||||
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type SigningKeycache interface {
|
||||
// SigningKey returns the latest valid key for signing. A valid key is one
|
||||
// that is both past its start time and before its deletion time.
|
||||
SigningKey(ctx context.Context) (id string, key interface{}, err error)
|
||||
// VerifyingKey returns the key with the provided id which should map to its
|
||||
// sequence number. The key is valid for verifying as long as it is not deleted
|
||||
// or past its deletion date. We must allow for keys prior to their start time
|
||||
// to account for clock skew between peers (one key may be past its start time
|
||||
// on one machine while another is not).
|
||||
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
const (
|
||||
// latestSequence is a special sequence number that represents the latest key.
|
||||
latestSequence = -1
|
||||
// refreshInterval is the interval at which the key cache will refresh.
|
||||
refreshInterval = time.Minute * 10
|
||||
)
|
||||
|
||||
type DBFetcher struct {
|
||||
DB database.Store
|
||||
Feature database.CryptoKeyFeature
|
||||
}
|
||||
|
||||
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
|
||||
}
|
||||
|
||||
return db2sdk.CryptoKeys(keys), nil
|
||||
}
|
||||
|
||||
// cache implements the caching functionality for both signing and encryption keys.
|
||||
type cache struct {
|
||||
clock quartz.Clock
|
||||
refreshCtx context.Context
|
||||
refreshCancel context.CancelFunc
|
||||
fetcher Fetcher
|
||||
logger slog.Logger
|
||||
feature codersdk.CryptoKeyFeature
|
||||
|
||||
mu sync.Mutex
|
||||
keys map[int32]codersdk.CryptoKey
|
||||
lastFetch time.Time
|
||||
refresher *quartz.Timer
|
||||
fetching bool
|
||||
closed bool
|
||||
cond *sync.Cond
|
||||
}
|
||||
|
||||
type CacheOption func(*cache)
|
||||
|
||||
func WithCacheClock(clock quartz.Clock) CacheOption {
|
||||
return func(d *cache) {
|
||||
d.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// NewSigningCache instantiates a cache. Close should be called to release resources
|
||||
// associated with its internal timer.
|
||||
func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
|
||||
) (SigningKeycache, error) {
|
||||
if !isSigningKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
}
|
||||
|
||||
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
feature codersdk.CryptoKeyFeature, opts ...func(*cache),
|
||||
) (EncryptionKeycache, error) {
|
||||
if !isEncryptionKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
}
|
||||
|
||||
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
|
||||
cache := &cache{
|
||||
clock: quartz.NewReal(),
|
||||
logger: logger,
|
||||
fetcher: fetcher,
|
||||
feature: feature,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cache)
|
||||
}
|
||||
|
||||
cache.cond = sync.NewCond(&cache.mu)
|
||||
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
|
||||
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
|
||||
|
||||
keys, err := cache.cryptoKeys(ctx)
|
||||
if err != nil {
|
||||
cache.refreshCancel()
|
||||
return nil, xerrors.Errorf("initial fetch: %w", err)
|
||||
}
|
||||
cache.keys = keys
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
|
||||
if !isEncryptionKeyFeature(c.feature) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, error) {
|
||||
if !isEncryptionKeyFeature(c.feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
seq, err := strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
|
||||
if !isSigningKeyFeature(c.feature) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error) {
|
||||
if !isSigningKeyFeature(c.feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
seq, err := strconv.ParseInt(id, 10, 64)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
|
||||
}
|
||||
|
||||
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
switch feature {
|
||||
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func idSecret(k codersdk.CryptoKey) (string, []byte, error) {
|
||||
key, err := hex.DecodeString(k.Secret)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("decode key: %w", err)
|
||||
}
|
||||
|
||||
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
|
||||
}
|
||||
|
||||
func (c *cache) cryptoKey(ctx context.Context, sequence int32) (string, []byte, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return "", nil, ErrClosed
|
||||
}
|
||||
|
||||
var key codersdk.CryptoKey
|
||||
var ok bool
|
||||
for key, ok = c.key(sequence); !ok && c.fetching && !c.closed; {
|
||||
c.cond.Wait()
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
return "", nil, ErrClosed
|
||||
}
|
||||
|
||||
if ok {
|
||||
return checkKey(key, sequence, c.clock.Now())
|
||||
}
|
||||
|
||||
c.fetching = true
|
||||
c.mu.Unlock()
|
||||
|
||||
keys, err := c.cryptoKeys(ctx)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("get keys: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.lastFetch = c.clock.Now()
|
||||
c.refresher.Reset(refreshInterval)
|
||||
c.keys = keys
|
||||
c.fetching = false
|
||||
c.cond.Broadcast()
|
||||
|
||||
key, ok = c.key(sequence)
|
||||
if !ok {
|
||||
return "", nil, ErrKeyNotFound
|
||||
}
|
||||
|
||||
return checkKey(key, sequence, c.clock.Now())
|
||||
}
|
||||
|
||||
func (c *cache) key(sequence int32) (codersdk.CryptoKey, bool) {
|
||||
if sequence == latestSequence {
|
||||
return c.keys[latestSequence], c.keys[latestSequence].CanSign(c.clock.Now())
|
||||
}
|
||||
|
||||
key, ok := c.keys[sequence]
|
||||
return key, ok
|
||||
}
|
||||
|
||||
func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []byte, error) {
|
||||
if sequence == latestSequence {
|
||||
if !key.CanSign(now) {
|
||||
return "", nil, ErrKeyInvalid
|
||||
}
|
||||
return idSecret(key)
|
||||
}
|
||||
|
||||
if !key.CanVerify(now) {
|
||||
return "", nil, ErrKeyInvalid
|
||||
}
|
||||
|
||||
return idSecret(key)
|
||||
}
|
||||
|
||||
// refresh fetches the keys and updates the cache.
|
||||
func (c *cache) refresh() {
|
||||
now := c.clock.Now("CryptoKeyCache", "refresh")
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// If something's already fetching, we don't need to do anything.
|
||||
if c.fetching {
|
||||
return
|
||||
}
|
||||
|
||||
// There's a window we must account for where the timer fires while a fetch
|
||||
// is ongoing but prior to the timer getting reset. In this case we want to
|
||||
// avoid double fetching.
|
||||
if now.Sub(c.lastFetch) < refreshInterval {
|
||||
return
|
||||
}
|
||||
|
||||
c.fetching = true
|
||||
|
||||
c.mu.Unlock()
|
||||
keys, err := c.cryptoKeys(c.refreshCtx)
|
||||
if err != nil {
|
||||
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// We don't defer an unlock here due to the deferred unlock at the top of the function.
|
||||
c.mu.Lock()
|
||||
|
||||
c.lastFetch = c.clock.Now()
|
||||
c.refresher.Reset(refreshInterval)
|
||||
c.keys = keys
|
||||
c.fetching = false
|
||||
c.cond.Broadcast()
|
||||
}
|
||||
|
||||
// cryptoKeys queries the control plane for the crypto keys.
|
||||
// Outside of initialization, this should only be called by fetch.
|
||||
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
|
||||
keys, err := c.fetcher.Fetch(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto keys: %w", err)
|
||||
}
|
||||
cache := toKeyMap(keys, c.clock.Now())
|
||||
return cache, nil
|
||||
}
|
||||
|
||||
func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
|
||||
m := make(map[int32]codersdk.CryptoKey)
|
||||
var latest codersdk.CryptoKey
|
||||
for _, key := range keys {
|
||||
m[key.Sequence] = key
|
||||
if key.Sequence > latest.Sequence && key.CanSign(now) {
|
||||
m[latestSequence] = key
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *cache) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
c.refreshCancel()
|
||||
c.refresher.Stop()
|
||||
c.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
517
coderd/cryptokeys/cache_test.go
Normal file
517
coderd/cryptokeys/cache_test.go
Normal file
@ -0,0 +1,517 @@
|
||||
package cryptokeys_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestCryptoKeyCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Signing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HitsCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 2,
|
||||
StartsAt: now,
|
||||
}
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{expected},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
id, got, err := cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 1, ff.called)
|
||||
})
|
||||
|
||||
t.Run("MissesCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
}
|
||||
ff.keys = []codersdk.CryptoKey{expected}
|
||||
|
||||
id, got, err := cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
// 1 on startup + missing cache.
|
||||
require.Equal(t, 2, ff.called)
|
||||
|
||||
// Ensure the cache gets hit this time.
|
||||
id, got, err = cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
// 1 on startup + missing cache.
|
||||
require.Equal(t, 2, ff.called)
|
||||
})
|
||||
|
||||
t.Run("IgnoresInvalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
now := clock.Now().UTC()
|
||||
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 1,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
}
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 2,
|
||||
StartsAt: now.Add(-time.Second),
|
||||
DeletesAt: now,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
id, got, err := cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
require.Equal(t, 1, ff.called)
|
||||
})
|
||||
|
||||
t.Run("KeyNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = cache.SigningKey(ctx)
|
||||
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Verifying", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HitsCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now,
|
||||
}
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 13,
|
||||
StartsAt: now,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 1, ff.called)
|
||||
})
|
||||
|
||||
t.Run("MissesCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
}
|
||||
ff.keys = []codersdk.CryptoKey{expected}
|
||||
|
||||
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 2, ff.called)
|
||||
|
||||
// Ensure the cache gets hit this time.
|
||||
got, err = cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 2, ff.called)
|
||||
})
|
||||
|
||||
t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now.Add(-time.Second),
|
||||
}
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 1, ff.called)
|
||||
})
|
||||
|
||||
t.Run("KeyPastDeletesAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now.Add(-time.Second),
|
||||
DeletesAt: now,
|
||||
}
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
|
||||
require.Equal(t, 1, ff.called)
|
||||
})
|
||||
|
||||
t.Run("KeyNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.VerifyingKey(ctx, "1")
|
||||
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CacheRefreshes", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now,
|
||||
DeletesAt: now.Add(time.Minute * 10),
|
||||
}
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
id, got, err := cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
require.Equal(t, 1, ff.called)
|
||||
|
||||
newKey := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 13,
|
||||
StartsAt: now,
|
||||
}
|
||||
ff.keys = []codersdk.CryptoKey{newKey}
|
||||
|
||||
// The ticker should fire and cause a request to coderd.
|
||||
dur, advance := clock.AdvanceNext()
|
||||
advance.MustWait(ctx)
|
||||
require.Equal(t, 2, ff.called)
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
|
||||
// Assert hits cache.
|
||||
id, got, err = cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(newKey), id)
|
||||
require.Equal(t, decodedSecret(t, newKey), got)
|
||||
require.Equal(t, 2, ff.called)
|
||||
|
||||
// We check again to ensure the timer has been reset.
|
||||
_, advance = clock.AdvanceNext()
|
||||
advance.MustWait(ctx)
|
||||
require.Equal(t, 3, ff.called)
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
})
|
||||
|
||||
// This test ensures that if the refresh timer races with an inflight request
|
||||
// and loses that it doesn't cause a redundant fetch.
|
||||
|
||||
t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now,
|
||||
DeletesAt: now.Add(time.Minute * 10),
|
||||
}
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a trap that blocks when the refresh timer fires.
|
||||
trap := clock.Trap().Now("refresh")
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, wait := clock.AdvanceNext()
|
||||
trapped := trap.MustWait(ctx)
|
||||
|
||||
newKey := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 13,
|
||||
StartsAt: now,
|
||||
}
|
||||
ff.keys = []codersdk.CryptoKey{newKey}
|
||||
|
||||
key, err := cache.VerifyingKey(ctx, keyID(newKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, ff.called)
|
||||
require.Equal(t, decodedSecret(t, newKey), key)
|
||||
|
||||
trapped.Release()
|
||||
wait.MustWait(ctx)
|
||||
require.Equal(t, 2, ff.called)
|
||||
trap.Close()
|
||||
|
||||
// The next timer should fire in 10 minutes.
|
||||
dur, wait := clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
require.Equal(t, 3, ff.called)
|
||||
})
|
||||
|
||||
t.Run("Closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
clock = quartz.NewMock(t)
|
||||
)
|
||||
|
||||
now := clock.Now()
|
||||
expected := codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
||||
Secret: generateKey(t, 64),
|
||||
Sequence: 12,
|
||||
StartsAt: now,
|
||||
}
|
||||
ff := &fakeFetcher{
|
||||
keys: []codersdk.CryptoKey{
|
||||
expected,
|
||||
},
|
||||
}
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
id, got, err := cache.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(expected), id)
|
||||
require.Equal(t, decodedSecret(t, expected), got)
|
||||
require.Equal(t, 1, ff.called)
|
||||
|
||||
key, err := cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expected), key)
|
||||
require.Equal(t, 1, ff.called)
|
||||
|
||||
cache.Close()
|
||||
|
||||
_, _, err = cache.SigningKey(ctx)
|
||||
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
||||
|
||||
_, err = cache.VerifyingKey(ctx, keyID(expected))
|
||||
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
type fakeFetcher struct {
|
||||
keys []codersdk.CryptoKey
|
||||
called int
|
||||
}
|
||||
|
||||
func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
|
||||
f.called++
|
||||
return f.keys, nil
|
||||
}
|
||||
|
||||
func keyID(key codersdk.CryptoKey) string {
|
||||
return strconv.FormatInt(int64(key.Sequence), 10)
|
||||
}
|
||||
|
||||
func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte {
|
||||
t.Helper()
|
||||
|
||||
secret, err := hex.DecodeString(key.Secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
return secret
|
||||
}
|
||||
|
||||
func generateKey(t *testing.T, size int) string {
|
||||
t.Helper()
|
||||
|
||||
key := make([]byte, size)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
return hex.EncodeToString(key)
|
||||
}
|
@ -1,286 +0,0 @@
|
||||
package cryptokeys
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// never represents the maximum value for a time.Duration.
|
||||
const never = 1<<63 - 1
|
||||
|
||||
// dbCache implements Keycache for callers with access to the database.
|
||||
type dbCache struct {
|
||||
db database.Store
|
||||
feature database.CryptoKeyFeature
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
|
||||
// The following are initialized by NewDBCache.
|
||||
keysMu sync.RWMutex
|
||||
keys map[int32]database.CryptoKey
|
||||
latestKey database.CryptoKey
|
||||
timer *quartz.Timer
|
||||
// invalidateAt is the time at which the keys cache should be invalidated.
|
||||
invalidateAt time.Time
|
||||
closed bool
|
||||
}
|
||||
|
||||
type DBCacheOption func(*dbCache)
|
||||
|
||||
func WithDBCacheClock(clock quartz.Clock) DBCacheOption {
|
||||
return func(d *dbCache) {
|
||||
d.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
// NewSigningCache creates a new DBCache. Close should be called to
|
||||
// release resources associated with its internal timer.
|
||||
func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) {
|
||||
if !isSigningKeyFeature(feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return newDBCache(logger, db, feature, opts...), nil
|
||||
}
|
||||
|
||||
func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) {
|
||||
if !isEncryptionKeyFeature(feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return newDBCache(logger, db, feature, opts...), nil
|
||||
}
|
||||
|
||||
func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache {
|
||||
d := &dbCache{
|
||||
db: db,
|
||||
feature: feature,
|
||||
clock: quartz.NewReal(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(d)
|
||||
}
|
||||
|
||||
// Initialize the timer. This will get properly initialized the first time we fetch.
|
||||
d.timer = d.clock.AfterFunc(never, d.clear)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) {
|
||||
if !isEncryptionKeyFeature(d.feature) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return d.latest(ctx)
|
||||
}
|
||||
|
||||
func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) {
|
||||
if !isEncryptionKeyFeature(d.feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return d.sequence(ctx, id)
|
||||
}
|
||||
|
||||
func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) {
|
||||
if !isSigningKeyFeature(d.feature) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return d.latest(ctx)
|
||||
}
|
||||
|
||||
func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) {
|
||||
if !isSigningKeyFeature(d.feature) {
|
||||
return nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
return d.sequence(ctx, id)
|
||||
}
|
||||
|
||||
// sequence returns the CryptoKey with the given sequence number, provided that
|
||||
// it is neither deleted nor has breached its deletion date. It should only be
|
||||
// used for verifying or decrypting payloads. To sign/encrypt call Signing.
|
||||
func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) {
|
||||
sequence, err := strconv.ParseInt(id, 10, 32)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err)
|
||||
}
|
||||
|
||||
d.keysMu.RLock()
|
||||
if d.closed {
|
||||
d.keysMu.RUnlock()
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
now := d.clock.Now()
|
||||
key, ok := d.keys[int32(sequence)]
|
||||
d.keysMu.RUnlock()
|
||||
if ok {
|
||||
return checkKey(key, now)
|
||||
}
|
||||
|
||||
d.keysMu.Lock()
|
||||
defer d.keysMu.Unlock()
|
||||
|
||||
if d.closed {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
key, ok = d.keys[int32(sequence)]
|
||||
if ok {
|
||||
return checkKey(key, now)
|
||||
}
|
||||
|
||||
err = d.fetch(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("fetch: %w", err)
|
||||
}
|
||||
|
||||
key, ok = d.keys[int32(sequence)]
|
||||
if !ok {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
|
||||
return checkKey(key, now)
|
||||
}
|
||||
|
||||
// latest returns the latest valid key for signing. A valid key is one that is
|
||||
// both past its start time and before its deletion time.
|
||||
func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) {
|
||||
d.keysMu.RLock()
|
||||
|
||||
if d.closed {
|
||||
d.keysMu.RUnlock()
|
||||
return "", nil, ErrClosed
|
||||
}
|
||||
|
||||
latest := d.latestKey
|
||||
d.keysMu.RUnlock()
|
||||
|
||||
now := d.clock.Now()
|
||||
if latest.CanSign(now) {
|
||||
return idSecret(latest)
|
||||
}
|
||||
|
||||
d.keysMu.Lock()
|
||||
defer d.keysMu.Unlock()
|
||||
|
||||
if d.closed {
|
||||
return "", nil, ErrClosed
|
||||
}
|
||||
|
||||
if d.latestKey.CanSign(now) {
|
||||
return idSecret(d.latestKey)
|
||||
}
|
||||
|
||||
// Refetch all keys for this feature so we can find the latest valid key.
|
||||
err := d.fetch(ctx)
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("fetch: %w", err)
|
||||
}
|
||||
|
||||
return idSecret(d.latestKey)
|
||||
}
|
||||
|
||||
// clear invalidates the cache. This forces the subsequent call to fetch fresh keys.
|
||||
func (d *dbCache) clear() {
|
||||
now := d.clock.Now("DBCache", "clear")
|
||||
d.keysMu.Lock()
|
||||
defer d.keysMu.Unlock()
|
||||
// Check if we raced with a fetch. It's possible that the timer fired and we
|
||||
// lost the race to the mutex. We want to avoid invalidating
|
||||
// a cache that was just refetched.
|
||||
if now.Before(d.invalidateAt) {
|
||||
return
|
||||
}
|
||||
d.keys = nil
|
||||
d.latestKey = database.CryptoKey{}
|
||||
}
|
||||
|
||||
// fetch fetches all keys for the given feature and determines the latest key.
|
||||
// It must be called while holding the keysMu lock.
|
||||
func (d *dbCache) fetch(ctx context.Context) error {
|
||||
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get crypto keys by feature: %w", err)
|
||||
}
|
||||
|
||||
now := d.clock.Now()
|
||||
_ = d.timer.Reset(time.Minute * 10)
|
||||
d.invalidateAt = now.Add(time.Minute * 10)
|
||||
|
||||
cache := make(map[int32]database.CryptoKey)
|
||||
var latest database.CryptoKey
|
||||
for _, key := range keys {
|
||||
cache[key.Sequence] = key
|
||||
if key.CanSign(now) && key.Sequence > latest.Sequence {
|
||||
latest = key
|
||||
}
|
||||
}
|
||||
|
||||
if len(cache) == 0 {
|
||||
return ErrKeyNotFound
|
||||
}
|
||||
|
||||
if !latest.CanSign(now) {
|
||||
return ErrKeyInvalid
|
||||
}
|
||||
|
||||
d.keys, d.latestKey = cache, latest
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) {
|
||||
if !key.CanVerify(now) {
|
||||
return nil, ErrKeyInvalid
|
||||
}
|
||||
|
||||
return key.DecodeString()
|
||||
}
|
||||
|
||||
func (d *dbCache) Close() error {
|
||||
d.keysMu.Lock()
|
||||
defer d.keysMu.Unlock()
|
||||
|
||||
if d.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
d.timer.Stop()
|
||||
d.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool {
|
||||
return feature == database.CryptoKeyFeatureWorkspaceApps
|
||||
}
|
||||
|
||||
func isSigningKeyFeature(feature database.CryptoKeyFeature) bool {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func idSecret(k database.CryptoKey) (string, interface{}, error) {
|
||||
key, err := k.DecodeString()
|
||||
if err != nil {
|
||||
return "", nil, xerrors.Errorf("decode key: %w", err)
|
||||
}
|
||||
|
||||
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
|
||||
}
|
@ -1,490 +0,0 @@
|
||||
package cryptokeys
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func Test_version(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HitsCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
logger = slogtest.Make(t, nil)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
expectedKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
cache := map[int32]database.CryptoKey{
|
||||
32: expectedKey,
|
||||
}
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
k.keys = cache
|
||||
|
||||
secret, err := k.sequence(ctx, keyID(expectedKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expectedKey), secret)
|
||||
})
|
||||
|
||||
t.Run("MissesCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
expectedKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 33,
|
||||
StartsAt: clock.Now(),
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
got, err := k.sequence(ctx, keyID(expectedKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expectedKey), got)
|
||||
})
|
||||
|
||||
t.Run("InvalidCachedKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
cache := map[int32]database.CryptoKey{
|
||||
32: {
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: clock.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
k.keys = cache
|
||||
|
||||
_, err := k.sequence(ctx, "32")
|
||||
require.ErrorIs(t, err, ErrKeyInvalid)
|
||||
})
|
||||
|
||||
t.Run("InvalidDBKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
invalidKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: clock.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
_, err := k.sequence(ctx, keyID(invalidKey))
|
||||
require.ErrorIs(t, err, ErrKeyInvalid)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_latest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HitsCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
latestKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
k.latestKey = latestKey
|
||||
|
||||
id, secret, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(latestKey), id)
|
||||
require.Equal(t, decodedSecret(t, latestKey), secret)
|
||||
})
|
||||
|
||||
t.Run("InvalidCachedKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
latestKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 33,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
|
||||
invalidKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now().Add(-time.Hour),
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: clock.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
k.latestKey = invalidKey
|
||||
|
||||
id, secret, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(latestKey), id)
|
||||
require.Equal(t, decodedSecret(t, latestKey), secret)
|
||||
})
|
||||
|
||||
t.Run("UsesActiveKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
inactiveKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
activeKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 33,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
id, secret, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(activeKey), id)
|
||||
require.Equal(t, decodedSecret(t, activeKey), secret)
|
||||
})
|
||||
|
||||
t.Run("NoValidKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
inactiveKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now().Add(time.Hour),
|
||||
}
|
||||
|
||||
invalidKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 33,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now().Add(-time.Hour),
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: clock.Now(),
|
||||
Valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
_, _, err := k.latest(ctx)
|
||||
require.ErrorIs(t, err, ErrKeyInvalid)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("InvalidatesCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
activeKey := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 33,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil)
|
||||
|
||||
_, _, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
dur, wait := clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
require.Len(t, k.keys, 0)
|
||||
require.Equal(t, database.CryptoKey{}, k.latestKey)
|
||||
})
|
||||
|
||||
t.Run("ResetsTimer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
key := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil)
|
||||
|
||||
// Advance it five minutes so that we can test that the
|
||||
// timer is reset and doesn't fire after another five minute.
|
||||
clock.Advance(time.Minute * 5)
|
||||
|
||||
id, secret, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(key), id)
|
||||
require.Equal(t, decodedSecret(t, key), secret)
|
||||
|
||||
// Advancing the clock now should require 10 minutes
|
||||
// before the timer fires again.
|
||||
dur, wait := clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
require.Len(t, k.keys, 0)
|
||||
require.Equal(t, database.CryptoKey{}, k.latestKey)
|
||||
})
|
||||
|
||||
// InvalidateAt tests that we have accounted for the race condition where a
|
||||
// timer fires to invalidate the cache at the same time we are fetching new
|
||||
// keys. In such cases we want to skip invalidation.
|
||||
t.Run("InvalidateAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
ctrl = gomock.NewController(t)
|
||||
mockDB = dbmock.NewMockStore(ctrl)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
trap := clock.Trap().Now("clear")
|
||||
|
||||
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
|
||||
defer k.Close()
|
||||
|
||||
key := database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: 32,
|
||||
Secret: sql.NullString{
|
||||
String: mustGenerateKey(t),
|
||||
Valid: true,
|
||||
},
|
||||
StartsAt: clock.Now(),
|
||||
}
|
||||
|
||||
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2)
|
||||
|
||||
// Move us past the initial timer.
|
||||
id, secret, err := k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(key), id)
|
||||
require.Equal(t, decodedSecret(t, key), secret)
|
||||
// Null these out so that we refetch.
|
||||
k.keys = nil
|
||||
k.latestKey = database.CryptoKey{}
|
||||
|
||||
// Initiate firing the timer.
|
||||
dur, wait := clock.AdvanceNext()
|
||||
require.Equal(t, time.Minute*10, dur)
|
||||
// Trap the function just before acquiring the mutex.
|
||||
call := trap.MustWait(ctx)
|
||||
|
||||
// Refetch keys.
|
||||
id, secret, err = k.latest(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(key), id)
|
||||
require.Equal(t, decodedSecret(t, key), secret)
|
||||
|
||||
// Let the rest of the timer function run.
|
||||
// It should see that we have refetched keys and
|
||||
// not invalidate.
|
||||
call.Release()
|
||||
wait.MustWait(ctx)
|
||||
require.Len(t, k.keys, 1)
|
||||
require.Equal(t, key, k.latestKey)
|
||||
trap.Close()
|
||||
|
||||
// Refetching the keys should've instantiated a new timer. This one should invalidate keys.
|
||||
_, wait = clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
require.Len(t, k.keys, 0)
|
||||
require.Equal(t, database.CryptoKey{}, k.latestKey)
|
||||
})
|
||||
}
|
||||
|
||||
func mustGenerateKey(t *testing.T) string {
|
||||
t.Helper()
|
||||
key, err := generateKey(64)
|
||||
require.NoError(t, err)
|
||||
return key
|
||||
}
|
||||
|
||||
func keyID(key database.CryptoKey) string {
|
||||
return strconv.FormatInt(int64(key.Sequence), 10)
|
||||
}
|
||||
|
||||
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
|
||||
t.Helper()
|
||||
decoded, err := key.DecodeString()
|
||||
require.NoError(t, err)
|
||||
return decoded
|
||||
}
|
@ -1,216 +0,0 @@
|
||||
package cryptokeys_test
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/cryptokeys"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestDBKeyCache(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("VerifyingKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("HitsCache", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
key := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Sequence: 1,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
})
|
||||
|
||||
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer k.Close()
|
||||
|
||||
got, err := k.VerifyingKey(ctx, keyID(key))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, key), got)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer k.Close()
|
||||
|
||||
_, err = k.VerifyingKey(ctx, "123")
|
||||
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Signing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Sequence: 10,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
})
|
||||
|
||||
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Sequence: 12,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
})
|
||||
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Sequence: 2,
|
||||
StartsAt: clock.Now().UTC(),
|
||||
})
|
||||
|
||||
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer k.Close()
|
||||
|
||||
id, key, err := k.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(expectedKey), id)
|
||||
require.Equal(t, decodedSecret(t, expectedKey), key)
|
||||
})
|
||||
|
||||
t.Run("Closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
logger = slogtest.Make(t, nil)
|
||||
)
|
||||
|
||||
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
Sequence: 10,
|
||||
StartsAt: clock.Now(),
|
||||
})
|
||||
|
||||
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer k.Close()
|
||||
|
||||
id, key, err := k.SigningKey(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyID(expectedKey), id)
|
||||
require.Equal(t, decodedSecret(t, expectedKey), key)
|
||||
|
||||
key, err = k.VerifyingKey(ctx, keyID(expectedKey))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decodedSecret(t, expectedKey), key)
|
||||
|
||||
k.Close()
|
||||
|
||||
_, _, err = k.SigningKey(ctx)
|
||||
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
||||
|
||||
_, err = k.VerifyingKey(ctx, keyID(expectedKey))
|
||||
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
||||
})
|
||||
|
||||
t.Run("InvalidSigningFeature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
logger = slogtest.Make(t, nil)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
_, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
|
||||
// Instantiate a signing cache and try to use it as an encryption cache.
|
||||
sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer sc.Close()
|
||||
|
||||
ec, ok := sc.(cryptokeys.EncryptionKeycache)
|
||||
require.True(t, ok)
|
||||
_, _, err = ec.EncryptingKey(ctx)
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
|
||||
_, err = ec.DecryptingKey(ctx, "123")
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
})
|
||||
|
||||
t.Run("InvalidEncryptionFeature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
logger = slogtest.Make(t, nil)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
_, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
|
||||
// Instantiate an encryption cache and try to use it as a signing cache.
|
||||
ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
|
||||
require.NoError(t, err)
|
||||
defer ec.Close()
|
||||
|
||||
sc, ok := ec.(cryptokeys.SigningKeycache)
|
||||
require.True(t, ok)
|
||||
_, _, err = sc.SigningKey(ctx)
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
|
||||
_, err = sc.VerifyingKey(ctx, "123")
|
||||
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
|
||||
})
|
||||
}
|
||||
|
||||
func keyID(key database.CryptoKey) string {
|
||||
return strconv.FormatInt(int64(key.Sequence), 10)
|
||||
}
|
||||
|
||||
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
|
||||
t.Helper()
|
||||
|
||||
secret, err := key.DecodeString()
|
||||
require.NoError(t, err)
|
||||
|
||||
return secret
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package cryptokeys
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyNotFound = xerrors.New("key not found")
|
||||
ErrKeyInvalid = xerrors.New("key is invalid for use")
|
||||
ErrClosed = xerrors.New("closed")
|
||||
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
|
||||
)
|
||||
|
||||
type EncryptionKeycache interface {
|
||||
// EncryptingKey returns the latest valid key for encrypting payloads. A valid
|
||||
// key is one that is both past its start time and before its deletion time.
|
||||
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
|
||||
// DecryptingKey returns the key with the provided id which maps to its sequence
|
||||
// number. The key is valid for decryption as long as it is not deleted or past
|
||||
// its deletion date. We must allow for keys prior to their start time to
|
||||
// account for clock skew between peers (one key may be past its start time on
|
||||
// one machine while another is not).
|
||||
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type SigningKeycache interface {
|
||||
// SigningKey returns the latest valid key for signing. A valid key is one
|
||||
// that is both past its start time and before its deletion time.
|
||||
SigningKey(ctx context.Context) (id string, key interface{}, err error)
|
||||
// VerifyingKey returns the key with the provided id which should map to its
|
||||
// sequence number. The key is valid for verifying as long as it is not deleted
|
||||
// or past its deletion date. We must allow for keys prior to their start time
|
||||
// to account for clock skew between peers (one key may be past its start time
|
||||
// on one machine while another is not).
|
||||
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
|
||||
io.Closer
|
||||
}
|
@ -27,7 +27,7 @@ type DecryptKeyProvider interface {
|
||||
func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) {
|
||||
id, key, err := e.EncryptingKey(ctx)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("get signing key: %w", err)
|
||||
return "", xerrors.Errorf("encrypting key: %w", err)
|
||||
}
|
||||
|
||||
encrypter, err := jose.NewEncrypter(
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
@ -238,10 +239,11 @@ func TestJWS(t *testing.T) {
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
StartsAt: time.Now(),
|
||||
})
|
||||
log = slogtest.Make(t, nil)
|
||||
log = slogtest.Make(t, nil)
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureOidcConvert}
|
||||
)
|
||||
|
||||
cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert)
|
||||
cache, err := cryptokeys.NewSigningCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureOIDCConvert)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := testClaims{
|
||||
@ -328,9 +330,11 @@ func TestJWE(t *testing.T) {
|
||||
StartsAt: time.Now(),
|
||||
})
|
||||
log = slogtest.Make(t, nil)
|
||||
|
||||
fetcher = &cryptokeys.DBFetcher{DB: db, Feature: database.CryptoKeyFeatureWorkspaceApps}
|
||||
)
|
||||
|
||||
cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps)
|
||||
cache, err := cryptokeys.NewEncryptionCache(ctx, log, fetcher, codersdk.CryptoKeyFeatureWorkspaceApp)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := testClaims{
|
||||
|
Reference in New Issue
Block a user