chore: Refactor Enterprise code to layer on top of AGPL (#4034)

* chore: Refactor Enterprise code to layer on top of AGPL

This is an experiment to invert the import order of the Enterprise
code to layer on top of AGPL.

* Fix Garrett's comments

* Add pointer.Handle to atomically obtain references

This uses a context to ensure the same value persists through
multiple executions to `Load()`.

* Remove entitlements API from AGPL coderd

* Remove AGPL Coder entitlements endpoint test

* Fix warnings output

* Add command-line flag to toggle audit logging

* Fix hasLicense being set

* Remove features interface

* Fix audit logging default

* Add bash as a dependency

* Add comment

* Add tests for resync and pubsub, and add back previous exp backoff retry

* Separate authz code again

* Add pointer loading example from comment

* Fix duplicate test, remove pointer.Handle

* Fix expired license

* Add entitlements struct

* Fix context passing
This commit is contained in:
Kyle Carberry
2022-09-19 23:11:01 -05:00
committed by GitHub
parent 714c366d16
commit db0ba8588e
39 changed files with 1402 additions and 2069 deletions

View File

@ -12,14 +12,13 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/tracing"
)
type RequestParams struct {
Features features.Service
Log slog.Logger
Audit Auditor
Log slog.Logger
Request *http.Request
Action database.AuditAction
@ -102,15 +101,6 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
params: p,
}
feats := struct {
Audit Auditor
}{}
err := p.Features.Get(&feats)
if err != nil {
p.Log.Error(p.Request.Context(), "unable to get auditor interface", slog.Error(err))
return req, func() {}
}
return req, func() {
ctx := context.Background()
logCtx := p.Request.Context()
@ -120,7 +110,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
return
}
diff := Diff(feats.Audit, req.Old, req.New)
diff := Diff(p.Audit, req.Old, req.New)
diffRaw, _ := json.Marshal(diff)
ip, err := parseIP(p.Request.RemoteAddr)
@ -128,7 +118,7 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
}
err = feats.Audit.Export(ctx, database.AuditLog{
err = p.Audit.Export(ctx, database.AuditLog{
ID: uuid.New(),
Time: database.Now(),
UserID: httpmw.APIKey(p.Request).UserID,

View File

@ -43,7 +43,7 @@ type HTTPAuthorizer struct {
// return
// }
func (api *API) Authorize(r *http.Request, action rbac.Action, object rbac.Objecter) bool {
return api.httpAuth.Authorize(r, action, object)
return api.HTTPAuth.Authorize(r, action, object)
}
// Authorize will return false if the user is not authorized to do the action.

View File

@ -7,6 +7,7 @@ import (
"net/url"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/andybalholm/brotli"
@ -24,9 +25,9 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
@ -50,6 +51,7 @@ type Options struct {
// CacheDir is used for caching files served by the API.
CacheDir string
Auditor audit.Auditor
AgentConnectionUpdateFrequency time.Duration
AgentInactiveDisconnectTimeout time.Duration
// APIRateLimit is the minutely throughput rate limit per user or ip.
@ -68,8 +70,6 @@ type Options struct {
Telemetry telemetry.Reporter
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService features.Service
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
@ -80,6 +80,9 @@ type Options struct {
// New constructs a Coder API handler.
func New(options *Options) *API {
if options == nil {
options = &Options{}
}
if options.AgentConnectionUpdateFrequency == 0 {
options.AgentConnectionUpdateFrequency = 3 * time.Second
}
@ -117,11 +120,8 @@ func New(options *Options) *API {
if options.TailnetCoordinator == nil {
options.TailnetCoordinator = tailnet.NewCoordinator()
}
if options.LicenseHandler == nil {
options.LicenseHandler = licenses()
}
if options.FeaturesService == nil {
options.FeaturesService = &featuresService{}
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
siteCacheDir := options.CacheDir
@ -142,14 +142,16 @@ func New(options *Options) *API {
r := chi.NewRouter()
api := &API{
Options: options,
Handler: r,
RootHandler: r,
siteHandler: site.Handler(site.FS(), binFS),
httpAuth: &HTTPAuthorizer{
HTTPAuth: &HTTPAuthorizer{
Authorizer: options.Authorizer,
Logger: options.Logger,
},
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
}
api.Auditor.Store(&options.Auditor)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
oauthConfigs := &httpmw.OAuth2Configs{
@ -218,6 +220,8 @@ func New(options *Options) *API {
})
r.Route("/api/v2", func(r chi.Router) {
api.APIHandler = r
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Route not found.",
@ -473,14 +477,6 @@ func New(options *Options) *API {
r.Get("/resources", api.workspaceBuildResources)
r.Get("/state", api.workspaceBuildState)
})
r.Route("/entitlements", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/", api.FeaturesService.EntitlementsAPI)
})
r.Route("/licenses", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Mount("/", options.LicenseHandler)
})
})
r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP)
@ -489,17 +485,20 @@ func New(options *Options) *API {
type API struct {
*Options
Auditor atomic.Pointer[audit.Auditor]
HTTPAuth *HTTPAuthorizer
derpServer *derp.Server
// APIHandler serves "/api/v2"
APIHandler chi.Router
// RootHandler serves "/"
RootHandler chi.Router
Handler chi.Router
derpServer *derp.Server
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
httpAuth *HTTPAuthorizer
metricsCache *metricscache.Cache
}
// Close waits for all WebSocket connections to drain before returning.

View File

@ -38,16 +38,6 @@ func TestBuildInfo(t *testing.T) {
require.Equal(t, buildinfo.Version(), buildInfo.Version, "version")
}
// TestAuthorizeAllEndpoints will check `authorize` is called on every endpoint registered.
func TestAuthorizeAllEndpoints(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
a := coderdtest.NewAuthTester(ctx, t, nil)
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
a.Test(ctx, assertRoute, skipRoutes)
}
func TestDERP(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)

View File

@ -22,143 +22,6 @@ import (
"github.com/coder/coder/testutil"
)
type RouteCheck struct {
NoAuthorize bool
AssertAction rbac.Action
AssertObject rbac.Object
StatusCode int
}
type AuthTester struct {
t *testing.T
api *coderd.API
authorizer *recordingAuthorizer
Client *codersdk.Client
Workspace codersdk.Workspace
Organization codersdk.Organization
Admin codersdk.CreateFirstUserResponse
Template codersdk.Template
Version codersdk.TemplateVersion
WorkspaceResource codersdk.WorkspaceResource
File codersdk.UploadResponse
TemplateVersionDryRun codersdk.ProvisionerJob
TemplateParam codersdk.Parameter
URLParams map[string]string
}
func NewAuthTester(ctx context.Context, t *testing.T, options *Options) *AuthTester {
authorizer := &recordingAuthorizer{}
if options == nil {
options = &Options{}
}
if options.Authorizer != nil {
t.Error("NewAuthTester cannot be called with custom Authorizer")
}
options.Authorizer = authorizer
options.IncludeProvisionerDaemon = true
client, _, api := newWithAPI(t, options)
admin := CreateFirstUser(t, client)
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")
// Setup some data in the database.
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
// Return a workspace resource
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agents: []*proto.Agent{{
Name: "agent",
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{{
Name: "testapp",
Url: "http://localhost:3000",
}},
}},
}},
},
},
}},
})
AwaitTemplateVersionJob(t, client, version.ID)
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
require.NoError(t, err, "upload file")
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err, "workspace resources")
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
ParameterValues: []codersdk.CreateParameterRequest{},
})
require.NoError(t, err, "template version dry-run")
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
Name: "test-param",
SourceValue: "hello world",
SourceScheme: codersdk.ParameterSourceSchemeData,
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
})
require.NoError(t, err, "create template param")
urlParameters := map[string]string{
"{organization}": admin.OrganizationID.String(),
"{user}": admin.UserID.String(),
"{organizationname}": organization.Name,
"{workspace}": workspace.ID.String(),
"{workspacebuild}": workspace.LatestBuild.ID.String(),
"{workspacename}": workspace.Name,
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
"{template}": template.ID.String(),
"{hash}": file.Hash,
"{workspaceresource}": workspaceResources[0].ID.String(),
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
"{templateversion}": version.ID.String(),
"{jobID}": templateVersionDryRun.ID.String(),
"{templatename}": template.Name,
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
// Only checking template scoped params here
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
string(templateParam.Scope), templateParam.ScopeID.String()),
}
return &AuthTester{
t: t,
api: api,
authorizer: authorizer,
Client: client,
Workspace: workspace,
Organization: organization,
Admin: admin,
Template: template,
Version: version,
WorkspaceResource: workspaceResources[0],
File: file,
TemplateVersionDryRun: templateVersionDryRun,
TemplateParam: templateParam,
URLParams: urlParameters,
}
}
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// Some quick reused objects
workspaceRBACObj := rbac.ResourceWorkspace.InOrg(a.Organization.ID).WithOwner(a.Workspace.OwnerID.String())
@ -181,7 +44,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"POST:/api/v2/users/login": {NoAuthorize: true},
"GET:/api/v2/users/authmethods": {NoAuthorize: true},
"POST:/api/v2/csp/reports": {NoAuthorize: true},
"GET:/api/v2/entitlements": {NoAuthorize: true},
// Has it's own auth
"GET:/api/v2/users/oauth2/github/callback": {NoAuthorize: true},
@ -408,6 +270,134 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
return skipRoutes, assertRoute
}
type RouteCheck struct {
NoAuthorize bool
AssertAction rbac.Action
AssertObject rbac.Object
StatusCode int
}
type AuthTester struct {
t *testing.T
api *coderd.API
authorizer *RecordingAuthorizer
Client *codersdk.Client
Workspace codersdk.Workspace
Organization codersdk.Organization
Admin codersdk.CreateFirstUserResponse
Template codersdk.Template
Version codersdk.TemplateVersion
WorkspaceResource codersdk.WorkspaceResource
File codersdk.UploadResponse
TemplateVersionDryRun codersdk.ProvisionerJob
TemplateParam codersdk.Parameter
URLParams map[string]string
}
func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, api *coderd.API, admin codersdk.CreateFirstUserResponse) *AuthTester {
authorizer, ok := api.Authorizer.(*RecordingAuthorizer)
if !ok {
t.Fail()
}
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")
// Setup some data in the database.
version := CreateTemplateVersion(t, client, admin.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
// Return a workspace resource
Resources: []*proto.Resource{{
Name: "some",
Type: "example",
Agents: []*proto.Agent{{
Name: "agent",
Id: "something",
Auth: &proto.Agent_Token{},
Apps: []*proto.App{{
Name: "testapp",
Url: "http://localhost:3000",
}},
}},
}},
},
},
}},
})
AwaitTemplateVersionJob(t, client, version.ID)
template := CreateTemplate(t, client, admin.OrganizationID, version.ID)
workspace := CreateWorkspace(t, client, admin.OrganizationID, template.ID)
AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
file, err := client.Upload(ctx, codersdk.ContentTypeTar, make([]byte, 1024))
require.NoError(t, err, "upload file")
workspaceResources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
require.NoError(t, err, "workspace resources")
templateVersionDryRun, err := client.CreateTemplateVersionDryRun(ctx, version.ID, codersdk.CreateTemplateVersionDryRunRequest{
ParameterValues: []codersdk.CreateParameterRequest{},
})
require.NoError(t, err, "template version dry-run")
templateParam, err := client.CreateParameter(ctx, codersdk.ParameterTemplate, template.ID, codersdk.CreateParameterRequest{
Name: "test-param",
SourceValue: "hello world",
SourceScheme: codersdk.ParameterSourceSchemeData,
DestinationScheme: codersdk.ParameterDestinationSchemeProvisionerVariable,
})
require.NoError(t, err, "create template param")
urlParameters := map[string]string{
"{organization}": admin.OrganizationID.String(),
"{user}": admin.UserID.String(),
"{organizationname}": organization.Name,
"{workspace}": workspace.ID.String(),
"{workspacebuild}": workspace.LatestBuild.ID.String(),
"{workspacename}": workspace.Name,
"{workspaceagent}": workspaceResources[0].Agents[0].ID.String(),
"{buildnumber}": strconv.FormatInt(int64(workspace.LatestBuild.BuildNumber), 10),
"{template}": template.ID.String(),
"{hash}": file.Hash,
"{workspaceresource}": workspaceResources[0].ID.String(),
"{workspaceapp}": workspaceResources[0].Agents[0].Apps[0].Name,
"{templateversion}": version.ID.String(),
"{jobID}": templateVersionDryRun.ID.String(),
"{templatename}": template.Name,
"{workspace_and_agent}": workspace.Name + "." + workspaceResources[0].Agents[0].Name,
// Only checking template scoped params here
"parameters/{scope}/{id}": fmt.Sprintf("parameters/%s/%s",
string(templateParam.Scope), templateParam.ScopeID.String()),
}
return &AuthTester{
t: t,
api: api,
authorizer: authorizer,
Client: client,
Workspace: workspace,
Organization: organization,
Admin: admin,
Template: template,
Version: version,
WorkspaceResource: workspaceResources[0],
File: file,
TemplateVersionDryRun: templateVersionDryRun,
TemplateParam: templateParam,
URLParams: urlParameters,
}
}
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
// Always fail auth from this point forward
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
@ -433,7 +423,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
}
err := chi.Walk(
a.api.Handler,
a.api.RootHandler,
func(
method string,
route string,
@ -513,14 +503,14 @@ type authCall struct {
Object rbac.Object
}
type recordingAuthorizer struct {
type RecordingAuthorizer struct {
Called *authCall
AlwaysReturn error
}
var _ rbac.Authorizer = (*recordingAuthorizer)(nil)
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
func (r *RecordingAuthorizer) ByRoleName(_ context.Context, subjectID string, roleNames []string, scope rbac.Scope, action rbac.Action, object rbac.Object) error {
r.Called = &authCall{
SubjectID: subjectID,
Roles: roleNames,
@ -531,7 +521,7 @@ func (r *recordingAuthorizer) ByRoleName(_ context.Context, subjectID string, ro
return r.AlwaysReturn
}
func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
func (r *RecordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID string, roles []string, scope rbac.Scope, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
return &fakePreparedAuthorizer{
Original: r,
SubjectID: subjectID,
@ -541,12 +531,12 @@ func (r *recordingAuthorizer) PrepareByRoleName(_ context.Context, subjectID str
}, nil
}
func (r *recordingAuthorizer) reset() {
func (r *RecordingAuthorizer) reset() {
r.Called = nil
}
type fakePreparedAuthorizer struct {
Original *recordingAuthorizer
Original *RecordingAuthorizer
SubjectID string
Roles []string
Scope rbac.Scope

View File

@ -0,0 +1,20 @@
package coderdtest_test
import (
"context"
"testing"
"github.com/coder/coder/coderd/coderdtest"
)
func TestAuthorizeAllEndpoints(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
Authorizer: &coderdtest.RecordingAuthorizer{},
IncludeProvisionerDaemon: true,
})
admin := coderdtest.CreateFirstUser(t, client)
a := coderdtest.NewAuthTester(context.Background(), t, client, api, admin)
skipRoute, assertRoute := coderdtest.AGPLRoutes(a)
a.Test(context.Background(), assertRoute, skipRoute)
}

View File

@ -80,7 +80,6 @@ type Options struct {
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
APIBuilder func(*coderd.Options) *coderd.API
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
}
@ -112,14 +111,11 @@ func NewWithProvisionerCloser(t *testing.T, options *Options) (*codersdk.Client,
// and is a temporary measure while the API to register provisioners is ironed
// out.
func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer) {
client, closer, _ := newWithAPI(t, options)
client, closer, _ := NewWithAPI(t, options)
return client, closer
}
// newWithAPI constructs an in-memory API instance and returns a client to talk to it.
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
// Do not expose the API or wrath shall descend upon thee.
func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
if options == nil {
options = &Options{}
}
@ -140,9 +136,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
close(options.AutobuildStats)
})
}
if options.APIBuilder == nil {
options.APIBuilder = coderd.New
}
// This can be hotswapped for a live database instance.
db := databasefake.New()
@ -166,8 +159,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer t.Cleanup(cancelFunc) // Defer to ensure cancelFunc is executed first.
lifecycleExecutor := executor.New(
ctx,
db,
@ -201,13 +192,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
}
features := coderd.DisabledImplementations
if options.Auditor != nil {
features.Auditor = options.Auditor
}
// We set the handler after server creation for the access URL.
coderAPI := options.APIBuilder(&coderd.Options{
return srv, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
@ -218,6 +203,7 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
Database: db,
Pubsub: pubsub,
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
@ -248,22 +234,30 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
FeaturesService: coderd.NewMockFeaturesService(features),
})
t.Cleanup(func() {
_ = coderAPI.Close()
})
srv.Config.Handler = coderAPI.Handler
}
}
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
// Most tests never need a reference to the API, but AuthorizationTest in this module uses it.
// Do not expose the API or wrath shall descend upon thee.
func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *coderd.API) {
if options == nil {
options = &Options{}
}
srv, cancelFunc, newOptions := NewOptions(t, options)
// We set the handler after server creation for the access URL.
coderAPI := coderd.New(newOptions)
srv.Config.Handler = coderAPI.RootHandler
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
}
t.Cleanup(func() {
cancelFunc()
_ = provisionerCloser.Close()
_ = coderAPI.Close()
})
return codersdk.New(serverURL), provisionerCloser, coderAPI
return codersdk.New(coderAPI.AccessURL), provisionerCloser, coderAPI
}
// NewProvisionerDaemon launches a provisionerd instance configured to work

