feat: enable key rotation (#15066)

This PR contains the remaining logic necessary to hook up key rotation
to the product.
This commit is contained in:
Jon Ayers
2024-10-25 17:14:35 +01:00
committed by GitHub
parent ccfffc6911
commit cd890aa3a0
54 changed files with 1412 additions and 1129 deletions

View File

@ -3,6 +3,7 @@ package cryptokeys
import (
"context"
"encoding/hex"
"fmt"
"io"
"strconv"
"sync"
@ -12,7 +13,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
@ -25,7 +26,7 @@ var (
)
type Fetcher interface {
Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error)
}
type EncryptionKeycache interface {
@ -62,27 +63,26 @@ const (
)
type DBFetcher struct {
DB database.Store
Feature database.CryptoKeyFeature
DB database.Store
}
func (d *DBFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) {
keys, err := d.DB.GetCryptoKeysByFeature(ctx, d.Feature)
func (d *DBFetcher) Fetch(ctx context.Context, feature codersdk.CryptoKeyFeature) ([]codersdk.CryptoKey, error) {
keys, err := d.DB.GetCryptoKeysByFeature(ctx, database.CryptoKeyFeature(feature))
if err != nil {
return nil, xerrors.Errorf("get crypto keys by feature: %w", err)
}
return db2sdk.CryptoKeys(keys), nil
return toSDKKeys(keys), nil
}
// cache implements the caching functionality for both signing and encryption keys.
type cache struct {
clock quartz.Clock
refreshCtx context.Context
refreshCancel context.CancelFunc
fetcher Fetcher
logger slog.Logger
feature codersdk.CryptoKeyFeature
ctx context.Context
cancel context.CancelFunc
clock quartz.Clock
fetcher Fetcher
logger slog.Logger
feature codersdk.CryptoKeyFeature
mu sync.Mutex
keys map[int32]codersdk.CryptoKey
@ -109,7 +109,8 @@ func NewSigningCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
if !isSigningKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
logger = logger.Named(fmt.Sprintf("%s_signing_keycache", feature))
return newCache(ctx, logger, fetcher, feature, opts...), nil
}
func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher,
@ -118,10 +119,11 @@ func NewEncryptionCache(ctx context.Context, logger slog.Logger, fetcher Fetcher
if !isEncryptionKeyFeature(feature) {
return nil, xerrors.Errorf("invalid feature: %s", feature)
}
return newCache(ctx, logger, fetcher, feature, opts...)
logger = logger.Named(fmt.Sprintf("%s_encryption_keycache", feature))
return newCache(ctx, logger, fetcher, feature, opts...), nil
}
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) (*cache, error) {
func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature codersdk.CryptoKeyFeature, opts ...func(*cache)) *cache {
cache := &cache{
clock: quartz.NewReal(),
logger: logger,
@ -134,16 +136,16 @@ func newCache(ctx context.Context, logger slog.Logger, fetcher Fetcher, feature
}
cache.cond = sync.NewCond(&cache.mu)
cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
//nolint:gocritic // We need to be able to read the keys in order to cache them.
cache.ctx, cache.cancel = context.WithCancel(dbauthz.AsKeyReader(ctx))
cache.refresher = cache.clock.AfterFunc(refreshInterval, cache.refresh)
keys, err := cache.cryptoKeys(ctx)
keys, err := cache.cryptoKeys(cache.ctx)
if err != nil {
cache.refreshCancel()
return nil, xerrors.Errorf("initial fetch: %w", err)
cache.logger.Critical(cache.ctx, "failed initial fetch", slog.Error(err))
}
cache.keys = keys
return cache, nil
return cache
}
func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error) {
@ -151,6 +153,8 @@ func (c *cache) EncryptingKey(ctx context.Context) (string, interface{}, error)
return "", nil, ErrInvalidFeature
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
return c.cryptoKey(ctx, latestSequence)
}
@ -164,6 +168,8 @@ func (c *cache) DecryptingKey(ctx context.Context, id string) (interface{}, erro
return nil, xerrors.Errorf("parse id: %w", err)
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
@ -176,6 +182,8 @@ func (c *cache) SigningKey(ctx context.Context) (string, interface{}, error) {
return "", nil, ErrInvalidFeature
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
return c.cryptoKey(ctx, latestSequence)
}
@ -188,7 +196,8 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
if err != nil {
return nil, xerrors.Errorf("parse id: %w", err)
}
//nolint:gocritic // cache can only read crypto keys.
ctx = dbauthz.AsKeyReader(ctx)
_, secret, err := c.cryptoKey(ctx, int32(seq))
if err != nil {
return nil, xerrors.Errorf("crypto key: %w", err)
@ -198,12 +207,12 @@ func (c *cache) VerifyingKey(ctx context.Context, id string) (interface{}, error
}
func isEncryptionKeyFeature(feature codersdk.CryptoKeyFeature) bool {
return feature == codersdk.CryptoKeyFeatureWorkspaceApp
return feature == codersdk.CryptoKeyFeatureWorkspaceAppsAPIKey
}
func isSigningKeyFeature(feature codersdk.CryptoKeyFeature) bool {
switch feature {
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert:
case codersdk.CryptoKeyFeatureTailnetResume, codersdk.CryptoKeyFeatureOIDCConvert, codersdk.CryptoKeyFeatureWorkspaceAppsToken:
return true
default:
return false
@ -292,14 +301,15 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (string, []
func (c *cache) refresh() {
now := c.clock.Now("CryptoKeyCache", "refresh")
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
c.mu.Unlock()
return
}
// If something's already fetching, we don't need to do anything.
if c.fetching {
c.mu.Unlock()
return
}
@ -307,20 +317,21 @@ func (c *cache) refresh() {
// is ongoing but prior to the timer getting reset. In this case we want to
// avoid double fetching.
if now.Sub(c.lastFetch) < refreshInterval {
c.mu.Unlock()
return
}
c.fetching = true
c.mu.Unlock()
keys, err := c.cryptoKeys(c.refreshCtx)
keys, err := c.cryptoKeys(c.ctx)
if err != nil {
c.logger.Error(c.refreshCtx, "fetch crypto keys", slog.Error(err))
c.logger.Error(c.ctx, "fetch crypto keys", slog.Error(err))
return
}
// We don't defer an unlock here due to the deferred unlock at the top of the function.
c.mu.Lock()
defer c.mu.Unlock()
c.lastFetch = c.clock.Now()
c.refresher.Reset(refreshInterval)
@ -332,9 +343,9 @@ func (c *cache) refresh() {
// cryptoKeys queries the control plane for the crypto keys.
// Outside of initialization, this should only be called by fetch.
func (c *cache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
keys, err := c.fetcher.Fetch(ctx)
keys, err := c.fetcher.Fetch(ctx, c.feature)
if err != nil {
return nil, xerrors.Errorf("crypto keys: %w", err)
return nil, xerrors.Errorf("fetch: %w", err)
}
cache := toKeyMap(keys, c.clock.Now())
return cache, nil
@ -361,9 +372,28 @@ func (c *cache) Close() error {
}
c.closed = true
c.refreshCancel()
c.cancel()
c.refresher.Stop()
c.cond.Broadcast()
return nil
}
// We have to do this to avoid a circular dependency on db2sdk (cryptokeys -> db2sdk -> tailnet -> cryptokeys)
func toSDKKeys(keys []database.CryptoKey) []codersdk.CryptoKey {
into := make([]codersdk.CryptoKey, 0, len(keys))
for _, key := range keys {
into = append(into, toSDK(key))
}
return into
}
func toSDK(key database.CryptoKey) codersdk.CryptoKey {
return codersdk.CryptoKey{
Feature: codersdk.CryptoKeyFeature(key.Feature),
Sequence: key.Sequence,
StartsAt: key.StartsAt,
DeletesAt: key.DeletesAt.Time,
Secret: key.Secret.String,
}
}