feat: implement provisioner auth middleware and proper org params (#12330)

* feat: provisioner auth in mw to allow ExtractOrg

Step to enable org scoped provisioner daemons

* chore: handle default org handling for provisioner daemons
This commit is contained in:
Steven Masley
2024-03-04 15:15:41 -06:00
committed by GitHub
parent 926fd7ffa6
commit 5c6974e55f
11 changed files with 201 additions and 30 deletions

View File

@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
r.Use(
api.provisionerDaemonsEnabledMW,
apiKeyMiddlewareOptional,
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
DB: api.Database,
Optional: true,
}, api.ProvisionerDaemonPSK),
// Either a user auth or provisioner auth is required
// to move forward.
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
httpmw.ExtractOrganizationParam(api.Database),
)
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)

View File

@ -2,7 +2,6 @@ package coderd
import (
"context"
"crypto/subtle"
"database/sql"
"errors"
"fmt"
@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
for _, daemon := range daemons {
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
}
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
}
type provisionerDaemonAuth struct {
@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
}
// Check for PSK
if p.psk != "" {
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
if provAuth {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags = provisionersdk.MutateTags(uuid.Nil, tags)
return tags, true
}
return nil, false
}

View File

@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
t.Run("PSK_daily_cost", func(t *testing.T) {
t.Parallel()
const provPSK = `provisionersftw`
client, user := coderdenttest.New(t, &coderdenttest.Options{
UserWorkspaceQuota: 10,
LicenseOptions: &coderdenttest.LicenseOptions{
@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
codersdk.FeatureTemplateRBAC: 1,
},
},
ProvisionerDaemonPSK: "provisionersftw",
ProvisionerDaemonPSK: provPSK,
})
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
Tags: map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
},
PreSharedKey: "provisionersftw",
PreSharedKey: provPSK,
})
}, &provisionerd.Options{
Logger: logger.Named("provisionerd"),
@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)
@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
require.NoError(t, err)