View File

@ -1,97 +0,0 @@
package coderd
import (
"net/http"
"reflect"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/features"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
func NewMockFeaturesService(feats FeatureInterfaces) features.Service {
return &featuresService{
feats: &feats,
}
}
type featuresService struct {
feats *FeatureInterfaces
}
func (*featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) {
feats := make(map[string]codersdk.Feature)
for _, f := range codersdk.FeatureNames {
feats[f] = codersdk.Feature{
Entitlement: codersdk.EntitlementNotEntitled,
Enabled: false,
}
}
httpapi.Write(rw, http.StatusOK, codersdk.Entitlements{
Features: feats,
Warnings: []string{},
HasLicense: false,
})
}
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// struct type containing feature interfaces as fields. The AGPL featureService always returns the
// "disabled" version of the feature interface because it doesn't include any enterprise features
// by definition.
func (f *featuresService) Get(ps any) error {
if reflect.TypeOf(ps).Kind() != reflect.Pointer {
return xerrors.New("input must be pointer to struct")
}
vs := reflect.ValueOf(ps).Elem()
if vs.Kind() != reflect.Struct {
return xerrors.New("input must be pointer to struct")
}
for i := 0; i < vs.NumField(); i++ {
vf := vs.Field(i)
tf := vf.Type()
if tf.Kind() != reflect.Interface {
return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String())
}
err := f.setImplementation(vf, tf)
if err != nil {
return err
}
}
return nil
}
// setImplementation finds the correct implementation for the field's type, and sets it on the
// struct. It returns an error if unsuccessful
func (f *featuresService) setImplementation(vf reflect.Value, tf reflect.Type) error {
feats := f.feats
if feats == nil {
feats = &DisabledImplementations
}
// when we get more than a few features it might make sense to have a data structure for finding
// the correct implementation that's faster than just a linear search, but for now just spin
// through the implementations we have.
vd := reflect.ValueOf(*feats)
for j := 0; j < vd.NumField(); j++ {
vdf := vd.Field(j)
if vdf.Type() == tf {
vf.Set(vdf)
return nil
}
}
return xerrors.Errorf("unable to find implementation of interface %s", tf.String())
}
// FeatureInterfaces contains a field for each interface controlled by an enterprise feature.
type FeatureInterfaces struct {
Auditor audit.Auditor
}
// DisabledImplementations includes all the implementations of turned-off features. There are no
// turned-on implementations in AGPL code.
var DisabledImplementations = FeatureInterfaces{
Auditor: audit.NewNop(),
}

