Use licenses to populate the Entitlements API (#3715)

* Use licenses for entitlements API

Signed-off-by: Spike Curtis <spike@coder.com>

* Tests for entitlements API

Signed-off-by: Spike Curtis <spike@coder.com>

* Add commentary about FeatureService

Signed-off-by: Spike Curtis <spike@coder.com>

* Lint

Signed-off-by: Spike Curtis <spike@coder.com>

* Quiet down the logs

Signed-off-by: Spike Curtis <spike@coder.com>

* Tell revive it's ok

Signed-off-by: Spike Curtis <spike@coder.com>

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis
2022-08-29 16:45:40 -07:00
committed by GitHub
parent 05f932b37e
commit cc346afce6
14 changed files with 773 additions and 10 deletions

View File

@ -66,6 +66,7 @@ type Options struct {
TracerProvider *sdktrace.TracerProvider
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService FeaturesService
}
// New constructs a Coder API handler.
@ -95,6 +96,9 @@ func New(options *Options) *API {
if options.LicenseHandler == nil {
options.LicenseHandler = licenses()
}
if options.FeaturesService == nil {
options.FeaturesService = featuresService{}
}
siteCacheDir := options.CacheDir
if siteCacheDir != "" {
@ -400,7 +404,7 @@ func New(options *Options) *API {
})
r.Route("/entitlements", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/", entitlements)
r.Get("/", api.FeaturesService.EntitlementsAPI)
})
r.Route("/licenses", func(r chi.Router) {
r.Use(apiKeyMiddleware)

View File

@ -246,6 +246,19 @@ func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) {
return int64(len(q.users)), nil
}
func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
active := int64(0)
for _, u := range q.users {
if u.Status == database.UserStatusActive {
active++
}
}
return active, nil
}
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2322,6 +2335,21 @@ func (q *fakeQuerier) GetLicenses(_ context.Context) ([]database.License, error)
return results, nil
}
func (q *fakeQuerier) GetUnexpiredLicenses(_ context.Context) ([]database.License, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
now := time.Now()
var results []database.License
for _, l := range q.licenses {
if l.Exp.After(now) {
results = append(results, l)
}
}
sort.Slice(results, func(i, j int) bool { return results[i].ID < results[j].ID })
return results, nil
}
func (q *fakeQuerier) DeleteLicense(_ context.Context, id int32) (int32, error) {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -25,6 +25,7 @@ type querier interface {
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveUserCount(ctx context.Context) (int64, error)
// GetAuditLogsBefore retrieves `limit` number of audit logs before the provided
// ID.
GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error)
@ -63,6 +64,7 @@ type querier interface {
GetTemplateVersionsCreatedAfter(ctx context.Context, createdAt time.Time) ([]TemplateVersion, error)
GetTemplates(ctx context.Context) ([]Template, error)
GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error)
GetUnexpiredLicenses(ctx context.Context) ([]License, error)
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
GetUserByID(ctx context.Context, id uuid.UUID) (User, error)
GetUserCount(ctx context.Context) (int64, error)

View File

@ -522,6 +522,41 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) {
return items, nil
}
const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many
SELECT id, uploaded_at, jwt, exp
FROM licenses
WHERE exp > NOW()
ORDER BY (id)
`
func (q *sqlQuerier) GetUnexpiredLicenses(ctx context.Context) ([]License, error) {
rows, err := q.db.QueryContext(ctx, getUnexpiredLicenses)
if err != nil {
return nil, err
}
defer rows.Close()
var items []License
for rows.Next() {
var i License
if err := rows.Scan(
&i.ID,
&i.UploadedAt,
&i.JWT,
&i.Exp,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertLicense = `-- name: InsertLicense :one
INSERT INTO
licenses (
@ -2664,6 +2699,22 @@ func (q *sqlQuerier) UpdateUserLinkedID(ctx context.Context, arg UpdateUserLinke
return i, err
}
const getActiveUserCount = `-- name: GetActiveUserCount :one
SELECT
COUNT(*)
FROM
users
WHERE
status = 'active'::public.user_status
`
func (q *sqlQuerier) GetActiveUserCount(ctx context.Context) (int64, error) {
row := q.db.QueryRowContext(ctx, getActiveUserCount)
var count int64
err := row.Scan(&count)
return count, err
}
const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one
SELECT
-- username is returned just to help for logging purposes

View File

@ -13,6 +13,12 @@ SELECT *
FROM licenses
ORDER BY (id);
-- name: GetUnexpiredLicenses :many
SELECT *
FROM licenses
WHERE exp > NOW()
ORDER BY (id);
-- name: DeleteLicense :one
DELETE
FROM licenses

View File

@ -28,6 +28,14 @@ SELECT
FROM
users;
-- name: GetActiveUserCount :one
SELECT
COUNT(*)
FROM
users
WHERE
status = 'active'::public.user_status;
-- name: InsertUser :one
INSERT INTO
users (

View File

@ -7,7 +7,20 @@ import (
"github.com/coder/coder/codersdk"
)
func entitlements(rw http.ResponseWriter, _ *http.Request) {
// FeaturesService is the interface for interacting with enterprise features.
type FeaturesService interface {
EntitlementsAPI(w http.ResponseWriter, r *http.Request)
// TODO
// Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a
// struct type containing feature interfaces as fields. The FeatureService sets all fields to
// the correct implementations depending on whether the features are turned on.
// Get(s any) error
}
type featuresService struct{}
func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) {
features := make(map[string]codersdk.Feature)
for _, f := range codersdk.FeatureNames {
features[f] = codersdk.Feature{

View File

@ -18,7 +18,7 @@ func TestEntitlements(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
rw := httptest.NewRecorder()
entitlements(rw, r)
featuresService{}.EntitlementsAPI(rw, r)
resp := rw.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)