mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
Use licenses to populate the Entitlements API (#3715)
* Use licenses for entitlements API Signed-off-by: Spike Curtis <spike@coder.com> * Tests for entitlements API Signed-off-by: Spike Curtis <spike@coder.com> * Add commentary about FeatureService Signed-off-by: Spike Curtis <spike@coder.com> * Lint Signed-off-by: Spike Curtis <spike@coder.com> * Quiet down the logs Signed-off-by: Spike Curtis <spike@coder.com> * Tell revive it's ok Signed-off-by: Spike Curtis <spike@coder.com> Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
@ -1,12 +1,18 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
)
|
||||
|
||||
const EnvAuditLogEnable = "CODER_AUDIT_LOG_ENABLE"
|
||||
|
||||
func NewEnterprise(options *coderd.Options) *coderd.API {
|
||||
var eOpts = *options
|
||||
if eOpts.Authorizer == nil {
|
||||
@ -26,5 +32,18 @@ func NewEnterprise(options *coderd.Options) *coderd.API {
|
||||
Authorizer: eOpts.Authorizer,
|
||||
Logger: eOpts.Logger,
|
||||
}).handler()
|
||||
en := Enablements{AuditLogs: true}
|
||||
auditLog := os.Getenv(EnvAuditLogEnable)
|
||||
auditLog = strings.ToLower(auditLog)
|
||||
if auditLog == "disable" || auditLog == "false" || auditLog == "0" || auditLog == "no" {
|
||||
en.AuditLogs = false
|
||||
}
|
||||
eOpts.FeaturesService = newFeaturesService(
|
||||
context.Background(),
|
||||
eOpts.Logger,
|
||||
eOpts.Database,
|
||||
eOpts.Pubsub,
|
||||
en,
|
||||
)
|
||||
return coderd.New(&eOpts)
|
||||
}
|
||||
|
261
enterprise/coderd/features.go
Normal file
261
enterprise/coderd/features.go
Normal file
@ -0,0 +1,261 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
agpl "github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type Enablements struct {
|
||||
AuditLogs bool
|
||||
}
|
||||
|
||||
type featuresService struct {
|
||||
logger slog.Logger
|
||||
database database.Store
|
||||
pubsub database.Pubsub
|
||||
keys map[string]ed25519.PublicKey
|
||||
enablements Enablements
|
||||
resyncInterval time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
entitlements entitlements
|
||||
}
|
||||
|
||||
// newFeaturesService creates a FeaturesService and starts it. It will continue running for the
|
||||
// duration of the passed ctx.
|
||||
func newFeaturesService(
|
||||
ctx context.Context,
|
||||
logger slog.Logger,
|
||||
db database.Store,
|
||||
pubsub database.Pubsub,
|
||||
enablements Enablements,
|
||||
) agpl.FeaturesService {
|
||||
fs := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: keys,
|
||||
enablements: enablements,
|
||||
resyncInterval: 10 * time.Minute,
|
||||
entitlements: entitlements{
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlementLimit: entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
go fs.syncEntitlements(ctx)
|
||||
return fs
|
||||
}
|
||||
|
||||
func (s *featuresService) EntitlementsAPI(rw http.ResponseWriter, r *http.Request) {
|
||||
s.mu.RLock()
|
||||
e := s.entitlements
|
||||
s.mu.RUnlock()
|
||||
|
||||
resp := codersdk.Entitlements{
|
||||
Features: make(map[string]codersdk.Feature),
|
||||
Warnings: make([]string, 0),
|
||||
HasLicense: e.hasLicense,
|
||||
}
|
||||
|
||||
// User limit
|
||||
uf := codersdk.Feature{
|
||||
Entitlement: e.activeUsers.state.toSDK(),
|
||||
Enabled: true,
|
||||
}
|
||||
if !e.activeUsers.unlimited {
|
||||
n, err := s.database.GetActiveUserCount(r.Context())
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Unable to query database",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
uf.Actual = &n
|
||||
uf.Limit = &e.activeUsers.limit
|
||||
if n > e.activeUsers.limit {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
fmt.Sprintf(
|
||||
"Your deployment has %d active users but is only licensed for %d",
|
||||
n, e.activeUsers.limit))
|
||||
}
|
||||
}
|
||||
resp.Features[codersdk.FeatureUserLimit] = uf
|
||||
|
||||
// Audit logs
|
||||
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
|
||||
Entitlement: e.auditLogs.state.toSDK(),
|
||||
Enabled: s.enablements.AuditLogs,
|
||||
}
|
||||
if e.auditLogs.state == gracePeriod && s.enablements.AuditLogs {
|
||||
resp.Warnings = append(resp.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired")
|
||||
}
|
||||
|
||||
httpapi.Write(rw, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
type entitlementState int
|
||||
|
||||
const (
|
||||
notEntitled entitlementState = iota
|
||||
gracePeriod
|
||||
entitled
|
||||
)
|
||||
|
||||
type entitlementLimit struct {
|
||||
unlimited bool
|
||||
limit int64
|
||||
}
|
||||
|
||||
type entitlement struct {
|
||||
state entitlementState
|
||||
}
|
||||
|
||||
func (s entitlementState) toSDK() codersdk.Entitlement {
|
||||
switch s {
|
||||
case notEntitled:
|
||||
return codersdk.EntitlementNotEntitled
|
||||
case gracePeriod:
|
||||
return codersdk.EntitlementGracePeriod
|
||||
case entitled:
|
||||
return codersdk.EntitlementEntitled
|
||||
default:
|
||||
panic("unknown entitlementState")
|
||||
}
|
||||
}
|
||||
|
||||
type numericalEntitlement struct {
|
||||
entitlement
|
||||
entitlementLimit
|
||||
}
|
||||
|
||||
type entitlements struct {
|
||||
hasLicense bool
|
||||
activeUsers numericalEntitlement
|
||||
auditLogs entitlement
|
||||
}
|
||||
|
||||
func (s *featuresService) getEntitlements(ctx context.Context) (entitlements, error) {
|
||||
licenses, err := s.database.GetUnexpiredLicenses(ctx)
|
||||
if err != nil {
|
||||
return entitlements{}, err
|
||||
}
|
||||
now := time.Now()
|
||||
e := entitlements{
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlementLimit: entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, l := range licenses {
|
||||
claims, err := validateDBLicense(l, s.keys)
|
||||
if err != nil {
|
||||
s.logger.Debug(ctx, "skipping invalid license",
|
||||
slog.F("id", l.ID), slog.Error(err))
|
||||
continue
|
||||
}
|
||||
e.hasLicense = true
|
||||
thisEntitlement := entitled
|
||||
if now.After(claims.LicenseExpires.Time) {
|
||||
// if the grace period were over, the validation fails, so if we are after
|
||||
// LicenseExpires we must be in grace period.
|
||||
thisEntitlement = gracePeriod
|
||||
}
|
||||
if claims.Features.UserLimit > 0 {
|
||||
e.activeUsers.state = thisEntitlement
|
||||
e.activeUsers.unlimited = false
|
||||
e.activeUsers.limit = max(e.activeUsers.limit, claims.Features.UserLimit)
|
||||
}
|
||||
if claims.Features.AuditLog > 0 {
|
||||
e.auditLogs.state = thisEntitlement
|
||||
}
|
||||
}
|
||||
return e, nil
|
||||
}
|
||||
|
||||
func (s *featuresService) syncEntitlements(ctx context.Context) {
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
eb.MaxElapsedTime = 0 // retry indefinitely
|
||||
b := backoff.WithContext(eb, ctx)
|
||||
updates := make(chan struct{}, 1)
|
||||
subscribed := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
if !subscribed {
|
||||
cancel, err := s.pubsub.Subscribe(PubSubEventLicenses, func(_ context.Context, _ []byte) {
|
||||
// don't block. If the channel is full, drop the event, as there is a resync
|
||||
// scheduled already.
|
||||
select {
|
||||
case updates <- struct{}{}:
|
||||
// pass
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to subscribe to license updates", slog.Error(err))
|
||||
time.Sleep(b.NextBackOff())
|
||||
continue
|
||||
}
|
||||
// nolint: revive
|
||||
defer cancel()
|
||||
subscribed = true
|
||||
s.logger.Debug(ctx, "successfully subscribed to pubsub")
|
||||
}
|
||||
|
||||
s.logger.Info(ctx, "syncing licensed entitlements")
|
||||
ents, err := s.getEntitlements(ctx)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to get feature entitlements", slog.Error(err))
|
||||
time.Sleep(b.NextBackOff())
|
||||
continue
|
||||
}
|
||||
b.Reset()
|
||||
|
||||
s.mu.Lock()
|
||||
s.entitlements = ents
|
||||
s.mu.Unlock()
|
||||
s.logger.Debug(ctx, "synced licensed entitlements")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(s.resyncInterval):
|
||||
continue
|
||||
case <-updates:
|
||||
s.logger.Debug(ctx, "got pubsub update")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func max(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
337
enterprise/coderd/features_internal_test.go
Normal file
337
enterprise/coderd/features_internal_test.go
Normal file
@ -0,0 +1,337 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestFeaturesService_EntitlementsAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
// Note that these are not actually used because we don't run the syncEntitlements
|
||||
// routine in this test.
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
|
||||
t.Run("NoLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: false,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{notEntitled},
|
||||
entitlementLimit{
|
||||
unlimited: true,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{notEntitled},
|
||||
},
|
||||
}
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.False(t, result.HasLicense)
|
||||
assert.Empty(t, result.Warnings)
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureUserLimit].Entitlement)
|
||||
assert.Equal(t, codersdk.EntitlementNotEntitled, result.Features[codersdk.FeatureAuditLog].Entitlement)
|
||||
})
|
||||
|
||||
t.Run("FullLicense", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: true,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{entitled},
|
||||
entitlementLimit{
|
||||
unlimited: false,
|
||||
limit: 100,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{entitled},
|
||||
},
|
||||
}
|
||||
_, err = db.InsertUser(ctx, database.InsertUserParams{
|
||||
ID: uuid.UUID{},
|
||||
Email: "",
|
||||
Username: "",
|
||||
HashedPassword: nil,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
RBACRoles: nil,
|
||||
LoginType: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.True(t, result.HasLicense)
|
||||
ul := result.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, ul.Entitlement)
|
||||
assert.Equal(t, int64(100), *ul.Limit)
|
||||
assert.Equal(t, int64(1), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := result.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementEntitled, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Empty(t, result.Warnings)
|
||||
})
|
||||
|
||||
t.Run("Warnings", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
entitlements: entitlements{
|
||||
hasLicense: true,
|
||||
activeUsers: numericalEntitlement{
|
||||
entitlement{gracePeriod},
|
||||
entitlementLimit{
|
||||
unlimited: false,
|
||||
limit: 4,
|
||||
},
|
||||
},
|
||||
auditLogs: entitlement{gracePeriod},
|
||||
},
|
||||
}
|
||||
for i := byte(0); i < 5; i++ {
|
||||
_, err = db.InsertUser(ctx, database.InsertUserParams{
|
||||
ID: uuid.UUID{i},
|
||||
Email: "",
|
||||
Username: "",
|
||||
HashedPassword: nil,
|
||||
CreatedAt: time.Time{},
|
||||
UpdatedAt: time.Time{},
|
||||
RBACRoles: nil,
|
||||
LoginType: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
result := requestEntitlements(t, uut)
|
||||
assert.True(t, result.HasLicense)
|
||||
ul := result.Features[codersdk.FeatureUserLimit]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, ul.Entitlement)
|
||||
assert.Equal(t, int64(4), *ul.Limit)
|
||||
assert.Equal(t, int64(5), *ul.Actual)
|
||||
assert.True(t, ul.Enabled)
|
||||
al := result.Features[codersdk.FeatureAuditLog]
|
||||
assert.Equal(t, codersdk.EntitlementGracePeriod, al.Entitlement)
|
||||
assert.True(t, al.Enabled)
|
||||
assert.Nil(t, al.Limit)
|
||||
assert.Nil(t, al.Actual)
|
||||
assert.Len(t, result.Warnings, 2)
|
||||
assert.Contains(t, result.Warnings,
|
||||
"Your deployment has 5 active users but is only licensed for 4")
|
||||
assert.Contains(t, result.Warnings,
|
||||
"Audit logging is enabled but your license for this feature is expired")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFeaturesServiceSyncEntitlements(t *testing.T) {
|
||||
t.Parallel()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
keyID := "testing"
|
||||
|
||||
// This tests that pubsub updates work by setting the resync interval very long
|
||||
t.Run("PubSub", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
resyncInterval: time.Hour, // no resyncs during test
|
||||
entitlements: entitlements{},
|
||||
}
|
||||
|
||||
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start of day, 3 licenses, one expired, one invalid
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
|
||||
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
|
||||
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
|
||||
|
||||
go uut.syncEntitlements(ctx)
|
||||
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
// New license
|
||||
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// User limit goes up, because 305 > 300
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
|
||||
|
||||
// New license with lower limit
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("add"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Need to delete the others before the limit lowers
|
||||
_, err = db.DeleteLicense(ctx, l1.ID)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
_, err = db.DeleteLicense(ctx, l0.ID)
|
||||
require.NoError(t, err)
|
||||
err = pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
|
||||
})
|
||||
|
||||
// This tests that periodic resyncs work by setting the resync interval very fast and
|
||||
// not sending any pubsub updates.
|
||||
t.Run("Resyncs", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
logger := slogtest.Make(t, nil)
|
||||
pubsub := database.NewPubsubInMemory()
|
||||
db := databasefake.New()
|
||||
uut := &featuresService{
|
||||
logger: logger,
|
||||
database: db,
|
||||
pubsub: pubsub,
|
||||
keys: map[string]ed25519.PublicKey{keyID: pub},
|
||||
enablements: Enablements{AuditLogs: true},
|
||||
resyncInterval: 10 * time.Millisecond,
|
||||
entitlements: entitlements{},
|
||||
}
|
||||
|
||||
_, invalidKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start of day, 3 licenses, one expired, one invalid
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 1000, -2*time.Hour, -1*time.Hour)
|
||||
_ = putLicense(ctx, t, db, invalidKey, "invalid", 900, time.Hour, 2*time.Hour)
|
||||
l0 := putLicense(ctx, t, db, priv, keyID, 300, time.Hour, 2*time.Hour)
|
||||
|
||||
go uut.syncEntitlements(ctx)
|
||||
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
// New license
|
||||
l1 := putLicense(ctx, t, db, priv, keyID, 305, time.Hour, 2*time.Hour)
|
||||
|
||||
// User limit goes up, because 305 > 300
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 305), testutil.IntervalFast)
|
||||
|
||||
// New license with lower limit
|
||||
_ = putLicense(ctx, t, db, priv, keyID, 295, time.Hour, 2*time.Hour)
|
||||
|
||||
// Need to delete the others before the limit lowers
|
||||
_, err = db.DeleteLicense(ctx, l1.ID)
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 300), testutil.IntervalFast)
|
||||
|
||||
_, err = db.DeleteLicense(ctx, l0.ID)
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, userLimitIs(uut, 295), testutil.IntervalFast)
|
||||
})
|
||||
}
|
||||
|
||||
func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
uut.EntitlementsAPI(rw, r)
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
dec := json.NewDecoder(resp.Body)
|
||||
var result codersdk.Entitlements
|
||||
err := dec.Decode(&result)
|
||||
require.NoError(t, err)
|
||||
return result
|
||||
}
|
||||
|
||||
func putLicense(
|
||||
ctx context.Context, t *testing.T, db database.Store,
|
||||
k ed25519.PrivateKey, keyID string, userLimit int64,
|
||||
timeToGrace, timeToExpire time.Duration,
|
||||
) database.License {
|
||||
t.Helper()
|
||||
c := &Claims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test@testing.test",
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(timeToExpire)),
|
||||
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
|
||||
},
|
||||
LicenseExpires: jwt.NewNumericDate(time.Now().Add(timeToGrace)),
|
||||
Version: CurrentVersion,
|
||||
Features: Features{
|
||||
UserLimit: userLimit,
|
||||
AuditLog: 1,
|
||||
},
|
||||
}
|
||||
j, err := makeLicense(c, k, keyID)
|
||||
require.NoError(t, err)
|
||||
l, err := db.InsertLicense(ctx, database.InsertLicenseParams{
|
||||
UploadedAt: c.IssuedAt.Time,
|
||||
JWT: j,
|
||||
Exp: c.ExpiresAt.Time,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return l
|
||||
}
|
||||
|
||||
func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool {
|
||||
return func(_ context.Context) bool {
|
||||
fs.mu.RLock()
|
||||
defer fs.mu.RUnlock()
|
||||
return fs.entitlements.activeUsers.limit == limit
|
||||
}
|
||||
}
|
@ -64,8 +64,9 @@ type Claims struct {
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidVersion = xerrors.New("license must be version 3")
|
||||
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
|
||||
ErrInvalidVersion = xerrors.New("license must be version 3")
|
||||
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
|
||||
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
|
||||
)
|
||||
|
||||
// parseLicense parses the license and returns the claims. If the license's signature is invalid or
|
||||
@ -92,6 +93,30 @@ func parseLicense(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, e
|
||||
return nil, xerrors.New("unable to parse Claims")
|
||||
}
|
||||
|
||||
// validateDBLicense validates a database.License record, and if valid, returns the claims. If
|
||||
// unparsable or invalid, it returns an error
|
||||
func validateDBLicense(l database.License, keys map[string]ed25519.PublicKey) (*Claims, error) {
|
||||
tok, err := jwt.ParseWithClaims(
|
||||
l.JWT,
|
||||
&Claims{},
|
||||
keyFunc(keys),
|
||||
jwt.WithValidMethods(ValidMethods),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if claims, ok := tok.Claims.(*Claims); ok && tok.Valid {
|
||||
if claims.Version != uint64(CurrentVersion) {
|
||||
return nil, ErrInvalidVersion
|
||||
}
|
||||
if claims.LicenseExpires == nil {
|
||||
return nil, ErrMissingLicenseExpires
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
return nil, xerrors.New("unable to parse Claims")
|
||||
}
|
||||
|
||||
func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) {
|
||||
return func(j *jwt.Token) (interface{}, error) {
|
||||
keyID, ok := j.Header[HeaderKeyID].(string)
|
||||
@ -297,5 +322,11 @@ func (a *licenseAPI) delete(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
err = a.pubsub.Publish(PubSubEventLicenses, []byte("delete"))
|
||||
if err != nil {
|
||||
a.logger.Error(context.Background(), "failed to publish license delete", slog.Error(err))
|
||||
// don't fail the HTTP request, since we did write it successfully to the database
|
||||
}
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
Reference in New Issue
Block a user