feat: add jwt pkg (#14928)

- Adds a `jwtutils` package to be shared amongst the various
packages in the codebase that make use of JWTs. It's intended to help us
standardize on one library instead of some implementations using
`go-jose` and others using `golang-jwt`.

The main reason we're converging on `go-jose` is due to its support for
JWEs, `golang-jwt` also has a repo to handle it but it doesn't look
maintained: https://github.com/golang-jwt/jwe
This commit is contained in:
Jon Ayers
2024-10-03 21:09:52 -05:00
committed by GitHub
parent 50d9206950
commit 68ec532ca7
13 changed files with 1001 additions and 122 deletions

View File

@ -537,7 +537,8 @@ gen/mark-fresh:
tailnet/tailnettest/coordinatormock.go \
tailnet/tailnettest/coordinateemock.go \
tailnet/tailnettest/multiagentmock.go \
"
"
for file in $$files; do
echo "$$file"
if [ ! -f "$$file" ]; then

View File

@ -2,6 +2,7 @@ package cryptokeys
import (
"context"
"strconv"
"sync"
"time"
@ -9,16 +10,14 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
// never represents the maximum value for a time.Duration.
const never = 1<<63 - 1
// DBCache implements Keycache for callers with access to the database.
type DBCache struct {
// dbCache implements Keycache for callers with access to the database.
type dbCache struct {
db database.Store
feature database.CryptoKeyFeature
logger slog.Logger
@ -34,18 +33,34 @@ type DBCache struct {
closed bool
}
type DBCacheOption func(*DBCache)
type DBCacheOption func(*dbCache)
func WithDBCacheClock(clock quartz.Clock) DBCacheOption {
return func(d *DBCache) {
return func(d *dbCache) {
d.clock = clock
}
}
// NewDBCache creates a new DBCache. Close should be called to
// NewSigningCache creates a new DBCache. Close should be called to
// release resources associated with its internal timer.
func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*DBCache)) *DBCache {
d := &DBCache{
func NewSigningCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (SigningKeycache, error) {
if !isSigningKeyFeature(feature) {
return nil, ErrInvalidFeature
}
return newDBCache(logger, db, feature, opts...), nil
}
func NewEncryptionCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) (EncryptionKeycache, error) {
if !isEncryptionKeyFeature(feature) {
return nil, ErrInvalidFeature
}
return newDBCache(logger, db, feature, opts...), nil
}
func newDBCache(logger slog.Logger, db database.Store, feature database.CryptoKeyFeature, opts ...func(*dbCache)) *dbCache {
d := &dbCache{
db: db,
feature: feature,
clock: quartz.NewReal(),
@ -56,23 +71,61 @@ func NewDBCache(logger slog.Logger, db database.Store, feature database.CryptoKe
opt(d)
}
// Initialize the timer. This will get properly initialized the first time we fetch.
d.timer = d.clock.AfterFunc(never, d.clear)
return d
}
// Verifying returns the CryptoKey with the given sequence number, provided that
func (d *dbCache) EncryptingKey(ctx context.Context) (id string, key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}
return d.latest(ctx)
}
func (d *dbCache) DecryptingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isEncryptionKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}
return d.sequence(ctx, id)
}
func (d *dbCache) SigningKey(ctx context.Context) (id string, key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return "", nil, ErrInvalidFeature
}
return d.latest(ctx)
}
func (d *dbCache) VerifyingKey(ctx context.Context, id string) (key interface{}, err error) {
if !isSigningKeyFeature(d.feature) {
return nil, ErrInvalidFeature
}
return d.sequence(ctx, id)
}
// sequence returns the CryptoKey with the given sequence number, provided that
// it is neither deleted nor has breached its deletion date. It should only be
// used for verifying or decrypting payloads. To sign/encrypt call Signing.
func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
func (d *dbCache) sequence(ctx context.Context, id string) (interface{}, error) {
sequence, err := strconv.ParseInt(id, 10, 32)
if err != nil {
return nil, xerrors.Errorf("expecting sequence number got %q: %w", id, err)
}
d.keysMu.RLock()
if d.closed {
d.keysMu.RUnlock()
return codersdk.CryptoKey{}, ErrClosed
return nil, ErrClosed
}
now := d.clock.Now()
key, ok := d.keys[sequence]
key, ok := d.keys[int32(sequence)]
d.keysMu.RUnlock()
if ok {
return checkKey(key, now)
@ -82,35 +135,35 @@ func (d *DBCache) Verifying(ctx context.Context, sequence int32) (codersdk.Crypt
defer d.keysMu.Unlock()
if d.closed {
return codersdk.CryptoKey{}, ErrClosed
return nil, ErrClosed
}
key, ok = d.keys[sequence]
key, ok = d.keys[int32(sequence)]
if ok {
return checkKey(key, now)
}
err := d.fetch(ctx)
err = d.fetch(ctx)
if err != nil {
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
return nil, xerrors.Errorf("fetch: %w", err)
}
key, ok = d.keys[sequence]
key, ok = d.keys[int32(sequence)]
if !ok {
return codersdk.CryptoKey{}, ErrKeyNotFound
return nil, ErrKeyNotFound
}
return checkKey(key, now)
}
// Signing returns the latest valid key for signing. A valid key is one that is
// latest returns the latest valid key for signing. A valid key is one that is
// both past its start time and before its deletion time.
func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
func (d *dbCache) latest(ctx context.Context) (string, interface{}, error) {
d.keysMu.RLock()
if d.closed {
d.keysMu.RUnlock()
return codersdk.CryptoKey{}, ErrClosed
return "", nil, ErrClosed
}
latest := d.latestKey
@ -118,31 +171,31 @@ func (d *DBCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
now := d.clock.Now()
if latest.CanSign(now) {
return db2sdk.CryptoKey(latest), nil
return idSecret(latest)
}
d.keysMu.Lock()
defer d.keysMu.Unlock()
if d.closed {
return codersdk.CryptoKey{}, ErrClosed
return "", nil, ErrClosed
}
if d.latestKey.CanSign(now) {
return db2sdk.CryptoKey(d.latestKey), nil
return idSecret(d.latestKey)
}
// Refetch all keys for this feature so we can find the latest valid key.
err := d.fetch(ctx)
if err != nil {
return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err)
return "", nil, xerrors.Errorf("fetch: %w", err)
}
return db2sdk.CryptoKey(d.latestKey), nil
return idSecret(d.latestKey)
}
// clear invalidates the cache. This forces the subsequent call to fetch fresh keys.
func (d *DBCache) clear() {
func (d *dbCache) clear() {
now := d.clock.Now("DBCache", "clear")
d.keysMu.Lock()
defer d.keysMu.Unlock()
@ -158,7 +211,7 @@ func (d *DBCache) clear() {
// fetch fetches all keys for the given feature and determines the latest key.
// It must be called while holding the keysMu lock.
func (d *DBCache) fetch(ctx context.Context) error {
func (d *dbCache) fetch(ctx context.Context) error {
keys, err := d.db.GetCryptoKeysByFeature(ctx, d.feature)
if err != nil {
return xerrors.Errorf("get crypto keys by feature: %w", err)
@ -189,22 +242,45 @@ func (d *DBCache) fetch(ctx context.Context) error {
return nil
}
func checkKey(key database.CryptoKey, now time.Time) (codersdk.CryptoKey, error) {
func checkKey(key database.CryptoKey, now time.Time) (interface{}, error) {
if !key.CanVerify(now) {
return codersdk.CryptoKey{}, ErrKeyInvalid
return nil, ErrKeyInvalid
}
return db2sdk.CryptoKey(key), nil
return key.DecodeString()
}
func (d *DBCache) Close() {
func (d *dbCache) Close() error {
d.keysMu.Lock()
defer d.keysMu.Unlock()
if d.closed {
return
return nil
}
d.timer.Stop()
d.closed = true
return nil
}
func isEncryptionKeyFeature(feature database.CryptoKeyFeature) bool {
return feature == database.CryptoKeyFeatureWorkspaceApps
}
func isSigningKeyFeature(feature database.CryptoKeyFeature) bool {
switch feature {
case database.CryptoKeyFeatureTailnetResume, database.CryptoKeyFeatureOidcConvert:
return true
default:
return false
}
}
func idSecret(k database.CryptoKey) (string, interface{}, error) {
key, err := k.DecodeString()
if err != nil {
return "", nil, xerrors.Errorf("decode key: %w", err)
}
return strconv.FormatInt(int64(k.Sequence), 10), key, nil
}

View File

@ -2,6 +2,7 @@ package cryptokeys
import (
"database/sql"
"strconv"
"testing"
"time"
@ -11,13 +12,12 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
func Test_Verifying(t *testing.T) {
func Test_version(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
@ -35,7 +35,7 @@ func Test_Verifying(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
}
@ -44,13 +44,13 @@ func Test_Verifying(t *testing.T) {
32: expectedKey,
}
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.keys = cache
got, err := k.Verifying(ctx, 32)
secret, err := k.sequence(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, decodedSecret(t, expectedKey), secret)
})
t.Run("MissesCache", func(t *testing.T) {
@ -69,20 +69,19 @@ func Test_Verifying(t *testing.T) {
Sequence: 33,
StartsAt: clock.Now(),
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{expectedKey}, nil)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
got, err := k.Verifying(ctx, 33)
got, err := k.sequence(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, db2sdk.CryptoKey(expectedKey), db2sdk.CryptoKey(k.latestKey))
require.Equal(t, decodedSecret(t, expectedKey), got)
})
t.Run("InvalidCachedKey", func(t *testing.T) {
@ -101,7 +100,7 @@ func Test_Verifying(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
DeletesAt: sql.NullTime{
@ -111,11 +110,11 @@ func Test_Verifying(t *testing.T) {
},
}
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.keys = cache
_, err := k.Verifying(ctx, 32)
_, err := k.sequence(ctx, "32")
require.ErrorIs(t, err, ErrKeyInvalid)
})
@ -134,7 +133,7 @@ func Test_Verifying(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
DeletesAt: sql.NullTime{
@ -144,15 +143,15 @@ func Test_Verifying(t *testing.T) {
}
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{invalidKey}, nil)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
_, err := k.Verifying(ctx, 32)
_, err := k.sequence(ctx, keyID(invalidKey))
require.ErrorIs(t, err, ErrKeyInvalid)
})
}
func Test_Signing(t *testing.T) {
func Test_latest(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
@ -170,19 +169,20 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
}
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.latestKey = latestKey
got, err := k.Signing(ctx)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(latestKey), got)
require.Equal(t, keyID(latestKey), id)
require.Equal(t, decodedSecret(t, latestKey), secret)
})
t.Run("InvalidCachedKey", func(t *testing.T) {
@ -200,7 +200,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
@ -210,7 +210,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(-time.Hour),
@ -222,13 +222,14 @@ func Test_Signing(t *testing.T) {
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{latestKey}, nil)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
k.latestKey = invalidKey
got, err := k.Signing(ctx)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(latestKey), got)
require.Equal(t, keyID(latestKey), id)
require.Equal(t, decodedSecret(t, latestKey), secret)
})
t.Run("UsesActiveKey", func(t *testing.T) {
@ -246,7 +247,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(time.Hour),
@ -256,7 +257,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
@ -264,12 +265,13 @@ func Test_Signing(t *testing.T) {
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, activeKey}, nil)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
got, err := k.Signing(ctx)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(activeKey), got)
require.Equal(t, keyID(activeKey), id)
require.Equal(t, decodedSecret(t, activeKey), secret)
})
t.Run("NoValidKeys", func(t *testing.T) {
@ -287,7 +289,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(time.Hour),
@ -297,7 +299,7 @@ func Test_Signing(t *testing.T) {
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now().Add(-time.Hour),
@ -309,10 +311,10 @@ func Test_Signing(t *testing.T) {
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{inactiveKey, invalidKey}, nil)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
_, err := k.Signing(ctx)
_, _, err := k.latest(ctx)
require.ErrorIs(t, err, ErrKeyInvalid)
})
}
@ -331,14 +333,14 @@ func Test_clear(t *testing.T) {
logger = slogtest.Make(t, nil)
)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
activeKey := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 33,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
@ -346,7 +348,7 @@ func Test_clear(t *testing.T) {
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{activeKey}, nil)
_, err := k.Signing(ctx)
_, _, err := k.latest(ctx)
require.NoError(t, err)
dur, wait := clock.AdvanceNext()
@ -367,14 +369,14 @@ func Test_clear(t *testing.T) {
logger = slogtest.Make(t, nil)
)
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
key := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
@ -386,9 +388,10 @@ func Test_clear(t *testing.T) {
// timer is reset and doesn't fire after another five minute.
clock.Advance(time.Minute * 5)
latest, err := k.Signing(ctx)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(key), latest)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Advancing the clock now should require 10 minutes
// before the timer fires again.
@ -415,14 +418,14 @@ func Test_clear(t *testing.T) {
trap := clock.Trap().Now("clear")
k := NewDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
k := newDBCache(logger, mockDB, database.CryptoKeyFeatureWorkspaceApps, WithDBCacheClock(clock))
defer k.Close()
key := database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Sequence: 32,
Secret: sql.NullString{
String: "secret",
String: mustGenerateKey(t),
Valid: true,
},
StartsAt: clock.Now(),
@ -431,9 +434,10 @@ func Test_clear(t *testing.T) {
mockDB.EXPECT().GetCryptoKeysByFeature(ctx, database.CryptoKeyFeatureWorkspaceApps).Return([]database.CryptoKey{key}, nil).Times(2)
// Move us past the initial timer.
latest, err := k.Signing(ctx)
id, secret, err := k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(key), latest)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Null these out so that we refetch.
k.keys = nil
k.latestKey = database.CryptoKey{}
@ -445,9 +449,10 @@ func Test_clear(t *testing.T) {
call := trap.MustWait(ctx)
// Refetch keys.
latest, err = k.Signing(ctx)
id, secret, err = k.latest(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(key), latest)
require.Equal(t, keyID(key), id)
require.Equal(t, decodedSecret(t, key), secret)
// Let the rest of the timer function run.
// It should see that we have refetched keys and
@ -465,3 +470,21 @@ func Test_clear(t *testing.T) {
require.Equal(t, database.CryptoKey{}, k.latestKey)
})
}
func mustGenerateKey(t *testing.T) string {
t.Helper()
key, err := generateKey(64)
require.NoError(t, err)
return key
}
func keyID(key database.CryptoKey) string {
return strconv.FormatInt(int64(key.Sequence), 10)
}
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
t.Helper()
decoded, err := key.DecodeString()
require.NoError(t, err)
return decoded
}

View File

@ -1,6 +1,7 @@
package cryptokeys_test
import (
"strconv"
"testing"
"github.com/stretchr/testify/require"
@ -10,7 +11,6 @@ import (
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/testutil"
@ -24,7 +24,7 @@ func TestMain(m *testing.M) {
func TestDBKeyCache(t *testing.T) {
t.Parallel()
t.Run("Verifying", func(t *testing.T) {
t.Run("VerifyingKey", func(t *testing.T) {
t.Parallel()
t.Run("HitsCache", func(t *testing.T) {
@ -38,17 +38,18 @@ func TestDBKeyCache(t *testing.T) {
)
key := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 1,
StartsAt: clock.Now().UTC(),
})
k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
got, err := k.Verifying(ctx, key.Sequence)
got, err := k.VerifyingKey(ctx, keyID(key))
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(key), got)
require.Equal(t, decodedSecret(t, key), got)
})
t.Run("NotFound", func(t *testing.T) {
@ -61,10 +62,11 @@ func TestDBKeyCache(t *testing.T) {
logger = slogtest.Make(t, nil)
)
k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
_, err := k.Verifying(ctx, 123)
_, err = k.VerifyingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
})
})
@ -80,29 +82,31 @@ func TestDBKeyCache(t *testing.T) {
)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 10,
StartsAt: clock.Now().UTC(),
})
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 12,
StartsAt: clock.Now().UTC(),
})
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 2,
StartsAt: clock.Now().UTC(),
})
k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
got, err := k.Signing(ctx)
id, key, err := k.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, keyID(expectedKey), id)
require.Equal(t, decodedSecret(t, expectedKey), key)
})
t.Run("Closed", func(t *testing.T) {
@ -116,28 +120,97 @@ func TestDBKeyCache(t *testing.T) {
)
expectedKey := dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
Feature: database.CryptoKeyFeatureOidcConvert,
Sequence: 10,
StartsAt: clock.Now(),
})
k := cryptokeys.NewDBCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
k, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer k.Close()
got, err := k.Signing(ctx)
id, key, err := k.SigningKey(ctx)
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, keyID(expectedKey), id)
require.Equal(t, decodedSecret(t, expectedKey), key)
got, err = k.Verifying(ctx, expectedKey.Sequence)
key, err = k.VerifyingKey(ctx, keyID(expectedKey))
require.NoError(t, err)
require.Equal(t, db2sdk.CryptoKey(expectedKey), got)
require.Equal(t, decodedSecret(t, expectedKey), key)
k.Close()
_, err = k.Signing(ctx)
_, _, err = k.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrClosed)
_, err = k.Verifying(ctx, expectedKey.Sequence)
_, err = k.VerifyingKey(ctx, keyID(expectedKey))
require.ErrorIs(t, err, cryptokeys.ErrClosed)
})
t.Run("InvalidSigningFeature", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil)
ctx = testutil.Context(t, testutil.WaitShort)
)
_, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
// Instantiate a signing cache and try to use it as an encryption cache.
sc, err := cryptokeys.NewSigningCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer sc.Close()
ec, ok := sc.(cryptokeys.EncryptionKeycache)
require.True(t, ok)
_, _, err = ec.EncryptingKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
_, err = ec.DecryptingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
})
t.Run("InvalidEncryptionFeature", func(t *testing.T) {
t.Parallel()
var (
db, _ = dbtestutil.NewDB(t)
clock = quartz.NewMock(t)
logger = slogtest.Make(t, nil)
ctx = testutil.Context(t, testutil.WaitShort)
)
_, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureOidcConvert, cryptokeys.WithDBCacheClock(clock))
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
// Instantiate an encryption cache and try to use it as a signing cache.
ec, err := cryptokeys.NewEncryptionCache(logger, db, database.CryptoKeyFeatureWorkspaceApps, cryptokeys.WithDBCacheClock(clock))
require.NoError(t, err)
defer ec.Close()
sc, ok := ec.(cryptokeys.SigningKeycache)
require.True(t, ok)
_, _, err = sc.SigningKey(ctx)
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
_, err = sc.VerifyingKey(ctx, "123")
require.ErrorIs(t, err, cryptokeys.ErrInvalidFeature)
})
}
func keyID(key database.CryptoKey) string {
return strconv.FormatInt(int64(key.Sequence), 10)
}
func decodedSecret(t *testing.T, key database.CryptoKey) []byte {
t.Helper()
secret, err := key.DecodeString()
require.NoError(t, err)
return secret
}

