Feature server implementation (#3899)

* Feature server implementation

Signed-off-by: Spike Curtis <spike@coder.com>

* Fix imports

Signed-off-by: Spike Curtis <spike@coder.com>

Signed-off-by: Spike Curtis <spike@coder.com>
This commit is contained in:
Spike Curtis
2022-09-06 11:59:10 -07:00
committed by GitHub
parent 1b6f9e54a3
commit a7cdec5d39
4 changed files with 402 additions and 11 deletions

View File

@ -2,7 +2,11 @@ package coderd
import ( import (
"net/http" "net/http"
"reflect"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
@ -11,11 +15,10 @@ import (
type FeaturesService interface { type FeaturesService interface {
EntitlementsAPI(w http.ResponseWriter, r *http.Request) EntitlementsAPI(w http.ResponseWriter, r *http.Request)
// TODO // Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a
// struct type containing feature interfaces as fields. The FeatureService sets all fields to // struct type containing feature interfaces as fields. The FeatureService sets all fields to
// the correct implementations depending on whether the features are turned on. // the correct implementations depending on whether the features are turned on.
// Get(s any) error Get(s any) error
} }
type featuresService struct{} type featuresService struct{}
@ -34,3 +37,57 @@ func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request)
HasLicense: false, HasLicense: false,
}) })
} }
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// struct type containing feature interfaces as fields. The AGPL featureService always returns the
// "disabled" version of the feature interface because it doesn't include any enterprise features
// by definition.
func (featuresService) Get(ps any) error {
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
return xerrors.New("input must be pointer to struct")
}
vs := reflect.ValueOf(ps).Elem()
if vs.Kind() != reflect.Struct {
return xerrors.New("input must be pointer to struct")
}
for i := 0; i < vs.NumField(); i++ {
vf := vs.Field(i)
tf := vf.Type()
if tf.Kind() != reflect.Interface {
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
}
err := setImplementation(vf, tf)
if err != nil {
return err
}
}
return nil
}
// setImplementation finds the correct implementation for the field's type, and sets it on the
// struct. It returns an error if unsuccessful
func setImplementation(vf reflect.Value, tf reflect.Type) error {
// when we get more than a few features it might make sense to have a data structure for finding
// the correct implementation that's faster than just a linear search, but for now just spin
// through the implementations we have.
vd := reflect.ValueOf(DisabledImplementations)
for j := 0; j < vd.NumField(); j++ {
vdf := vd.Field(j)
if vdf.Type() == tf {
vf.Set(vdf)
return nil
}
}
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
}
// FeatureInterfaces contains a field for each interface controlled by an enterprise feature.
type FeatureInterfaces struct {
Auditor audit.Auditor
}
// DisabledImplementations includes all the implementations of turned-off features. There are no
// turned-on implementations in AGPL code.
var DisabledImplementations = FeatureInterfaces{
Auditor: audit.NewNop(),
}

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
) )
@ -36,3 +37,64 @@ func TestEntitlements(t *testing.T) {
} }
}) })
} }
func TestFeaturesServiceGet(t *testing.T) {
t.Parallel()
t.Run("Auditor", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
})
t.Run("NotPointer", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
t.Run("UnknownInterface", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
test testInterface
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.test)
})
t.Run("PointerToNonStruct", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
var target audit.Auditor
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target)
})
t.Run("StructWithNonInterfaces", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
N int64
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
}
type testInterface interface {
Test() error
}

View File