View File

@ -1,13 +0,0 @@
package features
import "net/http"
// Service is the interface for interacting with enterprise features.
type Service interface {
EntitlementsAPI(w http.ResponseWriter, r *http.Request)
// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a
// struct type containing feature interfaces as fields. The FeatureService sets all fields to
// the correct implementations depending on whether the features are turned on.
Get(s any) error
}

View File

@ -1,100 +0,0 @@
package coderd
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/codersdk"
)
func TestEntitlements(t *testing.T) {
t.Parallel()
t.Run("GET", func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil)
rw := httptest.NewRecorder()
(&featuresService{}).EntitlementsAPI(rw, r)
resp := rw.Result()
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
dec := json.NewDecoder(resp.Body)
var result codersdk.Entitlements
err := dec.Decode(&result)
require.NoError(t, err)
assert.False(t, result.HasLicense)
assert.Empty(t, result.Warnings)
for _, f := range codersdk.FeatureNames {
require.Contains(t, result.Features, f)
fe := result.Features[f]
assert.False(t, fe.Enabled)
assert.Equal(t, codersdk.EntitlementNotEntitled, fe.Entitlement)
}
})
}
func TestFeaturesServiceGet(t *testing.T) {
t.Parallel()
t.Run("Auditor", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.NoError(t, err)
assert.NotNil(t, target.Auditor)
})
t.Run("NotPointer", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
Auditor audit.Auditor
}{}
err := uut.Get(target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
t.Run("UnknownInterface", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
test testInterface
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.test)
})
t.Run("PointerToNonStruct", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
var target audit.Auditor
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target)
})
t.Run("StructWithNonInterfaces", func(t *testing.T) {
t.Parallel()
uut := featuresService{}
target := struct {
N int64
Auditor audit.Auditor
}{}
err := uut.Get(&target)
require.Error(t, err)
assert.Nil(t, target.Auditor)
})
}
type testInterface interface {
Test() error
}