2
coderd/cryptokeys/doc.go Normal file
View File

@ -0,0 +1,2 @@
// Package cryptokeys provides an abstraction for fetching internally used cryptographic keys mainly for JWT signing and verification.
package cryptokeys

View File

@ -2,20 +2,40 @@ package cryptokeys
import (
"context"
"io"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/codersdk"
)
var (
ErrKeyNotFound = xerrors.New("key not found")
ErrKeyInvalid = xerrors.New("key is invalid for use")
ErrClosed = xerrors.New("closed")
ErrKeyNotFound = xerrors.New("key not found")
ErrKeyInvalid = xerrors.New("key is invalid for use")
ErrClosed = xerrors.New("closed")
ErrInvalidFeature = xerrors.New("invalid feature for this operation")
)
// Keycache provides an abstraction for fetching signing keys.
type Keycache interface {
Signing(ctx context.Context) (codersdk.CryptoKey, error)
Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error)
type EncryptionKeycache interface {
// EncryptingKey returns the latest valid key for encrypting payloads. A valid
// key is one that is both past its start time and before its deletion time.
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
// DecryptingKey returns the key with the provided id which maps to its sequence
// number. The key is valid for decryption as long as it is not deleted or past
// its deletion date. We must allow for keys prior to their start time to
// account for clock skew between peers (one key may be past its start time on
// one machine while another is not).
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}
type SigningKeycache interface {
// SigningKey returns the latest valid key for signing. A valid key is one
// that is both past its start time and before its deletion time.
SigningKey(ctx context.Context) (id string, key interface{}, err error)
// VerifyingKey returns the key with the provided id which should map to its
// sequence number. The key is valid for verifying as long as it is not deleted
// or past its deletion date. We must allow for keys prior to their start time
// to account for clock skew between peers (one key may be past its start time
// on one machine while another is not).
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
io.Closer
}

