From af125c37957b2b0b06b1a55c9732395ebe724804 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Fri, 23 Aug 2024 16:21:58 -0500 Subject: [PATCH] chore: refactor entitlements to be a safe object to use (#14406) * chore: refactor entitlements to be passable as an argument Previously, all usage of entitlements requires mutex usage on the api struct directly. This prevents passing the entitlements to a sub package. It also creates the possibility for misuse. --- coderd/coderd.go | 8 ++ coderd/entitlements/entitlements.go | 109 ++++++++++++++++++ coderd/entitlements/entitlements_test.go | 63 ++++++++++ codersdk/deployment.go | 6 + enterprise/coderd/coderd.go | 72 +++++------- enterprise/coderd/jfrog.go | 6 +- enterprise/coderd/license/metricscollector.go | 12 +- .../coderd/license/metricscollector_test.go | 16 +-- enterprise/coderd/licenses.go | 4 +- enterprise/coderd/provisionerdaemons.go | 6 +- enterprise/coderd/scim.go | 6 +- enterprise/coderd/templates.go | 18 +-- enterprise/coderd/userauth.go | 12 +- enterprise/coderd/users.go | 10 +- enterprise/coderd/workspaceagents.go | 5 +- enterprise/coderd/workspacequota.go | 4 +- site/site.go | 14 +-- 17 files changed, 247 insertions(+), 124 deletions(-) create mode 100644 coderd/entitlements/entitlements.go create mode 100644 coderd/entitlements/entitlements_test.go diff --git a/coderd/coderd.go b/coderd/coderd.go index d26a49b87b..8ec8400db4 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -37,6 +37,7 @@ import ( "tailscale.com/util/singleflight" "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/quartz" "github.com/coder/serpent" @@ -157,6 +158,9 @@ type Options struct { TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error // RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license. RefreshEntitlements func(ctx context.Context) error + // Entitlements can come from the enterprise caller if enterprise code is + // included. + Entitlements *entitlements.Set // PostAuthAdditionalHeadersFunc is used to add additional headers to the response // after a successful authentication. // This is somewhat janky, but seemingly the only reasonable way to add a header @@ -263,6 +267,9 @@ func New(options *Options) *API { if options == nil { options = &Options{} } + if options.Entitlements == nil { + options.Entitlements = entitlements.New() + } if options.NewTicker == nil { options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) { ticker := time.NewTicker(duration) @@ -500,6 +507,7 @@ func New(options *Options) *API { DocsURL: options.DeploymentValues.DocsURL.String(), AppearanceFetcher: &api.AppearanceFetcher, BuildInfo: buildInfo, + Entitlements: options.Entitlements, }) api.SiteHandler.Experiments.Store(&experiments) diff --git a/coderd/entitlements/entitlements.go b/coderd/entitlements/entitlements.go new file mode 100644 index 0000000000..9efc3a6317 --- /dev/null +++ b/coderd/entitlements/entitlements.go @@ -0,0 +1,109 @@ +package entitlements + +import ( + "encoding/json" + "net/http" + "sync" + "time" + + "github.com/coder/coder/v2/codersdk" +) + +type Set struct { + entitlementsMu sync.RWMutex + entitlements codersdk.Entitlements +} + +func New() *Set { + return &Set{ + // Some defaults for an unlicensed instance. + // These will be updated when coderd is initialized. + entitlements: codersdk.Entitlements{ + Features: map[codersdk.FeatureName]codersdk.Feature{}, + Warnings: nil, + Errors: nil, + HasLicense: false, + Trial: false, + RequireTelemetry: false, + RefreshedAt: time.Time{}, + }, + } +} + +// AllowRefresh returns whether the entitlements are allowed to be refreshed. +// If it returns false, that means it was recently refreshed and the caller should +// wait the returned duration before trying again. +func (l *Set) AllowRefresh(now time.Time) (bool, time.Duration) { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + diff := now.Sub(l.entitlements.RefreshedAt) + if diff < time.Minute { + return false, time.Minute - diff + } + + return true, 0 +} + +func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + f, ok := l.entitlements.Features[name] + return f, ok +} + +func (l *Set) Enabled(feature codersdk.FeatureName) bool { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + f, ok := l.entitlements.Features[feature] + if !ok { + return false + } + return f.Enabled +} + +// AsJSON is used to return this to the api without exposing the entitlements for +// mutation. +func (l *Set) AsJSON() json.RawMessage { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + b, _ := json.Marshal(l.entitlements) + return b +} + +func (l *Set) Replace(entitlements codersdk.Entitlements) { + l.entitlementsMu.Lock() + defer l.entitlementsMu.Unlock() + + l.entitlements = entitlements +} + +func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) { + l.entitlementsMu.Lock() + defer l.entitlementsMu.Unlock() + + do(&l.entitlements) +} + +func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + oldFeature := l.entitlements.Features[featureName] + if oldFeature.Enabled != newFeature.Enabled { + return false, true, newFeature.Enabled + } + return false, false, newFeature.Enabled +} + +func (l *Set) WriteEntitlementWarningHeaders(header http.Header) { + l.entitlementsMu.RLock() + defer l.entitlementsMu.RUnlock() + + for _, warning := range l.entitlements.Warnings { + header.Add(codersdk.EntitlementsWarningHeader, warning) + } +} diff --git a/coderd/entitlements/entitlements_test.go b/coderd/entitlements/entitlements_test.go new file mode 100644 index 0000000000..f5dbb1f7a7 --- /dev/null +++ b/coderd/entitlements/entitlements_test.go @@ -0,0 +1,63 @@ +package entitlements_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/entitlements" + "github.com/coder/coder/v2/codersdk" +) + +func TestUpdate(t *testing.T) { + t.Parallel() + + set := entitlements.New() + require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations)) + + set.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{ + Enabled: true, + Entitlement: codersdk.EntitlementEntitled, + } + }) + require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations)) +} + +func TestAllowRefresh(t *testing.T) { + t.Parallel() + + now := time.Now() + set := entitlements.New() + set.Update(func(entitlements *codersdk.Entitlements) { + entitlements.RefreshedAt = now + }) + + ok, wait := set.AllowRefresh(now) + require.False(t, ok) + require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5) + + set.Update(func(entitlements *codersdk.Entitlements) { + entitlements.RefreshedAt = now.Add(time.Minute * -2) + }) + + ok, wait = set.AllowRefresh(now) + require.True(t, ok) + require.Equal(t, time.Duration(0), wait) +} + +func TestReplace(t *testing.T) { + t.Parallel() + + set := entitlements.New() + require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations)) + set.Replace(codersdk.Entitlements{ + Features: map[codersdk.FeatureName]codersdk.Feature{ + codersdk.FeatureMultipleOrganizations: { + Enabled: true, + }, + }, + }) + require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations)) +} diff --git a/codersdk/deployment.go b/codersdk/deployment.go index be99ab5388..fe0bebf744 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -35,6 +35,12 @@ const ( EntitlementNotEntitled Entitlement = "not_entitled" ) +// Entitled returns if the entitlement can be used. So this is true if it +// is entitled or still in it's grace period. +func (e Entitlement) Entitled() bool { + return e == EntitlementEntitled || e == EntitlementGracePeriod +} + // Weight converts the enum types to a numerical value for easier // comparisons. Easier than sets of if statements. func (e Entitlement) Weight() int { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index effe2aa2ce..066bea50b2 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -15,6 +15,7 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/appearance" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/entitlements" agplportsharing "github.com/coder/coder/v2/coderd/portsharing" "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/enterprise/coderd/portsharing" @@ -103,19 +104,26 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } return nil, xerrors.Errorf("init database encryption: %w", err) } + + entitlementsSet := entitlements.New() options.Database = cryptDB api := &API{ - ctx: ctx, - cancel: cancelFunc, - Options: options, + ctx: ctx, + cancel: cancelFunc, + Options: options, + entitlements: entitlementsSet, provisionerDaemonAuth: &provisionerDaemonAuth{ psk: options.ProvisionerDaemonPSK, authorizer: options.Authorizer, db: options.Database, }, + licenseMetricsCollector: &license.MetricsCollector{ + Entitlements: entitlementsSet, + }, } // This must happen before coderd initialization! options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader + options.Options.Entitlements = api.entitlements api.AGPL = coderd.New(options.Options) defer func() { if err != nil { @@ -493,7 +501,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } api.AGPL.WorkspaceProxiesFetchUpdater.Store(&fetchUpdater) - err = api.PrometheusRegistry.Register(&api.licenseMetricsCollector) + err = api.PrometheusRegistry.Register(api.licenseMetricsCollector) if err != nil { return nil, xerrors.Errorf("unable to register license metrics collector") } @@ -553,13 +561,11 @@ type API struct { // ProxyHealth checks the reachability of all workspace proxies. ProxyHealth *proxyhealth.ProxyHealth - entitlementsUpdateMu sync.Mutex - entitlementsMu sync.RWMutex - entitlements codersdk.Entitlements + entitlements *entitlements.Set provisionerDaemonAuth *provisionerDaemonAuth - licenseMetricsCollector license.MetricsCollector + licenseMetricsCollector *license.MetricsCollector tailnetService *tailnet.ClientService } @@ -588,11 +594,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade // has no roles. This is a normal user! return } - api.entitlementsMu.RLock() - defer api.entitlementsMu.RUnlock() - for _, warning := range api.entitlements.Warnings { - header.Add(codersdk.EntitlementsWarningHeader, warning) - } + + api.entitlements.WriteEntitlementWarningHeaders(header) } func (api *API) Close() error { @@ -614,9 +617,6 @@ func (api *API) Close() error { } func (api *API) updateEntitlements(ctx context.Context) error { - api.entitlementsUpdateMu.Lock() - defer api.entitlementsUpdateMu.Unlock() - replicas := api.replicaManager.AllPrimary() agedReplicas := make([]database.Replica, 0, len(replicas)) for _, replica := range replicas { @@ -632,7 +632,7 @@ func (api *API) updateEntitlements(ctx context.Context) error { agedReplicas = append(agedReplicas, replica) } - entitlements, err := license.Entitlements( + reloadedEntitlements, err := license.Entitlements( ctx, api.Database, len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{ codersdk.FeatureAuditLog: api.AuditLogging, @@ -652,29 +652,24 @@ func (api *API) updateEntitlements(ctx context.Context) error { return err } - if entitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() { + if reloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() { // We can't fail because then the user couldn't remove the offending // license w/o a restart. // // We don't simply append to entitlement.Errors since we don't want any // enterprise features enabled. - api.entitlements.Errors = []string{ - "License requires telemetry but telemetry is disabled", - } + api.entitlements.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Errors = []string{ + "License requires telemetry but telemetry is disabled", + } + }) + api.Logger.Error(ctx, "license requires telemetry enabled") return nil } featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) { - if api.entitlements.Features == nil { - return true, false, entitlements.Features[featureName].Enabled - } - oldFeature := api.entitlements.Features[featureName] - newFeature := entitlements.Features[featureName] - if oldFeature.Enabled != newFeature.Enabled { - return false, true, newFeature.Enabled - } - return false, false, newFeature.Enabled + return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName]) } shouldUpdate := func(initial, changed, enabled bool) bool { @@ -831,20 +826,16 @@ func (api *API) updateEntitlements(ctx context.Context) error { } // External token encryption is soft-enforced - featureExternalTokenEncryption := entitlements.Features[codersdk.FeatureExternalTokenEncryption] + featureExternalTokenEncryption := reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0 if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled { msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize()) api.Logger.Warn(ctx, msg) - entitlements.Warnings = append(entitlements.Warnings, msg) + reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg) } - entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption + reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption - api.entitlementsMu.Lock() - defer api.entitlementsMu.Unlock() - api.entitlements = entitlements - api.licenseMetricsCollector.Entitlements.Store(&entitlements) - api.AGPL.SiteHandler.Entitlements.Store(&entitlements) + api.entitlements.Replace(reloadedEntitlements) return nil } @@ -1024,10 +1015,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(* // @Router /entitlements [get] func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - api.entitlementsMu.RLock() - entitlements := api.entitlements - api.entitlementsMu.RUnlock() - httpapi.Write(ctx, rw, http.StatusOK, entitlements) + httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON()) } func (api *API) runEntitlementsLoop(ctx context.Context) { diff --git a/enterprise/coderd/jfrog.go b/enterprise/coderd/jfrog.go index 9262c673eb..e1afe473c2 100644 --- a/enterprise/coderd/jfrog.go +++ b/enterprise/coderd/jfrog.go @@ -104,14 +104,10 @@ func (api *API) jFrogXrayScan(rw http.ResponseWriter, r *http.Request) { func (api *API) jfrogEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - api.entitlementsMu.RLock() // This doesn't actually use the external auth feature but we want // to lock this behind an enterprise license and it's somewhat // related to external auth (in that it is JFrog integration). - enabled := api.entitlements.Features[codersdk.FeatureMultipleExternalAuth].Enabled - api.entitlementsMu.RUnlock() - - if !enabled { + if !api.entitlements.Enabled(codersdk.FeatureMultipleExternalAuth) { httpapi.RouteNotFound(rw) return } diff --git a/enterprise/coderd/license/metricscollector.go b/enterprise/coderd/license/metricscollector.go index 85aac23b2f..8c0ccd83fb 100644 --- a/enterprise/coderd/license/metricscollector.go +++ b/enterprise/coderd/license/metricscollector.go @@ -1,10 +1,9 @@ package license import ( - "sync/atomic" - "github.com/prometheus/client_golang/prometheus" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/codersdk" ) @@ -15,7 +14,7 @@ var ( ) type MetricsCollector struct { - Entitlements atomic.Pointer[codersdk.Entitlements] + Entitlements *entitlements.Set } var _ prometheus.Collector = new(MetricsCollector) @@ -27,12 +26,7 @@ func (*MetricsCollector) Describe(descCh chan<- *prometheus.Desc) { } func (mc *MetricsCollector) Collect(metricsCh chan<- prometheus.Metric) { - entitlements := mc.Entitlements.Load() - if entitlements == nil || entitlements.Features == nil { - return - } - - userLimitEntitlement, ok := entitlements.Features[codersdk.FeatureUserLimit] + userLimitEntitlement, ok := mc.Entitlements.Feature(codersdk.FeatureUserLimit) if !ok { return } diff --git a/enterprise/coderd/license/metricscollector_test.go b/enterprise/coderd/license/metricscollector_test.go index 36661c8cdb..0ce9e8e4b5 100644 --- a/enterprise/coderd/license/metricscollector_test.go +++ b/enterprise/coderd/license/metricscollector_test.go @@ -9,6 +9,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/license" ) @@ -25,14 +26,13 @@ func TestCollectLicenseMetrics(t *testing.T) { actualUsers = 4 userLimit = 7 ) - sut.Entitlements.Store(&codersdk.Entitlements{ - Features: map[codersdk.FeatureName]codersdk.Feature{ - codersdk.FeatureUserLimit: { - Enabled: true, - Actual: ptr.Int64(actualUsers), - Limit: ptr.Int64(userLimit), - }, - }, + sut.Entitlements = entitlements.New() + sut.Entitlements.Update(func(entitlements *codersdk.Entitlements) { + entitlements.Features[codersdk.FeatureUserLimit] = codersdk.Feature{ + Enabled: true, + Actual: ptr.Int64(actualUsers), + Limit: ptr.Int64(userLimit), + } }) registry.Register(&sut) diff --git a/enterprise/coderd/licenses.go b/enterprise/coderd/licenses.go index 54bc57b649..7db217234c 100644 --- a/enterprise/coderd/licenses.go +++ b/enterprise/coderd/licenses.go @@ -189,12 +189,10 @@ func (api *API) postRefreshEntitlements(rw http.ResponseWriter, r *http.Request) // Prevent abuse by limiting how often we allow a forced refresh. now := time.Now() - if diff := now.Sub(api.entitlements.RefreshedAt); diff < time.Minute { - wait := time.Minute - diff + if ok, wait := api.entitlements.AllowRefresh(now); !ok { rw.Header().Set("Retry-After", strconv.Itoa(int(wait.Seconds()))) httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: fmt.Sprintf("Entitlements already recently refreshed, please wait %d seconds to force a new refresh", int(wait.Seconds())), - Detail: fmt.Sprintf("Last refresh at %s", now.UTC().String()), }) return } diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index ff5eb70944..52836da237 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -39,11 +39,7 @@ import ( func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - api.entitlementsMu.RLock() - epd := api.entitlements.Features[codersdk.FeatureExternalProvisionerDaemons].Enabled - api.entitlementsMu.RUnlock() - - if !epd { + if !api.entitlements.Enabled(codersdk.FeatureExternalProvisionerDaemons) { httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ Message: "External provisioner daemons is an Enterprise feature. Contact sales!", }) diff --git a/enterprise/coderd/scim.go b/enterprise/coderd/scim.go index 9a803c51d9..7211880096 100644 --- a/enterprise/coderd/scim.go +++ b/enterprise/coderd/scim.go @@ -25,11 +25,7 @@ import ( func (api *API) scimEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - api.entitlementsMu.RLock() - scim := api.entitlements.Features[codersdk.FeatureSCIM].Enabled - api.entitlementsMu.RUnlock() - - if !scim { + if !api.entitlements.Enabled(codersdk.FeatureSCIM) { httpapi.RouteNotFound(rw) return } diff --git a/enterprise/coderd/templates.go b/enterprise/coderd/templates.go index cf7a34530a..bd0b803cb9 100644 --- a/enterprise/coderd/templates.go +++ b/enterprise/coderd/templates.go @@ -342,28 +342,14 @@ func convertSDKTemplateRole(role codersdk.TemplateRole) []policy.Action { // TODO move to api.RequireFeatureMW when we are OK with changing the behavior. func (api *API) templateRBACEnabledMW(next http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - api.entitlementsMu.RLock() - rbac := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled - api.entitlementsMu.RUnlock() - - if !rbac { - httpapi.RouteNotFound(rw) - return - } - - next.ServeHTTP(rw, r) - }) + return api.RequireFeatureMW(codersdk.FeatureTemplateRBAC)(next) } func (api *API) RequireFeatureMW(feat codersdk.FeatureName) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { // Entitlement must be enabled. - api.entitlementsMu.RLock() - enabled := api.entitlements.Features[feat].Enabled - api.entitlementsMu.RUnlock() - if !enabled { + if !api.entitlements.Enabled(feat) { licenseType := "a Premium" if feat.Enterprise() { licenseType = "an Enterprise" diff --git a/enterprise/coderd/userauth.go b/enterprise/coderd/userauth.go index a2dcac6085..5c972515b7 100644 --- a/enterprise/coderd/userauth.go +++ b/enterprise/coderd/userauth.go @@ -14,11 +14,7 @@ import ( // nolint: revive func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error { - api.entitlementsMu.RLock() - enabled := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled - api.entitlementsMu.RUnlock() - - if !enabled { + if !api.entitlements.Enabled(codersdk.FeatureTemplateRBAC) { return nil } @@ -82,11 +78,7 @@ func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db databa } func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { - api.entitlementsMu.RLock() - enabled := api.entitlements.Features[codersdk.FeatureUserRoleManagement].Enabled - api.entitlementsMu.RUnlock() - - if !enabled { + if !api.entitlements.Enabled(codersdk.FeatureUserRoleManagement) { logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", slog.F("user_id", userID), slog.F("roles", roles), ) diff --git a/enterprise/coderd/users.go b/enterprise/coderd/users.go index 07e66708b1..808f91140f 100644 --- a/enterprise/coderd/users.go +++ b/enterprise/coderd/users.go @@ -18,18 +18,14 @@ const TimeFormatHHMM = "15:04" func (api *API) autostopRequirementEnabledMW(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - // Entitlement must be enabled. - api.entitlementsMu.RLock() - entitled := api.entitlements.Features[codersdk.FeatureAdvancedTemplateScheduling].Entitlement != codersdk.EntitlementNotEntitled - enabled := api.entitlements.Features[codersdk.FeatureAdvancedTemplateScheduling].Enabled - api.entitlementsMu.RUnlock() - if !entitled { + feature, ok := api.entitlements.Feature(codersdk.FeatureAdvancedTemplateScheduling) + if !ok || !feature.Entitlement.Entitled() { httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ Message: "Advanced template scheduling (and user quiet hours schedule) is an Enterprise feature. Contact sales!", }) return } - if !enabled { + if !feature.Enabled { httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ Message: "Advanced template scheduling (and user quiet hours schedule) is not enabled.", }) diff --git a/enterprise/coderd/workspaceagents.go b/enterprise/coderd/workspaceagents.go index d14aa9580b..baf4a9f4a9 100644 --- a/enterprise/coderd/workspaceagents.go +++ b/enterprise/coderd/workspaceagents.go @@ -9,10 +9,7 @@ import ( ) func (api *API) shouldBlockNonBrowserConnections(rw http.ResponseWriter) bool { - api.entitlementsMu.RLock() - browserOnly := api.entitlements.Features[codersdk.FeatureBrowserOnly].Enabled - api.entitlementsMu.RUnlock() - if browserOnly { + if api.entitlements.Enabled(codersdk.FeatureBrowserOnly) { httpapi.Write(context.Background(), rw, http.StatusConflict, codersdk.Response{ Message: "Non-browser connections are disabled for your deployment.", }) diff --git a/enterprise/coderd/workspacequota.go b/enterprise/coderd/workspacequota.go index f93ab4ffc4..da6546687d 100644 --- a/enterprise/coderd/workspacequota.go +++ b/enterprise/coderd/workspacequota.go @@ -155,9 +155,7 @@ func (api *API) workspaceQuota(rw http.ResponseWriter, r *http.Request) { user = httpmw.UserParam(r) ) - api.entitlementsMu.RLock() - licensed := api.entitlements.Features[codersdk.FeatureTemplateRBAC].Enabled - api.entitlementsMu.RUnlock() + licensed := api.entitlements.Enabled(codersdk.FeatureTemplateRBAC) // There are no groups and thus no allowance if RBAC isn't licensed. var quotaAllowance int64 = -1 diff --git a/site/site.go b/site/site.go index 42d7968b33..b168910ce3 100644 --- a/site/site.go +++ b/site/site.go @@ -38,6 +38,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" @@ -79,6 +80,7 @@ type Options struct { DocsURL string BuildInfo codersdk.BuildInfoResponse AppearanceFetcher *atomic.Pointer[appearance.Fetcher] + Entitlements *entitlements.Set } func New(opts *Options) *Handler { @@ -91,6 +93,7 @@ func New(opts *Options) *Handler { handler := &Handler{ opts: opts, secureHeaders: secureHeaders(), + Entitlements: opts.Entitlements, } // html files are handled by a text/template. Non-html files @@ -173,7 +176,7 @@ type Handler struct { // regions if the user does not have the correct permissions. RegionsFetcher func(ctx context.Context) (any, error) - Entitlements atomic.Pointer[codersdk.Entitlements] + Entitlements *entitlements.Set Experiments atomic.Pointer[codersdk.Experiments] } @@ -379,15 +382,12 @@ func (h *Handler) renderHTMLWithState(r *http.Request, filePath string, state ht state.User = html.EscapeString(string(user)) } }() - entitlements := h.Entitlements.Load() - if entitlements != nil { + + if h.Entitlements != nil { wg.Add(1) go func() { defer wg.Done() - entitlements, err := json.Marshal(entitlements) - if err == nil { - state.Entitlements = html.EscapeString(string(entitlements)) - } + state.Entitlements = html.EscapeString(string(h.Entitlements.AsJSON())) }() }