Files
coder/coderd/cryptokeys/cache_test.go
Jon Ayers cd890aa3a0 feat: enable key rotation (#15066)
This PR contains the remaining logic necessary to hook up key rotation
to the product.
2024-10-25 17:14:35 +01:00

518 lines
13 KiB
Go

package cryptokeys_test
import (
"context"
"crypto/rand"
"encoding/hex"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// TestMain is the package-level test entry point. It hands the whole run to
// goleak, which executes every test and then reports any goroutines still
// alive afterwards as a failure.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
// TestCryptoKeyCache exercises the cache returned by cryptokeys.NewSigningCache:
// signing and verifying lookups (cache hits and misses), filtering of keys past
// their DeletesAt, the periodic refresh timer, and behavior after Close.
//
// Every subtest drives the cache with a quartz mock clock so that timers only
// fire when the test explicitly advances time. fakeFetcher.called counts how
// many times the cache went back to the fetcher; construction itself performs
// one fetch ("1 on startup").
func TestCryptoKeyCache(t *testing.T) {
	t.Parallel()

	t.Run("Signing", func(t *testing.T) {
		t.Parallel()

		// A key available at construction time is returned without another fetch.
		t.Run("HitsCache", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			now := clock.Now().UTC()
			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 2,
				StartsAt: now,
			}
			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{expected},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			id, got, err := cache.SigningKey(ctx)
			require.NoError(t, err)
			require.Equal(t, keyID(expected), id)
			require.Equal(t, decodedSecret(t, expected), got)
			// Only the fetch performed during construction.
			require.Equal(t, 1, ff.called)
		})

		// With no keys at construction, SigningKey refetches once and later
		// calls are served from the cache.
		t.Run("MissesCache", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 12,
				StartsAt: clock.Now().UTC(),
			}
			ff.keys = []codersdk.CryptoKey{expected}

			id, got, err := cache.SigningKey(ctx)
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, keyID(expected), id)
			// 1 on startup + missing cache.
			require.Equal(t, 2, ff.called)

			// Ensure the cache gets hit this time.
			id, got, err = cache.SigningKey(ctx)
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, keyID(expected), id)
			// Still 2: this call was answered from the cache without fetching.
			require.Equal(t, 2, ff.called)
		})

		// A key whose DeletesAt has been reached is not chosen for signing,
		// even though its sequence is higher than the valid key's.
		t.Run("IgnoresInvalid", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			now := clock.Now().UTC()
			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 1,
				StartsAt: now,
			}
			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{
					expected,
					{
						Feature:   codersdk.CryptoKeyFeatureTailnetResume,
						Secret:    generateKey(t, 64),
						Sequence:  2,
						StartsAt:  now.Add(-time.Second),
						DeletesAt: now,
					},
				},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			id, got, err := cache.SigningKey(ctx)
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, keyID(expected), id)
			require.Equal(t, 1, ff.called)
		})

		// With no keys available at all, SigningKey reports ErrKeyNotFound.
		t.Run("KeyNotFound", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				// Use a mock clock like every other subtest so the cache's
				// refresh timer is never backed by the real clock.
				clock = quartz.NewMock(t)
			)

			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			_, _, err = cache.SigningKey(ctx)
			require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
		})
	})

	t.Run("Verifying", func(t *testing.T) {
		t.Parallel()

		// A key ID present at construction is resolved without another fetch.
		t.Run("HitsCache", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			now := clock.Now().UTC()
			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 12,
				StartsAt: now,
			}
			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{
					expected,
					{
						Feature:  codersdk.CryptoKeyFeatureTailnetResume,
						Secret:   generateKey(t, 64),
						Sequence: 13,
						StartsAt: now,
					},
				},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			got, err := cache.VerifyingKey(ctx, keyID(expected))
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, 1, ff.called)
		})

		// An unknown key ID triggers one refetch; the next lookup hits the cache.
		t.Run("MissesCache", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 12,
				StartsAt: clock.Now().UTC(),
			}
			ff.keys = []codersdk.CryptoKey{expected}

			got, err := cache.VerifyingKey(ctx, keyID(expected))
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			// 1 on startup + missing cache.
			require.Equal(t, 2, ff.called)

			// Ensure the cache gets hit this time.
			got, err = cache.VerifyingKey(ctx, keyID(expected))
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, 2, ff.called)
		})

		// A key whose StartsAt lies in the past is accepted for verification.
		t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			now := clock.Now().UTC()
			expected := codersdk.CryptoKey{
				Feature:  codersdk.CryptoKeyFeatureTailnetResume,
				Secret:   generateKey(t, 64),
				Sequence: 12,
				StartsAt: now.Add(-time.Second),
			}
			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{
					expected,
				},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			got, err := cache.VerifyingKey(ctx, keyID(expected))
			require.NoError(t, err)
			require.Equal(t, decodedSecret(t, expected), got)
			require.Equal(t, 1, ff.called)
		})

		// A key whose DeletesAt has been reached is rejected with ErrKeyInvalid.
		t.Run("KeyPastDeletesAt", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			now := clock.Now().UTC()
			expected := codersdk.CryptoKey{
				Feature:   codersdk.CryptoKeyFeatureTailnetResume,
				Secret:    generateKey(t, 64),
				Sequence:  12,
				StartsAt:  now.Add(-time.Second),
				DeletesAt: now,
			}
			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{
					expected,
				},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			_, err = cache.VerifyingKey(ctx, keyID(expected))
			require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
			require.Equal(t, 1, ff.called)
		})

		// An ID that the fetcher never returns yields ErrKeyNotFound.
		t.Run("KeyNotFound", func(t *testing.T) {
			t.Parallel()

			var (
				ctx    = testutil.Context(t, testutil.WaitShort)
				logger = slogtest.Make(t, nil)
				clock  = quartz.NewMock(t)
			)

			ff := &fakeFetcher{
				keys: []codersdk.CryptoKey{},
			}

			cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
			require.NoError(t, err)

			_, err = cache.VerifyingKey(ctx, "1")
			require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
		})
	})

	// The refresh timer fires every 10 minutes, refetches keys, and re-arms.
	t.Run("CacheRefreshes", func(t *testing.T) {
		t.Parallel()

		var (
			ctx    = testutil.Context(t, testutil.WaitShort)
			logger = slogtest.Make(t, nil)
			clock  = quartz.NewMock(t)
		)

		now := clock.Now().UTC()
		expected := codersdk.CryptoKey{
			Feature:   codersdk.CryptoKeyFeatureTailnetResume,
			Secret:    generateKey(t, 64),
			Sequence:  12,
			StartsAt:  now,
			DeletesAt: now.Add(time.Minute * 10),
		}
		ff := &fakeFetcher{
			keys: []codersdk.CryptoKey{
				expected,
			},
		}

		cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
		require.NoError(t, err)

		id, got, err := cache.SigningKey(ctx)
		require.NoError(t, err)
		require.Equal(t, decodedSecret(t, expected), got)
		require.Equal(t, keyID(expected), id)
		require.Equal(t, 1, ff.called)

		newKey := codersdk.CryptoKey{
			Feature:  codersdk.CryptoKeyFeatureTailnetResume,
			Secret:   generateKey(t, 64),
			Sequence: 13,
			StartsAt: now,
		}
		ff.keys = []codersdk.CryptoKey{newKey}

		// The ticker should fire and cause a request to coderd.
		dur, advance := clock.AdvanceNext()
		advance.MustWait(ctx)
		require.Equal(t, 2, ff.called)
		require.Equal(t, time.Minute*10, dur)

		// Assert hits cache.
		id, got, err = cache.SigningKey(ctx)
		require.NoError(t, err)
		require.Equal(t, keyID(newKey), id)
		require.Equal(t, decodedSecret(t, newKey), got)
		require.Equal(t, 2, ff.called)

		// We check again to ensure the timer has been reset. Capture the new
		// duration so the assertion below checks this advance rather than
		// re-checking the previous one.
		dur, advance = clock.AdvanceNext()
		advance.MustWait(ctx)
		require.Equal(t, 3, ff.called)
		require.Equal(t, time.Minute*10, dur)
	})

	// This test ensures that if the refresh timer races with an inflight request
	// and loses that it doesn't cause a redundant fetch.
	t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
		t.Parallel()

		var (
			ctx    = testutil.Context(t, testutil.WaitShort)
			logger = slogtest.Make(t, nil)
			clock  = quartz.NewMock(t)
		)

		now := clock.Now().UTC()
		expected := codersdk.CryptoKey{
			Feature:   codersdk.CryptoKeyFeatureTailnetResume,
			Secret:    generateKey(t, 64),
			Sequence:  12,
			StartsAt:  now,
			DeletesAt: now.Add(time.Minute * 10),
		}
		ff := &fakeFetcher{
			keys: []codersdk.CryptoKey{
				expected,
			},
		}

		// Create a trap that blocks when the refresh timer fires.
		trap := clock.Trap().Now("refresh")

		cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
		require.NoError(t, err)

		_, wait := clock.AdvanceNext()
		trapped := trap.MustWait(ctx)

		newKey := codersdk.CryptoKey{
			Feature:  codersdk.CryptoKeyFeatureTailnetResume,
			Secret:   generateKey(t, 64),
			Sequence: 13,
			StartsAt: now,
		}
		ff.keys = []codersdk.CryptoKey{newKey}

		// While the timer-driven refresh is parked in the trap, a cache miss
		// performs the fetch itself.
		key, err := cache.VerifyingKey(ctx, keyID(newKey))
		require.NoError(t, err)
		require.Equal(t, 2, ff.called)
		require.Equal(t, decodedSecret(t, newKey), key)

		// Releasing the parked refresh must not fetch a second time.
		trapped.Release()
		wait.MustWait(ctx)
		require.Equal(t, 2, ff.called)
		trap.Close()

		// The next timer should fire in 10 minutes.
		dur, wait := clock.AdvanceNext()
		wait.MustWait(ctx)
		require.Equal(t, time.Minute*10, dur)
		require.Equal(t, 3, ff.called)
	})

	// After Close, both lookup methods fail with ErrClosed.
	t.Run("Closed", func(t *testing.T) {
		t.Parallel()

		var (
			ctx    = testutil.Context(t, testutil.WaitShort)
			logger = slogtest.Make(t, nil)
			clock  = quartz.NewMock(t)
		)

		now := clock.Now().UTC()
		expected := codersdk.CryptoKey{
			Feature:  codersdk.CryptoKeyFeatureTailnetResume,
			Secret:   generateKey(t, 64),
			Sequence: 12,
			StartsAt: now,
		}
		ff := &fakeFetcher{
			keys: []codersdk.CryptoKey{
				expected,
			},
		}

		cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
		require.NoError(t, err)

		id, got, err := cache.SigningKey(ctx)
		require.NoError(t, err)
		require.Equal(t, keyID(expected), id)
		require.Equal(t, decodedSecret(t, expected), got)
		require.Equal(t, 1, ff.called)

		key, err := cache.VerifyingKey(ctx, keyID(expected))
		require.NoError(t, err)
		require.Equal(t, decodedSecret(t, expected), key)
		require.Equal(t, 1, ff.called)

		cache.Close()

		_, _, err = cache.SigningKey(ctx)
		require.ErrorIs(t, err, cryptokeys.ErrClosed)

		_, err = cache.VerifyingKey(ctx, keyID(expected))
		require.ErrorIs(t, err, cryptokeys.ErrClosed)
	})
}
// fakeFetcher is a stub key source for the cache under test. It performs no
// locking; tests coordinate access to its fields themselves.
type fakeFetcher struct {
	// keys is handed back unchanged by every Fetch call.
	keys []codersdk.CryptoKey
	// called counts how many times Fetch has been invoked.
	called int
}

// Fetch records the invocation and returns the currently configured keys,
// ignoring both the context and the requested feature. It never errors.
func (stub *fakeFetcher) Fetch(context.Context, codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
	stub.called++
	current := stub.keys
	return current, nil
}
// keyID renders a key's sequence number in base 10, which is the form the
// cache uses as a key identifier in these tests.
func keyID(key codersdk.CryptoKey) string {
	seq := int64(key.Sequence)
	return strconv.FormatInt(seq, 10)
}
// decodedSecret returns the raw bytes behind key.Secret, which is stored
// hex-encoded. A malformed secret fails the calling test immediately.
func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte {
	t.Helper()

	raw, decodeErr := hex.DecodeString(key.Secret)
	require.NoError(t, decodeErr)

	return raw
}
// generateKey produces size bytes from crypto/rand and returns them
// hex-encoded, matching the encoding decodedSecret expects. A failure to read
// randomness fails the calling test.
func generateKey(t *testing.T, size int) string {
	t.Helper()

	raw := make([]byte, size)
	_, readErr := rand.Read(raw)
	require.NoError(t, readErr)

	encoded := hex.EncodeToString(raw)
	return encoded
}