feat: implement key rotation system (#14710)

This commit is contained in:
Jon Ayers
2024-09-19 19:12:44 +01:00
committed by GitHub
parent dbe6b6c224
commit 2d5c068525
5 changed files with 1029 additions and 1 deletions

View File

@ -902,7 +902,11 @@ func CryptoKey(t testing.TB, db database.Store, seed database.CryptoKey) databas
seed.Feature = takeFirst(seed.Feature, database.CryptoKeyFeatureWorkspaceApps)
if !seed.Secret.Valid {
// An empty string for the secret is interpreted as
// a caller wanting a new secret to be generated.
// To generate a key with a NULL secret set Valid=false
// and String to a non-empty string.
if seed.Secret.String == "" {
secret, err := newCryptoKeySecret(seed.Feature)
require.NoError(t, err, "generate secret")
seed.Secret = sql.NullString{

View File

@ -11,6 +11,7 @@ const (
LockIDDBRollup
LockIDDBPurge
LockIDNotificationsReportGenerator
LockIDCryptoKeyRotation
)
// GenLockID generates a unique and consistent lock ID from a given string.

298
coderd/keyrotate/rotate.go Normal file
View File

@ -0,0 +1,298 @@
package keyrotate
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"time"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/quartz"
)
const (
WorkspaceAppsTokenDuration = time.Minute
OIDCConvertTokenDuration = time.Minute * 5
TailnetResumeTokenDuration = time.Hour * 24
// defaultRotationInterval is the default interval at which keys are checked for rotation.
defaultRotationInterval = time.Minute * 10
// DefaultKeyDuration is the default duration for which a key is valid. It applies to all features.
DefaultKeyDuration = time.Hour * 24 * 30
)
// rotator is responsible for rotating keys in the database.
type rotator struct {
db database.Store
logger slog.Logger
clock quartz.Clock
keyDuration time.Duration
features []database.CryptoKeyFeature
}
type Option func(*rotator)
func WithClock(clock quartz.Clock) Option {
return func(r *rotator) {
r.clock = clock
}
}
func WithKeyDuration(keyDuration time.Duration) Option {
return func(r *rotator) {
r.keyDuration = keyDuration
}
}
// StartRotator starts a background process that rotates keys in the database.
// It ensures there's at least one valid key per feature prior to returning.
// Canceling the provided context will stop the background process.
func StartRotator(ctx context.Context, logger slog.Logger, db database.Store, opts ...Option) error {
kr := &rotator{
db: db,
logger: logger,
clock: quartz.NewReal(),
keyDuration: DefaultKeyDuration,
features: database.AllCryptoKeyFeatureValues(),
}
for _, opt := range opts {
opt(kr)
}
err := kr.rotateKeys(ctx)
if err != nil {
return xerrors.Errorf("rotate keys: %w", err)
}
go kr.start(ctx)
return nil
}
// start begins the process of rotating keys.
// Canceling the context will stop the rotation process.
func (k *rotator) start(ctx context.Context) {
k.clock.TickerFunc(ctx, defaultRotationInterval, func() error {
err := k.rotateKeys(ctx)
if err != nil {
k.logger.Error(ctx, "failed to rotate keys", slog.Error(err))
}
return nil
})
k.logger.Debug(ctx, "ctx canceled, stopping key rotation")
}
// rotateKeys checks for any keys needing rotation or deletion and
// may insert a new key if it detects that a valid one does
// not exist for a feature.
func (k *rotator) rotateKeys(ctx context.Context) error {
return k.db.InTx(
func(tx database.Store) error {
err := tx.AcquireLock(ctx, database.LockIDCryptoKeyRotation)
if err != nil {
return xerrors.Errorf("acquire lock: %w", err)
}
cryptokeys, err := tx.GetCryptoKeys(ctx)
if err != nil {
return xerrors.Errorf("get keys: %w", err)
}
featureKeys, err := keysByFeature(cryptokeys, k.features)
if err != nil {
return xerrors.Errorf("keys by feature: %w", err)
}
now := dbtime.Time(k.clock.Now().UTC())
for feature, keys := range featureKeys {
// We'll use a counter to determine if we should insert a new key. We should always have at least one key for a feature.
var validKeys int
for _, key := range keys {
switch {
case shouldDeleteKey(key, now):
_, err := tx.DeleteCryptoKey(ctx, database.DeleteCryptoKeyParams{
Feature: key.Feature,
Sequence: key.Sequence,
})
if err != nil {
return xerrors.Errorf("delete key: %w", err)
}
k.logger.Debug(ctx, "deleted key",
slog.F("key", key.Sequence),
slog.F("feature", key.Feature),
)
case shouldRotateKey(key, k.keyDuration, now):
_, err := k.rotateKey(ctx, tx, key, now)
if err != nil {
return xerrors.Errorf("rotate key: %w", err)
}
k.logger.Debug(ctx, "rotated key",
slog.F("key", key.Sequence),
slog.F("feature", key.Feature),
)
validKeys++
default:
// We only consider keys without a populated deletes_at field as valid.
// This is because under normal circumstances the deletes_at field
// is set during rotation (meaning a new key was generated)
// but it's possible if the database was manually altered to
// delete the new key we may be in a situation where there
// isn't a key to replace the one scheduled for deletion.
if !key.DeletesAt.Valid {
validKeys++
}
}
}
if validKeys == 0 {
k.logger.Info(ctx, "no valid keys detected, inserting new key",
slog.F("feature", feature),
)
_, err := k.insertNewKey(ctx, tx, feature, now)
if err != nil {
return xerrors.Errorf("insert new key: %w", err)
}
}
}
return nil
}, &sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
})
}
func (k *rotator) insertNewKey(ctx context.Context, tx database.Store, feature database.CryptoKeyFeature, startsAt time.Time) (database.CryptoKey, error) {
secret, err := generateNewSecret(feature)
if err != nil {
return database.CryptoKey{}, xerrors.Errorf("generate new secret: %w", err)
}
latestKey, err := tx.GetLatestCryptoKeyByFeature(ctx, feature)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return database.CryptoKey{}, xerrors.Errorf("get latest key: %w", err)
}
newKey, err := tx.InsertCryptoKey(ctx, database.InsertCryptoKeyParams{
Feature: feature,
Sequence: latestKey.Sequence + 1,
Secret: sql.NullString{
String: secret,
Valid: true,
},
// Set by dbcrypt if it's required.
SecretKeyID: sql.NullString{},
StartsAt: startsAt.UTC(),
})
if err != nil {
return database.CryptoKey{}, xerrors.Errorf("inserting new key: %w", err)
}
k.logger.Info(ctx, "inserted new key for feature", slog.F("feature", feature))
return newKey, nil
}
func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database.CryptoKey, now time.Time) ([]database.CryptoKey, error) {
startsAt := minStartsAt(key, now, k.keyDuration)
newKey, err := k.insertNewKey(ctx, tx, key.Feature, startsAt)
if err != nil {
return nil, xerrors.Errorf("insert new key: %w", err)
}
// Set old key's deletes_at to an hour + however long the token
// for this feature is expected to be valid for. This should
// allow for sufficient time for the new key to propagate to
// dependent services (i.e. Workspace Proxies).
deletesAt := startsAt.Add(time.Hour).Add(tokenDuration(key.Feature))
updatedKey, err := tx.UpdateCryptoKeyDeletesAt(ctx, database.UpdateCryptoKeyDeletesAtParams{
Feature: key.Feature,
Sequence: key.Sequence,
DeletesAt: sql.NullTime{
Time: deletesAt.UTC(),
Valid: true,
},
})
if err != nil {
return nil, xerrors.Errorf("update old key's deletes_at: %w", err)
}
return []database.CryptoKey{updatedKey, newKey}, nil
}
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
return generateKey(96)
case database.CryptoKeyFeatureOidcConvert:
return generateKey(32)
case database.CryptoKeyFeatureTailnetResume:
return generateKey(64)
}
return "", xerrors.Errorf("unknown feature: %s", feature)
}
func generateKey(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", xerrors.Errorf("rand read: %w", err)
}
return hex.EncodeToString(b), nil
}
func tokenDuration(feature database.CryptoKeyFeature) time.Duration {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
return WorkspaceAppsTokenDuration
case database.CryptoKeyFeatureOidcConvert:
return OIDCConvertTokenDuration
case database.CryptoKeyFeatureTailnetResume:
return TailnetResumeTokenDuration
default:
return 0
}
}
func shouldDeleteKey(key database.CryptoKey, now time.Time) bool {
return key.DeletesAt.Valid && !now.Before(key.DeletesAt.Time.UTC())
}
func shouldRotateKey(key database.CryptoKey, keyDuration time.Duration, now time.Time) bool {
// If deletes_at is set, we've already inserted a key.
if key.DeletesAt.Valid {
return false
}
expirationTime := key.ExpiresAt(keyDuration)
return !now.Add(time.Hour).UTC().Before(expirationTime)
}
func keysByFeature(keys []database.CryptoKey, features []database.CryptoKeyFeature) (map[database.CryptoKeyFeature][]database.CryptoKey, error) {
m := map[database.CryptoKeyFeature][]database.CryptoKey{}
for _, feature := range features {
m[feature] = []database.CryptoKey{}
}
for _, key := range keys {
if _, ok := m[key.Feature]; !ok {
return nil, xerrors.Errorf("unknown feature: %s", key.Feature)
}
m[key.Feature] = append(m[key.Feature], key)
}
return m, nil
}
// minStartsAt ensures the minimum starts_at time we use for a new
// key is no less than 3*the default rotation interval.
func minStartsAt(key database.CryptoKey, now time.Time, keyDuration time.Duration) time.Time {
expiresAt := key.ExpiresAt(keyDuration)
minStartsAt := now.Add(3 * defaultRotationInterval)
if expiresAt.Before(minStartsAt) {
return minStartsAt
}
return expiresAt
}

