mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat: implement key rotation system (#14710)
This commit is contained in:
@ -902,7 +902,11 @@ func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) databas
|
||||
|
||||
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps)
|
||||
|
||||
if !seed.Secret.Valid {
|
||||
// An empty string for the secret is interpreted as
|
||||
// a caller wanting a new secret to be generated.
|
||||
// To generate a key with a NULL secret set Valid=false
|
||||
// and String to a non-empty string.
|
||||
if seed.Secret.String == "" {
|
||||
secret, err := newCryptoKeySecret(seed.Feature)
|
||||
require.NoError(t, err, "generate secret")
|
||||
seed.Secret = sql.NullString{
|
||||
|
@ -11,6 +11,7 @@ const (
|
||||
LockIDDBRollup
|
||||
LockIDDBPurge
|
||||
LockIDNotificationsReportGenerator
|
||||
LockIDCryptoKeyRotation
|
||||
)
|
||||
|
||||
// GenLockID generates a unique and consistent lock ID from a given string.
|
||||
|
298
coderd/keyrotate/rotate.go
Normal file
298
coderd/keyrotate/rotate.go
Normal file
@ -0,0 +1,298 @@
|
||||
package keyrotate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
WorkspaceAppsTokenDuration = time.Minute
|
||||
OIDCConvertTokenDuration = time.Minute * 5
|
||||
TailnetResumeTokenDuration = time.Hour * 24
|
||||
|
||||
// defaultRotationInterval is the default interval at which keys are checked for rotation.
|
||||
defaultRotationInterval = time.Minute * 10
|
||||
// DefaultKeyDuration is the default duration for which a key is valid. It applies to all features.
|
||||
DefaultKeyDuration = time.Hour * 24 * 30
|
||||
)
|
||||
|
||||
// rotator is responsible for rotating keys in the database.
|
||||
type rotator struct {
|
||||
db database.Store
|
||||
logger slog.Logger
|
||||
clock quartz.Clock
|
||||
keyDuration time.Duration
|
||||
|
||||
features []database.CryptoKeyFeature
|
||||
}
|
||||
|
||||
type Option func(*rotator)
|
||||
|
||||
func WithClock(clock quartz.Clock) Option {
|
||||
return func(r *rotator) {
|
||||
r.clock = clock
|
||||
}
|
||||
}
|
||||
|
||||
func WithKeyDuration(keyDuration time.Duration) Option {
|
||||
return func(r *rotator) {
|
||||
r.keyDuration = keyDuration
|
||||
}
|
||||
}
|
||||
|
||||
// StartRotator starts a background process that rotates keys in the database.
|
||||
// It ensures there's at least one valid key per feature prior to returning.
|
||||
// Canceling the provided context will stop the background process.
|
||||
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...Option) error {
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
logger: logger,
|
||||
clock: quartz.NewReal(),
|
||||
keyDuration: DefaultKeyDuration,
|
||||
features: database.AllCryptoKeyFeatureValues(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(kr)
|
||||
}
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("rotate keys: %w", err)
|
||||
}
|
||||
|
||||
go kr.start(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// start begins the process of rotating keys.
|
||||
// Canceling the context will stop the rotation process.
|
||||
func (k *rotator) start(ctx context.Context) {
|
||||
k.clock.TickerFunc(ctx, defaultRotationInterval, func() error {
|
||||
err := k.rotateKeys(ctx)
|
||||
if err != nil {
|
||||
k.logger.Error(ctx, "failed to rotate keys", slog.Error(err))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
k.logger.Debug(ctx, "ctx canceled, stopping key rotation")
|
||||
}
|
||||
|
||||
// rotateKeys checks for any keys needing rotation or deletion and
|
||||
// may insert a new key if it detects that a valid one does
|
||||
// not exist for a feature.
|
||||
func (k *rotator) rotateKeys(ctx context.Context) error {
|
||||
return k.db.InTx(
|
||||
func(tx database.Store) error {
|
||||
err := tx.AcquireLock(ctx, database.LockIDCryptoKeyRotation)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("acquire lock: %w", err)
|
||||
}
|
||||
|
||||
cryptokeys, err := tx.GetCryptoKeys(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get keys: %w", err)
|
||||
}
|
||||
|
||||
featureKeys, err := keysByFeature(cryptokeys, k.features)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("keys by feature: %w", err)
|
||||
}
|
||||
|
||||
now := dbtime.Time(k.clock.Now().UTC())
|
||||
for feature, keys := range featureKeys {
|
||||
// We'll use a counter to determine if we should insert a new key. We should always have at least one key for a feature.
|
||||
var validKeys int
|
||||
for _, key := range keys {
|
||||
switch {
|
||||
case shouldDeleteKey(key, now):
|
||||
_, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{
|
||||
Feature: key.Feature,
|
||||
Sequence: key.Sequence,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete key: %w", err)
|
||||
}
|
||||
k.logger.Debug(ctx, "deleted key",
|
||||
slog.F("key", key.Sequence),
|
||||
slog.F("feature", key.Feature),
|
||||
)
|
||||
case shouldRotateKey(key, k.keyDuration, now):
|
||||
_, err := k.rotateKey(ctx, tx, key, now)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("rotate key: %w", err)
|
||||
}
|
||||
k.logger.Debug(ctx, "rotated key",
|
||||
slog.F("key", key.Sequence),
|
||||
slog.F("feature", key.Feature),
|
||||
)
|
||||
validKeys++
|
||||
default:
|
||||
// We only consider keys without a populated deletes_at field as valid.
|
||||
// This is because under normal circumstances the deletes_at field
|
||||
// is set during rotation (meaning a new key was generated)
|
||||
// but it's possible if the database was manually altered to
|
||||
// delete the new key we may be in a situation where there
|
||||
// isn't a key to replace the one scheduled for deletion.
|
||||
if !key.DeletesAt.Valid {
|
||||
validKeys++
|
||||
}
|
||||
}
|
||||
}
|
||||
if validKeys == 0 {
|
||||
k.logger.Info(ctx, "no valid keys detected, inserting new key",
|
||||
slog.F("feature", feature),
|
||||
)
|
||||
_, err := k.insertNewKey(ctx, tx, feature, now)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert new key: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, &sql.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
})
|
||||
}
|
||||
|
||||
func (k *rotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, startsAt time.Time) (database.CryptoKey, error) {
|
||||
secret, err := generateNewSecret(feature)
|
||||
if err != nil {
|
||||
return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err)
|
||||
}
|
||||
|
||||
latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err)
|
||||
}
|
||||
|
||||
newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{
|
||||
Feature: feature,
|
||||
Sequence: latestKey.Sequence + 1,
|
||||
Secret: sql.NullString{
|
||||
String: secret,
|
||||
Valid: true,
|
||||
},
|
||||
// Set by dbcrypt if it's required.
|
||||
SecretKeyID: sql.NullString{},
|
||||
StartsAt: startsAt.UTC(),
|
||||
})
|
||||
if err != nil {
|
||||
return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err)
|
||||
}
|
||||
|
||||
k.logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature))
|
||||
return newKey, nil
|
||||
}
|
||||
|
||||
func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey, now time.Time) ([]database.CryptoKey, error) {
|
||||
startsAt := minStartsAt(key, now, k.keyDuration)
|
||||
newKey, err := k.insertNewKey(ctx, tx, key.Feature, startsAt)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert new key: %w", err)
|
||||
}
|
||||
|
||||
// Set old key's deletes_at to an hour + however long the token
|
||||
// for this feature is expected to be valid for. This should
|
||||
// allow for sufficient time for the new key to propagate to
|
||||
// dependent services (i.e. Workspace Proxies).
|
||||
deletesAt := startsAt.Add(time.Hour).Add(tokenDuration(key.Feature))
|
||||
|
||||
updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{
|
||||
Feature: key.Feature,
|
||||
Sequence: key.Sequence,
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: deletesAt.UTC(),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("update old key's deletes_at: %w", err)
|
||||
}
|
||||
|
||||
return []database.CryptoKey{updatedKey, newKey}, nil
|
||||
}
|
||||
|
||||
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
return generateKey(96)
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
return generateKey(32)
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
return generateKey(64)
|
||||
}
|
||||
return "", xerrors.Errorf("unknown feature: %s", feature)
|
||||
}
|
||||
|
||||
func generateKey(length int) (string, error) {
|
||||
b := make([]byte, length)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("rand read: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
|
||||
switch feature {
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
return WorkspaceAppsTokenDuration
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
return OIDCConvertTokenDuration
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
return TailnetResumeTokenDuration
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func shouldDeleteKey(key database.CryptoKey, now time.Time) bool {
|
||||
return key.DeletesAt.Valid && !now.Before(key.DeletesAt.Time.UTC())
|
||||
}
|
||||
|
||||
func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool {
|
||||
// If deletes_at is set, we've already inserted a key.
|
||||
if key.DeletesAt.Valid {
|
||||
return false
|
||||
}
|
||||
expirationTime := key.ExpiresAt(keyDuration)
|
||||
return !now.Add(time.Hour).UTC().Before(expirationTime)
|
||||
}
|
||||
|
||||
func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) (map[database.CryptoKeyFeature][]database.CryptoKey, error) {
|
||||
m := map[database.CryptoKeyFeature][]database.CryptoKey{}
|
||||
for _, feature := range features {
|
||||
m[feature] = []database.CryptoKey{}
|
||||
}
|
||||
for _, key := range keys {
|
||||
if _, ok := m[key.Feature]; !ok {
|
||||
return nil, xerrors.Errorf("unknown feature: %s", key.Feature)
|
||||
}
|
||||
|
||||
m[key.Feature] = append(m[key.Feature], key)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// minStartsAt ensures the minimum starts_at time we use for a new
|
||||
// key is no less than 3*the default rotation interval.
|
||||
func minStartsAt(key database.CryptoKey, now time.Time, keyDuration time.Duration) time.Time {
|
||||
expiresAt := key.ExpiresAt(keyDuration)
|
||||
minStartsAt := now.Add(3 * defaultRotationInterval)
|
||||
if expiresAt.Before(minStartsAt) {
|
||||
return minStartsAt
|
||||
}
|
||||
return expiresAt
|
||||
}
|
601
coderd/keyrotate/rotate_internal_test.go
Normal file
601
coderd/keyrotate/rotate_internal_test.go
Normal file
@ -0,0 +1,601 @@
|
||||
package keyrotate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func Test_rotateKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("RotatesKeysNearExpiration", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// Seed the database with an existing key.
|
||||
oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now,
|
||||
Sequence: 15,
|
||||
})
|
||||
|
||||
// Advance the window to just inside rotation time.
|
||||
_ = clock.Advance(keyDuration - time.Minute*59)
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
now = dbnow(clock)
|
||||
expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour)
|
||||
|
||||
// Fetch the old key, it should have an deletes_at now.
|
||||
oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: oldKey.Feature,
|
||||
Sequence: oldKey.Sequence,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt)
|
||||
|
||||
// The new key should be created and have a starts_at of the old key's expires_at.
|
||||
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
Sequence: oldKey.Sequence + 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
|
||||
|
||||
// Advance the clock just before the keys delete time.
|
||||
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second)
|
||||
|
||||
// No action should be taken.
|
||||
err = kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2)
|
||||
|
||||
// Advance the clock just past the keys delete time.
|
||||
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) + time.Second)
|
||||
|
||||
// We should have deleted the old key.
|
||||
err = kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The old key should be "deleted".
|
||||
_, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: oldKey.Feature,
|
||||
Sequence: oldKey.Sequence,
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
|
||||
keys, err = db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, newKey, keys[0])
|
||||
})
|
||||
|
||||
t.Run("DoesNotRotateValidKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// Seed the database with an existing key
|
||||
existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now,
|
||||
Sequence: 123,
|
||||
})
|
||||
|
||||
// Advance the clock by 6 days, 22 hours. Once we
|
||||
// breach the last hour we will insert a new key.
|
||||
clock.Advance(keyDuration - 2*time.Hour)
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
require.Equal(t, existingKey, keys[0])
|
||||
|
||||
// Advance it again to just before the key is scheduled to be rotated for sanity purposes.
|
||||
clock.Advance(time.Hour - time.Second)
|
||||
|
||||
err = kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that the existing key is still the only key in the database
|
||||
keys, err = db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], existingKey.Feature, existingKey.StartsAt.UTC(), nullTime, existingKey.Sequence)
|
||||
})
|
||||
|
||||
// Simulate a situation where the database was manually altered such that we only have a key that is scheduled to be deleted and assert we insert a new key.
|
||||
t.Run("DeletesExpiredKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// Seed the database with an existing key
|
||||
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 789,
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: now,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We should only get one key since the old key
|
||||
// should be deleted.
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], deletingKey.Feature, deletingKey.DeletesAt.Time.UTC(), nullTime, deletingKey.Sequence+1)
|
||||
// The old key should be "deleted".
|
||||
_, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: deletingKey.Feature,
|
||||
Sequence: deletingKey.Sequence,
|
||||
})
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
})
|
||||
|
||||
// This tests a situation where we have a key scheduled for deletion but it's still valid for use.
|
||||
// If no other key is detected we should insert a new key.
|
||||
t.Run("AddsKeyForDeletingKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// Seed the database with an existing key
|
||||
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now,
|
||||
Sequence: 456,
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: now.Add(time.Hour),
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
|
||||
// We should only have inserted a key.
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2)
|
||||
oldKey, newKey := keys[0], keys[1]
|
||||
if oldKey.Sequence != deletingKey.Sequence {
|
||||
oldKey, newKey = newKey, oldKey
|
||||
}
|
||||
requireKey(t, oldKey, deletingKey.Feature, deletingKey.StartsAt.UTC(), deletingKey.DeletesAt, deletingKey.Sequence)
|
||||
requireKey(t, newKey, deletingKey.Feature, now, nullTime, deletingKey.Sequence+1)
|
||||
})
|
||||
|
||||
t.Run("NoKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1)
|
||||
})
|
||||
|
||||
// Assert we insert a new key when the only key was manually deleted.
|
||||
t.Run("OnlyDeletedKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{
|
||||
database.CryptoKeyFeatureWorkspaceApps,
|
||||
},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now,
|
||||
Sequence: 19,
|
||||
DeletesAt: sql.NullTime{
|
||||
Time: now.Add(time.Hour),
|
||||
Valid: true,
|
||||
},
|
||||
Secret: sql.NullString{
|
||||
String: "deleted",
|
||||
Valid: false,
|
||||
},
|
||||
})
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 1)
|
||||
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1)
|
||||
})
|
||||
|
||||
// This tests ensures that rotation works with multiple
|
||||
// features. It's mainly a sanity test since some bugs
|
||||
// are not unveiled in the simple n=1 case.
|
||||
t.Run("AllFeatures", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 30
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: database.AllCryptoKeyFeatureValues(),
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
// We'll test a scenario where one feature has no valid keys.
|
||||
// Another has a key that should be rotate. And one that
|
||||
// has a valid key that shouldn't trigger an action.
|
||||
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureTailnetResume,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 5,
|
||||
Secret: sql.NullString{
|
||||
String: "older key",
|
||||
Valid: false,
|
||||
},
|
||||
})
|
||||
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureTailnetResume,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 19,
|
||||
Secret: sql.NullString{
|
||||
String: "old key",
|
||||
Valid: false,
|
||||
},
|
||||
})
|
||||
|
||||
// Insert a key that should be rotated.
|
||||
rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now.Add(-keyDuration + time.Hour),
|
||||
Sequence: 42,
|
||||
})
|
||||
|
||||
// Insert a key that should not trigger an action.
|
||||
validKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureOidcConvert,
|
||||
StartsAt: now,
|
||||
Sequence: 17,
|
||||
})
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 4)
|
||||
|
||||
kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues())
|
||||
require.NoError(t, err)
|
||||
|
||||
// No actions on OIDC convert.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1)
|
||||
// Workspace apps should have been rotated.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2)
|
||||
// No existing key for tailnet resume should've
|
||||
// caused a key to be inserted.
|
||||
require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1)
|
||||
|
||||
oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0]
|
||||
tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0]
|
||||
requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence)
|
||||
requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1)
|
||||
|
||||
newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0]
|
||||
oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1]
|
||||
if newKey.Sequence == rotatedKey.Sequence {
|
||||
oldKey, newKey = newKey, oldKey
|
||||
}
|
||||
deletesAt := sql.NullTime{
|
||||
Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour),
|
||||
Valid: true,
|
||||
}
|
||||
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
|
||||
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
|
||||
})
|
||||
|
||||
t.Run("UnknownFeature", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 7
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeature("unknown")},
|
||||
}
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("MinStartsAt", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 5
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
|
||||
}
|
||||
|
||||
expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now.Add(-keyDuration),
|
||||
Sequence: 345,
|
||||
})
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2)
|
||||
|
||||
rotatedKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: expiringKey.Feature,
|
||||
Sequence: expiringKey.Sequence + 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, now.Add(defaultRotationInterval*3), rotatedKey.StartsAt.UTC())
|
||||
})
|
||||
|
||||
// Test that the the deletes_at of a key that is well past its expiration
|
||||
// Has its deletes_at field set to value that is relative
|
||||
// to the current time to afford propagation time for the
|
||||
// new key.
|
||||
t.Run("ExtensivelyExpiredKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
keyDuration = time.Hour * 24 * 3
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
kr := &rotator{
|
||||
db: db,
|
||||
keyDuration: keyDuration,
|
||||
clock: clock,
|
||||
logger: logger,
|
||||
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
|
||||
}
|
||||
|
||||
now := dbnow(clock)
|
||||
|
||||
expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now.Add(-keyDuration - 2*time.Hour),
|
||||
Sequence: 19,
|
||||
})
|
||||
|
||||
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now,
|
||||
Sequence: 20,
|
||||
Secret: sql.NullString{
|
||||
String: "deleted",
|
||||
Valid: false,
|
||||
},
|
||||
})
|
||||
|
||||
err := kr.rotateKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, 2)
|
||||
|
||||
deletesAtKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: expiredKey.Feature,
|
||||
Sequence: expiredKey.Sequence,
|
||||
})
|
||||
|
||||
deletesAt := sql.NullTime{
|
||||
Time: now.Add(defaultRotationInterval * 3).Add(WorkspaceAppsTokenDuration + time.Hour),
|
||||
Valid: true,
|
||||
}
|
||||
require.NoError(t, err)
|
||||
requireKey(t, deletesAtKey, expiredKey.Feature, expiredKey.StartsAt.UTC(), deletesAt, expiredKey.Sequence)
|
||||
|
||||
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: expiredKey.Feature,
|
||||
Sequence: deletedKey.Sequence + 1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireKey(t, newKey, expiredKey.Feature, now.Add(defaultRotationInterval*3), nullTime, deletedKey.Sequence+1)
|
||||
})
|
||||
}
|
||||
|
||||
func dbnow(c quartz.Clock) time.Time {
|
||||
return dbtime.Time(c.Now().UTC())
|
||||
}
|
||||
|
||||
func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt sql.NullTime, sequence int32) {
|
||||
t.Helper()
|
||||
require.Equal(t, feature, key.Feature)
|
||||
require.Equal(t, startsAt, key.StartsAt.UTC())
|
||||
require.Equal(t, deletesAt.Valid, key.DeletesAt.Valid)
|
||||
require.Equal(t, deletesAt.Time.UTC(), key.DeletesAt.Time.UTC())
|
||||
require.Equal(t, sequence, key.Sequence)
|
||||
|
||||
secret, err := hex.DecodeString(key.Secret.String)
|
||||
require.NoError(t, err)
|
||||
|
||||
switch key.Feature {
|
||||
case database.CryptoKeyFeatureOidcConvert:
|
||||
require.Len(t, secret, 32)
|
||||
case database.CryptoKeyFeatureWorkspaceApps:
|
||||
require.Len(t, secret, 96)
|
||||
case database.CryptoKeyFeatureTailnetResume:
|
||||
require.Len(t, secret, 64)
|
||||
default:
|
||||
t.Fatalf("unknown key feature: %s", key.Feature)
|
||||
}
|
||||
}
|
||||
|
||||
var nullTime = sql.NullTime{Time: time.Time{}, Valid: false}
|
124
coderd/keyrotate/rotate_test.go
Normal file
124
coderd/keyrotate/rotate_test.go
Normal file
@ -0,0 +1,124 @@
|
||||
package keyrotate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/keyrotate"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestRotator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("NoKeysOnInit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
dbkeys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbkeys, 0)
|
||||
|
||||
err = keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch the keys from the database and ensure they
|
||||
// are as expected.
|
||||
dbkeys, err = db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbkeys, len(database.AllCryptoKeyFeatureValues()))
|
||||
requireContainsAllFeatures(t, dbkeys)
|
||||
})
|
||||
|
||||
t.Run("RotateKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
clock = quartz.NewMock(t)
|
||||
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
ctx = testutil.Context(t, testutil.WaitShort)
|
||||
)
|
||||
|
||||
now := clock.Now().UTC()
|
||||
|
||||
rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
|
||||
Feature: database.CryptoKeyFeatureWorkspaceApps,
|
||||
StartsAt: now.Add(-keyrotate.DefaultKeyDuration + time.Hour + time.Minute),
|
||||
Sequence: 12345,
|
||||
})
|
||||
|
||||
trap := clock.Trap().TickerFunc()
|
||||
t.Cleanup(trap.Close)
|
||||
|
||||
err := keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock))
|
||||
require.NoError(t, err)
|
||||
|
||||
initialKeyLen := len(database.AllCryptoKeyFeatureValues())
|
||||
// Fetch the keys from the database and ensure they
|
||||
// are as expected.
|
||||
dbkeys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dbkeys, initialKeyLen)
|
||||
requireContainsAllFeatures(t, dbkeys)
|
||||
|
||||
trap.MustWait(ctx).Release()
|
||||
_, wait := clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
|
||||
keys, err := db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, initialKeyLen+1)
|
||||
|
||||
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence)
|
||||
require.Equal(t, rotatingKey.ExpiresAt(keyrotate.DefaultKeyDuration), newKey.StartsAt.UTC())
|
||||
require.False(t, newKey.DeletesAt.Valid)
|
||||
|
||||
oldKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
|
||||
Feature: rotatingKey.Feature,
|
||||
Sequence: rotatingKey.Sequence,
|
||||
})
|
||||
expectedDeletesAt := rotatingKey.StartsAt.Add(keyrotate.DefaultKeyDuration + time.Hour + keyrotate.WorkspaceAppsTokenDuration)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, rotatingKey.StartsAt, oldKey.StartsAt)
|
||||
require.True(t, oldKey.DeletesAt.Valid)
|
||||
require.Equal(t, expectedDeletesAt, oldKey.DeletesAt.Time)
|
||||
|
||||
// Try rotating again and ensure no keys are rotated.
|
||||
_, wait = clock.AdvanceNext()
|
||||
wait.MustWait(ctx)
|
||||
|
||||
keys, err = db.GetCryptoKeys(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, keys, initialKeyLen+1)
|
||||
})
|
||||
}
|
||||
|
||||
func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) {
|
||||
t.Helper()
|
||||
|
||||
features := make(map[database.CryptoKeyFeature]bool)
|
||||
for _, key := range keys {
|
||||
features[key.Feature] = true
|
||||
}
|
||||
for _, feature := range database.AllCryptoKeyFeatureValues() {
|
||||
require.True(t, features[feature])
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user