@ -5,17 +5,23 @@ import (
"crypto/ed25519" "crypto/ed25519"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"sync" "sync"
"time" "time"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"golang.org/x/xerrors"
"cdr.dev/slog" "cdr.dev/slog"
agpl "github.com/coder/coder/coderd" agpl "github.com/coder/coder/coderd"
agplAudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
) )
type Enablements struct { type Enablements struct {
@ -29,6 +35,13 @@ type featuresService struct {
keys map[string]ed25519.PublicKey keys map[string]ed25519.PublicKey
enablements Enablements enablements Enablements
resyncInterval time.Duration resyncInterval time.Duration
// enabledImplementations includes an "enabled" implementation of every feature. This is
// initialized at start of day and remains static. The consequence of this is that these things
// are hanging around using memory even if not licensed or in use, but it greatly simplifies the
// logic because we don't have to bother creating and destroying them as entitlements change.
// If we have a particularly memory-hungry feature in future, we might wish to reconsider this
// choice.
enabledImplementations agpl.FeatureInterfaces
mu sync.RWMutex mu sync.RWMutex
entitlements entitlements entitlements entitlements
@ -44,11 +57,18 @@ func newFeaturesService(
enablements Enablements, enablements Enablements,
) agpl.FeaturesService { ) agpl.FeaturesService {
fs := &featuresService{ fs := &featuresService{
logger: logger, logger: logger,
database: db, database: db,
pubsub: pubsub, pubsub: pubsub,
keys: keys, keys: keys,
enablements: enablements, enablements: enablements,
enabledImplementations: agpl.FeatureInterfaces{
Auditor: audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(db, true),
backends.NewSlog(logger),
),
},
resyncInterval: 10 * time.Minute, resyncInterval: 10 * time.Minute,
entitlements: entitlements{ entitlements: entitlements{
activeUsers: numericalEntitlement{ activeUsers: numericalEntitlement{
@ -259,3 +279,48 @@ func max(a, b int64) int64 {
} }
return b return b
} }
func (s *featuresService) Get(ps any) error {
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
return xerrors.New("input must be pointer to struct")
}
vs := reflect.ValueOf(ps).Elem()
if vs.Kind() != reflect.Struct {
return xerrors.New("input must be pointer to struct")
}
// grab a local copy of entitlements so that we have a consistent set, but aren't keeping it
// locked from updates while we process.
s.mu.RLock()
ent := s.entitlements
s.mu.RUnlock()
for i := 0; i < vs.NumField(); i++ {
vf := vs.Field(i)
tf := vf.Type()
if tf.Kind() != reflect.Interface {
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
}
err := s.setImplementation(ent, vf, tf)
if err != nil {
return err
}
}
return nil
}
func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error {
// c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface
switch tf {
case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem():
// Audit logging
if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled {
vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor))
return nil
}
vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor))
return nil
default:
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
}
}

View File

@ -7,21 +7,24 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd" agplCoderd "github.com/coder/coder/coderd"
agplAudit "github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/audit"
"github.com/coder/coder/enterprise/audit/backends"
"github.com/coder/coder/testutil" "github.com/coder/coder/testutil"
) )
@ -282,7 +285,7 @@ func TestFeaturesServiceSyncEntitlements(t *testing.T) {
}) })
} }
func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements { func requestEntitlements(t *testing.T, uut agplCoderd.FeaturesService) codersdk.Entitlements {
t.Helper() t.Helper()
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -335,3 +338,207 @@ func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool {
return fs.entitlements.activeUsers.limit == limit return fs.entitlements.activeUsers.limit == limit
} }
} }
func TestFeaturesServiceGet(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
// Note that these are not actually used because we don't run the syncEntitlements
// routine in this test.
pubsub := database.NewPubsubInMemory()
pub, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
keyID := "testing"
db := databasefake.New()
t.Run("AuditorOff", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
nop := agplAudit.NewNop()
assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type())
})
t.Run("AuditorOn", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{entitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
ea := audit.NewAuditor(
audit.DefaultFilter,
backends.NewPostgres(db, true),
backends.NewSlog(logger),
)
assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type())
})
t.Run("NotPointer", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
Auditor agplAudit.Auditor
}{}
err := uut.Get(target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
t.Run("UnknownInterface", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
test testInterface
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.test)
})
t.Run("PointerToNonStruct", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
var target agplAudit.Auditor
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target)
})
t.Run("StructWithNonInterfaces", func(t *testing.T) {
t.Parallel()
uut := &featuresService{
logger: logger,
database: db,
pubsub: pubsub,
keys: map[string]ed25519.PublicKey{keyID: pub},
enablements: Enablements{AuditLogs: true},
enabledImplementations: agplCoderd.FeatureInterfaces{
Auditor: audit.NewAuditor(audit.DefaultFilter),
},
entitlements: entitlements{
hasLicense: false,
activeUsers: numericalEntitlement{
entitlement{notEntitled},
entitlementLimit{
unlimited: true,
},
},
auditLogs: entitlement{notEntitled},
},
}
target := struct {
N int64
Auditor agplAudit.Auditor
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
}
type testInterface interface {
Test() error
}