View File

@ -227,9 +227,9 @@ func (k *rotator) rotateKey(ctx context.Context, tx database.Store, key database
func generateNewSecret(feature database.CryptoKeyFeature) (string, error) {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
return generateKey(96)
case database.CryptoKeyFeatureOidcConvert:
return generateKey(32)
case database.CryptoKeyFeatureOidcConvert:
return generateKey(64)
case database.CryptoKeyFeatureTailnetResume:
return generateKey(64)
}

View File

@ -588,9 +588,9 @@ func requireKey(t *testing.T, key database.CryptoKey, feature database.CryptoKey
switch key.Feature {
case database.CryptoKeyFeatureOidcConvert:
require.Len(t, secret, 32)
require.Len(t, secret, 64)
case database.CryptoKeyFeatureWorkspaceApps:
require.Len(t, secret, 96)
require.Len(t, secret, 32)
case database.CryptoKeyFeatureTailnetResume:
require.Len(t, secret, 64)
default:

View File

@ -988,9 +988,9 @@ func takeFirst[Value comparable](values ...Value) Value {
func newCryptoKeySecret(feature database.CryptoKeyFeature) (string, error) {
switch feature {
case database.CryptoKeyFeatureWorkspaceApps:
return generateCryptoKey(96)
case database.CryptoKeyFeatureOidcConvert:
return generateCryptoKey(32)
case database.CryptoKeyFeatureOidcConvert:
return generateCryptoKey(64)
case database.CryptoKeyFeatureTailnetResume:
return generateCryptoKey(64)
}

121
coderd/jwtutils/jwe.go Normal file
View File

@ -0,0 +1,121 @@
package jwtutils
import (
"context"
"encoding/json"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"golang.org/x/xerrors"
)
const (
encryptKeyAlgo = jose.A256GCMKW
encryptContentAlgo = jose.A256GCM
)
type EncryptKeyProvider interface {
EncryptingKey(ctx context.Context) (id string, key interface{}, err error)
}
type DecryptKeyProvider interface {
DecryptingKey(ctx context.Context, id string) (key interface{}, err error)
}
// Encrypt encrypts a token and returns it as a string.
func Encrypt(ctx context.Context, e EncryptKeyProvider, claims Claims) (string, error) {
id, key, err := e.EncryptingKey(ctx)
if err != nil {
return "", xerrors.Errorf("get signing key: %w", err)
}
encrypter, err := jose.NewEncrypter(
encryptContentAlgo,
jose.Recipient{
Algorithm: encryptKeyAlgo,
Key: key,
},
&jose.EncrypterOptions{
Compression: jose.DEFLATE,
ExtraHeaders: map[jose.HeaderKey]interface{}{
keyIDHeaderKey: id,
},
},
)
if err != nil {
return "", xerrors.Errorf("initialize encrypter: %w", err)
}
payload, err := json.Marshal(claims)
if err != nil {
return "", xerrors.Errorf("marshal payload: %w", err)
}
encrypted, err := encrypter.Encrypt(payload)
if err != nil {
return "", xerrors.Errorf("encrypt: %w", err)
}
compact, err := encrypted.CompactSerialize()
if err != nil {
return "", xerrors.Errorf("compact serialize: %w", err)
}
return compact, nil
}
// DecryptOptions are options for decrypting a JWE.
type DecryptOptions struct {
RegisteredClaims jwt.Expected
KeyAlgorithm jose.KeyAlgorithm
ContentEncryptionAlgorithm jose.ContentEncryption
}
// Decrypt decrypts the token using the provided key. It unmarshals into the provided claims.
func Decrypt(ctx context.Context, d DecryptKeyProvider, token string, claims Claims, opts ...func(*DecryptOptions)) error {
options := DecryptOptions{
RegisteredClaims: jwt.Expected{
Time: time.Now(),
},
KeyAlgorithm: encryptKeyAlgo,
ContentEncryptionAlgorithm: encryptContentAlgo,
}
for _, opt := range opts {
opt(&options)
}
object, err := jose.ParseEncrypted(token,
[]jose.KeyAlgorithm{options.KeyAlgorithm},
[]jose.ContentEncryption{options.ContentEncryptionAlgorithm},
)
if err != nil {
return xerrors.Errorf("parse jwe: %w", err)
}
if object.Header.Algorithm != string(encryptKeyAlgo) {
return xerrors.Errorf("expected JWE algorithm to be %q, got %q", encryptKeyAlgo, object.Header.Algorithm)
}
kid := object.Header.KeyID
if kid == "" {
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
}
key, err := d.DecryptingKey(ctx, kid)
if err != nil {
return xerrors.Errorf("key with id %q: %w", kid, err)
}
decrypted, err := object.Decrypt(key)
if err != nil {
return xerrors.Errorf("decrypt: %w", err)
}
if err := json.Unmarshal(decrypted, &claims); err != nil {
return xerrors.Errorf("unmarshal: %w", err)
}
return claims.Validate(options.RegisteredClaims)
}

127
coderd/jwtutils/jws.go Normal file
View File

@ -0,0 +1,127 @@
package jwtutils
import (
"context"
"encoding/json"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"golang.org/x/xerrors"
)
const (
keyIDHeaderKey = "kid"
)
// Claims defines the payload for a JWT. Most callers
// should embed jwt.Claims
type Claims interface {
Validate(jwt.Expected) error
}
const (
signingAlgo = jose.HS512
)
type SigningKeyProvider interface {
SigningKey(ctx context.Context) (id string, key interface{}, err error)
}
type VerifyKeyProvider interface {
VerifyingKey(ctx context.Context, id string) (key interface{}, err error)
}
// Sign signs a token and returns it as a string.
func Sign(ctx context.Context, s SigningKeyProvider, claims Claims) (string, error) {
id, key, err := s.SigningKey(ctx)
if err != nil {
return "", xerrors.Errorf("get signing key: %w", err)
}
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: signingAlgo,
Key: key,
}, &jose.SignerOptions{
ExtraHeaders: map[jose.HeaderKey]interface{}{
keyIDHeaderKey: id,
},
})
if err != nil {
return "", xerrors.Errorf("new signer: %w", err)
}
payload, err := json.Marshal(claims)
if err != nil {
return "", xerrors.Errorf("marshal claims: %w", err)
}
signed, err := signer.Sign(payload)
if err != nil {
return "", xerrors.Errorf("sign payload: %w", err)
}
compact, err := signed.CompactSerialize()
if err != nil {
return "", xerrors.Errorf("compact serialize: %w", err)
}
return compact, nil
}
// VerifyOptions are options for verifying a JWT.
type VerifyOptions struct {
RegisteredClaims jwt.Expected
SignatureAlgorithm jose.SignatureAlgorithm
}
// Verify verifies that a token was signed by the provided key. It unmarshals into the provided claims.
func Verify(ctx context.Context, v VerifyKeyProvider, token string, claims Claims, opts ...func(*VerifyOptions)) error {
options := VerifyOptions{
RegisteredClaims: jwt.Expected{
Time: time.Now(),
},
SignatureAlgorithm: signingAlgo,
}
for _, opt := range opts {
opt(&options)
}
object, err := jose.ParseSigned(token, []jose.SignatureAlgorithm{options.SignatureAlgorithm})
if err != nil {
return xerrors.Errorf("parse JWS: %w", err)
}
if len(object.Signatures) != 1 {
return xerrors.New("expected 1 signature")
}
signature := object.Signatures[0]
if signature.Header.Algorithm != string(signingAlgo) {
return xerrors.Errorf("expected JWS algorithm to be %q, got %q", signingAlgo, object.Signatures[0].Header.Algorithm)
}
kid := signature.Header.KeyID
if kid == "" {
return xerrors.Errorf("expected %q header to be a string", keyIDHeaderKey)
}
key, err := v.VerifyingKey(ctx, kid)
if err != nil {
return xerrors.Errorf("key with id %q: %w", kid, err)
}
payload, err := object.Verify(key)
if err != nil {
return xerrors.Errorf("verify payload: %w", err)
}
err = json.Unmarshal(payload, &claims)
if err != nil {
return xerrors.Errorf("unmarshal payload: %w", err)
}
return claims.Validate(options.RegisteredClaims)
}

436
coderd/jwtutils/jwt_test.go Normal file
View File

@ -0,0 +1,436 @@
package jwtutils_test
import (
"context"
"crypto/rand"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/cryptokeys"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/jwtutils"
"github.com/coder/coder/v2/testutil"
)
func TestClaims(t *testing.T) {
t.Parallel()
type tokenType struct {
Name string
KeySize int
Sign bool
}
types := []tokenType{
{
Name: "JWE",
Sign: false,
KeySize: 32,
},
{
Name: "JWS",
Sign: true,
KeySize: 64,
},
}
type testcase struct {
name string
claims jwtutils.Claims
expectedClaims jwt.Expected
expectedErr error
}
cases := []testcase{
{
name: "OK",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
},
{
name: "WrongIssuer",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
expectedClaims: jwt.Expected{
Issuer: "coder2",
},
expectedErr: jwt.ErrInvalidIssuer,
},
{
name: "WrongSubject",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
expectedClaims: jwt.Expected{
Subject: "user2@coder.com",
},
expectedErr: jwt.ErrInvalidSubject,
},
{
name: "WrongAudience",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
},
{
name: "Expired",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
},
expectedClaims: jwt.Expected{
Time: time.Now().Add(time.Minute * 3),
},
expectedErr: jwt.ErrExpired,
},
{
name: "IssuedInFuture",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
expectedClaims: jwt.Expected{
Time: time.Now().Add(-time.Minute * 3),
},
expectedErr: jwt.ErrIssuedInTheFuture,
},
{
name: "IsBefore",
claims: jwt.Claims{
Issuer: "coder",
Subject: "user@coder.com",
Audience: jwt.Audience{"coder"},
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now().Add(time.Minute * 5)),
},
expectedClaims: jwt.Expected{
Time: time.Now().Add(time.Minute * 3),
},
expectedErr: jwt.ErrNotValidYet,
},
}
for _, tt := range types {
tt := tt
t.Run(tt.Name, func(t *testing.T) {
t.Parallel()
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
key = newKey(t, tt.KeySize)
token string
err error
)
if tt.Sign {
token, err = jwtutils.Sign(ctx, key, c.claims)
} else {
token, err = jwtutils.Encrypt(ctx, key, c.claims)
}
require.NoError(t, err)
var actual jwt.Claims
if tt.Sign {
err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(c.expectedClaims))
} else {
err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(c.expectedClaims))
}
if c.expectedErr != nil {
require.ErrorIs(t, err, c.expectedErr)
} else {
require.NoError(t, err)
require.Equal(t, c.claims, actual)
}
})
}
})
}
}
func TestJWS(t *testing.T) {
t.Parallel()
t.Run("WrongSignatureAlgorithm", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
key := newKey(t, 64)
token, err := jwtutils.Sign(ctx, key, jwt.Claims{})
require.NoError(t, err)
var actual testClaims
err = jwtutils.Verify(ctx, key, token, &actual, withSignatureAlgorithm(jose.HS256))
require.Error(t, err)
})
t.Run("CustomClaims", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
key = newKey(t, 64)
)
expected := testClaims{
MyClaim: "my_value",
}
token, err := jwtutils.Sign(ctx, key, expected)
require.NoError(t, err)
var actual testClaims
err = jwtutils.Verify(ctx, key, token, &actual, withVerifyExpected(jwt.Expected{}))
require.NoError(t, err)
require.Equal(t, expected, actual)
})
t.Run("WithKeycache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db, _ = dbtestutil.NewDB(t)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureOidcConvert,
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
)
cache, err := cryptokeys.NewSigningCache(log, db, database.CryptoKeyFeatureOidcConvert)
require.NoError(t, err)
claims := testClaims{
MyClaim: "my_value",
Claims: jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token, err := jwtutils.Sign(ctx, cache, claims)
require.NoError(t, err)
var actual testClaims
err = jwtutils.Verify(ctx, cache, token, &actual)
require.NoError(t, err)
require.Equal(t, claims, actual)
})
}
func TestJWE(t *testing.T) {
t.Parallel()
t.Run("WrongKeyAlgorithm", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
key = newKey(t, 32)
)
token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{})
require.NoError(t, err)
var actual testClaims
err = jwtutils.Decrypt(ctx, key, token, &actual, withKeyAlgorithm(jose.A128GCMKW))
require.Error(t, err)
})
t.Run("WrongContentyEncryption", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
key = newKey(t, 32)
)
token, err := jwtutils.Encrypt(ctx, key, jwt.Claims{})
require.NoError(t, err)
var actual testClaims
err = jwtutils.Decrypt(ctx, key, token, &actual, withContentEncryptionAlgorithm(jose.A128GCM))
require.Error(t, err)
})
t.Run("CustomClaims", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
key = newKey(t, 32)
)
expected := testClaims{
MyClaim: "my_value",
}
token, err := jwtutils.Encrypt(ctx, key, expected)
require.NoError(t, err)
var actual testClaims
err = jwtutils.Decrypt(ctx, key, token, &actual, withDecryptExpected(jwt.Expected{}))
require.NoError(t, err)
require.Equal(t, expected, actual)
})
t.Run("WithKeycache", func(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db, _ = dbtestutil.NewDB(t)
_ = dbgen.CryptoKey(t, db, database.CryptoKey{
Feature: database.CryptoKeyFeatureWorkspaceApps,
StartsAt: time.Now(),
})
log = slogtest.Make(t, nil)
)
cache, err := cryptokeys.NewEncryptionCache(log, db, database.CryptoKeyFeatureWorkspaceApps)
require.NoError(t, err)
claims := testClaims{
MyClaim: "my_value",
Claims: jwt.Claims{
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour)),
},
}
token, err := jwtutils.Encrypt(ctx, cache, claims)
require.NoError(t, err)
var actual testClaims
err = jwtutils.Decrypt(ctx, cache, token, &actual)
require.NoError(t, err)
require.Equal(t, claims, actual)
})
}
func generateSecret(t *testing.T, keySize int) []byte {
t.Helper()
b := make([]byte, keySize)
_, err := rand.Read(b)
require.NoError(t, err)
return b
}
type testClaims struct {
MyClaim string `json:"my_claim"`
jwt.Claims
}
func withDecryptExpected(e jwt.Expected) func(*jwtutils.DecryptOptions) {
return func(opts *jwtutils.DecryptOptions) {
opts.RegisteredClaims = e
}
}
func withVerifyExpected(e jwt.Expected) func(*jwtutils.VerifyOptions) {
return func(opts *jwtutils.VerifyOptions) {
opts.RegisteredClaims = e
}
}
func withSignatureAlgorithm(alg jose.SignatureAlgorithm) func(*jwtutils.VerifyOptions) {
return func(opts *jwtutils.VerifyOptions) {
opts.SignatureAlgorithm = alg
}
}
func withKeyAlgorithm(alg jose.KeyAlgorithm) func(*jwtutils.DecryptOptions) {
return func(opts *jwtutils.DecryptOptions) {
opts.KeyAlgorithm = alg
}
}
func withContentEncryptionAlgorithm(alg jose.ContentEncryption) func(*jwtutils.DecryptOptions) {
return func(opts *jwtutils.DecryptOptions) {
opts.ContentEncryptionAlgorithm = alg
}
}
type key struct {
t testing.TB
id string
secret []byte
}
func newKey(t *testing.T, size int) *key {
t.Helper()
id := uuid.New().String()
secret := generateSecret(t, size)
return &key{
t: t,
id: id,
secret: secret,
}
}
func (k *key) SigningKey(_ context.Context) (id string, key interface{}, err error) {
return k.id, k.secret, nil
}
func (k *key) VerifyingKey(_ context.Context, id string) (key interface{}, err error) {
k.t.Helper()
require.Equal(k.t, k.id, id)
return k.secret, nil
}
func (k *key) EncryptingKey(_ context.Context) (id string, key interface{}, err error) {
return k.id, k.secret, nil
}
func (k *key) DecryptingKey(_ context.Context, id string) (key interface{}, err error) {
k.t.Helper()
require.Equal(k.t, k.id, id)
return k.secret, nil
}

2
go.mod
View File

@ -207,6 +207,7 @@ require (
github.com/coder/serpent v0.8.0
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
github.com/emersion/go-smtp v0.21.2
github.com/go-jose/go-jose/v4 v4.0.2
github.com/gomarkdown/markdown v0.0.0-20231222211730-1d6d20845b47
github.com/google/go-github/v61 v61.0.0
github.com/mocktools/go-smtp-mock/v2 v2.3.0
@ -224,7 +225,6 @@ require (
github.com/charmbracelet/x/ansi v0.2.3 // indirect
github.com/charmbracelet/x/term v0.2.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/go-jose/go-jose/v4 v4.0.2 // indirect
github.com/go-viper/mapstructure/v2 v2.0.0 // indirect
github.com/hashicorp/go-plugin v1.6.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.7 // indirect