mirror of
https://github.com/coder/coder.git
synced 2025-07-21 01:28:49 +00:00
feat: enable key rotation (#15066)
This PR contains the remaining logic necessary to hook up key rotation to the product.
This commit is contained in:
@ -3,6 +3,7 @@ package cryptokeys
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
@ -12,7 +13,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
@ -25,7 +26,7 @@ var (
|
||||
)
|
||||
|
||||
type Fetcher interface {
|
||||
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
|
||||
Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error)
|
||||
}
|
||||
|
||||
type EncryptionKeycache interface {
|
||||
@ -62,27 +63,26 @@ const (
|
||||
)
|
||||
|
||||
type DBFetcher struct {
|
||||
DB database.Store
|
||||
Feature database.CryptoKeyFeature
|
||||
DB database.Store
|
||||
}
|
||||
|
||||
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
|
||||
func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
|
||||
keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
|
||||
}
|
||||
|
||||
return db2sdk.CryptoKeys(keys), nil
|
||||
return toSDKKeys(keys), nil
|
||||
}
|
||||
|
||||
// cache implements the caching functionality for both signing and encryption keys.
|
||||
type cache struct {
|
||||
clock quartz.Clock
|
||||
refreshCtx context.Context
|
||||
refreshCancel context.CancelFunc
|
||||
fetcher Fetcher
|
||||
logger slog.Logger
|
||||
feature codersdk.CryptoKeyFeature
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
clock quartz.Clock
|
||||
fetcher Fetcher
|
||||
logger slog.Logger
|
||||
feature codersdk.CryptoKeyFeature
|
||||
|
||||
mu sync.Mutex
|
||||
keys map[int32]codersdk.CryptoKey
|
||||
@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
if !isSigningKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature))
|
||||
return newCache(ctx, logger, fetcher, feature, opts...), nil
|
||||
}
|
||||
|
||||
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
|
||||
@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher
|
||||
if !isEncryptionKeyFeature(feature) {
|
||||
return nil, xerrors.Errorf("invalid feature: %s", feature)
|
||||
}
|
||||
return newCache(ctx, logger, fetcher, feature, opts...)
|
||||
logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature))
|
||||
return newCache(ctx, logger, fetcher, feature, opts...), nil
|
||||
}
|
||||
|
||||
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
|
||||
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache {
|
||||
cache := &cache{
|
||||
clock: quartz.NewReal(),
|
||||
logger: logger,
|
||||
@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature
|
||||
}
|
||||
|
||||
cache.cond = sync.NewCond(&cache.mu)
|
||||
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
|
||||
//nolint:gocritic // We need to be able to read the keys in order to cache them.
|
||||
cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx))
|
||||
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
|
||||
|
||||
keys, err := cache.cryptoKeys(ctx)
|
||||
keys, err := cache.cryptoKeys(cache.ctx)
|
||||
if err != nil {
|
||||
cache.refreshCancel()
|
||||
return nil, xerrors.Errorf("initial fetch: %w", err)
|
||||
cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err))
|
||||
}
|
||||
cache.keys = keys
|
||||
return cache, nil
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
|
||||
@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error)
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
|
||||
return "", nil, ErrInvalidFeature
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
return c.cryptoKey(ctx, latestSequence)
|
||||
}
|
||||
|
||||
@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse id: %w", err)
|
||||
}
|
||||
|
||||
//nolint:gocritic // cache can only read crypto keys.
|
||||
ctx = dbauthz.AsKeyReader(ctx)
|
||||
_, secret, err := c.cryptoKey(ctx, int32(seq))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto key: %w", err)
|
||||
@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
|
||||
}
|
||||
|
||||
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
|
||||
return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey
|
||||
}
|
||||
|
||||
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
|
||||
switch feature {
|
||||
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
|
||||
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []
|
||||
func (c *cache) refresh() {
|
||||
now := c.clock.Now("CryptoKeyCache", "refresh")
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// If something's already fetching, we don't need to do anything.
|
||||
if c.fetching {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
@ -307,20 +317,21 @@ func (c *cache) refresh() {
|
||||
// is ongoing but prior to the timer getting reset. In this case we want to
|
||||
// avoid double fetching.
|
||||
if now.Sub(c.lastFetch) < refreshInterval {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
c.fetching = true
|
||||
|
||||
c.mu.Unlock()
|
||||
keys, err := c.cryptoKeys(c.refreshCtx)
|
||||
keys, err := c.cryptoKeys(c.ctx)
|
||||
if err != nil {
|
||||
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
|
||||
c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// We don't defer an unlock here due to the deferred unlock at the top of the function.
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.lastFetch = c.clock.Now()
|
||||
c.refresher.Reset(refreshInterval)
|
||||
@ -332,9 +343,9 @@ func (c *cache) refresh() {
|
||||
// cryptoKeys queries the control plane for the crypto keys.
|
||||
// Outside of initialization, this should only be called by fetch.
|
||||
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
|
||||
keys, err := c.fetcher.Fetch(ctx)
|
||||
keys, err := c.fetcher.Fetch(ctx, c.feature)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("crypto keys: %w", err)
|
||||
return nil, xerrors.Errorf("fetch: %w", err)
|
||||
}
|
||||
cache := toKeyMap(keys, c.clock.Now())
|
||||
return cache, nil
|
||||
@ -361,9 +372,28 @@ func (c *cache) Close() error {
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
c.refreshCancel()
|
||||
c.cancel()
|
||||
c.refresher.Stop()
|
||||
c.cond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys)
|
||||
func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
|
||||
into := make([]codersdk.CryptoKey, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
into = append(into, toSDK(key))
|
||||
}
|
||||
return into
|
||||
}
|
||||
|
||||
func toSDK(key database.CryptoKey) codersdk.CryptoKey {
|
||||
return codersdk.CryptoKey{
|
||||
Feature: codersdk.CryptoKeyFeature(key.Feature),
|
||||
Sequence: key.Sequence,
|
||||
StartsAt: key.StartsAt,
|
||||
DeletesAt: key.DeletesAt.Time,
|
||||
Secret: key.Secret.String,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user