mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
518 lines
13 KiB
Go
518 lines
13 KiB
Go
package cryptokeys_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
|
|
"github.com/coder/coder/v2/coderd/cryptokeys"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
func TestCryptoKeyCache(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Signing", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("HitsCache", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 2,
|
|
StartsAt: now,
|
|
}
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{expected},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
id, got, err := cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, keyID(expected), id)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 1, ff.called)
|
|
})
|
|
|
|
t.Run("MissesCache", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: clock.Now().UTC(),
|
|
}
|
|
ff.keys = []codersdk.CryptoKey{expected}
|
|
|
|
id, got, err := cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, keyID(expected), id)
|
|
// 1 on startup + missing cache.
|
|
require.Equal(t, 2, ff.called)
|
|
|
|
// Ensure the cache gets hit this time.
|
|
id, got, err = cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, keyID(expected), id)
|
|
// 1 on startup + missing cache.
|
|
require.Equal(t, 2, ff.called)
|
|
})
|
|
|
|
t.Run("IgnoresInvalid", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
now := clock.Now().UTC()
|
|
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 1,
|
|
StartsAt: clock.Now().UTC(),
|
|
}
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 2,
|
|
StartsAt: now.Add(-time.Second),
|
|
DeletesAt: now,
|
|
},
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
id, got, err := cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, keyID(expected), id)
|
|
require.Equal(t, 1, ff.called)
|
|
})
|
|
|
|
t.Run("KeyNotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
)
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume)
|
|
require.NoError(t, err)
|
|
|
|
_, _, err = cache.SigningKey(ctx)
|
|
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
|
|
})
|
|
})
|
|
|
|
t.Run("Verifying", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("HitsCache", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now,
|
|
}
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 13,
|
|
StartsAt: now,
|
|
},
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 1, ff.called)
|
|
})
|
|
|
|
t.Run("MissesCache", func(t *testing.T) {
|
|
t.Parallel()
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: clock.Now().UTC(),
|
|
}
|
|
ff.keys = []codersdk.CryptoKey{expected}
|
|
|
|
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 2, ff.called)
|
|
|
|
// Ensure the cache gets hit this time.
|
|
got, err = cache.VerifyingKey(ctx, keyID(expected))
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 2, ff.called)
|
|
})
|
|
|
|
t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now.Add(-time.Second),
|
|
}
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
got, err := cache.VerifyingKey(ctx, keyID(expected))
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 1, ff.called)
|
|
})
|
|
|
|
t.Run("KeyPastDeletesAt", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now.Add(-time.Second),
|
|
DeletesAt: now,
|
|
}
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
_, err = cache.VerifyingKey(ctx, keyID(expected))
|
|
require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
|
|
require.Equal(t, 1, ff.called)
|
|
})
|
|
|
|
t.Run("KeyNotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
_, err = cache.VerifyingKey(ctx, "1")
|
|
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
|
|
})
|
|
})
|
|
|
|
t.Run("CacheRefreshes", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now,
|
|
DeletesAt: now.Add(time.Minute * 10),
|
|
}
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
id, got, err := cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, keyID(expected), id)
|
|
require.Equal(t, 1, ff.called)
|
|
|
|
newKey := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 13,
|
|
StartsAt: now,
|
|
}
|
|
ff.keys = []codersdk.CryptoKey{newKey}
|
|
|
|
// The ticker should fire and cause a request to coderd.
|
|
dur, advance := clock.AdvanceNext()
|
|
advance.MustWait(ctx)
|
|
require.Equal(t, 2, ff.called)
|
|
require.Equal(t, time.Minute*10, dur)
|
|
|
|
// Assert hits cache.
|
|
id, got, err = cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, keyID(newKey), id)
|
|
require.Equal(t, decodedSecret(t, newKey), got)
|
|
require.Equal(t, 2, ff.called)
|
|
|
|
// We check again to ensure the timer has been reset.
|
|
_, advance = clock.AdvanceNext()
|
|
advance.MustWait(ctx)
|
|
require.Equal(t, 3, ff.called)
|
|
require.Equal(t, time.Minute*10, dur)
|
|
})
|
|
|
|
// This test ensures that if the refresh timer races with an inflight request
|
|
// and loses that it doesn't cause a redundant fetch.
|
|
|
|
t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now().UTC()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now,
|
|
DeletesAt: now.Add(time.Minute * 10),
|
|
}
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
},
|
|
}
|
|
|
|
// Create a trap that blocks when the refresh timer fires.
|
|
trap := clock.Trap().Now("refresh")
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
_, wait := clock.AdvanceNext()
|
|
trapped := trap.MustWait(ctx)
|
|
|
|
newKey := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 13,
|
|
StartsAt: now,
|
|
}
|
|
ff.keys = []codersdk.CryptoKey{newKey}
|
|
|
|
key, err := cache.VerifyingKey(ctx, keyID(newKey))
|
|
require.NoError(t, err)
|
|
require.Equal(t, 2, ff.called)
|
|
require.Equal(t, decodedSecret(t, newKey), key)
|
|
|
|
trapped.Release()
|
|
wait.MustWait(ctx)
|
|
require.Equal(t, 2, ff.called)
|
|
trap.Close()
|
|
|
|
// The next timer should fire in 10 minutes.
|
|
dur, wait := clock.AdvanceNext()
|
|
wait.MustWait(ctx)
|
|
require.Equal(t, time.Minute*10, dur)
|
|
require.Equal(t, 3, ff.called)
|
|
})
|
|
|
|
t.Run("Closed", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
ctx = testutil.Context(t, testutil.WaitShort)
|
|
logger = slogtest.Make(t, nil)
|
|
clock = quartz.NewMock(t)
|
|
)
|
|
|
|
now := clock.Now()
|
|
expected := codersdk.CryptoKey{
|
|
Feature: codersdk.CryptoKeyFeatureTailnetResume,
|
|
Secret: generateKey(t, 64),
|
|
Sequence: 12,
|
|
StartsAt: now,
|
|
}
|
|
ff := &fakeFetcher{
|
|
keys: []codersdk.CryptoKey{
|
|
expected,
|
|
},
|
|
}
|
|
|
|
cache, err := cryptokeys.NewSigningCache(ctx, logger, ff, codersdk.CryptoKeyFeatureTailnetResume, cryptokeys.WithCacheClock(clock))
|
|
require.NoError(t, err)
|
|
|
|
id, got, err := cache.SigningKey(ctx)
|
|
require.NoError(t, err)
|
|
require.Equal(t, keyID(expected), id)
|
|
require.Equal(t, decodedSecret(t, expected), got)
|
|
require.Equal(t, 1, ff.called)
|
|
|
|
key, err := cache.VerifyingKey(ctx, keyID(expected))
|
|
require.NoError(t, err)
|
|
require.Equal(t, decodedSecret(t, expected), key)
|
|
require.Equal(t, 1, ff.called)
|
|
|
|
cache.Close()
|
|
|
|
_, _, err = cache.SigningKey(ctx)
|
|
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
|
|
|
_, err = cache.VerifyingKey(ctx, keyID(expected))
|
|
require.ErrorIs(t, err, cryptokeys.ErrClosed)
|
|
})
|
|
}
|
|
|
|
type fakeFetcher struct {
|
|
keys []codersdk.CryptoKey
|
|
called int
|
|
}
|
|
|
|
func (f *fakeFetcher) Fetch(_ context.Context, _ codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
|
|
f.called++
|
|
return f.keys, nil
|
|
}
|
|
|
|
func keyID(key codersdk.CryptoKey) string {
|
|
return strconv.FormatInt(int64(key.Sequence), 10)
|
|
}
|
|
|
|
func decodedSecret(t *testing.T, key codersdk.CryptoKey) []byte {
|
|
t.Helper()
|
|
|
|
secret, err := hex.DecodeString(key.Secret)
|
|
require.NoError(t, err)
|
|
|
|
return secret
|
|
}
|
|
|
|
func generateKey(t *testing.T, size int) string {
|
|
t.Helper()
|
|
|
|
key := make([]byte, size)
|
|
_, err := rand.Read(key)
|
|
require.NoError(t, err)
|
|
|
|
return hex.EncodeToString(key)
|
|
}
|