View File

@ -0,0 +1,601 @@
package keyrotate
import (
"database/sql"
"encoding/hex"
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_rotateKeys(t *testing.T) {
t.Parallel()
t.Run("RotatesKeysNearExpiration", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
now := dbnow(clock)
// Seed the database with an existing key.
oldKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now,
Sequence: 15,
})
// Advance the window to just inside rotation time.
_ = clock.Advance(keyDuration - time.Minute*59)
err := kr.rotateKeys(ctx)
require.NoError(t, err)
now = dbnow(clock)
expectedDeletesAt := oldKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour)
// Fetch the old key, it should have an deletes_at now.
oldKey, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: oldKey.Feature,
Sequence: oldKey.Sequence,
})
require.NoError(t, err)
require.Equal(t, oldKey.DeletesAt.Time.UTC(), expectedDeletesAt)
// The new key should be created and have a starts_at of the old key's expires_at.
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: oldKey.Sequence + 1,
})
require.NoError(t, err)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, oldKey.ExpiresAt(keyDuration), nullTime, oldKey.Sequence+1)
// Advance the clock just before the keys delete time.
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) - time.Second)
// No action should be taken.
err = kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 2)
// Advance the clock just past the keys delete time.
clock.Advance(oldKey.DeletesAt.Time.UTC().Sub(now) + time.Second)
// We should have deleted the old key.
err = kr.rotateKeys(ctx)
require.NoError(t, err)
// The old key should be "deleted".
_, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: oldKey.Feature,
Sequence: oldKey.Sequence,
})
require.ErrorIs(t, err, sql.ErrNoRows)
keys, err = db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, newKey, keys[0])
})
t.Run("DoesNotRotateValidKeys", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
now := dbnow(clock)
// Seed the database with an existing key
existingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now,
Sequence: 123,
})
// Advance the clock by 6 days, 22 hours. Once we
// breach the last hour we will insert a new key.
clock.Advance(keyDuration - 2*time.Hour)
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
require.Equal(t, existingKey, keys[0])
// Advance it again to just before the key is scheduled to be rotated for sanity purposes.
clock.Advance(time.Hour - time.Second)
err = kr.rotateKeys(ctx)
require.NoError(t, err)
// Verify that the existing key is still the only key in the database
keys, err = db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], existingKey.Feature, existingKey.StartsAt.UTC(), nullTime, existingKey.Sequence)
})
// Simulate a situation where the database was manually altered such that we only have a key that is scheduled to be deleted and assert we insert a new key.
t.Run("DeletesExpiredKeys", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
now := dbnow(clock)
// Seed the database with an existing key
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now.Add(-keyDuration),
Sequence: 789,
DeletesAt: sql.NullTime{
Time: now,
Valid: true,
},
})
err := kr.rotateKeys(ctx)
require.NoError(t, err)
// We should only get one key since the old key
// should be deleted.
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], deletingKey.Feature, deletingKey.DeletesAt.Time.UTC(), nullTime, deletingKey.Sequence+1)
// The old key should be "deleted".
_, err = db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: deletingKey.Feature,
Sequence: deletingKey.Sequence,
})
require.ErrorIs(t, err, sql.ErrNoRows)
})
// This tests a situation where we have a key scheduled for deletion but it's still valid for use.
// If no other key is detected we should insert a new key.
t.Run("AddsKeyForDeletingKey", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
now := dbnow(clock)
// Seed the database with an existing key
deletingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now,
Sequence: 456,
DeletesAt: sql.NullTime{
Time: now.Add(time.Hour),
Valid: true,
},
})
// We should only have inserted a key.
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 2)
oldKey, newKey := keys[0], keys[1]
if oldKey.Sequence != deletingKey.Sequence {
oldKey, newKey = newKey, oldKey
}
requireKey(t, oldKey, deletingKey.Feature, deletingKey.StartsAt.UTC(), deletingKey.DeletesAt, deletingKey.Sequence)
requireKey(t, newKey, deletingKey.Feature, now, nullTime, deletingKey.Sequence+1)
})
t.Run("NoKeys", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, clock.Now().UTC(), nullTime, 1)
})
// Assert we insert a new key when the only key was manually deleted.
t.Run("OnlyDeletedKeys", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{
database.CryptoKeyFeatureWorkspaceApps,
},
}
now := dbnow(clock)
deletedkey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now,
Sequence: 19,
DeletesAt: sql.NullTime{
Time: now.Add(time.Hour),
Valid: true,
},
Secret: sql.NullString{
String: "deleted",
Valid: false,
},
})
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 1)
requireKey(t, keys[0], database.CryptoKeyFeatureWorkspaceApps, now, nullTime, deletedkey.Sequence+1)
})
// This tests ensures that rotation works with multiple
// features. It's mainly a sanity test since some bugs
// are not unveiled in the simple n=1 case.
t.Run("AllFeatures", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 30
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: database.AllCryptoKeyFeatureValues(),
}
now := dbnow(clock)
// We'll test a scenario where one feature has no valid keys.
// Another has a key that should be rotate. And one that
// has a valid key that shouldn't trigger an action.
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureTailnetResume,
StartsAt: now.Add(-keyDuration),
Sequence: 5,
Secret: sql.NullString{
String: "older key",
Valid: false,
},
})
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureTailnetResume,
StartsAt: now.Add(-keyDuration),
Sequence: 19,
Secret: sql.NullString{
String: "old key",
Valid: false,
},
})
// Insert a key that should be rotated.
rotatedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now.Add(-keyDuration + time.Hour),
Sequence: 42,
})
// Insert a key that should not trigger an action.
validKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
StartsAt: now,
Sequence: 17,
})
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 4)
kbf, err := keysByFeature(keys, database.AllCryptoKeyFeatureValues())
require.NoError(t, err)
// No actions on OIDC convert.
require.Len(t, kbf[database.CryptoKeyFeatureOidcConvert], 1)
// Workspace apps should have been rotated.
require.Len(t, kbf[database.CryptoKeyFeatureWorkspaceApps], 2)
// No existing key for tailnet resume should've
// caused a key to be inserted.
require.Len(t, kbf[database.CryptoKeyFeatureTailnetResume], 1)
oidcKey := kbf[database.CryptoKeyFeatureOidcConvert][0]
tailnetKey := kbf[database.CryptoKeyFeatureTailnetResume][0]
requireKey(t, oidcKey, database.CryptoKeyFeatureOidcConvert, now, nullTime, validKey.Sequence)
requireKey(t, tailnetKey, database.CryptoKeyFeatureTailnetResume, now, nullTime, deletedKey.Sequence+1)
newKey := kbf[database.CryptoKeyFeatureWorkspaceApps][0]
oldKey := kbf[database.CryptoKeyFeatureWorkspaceApps][1]
if newKey.Sequence == rotatedKey.Sequence {
oldKey, newKey = newKey, oldKey
}
deletesAt := sql.NullTime{
Time: rotatedKey.ExpiresAt(keyDuration).Add(WorkspaceAppsTokenDuration + time.Hour),
Valid: true,
}
requireKey(t, oldKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.StartsAt.UTC(), deletesAt, rotatedKey.Sequence)
requireKey(t, newKey, database.CryptoKeyFeatureWorkspaceApps, rotatedKey.ExpiresAt(keyDuration), nullTime, rotatedKey.Sequence+1)
})
t.Run("UnknownFeature", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 7
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{database.CryptoKeyFeature("unknown")},
}
err := kr.rotateKeys(ctx)
require.Error(t, err)
})
t.Run("MinStartsAt", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 5
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
now := dbnow(clock)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
}
expiringKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now.Add(-keyDuration),
Sequence: 345,
})
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 2)
rotatedKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: expiringKey.Feature,
Sequence: expiringKey.Sequence + 1,
})
require.NoError(t, err)
require.Equal(t, now.Add(defaultRotationInterval*3), rotatedKey.StartsAt.UTC())
})
// Test that the the deletes_at of a key that is well past its expiration
// Has its deletes_at field set to value that is relative
// to the current time to afford propagation time for the
// new key.
t.Run("ExtensivelyExpiredKey", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
keyDuration = time.Hour * 24 * 3
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
kr := &rotator{
db: db,
keyDuration: keyDuration,
clock: clock,
logger: logger,
features: []database.CryptoKeyFeature{database.CryptoKeyFeatureWorkspaceApps},
}
now := dbnow(clock)
expiredKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now.Add(-keyDuration - 2*time.Hour),
Sequence: 19,
})
deletedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now,
Sequence: 20,
Secret: sql.NullString{
String: "deleted",
Valid: false,
},
})
err := kr.rotateKeys(ctx)
require.NoError(t, err)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, 2)
deletesAtKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: expiredKey.Feature,
Sequence: expiredKey.Sequence,
})
deletesAt := sql.NullTime{
Time: now.Add(defaultRotationInterval * 3).Add(WorkspaceAppsTokenDuration + time.Hour),
Valid: true,
}
require.NoError(t, err)
requireKey(t, deletesAtKey, expiredKey.Feature, expiredKey.StartsAt.UTC(), deletesAt, expiredKey.Sequence)
newKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: expiredKey.Feature,
Sequence: deletedKey.Sequence + 1,
})
require.NoError(t, err)
requireKey(t, newKey, expiredKey.Feature, now.Add(defaultRotationInterval*3), nullTime, deletedKey.Sequence+1)
})
}
func dbnow(c quartz.Clock) time.Time {
return dbtime.Time(c.Now().UTC())
}
func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKeyFeature, startsAt time.Time, deletesAt sql.NullTime, sequence int32) {
t.Helper()
require.Equal(t, feature, key.Feature)
require.Equal(t, startsAt, key.StartsAt.UTC())
require.Equal(t, deletesAt.Valid, key.DeletesAt.Valid)
require.Equal(t, deletesAt.Time.UTC(), key.DeletesAt.Time.UTC())
require.Equal(t, sequence, key.Sequence)
secret, err := hex.DecodeString(key.Secret.String)
require.NoError(t, err)
switch key.Feature {
case database.CryptoKeyFeatureOidcConvert:
require.Len(t, secret, 32)
case database.CryptoKeyFeatureWorkspaceApps:
require.Len(t, secret, 96)
case database.CryptoKeyFeatureTailnetResume:
require.Len(t, secret, 64)
default:
t.Fatalf("unknown key feature: %s", key.Feature)
}
}
var nullTime = sql.NullTime{Time: time.Time{}, Valid: false}