View File

@ -1,24 +0,0 @@
package coderd
import (
"net/http"
"github.com/go-chi/chi/v5"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
func licenses() http.Handler {
r := chi.NewRouter()
r.NotFound(unsupported)
return r
}
func unsupported(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
Message: "Unsupported",
Detail: "These endpoints are not supported in AGPL-licensed Coder",
Validations: nil,
})
}

View File

@ -48,7 +48,7 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, daemons)
daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",

View File

@ -41,7 +41,7 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
api := New(&opts)
defer api.Close()
server := httptest.NewServer(api.Handler)
server := httptest.NewServer(api.RootHandler)
defer server.Close()
userID := uuid.New()
keyID, keySecret, err := generateAPIKeyIDSecret()

View File

@ -85,11 +85,12 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) {
func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
})
)
defer commitAudit()
@ -139,17 +140,18 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
createTemplate codersdk.CreateTemplateRequest
organization = httpmw.OrganizationParam(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitTemplateAudit()
@ -340,7 +342,7 @@ func (api *API) templatesByOrganization(rw http.ResponseWriter, r *http.Request)
}
// Filter templates based on rbac permissions
templates, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, templates)
templates, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, templates)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching templates.",
@ -435,11 +437,12 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()

View File

@ -559,11 +559,12 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) {
var (
template = httpmw.TemplateParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -631,11 +632,12 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
var (
apiKey = httpmw.APIKey(r)
organization = httpmw.OrganizationParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
req codersdk.CreateTemplateVersionRequest

View File

@ -220,7 +220,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}
users, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, users)
users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
@ -255,11 +255,12 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
// Creates a new user.
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
auditor := *api.Auditor.Load()
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
defer commitAudit()
@ -339,12 +340,13 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
}
func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) {
auditor := *api.Auditor.Load()
user := httpmw.UserParam(r)
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionDelete,
})
aReq.Old = user
defer commitAudit()
@ -414,11 +416,12 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) {
func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) {
var (
user = httpmw.UserParam(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -494,11 +497,12 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW
var (
user = httpmw.UserParam(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -560,11 +564,12 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) {
var (
user = httpmw.UserParam(r)
params codersdk.UpdateUserPasswordRequest
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -673,7 +678,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
}
// Only include ones we can read from RBAC.
memberships, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, memberships)
memberships, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, memberships)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching memberships.",
@ -698,11 +703,12 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) {
user = httpmw.UserParam(r)
actorRoles = httpmw.UserAuthorization(r)
apiKey = httpmw.APIKey(r)
auditor = *api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -812,7 +818,7 @@ func (api *API) organizationsByUser(rw http.ResponseWriter, r *http.Request) {
}
// Only return orgs the user can read.
organizations, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, organizations)
organizations, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, organizations)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching organizations.",
@ -1176,9 +1182,9 @@ func (api *API) createUser(ctx context.Context, store database.Store, req create
func (api *API) setAuthCookie(rw http.ResponseWriter, cookie *http.Cookie) {
http.SetCookie(rw, cookie)
devurlCookie := api.applicationCookie(cookie)
if devurlCookie != nil {
http.SetCookie(rw, devurlCookie)
appCookie := api.applicationCookie(cookie)
if appCookie != nil {
http.SetCookie(rw, appCookie)
}
}

View File

@ -32,7 +32,11 @@ func TestFirstUser(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
has, err := client.HasFirstUser(context.Background())
require.NoError(t, err)
require.False(t, has)
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{})
require.Error(t, err)
})

View File

@ -119,7 +119,7 @@ func (api *API) workspaces(rw http.ResponseWriter, r *http.Request) {
}
// Only return workspaces the user can read
workspaces, err = AuthorizeFilter(api.httpAuth, r, rbac.ActionRead, workspaces)
workspaces, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, workspaces)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspaces.",
@ -217,11 +217,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
var (
organization = httpmw.OrganizationParam(r)
apiKey = httpmw.APIKey(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionCreate,
})
)
defer commitAudit()
@ -480,11 +481,12 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -556,11 +558,12 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()
@ -616,11 +619,12 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
var (
workspace = httpmw.WorkspaceParam(r)
auditor = api.Auditor.Load()
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
Features: api.FeaturesService,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
Audit: *auditor,
Log: api.Logger,
Request: r,
Action: database.AuditActionWrite,
})
)
defer commitAudit()