mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
Use licenses to populate the Entitlements API (#3715)
* Use licenses for entitlements API Signed-off-by: Spike Curtis <spike@coder.com> * Tests for entitlements API Signed-off-by: Spike Curtis <spike@coder.com> * Add commentary about FeatureService Signed-off-by: Spike Curtis <spike@coder.com> * Lint Signed-off-by: Spike Curtis <spike@coder.com> * Quiet down the logs Signed-off-by: Spike Curtis <spike@coder.com> * Tell revive it's ok Signed-off-by: Spike Curtis <spike@coder.com> Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
@ -66,6 +66,7 @@ type Options struct {
|
||||
TracerProvider *sdktrace.TracerProvider
|
||||
AutoImportTemplates []AutoImportTemplate
|
||||
LicenseHandler http.Handler
|
||||
FeaturesService FeaturesService
|
||||
}
|
||||
|
||||
// New constructs a Coder API handler.
|
||||
@ -95,6 +96,9 @@ func New(options *Options) *API {
|
||||
if options.LicenseHandler == nil {
|
||||
options.LicenseHandler = licenses()
|
||||
}
|
||||
if options.FeaturesService == nil {
|
||||
options.FeaturesService = featuresService{}
|
||||
}
|
||||
|
||||
siteCacheDir := options.CacheDir
|
||||
if siteCacheDir != "" {
|
||||
@ -400,7 +404,7 @@ func New(options *Options) *API {
|
||||
})
|
||||
r.Route("/entitlements", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/", entitlements)
|
||||
r.Get("/", api.FeaturesService.EntitlementsAPI)
|
||||
})
|
||||
r.Route("/licenses", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
|
@ -246,6 +246,19 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) {
|
||||
return int64(len(q.users)), nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
active := int64(0)
|
||||
for _, u := range q.users {
|
||||
if u.Status == database.UserStatusActive {
|
||||
active++
|
||||
}
|
||||
}
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@ -2322,6 +2335,21 @@ func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
var results []database.License
|
||||
for _, l := range q.licenses {
|
||||
if l.Exp.After(now) {
|
||||
results = append(results, l)
|
||||
}
|
||||
}
|
||||
sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID })
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
@ -25,6 +25,7 @@ type querier interface {
|
||||
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
|
||||
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
|
||||
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
|
||||
GetActiveUserCount(ctx context.Context) (int64, error)
|
||||
// GetAuditLogsBefore retrieves `limit` number of audit logs before the provided
|
||||
// ID.
|
||||
GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error)
|
||||
@ -63,6 +64,7 @@ type querier interface {
|
||||
GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error)
|
||||
GetTemplates(ctx context.Context) ([]Template, error)
|
||||
GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error)
|
||||
GetUnexpiredLicenses(ctx context.Context) ([]License, error)
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
|
||||
GetUserCount(ctx context.Context) (int64, error)
|
||||
|
@ -522,6 +522,41 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) {
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many
|
||||
SELECT id, uploaded_at, jwt, exp
|
||||
FROM licenses
|
||||
WHERE exp > NOW()
|
||||
ORDER BY (id)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetUnexpiredLicenses(ctx context.Context) ([]License, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUnexpiredLicenses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []License
|
||||
for rows.Next() {
|
||||
var i License
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.UploadedAt,
|
||||
&i.JWT,
|
||||
&i.Exp,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertLicense = `-- name: InsertLicense :one
|
||||
INSERT INTO
|
||||
licenses (
|
||||
@ -2664,6 +2699,22 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getActiveUserCount = `-- name: GetActiveUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
status = 'active'::public.user_status
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, getActiveUserCount)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one
|
||||
SELECT
|
||||
-- username is returned just to help for logging purposes
|
||||
|
@ -13,6 +13,12 @@ SELECT *
|
||||
FROM licenses
|
||||
ORDER BY (id);
|
||||
|
||||
-- name: GetUnexpiredLicenses :many
|
||||
SELECT *
|
||||
FROM licenses
|
||||
WHERE exp > NOW()
|
||||
ORDER BY (id);
|
||||
|
||||
-- name: DeleteLicense :one
|
||||
DELETE
|
||||
FROM licenses
|
||||
|
@ -28,6 +28,14 @@ SELECT
|
||||
FROM
|
||||
users;
|
||||
|
||||
-- name: GetActiveUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
status = 'active'::public.user_status;
|
||||
|
||||
-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
users (
|
||||
|
@ -7,7 +7,20 @@ import (
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func entitlements(rw http.ResponseWriter, _ *http.Request) {
|
||||
// FeaturesService is the interface for interacting with enterprise features.
|
||||
type FeaturesService interface {
|
||||
EntitlementsAPI(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
// TODO
|
||||
// Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a
|
||||
// struct type containing feature interfaces as fields. The FeatureService sets all fields to
|
||||
// the correct implementations depending on whether the features are turned on.
|
||||
// Get(s any) error
|
||||
}
|
||||
|
||||
type featuresService struct{}
|
||||
|
||||
func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) {
|
||||
features := make(map[string]codersdk.Feature)
|
||||
for _, f := range codersdk.FeatureNames {
|
||||
features[f] = codersdk.Feature{
|
||||
|
@ -18,7 +18,7 @@ func TestEntitlements(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
entitlements(rw, r)
|
||||
featuresService{}.EntitlementsAPI(rw, r)
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
Reference in New Issue
Block a user