View File

@ -0,0 +1,124 @@
package keyrotate_test
import (
"testing"
"time"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/keyrotate"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func TestRotator(t *testing.T) {
t.Parallel()
t.Run("NoKeysOnInit", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
dbkeys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, dbkeys, 0)
err = keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock))
require.NoError(t, err)
// Fetch the keys from the database and ensure they
// are as expected.
dbkeys, err = db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, dbkeys, len(database.AllCryptoKeyFeatureValues()))
requireContainsAllFeatures(t, dbkeys)
})
t.Run("RotateKeys", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx = testutil.Context(t, testutil.WaitShort)
)
now := clock.Now().UTC()
rotatingKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: now.Add(-keyrotate.DefaultKeyDuration + time.Hour + time.Minute),
Sequence: 12345,
})
trap := clock.Trap().TickerFunc()
t.Cleanup(trap.Close)
err := keyrotate.StartRotator(ctx, logger, db, keyrotate.WithClock(clock))
require.NoError(t, err)
initialKeyLen := len(database.AllCryptoKeyFeatureValues())
// Fetch the keys from the database and ensure they
// are as expected.
dbkeys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, dbkeys, initialKeyLen)
requireContainsAllFeatures(t, dbkeys)
trap.MustWait(ctx).Release()
_, wait := clock.AdvanceNext()
wait.MustWait(ctx)
keys, err := db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, initialKeyLen+1)
newKey, err := db.GetLatestCryptoKeyByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps)
require.NoError(t, err)
require.Equal(t, rotatingKey.Sequence+1, newKey.Sequence)
require.Equal(t, rotatingKey.ExpiresAt(keyrotate.DefaultKeyDuration), newKey.StartsAt.UTC())
require.False(t, newKey.DeletesAt.Valid)
oldKey, err := db.GetCryptoKeyByFeatureAndSequence(ctx, database.GetCryptoKeyByFeatureAndSequenceParams{
Feature: rotatingKey.Feature,
Sequence: rotatingKey.Sequence,
})
expectedDeletesAt := rotatingKey.StartsAt.Add(keyrotate.DefaultKeyDuration + time.Hour + keyrotate.WorkspaceAppsTokenDuration)
require.NoError(t, err)
require.Equal(t, rotatingKey.StartsAt, oldKey.StartsAt)
require.True(t, oldKey.DeletesAt.Valid)
require.Equal(t, expectedDeletesAt, oldKey.DeletesAt.Time)
// Try rotating again and ensure no keys are rotated.
_, wait = clock.AdvanceNext()
wait.MustWait(ctx)
keys, err = db.GetCryptoKeys(ctx)
require.NoError(t, err)
require.Len(t, keys, initialKeyLen+1)
})
}
func requireContainsAllFeatures(t *testing.T, keys []database.CryptoKey) {
t.Helper()
features := make(map[database.CryptoKeyFeature]bool)
for _, key := range keys {
features[key.Feature] = true
}
for _, feature := range database.AllCryptoKeyFeatureValues() {
require.True(t, features[feature])
}
}