mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
chore: improve testing coverage on ExtractProvisionerDaemonAuthenticated middleware (#15622)
This one aims to resolve #15604 Created some table tests for the main cases - also preferred to create two isolated cases for the most complicated cases in order to keep table tests simple enough. Give us full coverage on the middleware logic, for both optional and non optional cases - PSK and ProvisionerKey.
This commit is contained in:
@ -25,6 +25,9 @@ type ExtractProvisionerAuthConfig struct {
|
||||
PSK string
|
||||
}
|
||||
|
||||
// ExtractProvisionerDaemonAuthenticated authenticates a request as a provisioner daemon.
|
||||
// If the request is not authenticated, the next handler is called unless Optional is true.
|
||||
// This function currently is tested inside the enterprise package.
|
||||
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
5
enterprise/coderd/httpmw/doc.go
Normal file
5
enterprise/coderd/httpmw/doc.go
Normal file
@ -0,0 +1,5 @@
|
||||
// Package httpmw contains middleware for HTTP handlers.
|
||||
// Currently, the tested middleware is inside the AGPL package.
|
||||
// As the middleware also contains enterprise-only logic, tests had to be
|
||||
// moved here.
|
||||
package httpmw
|
290
enterprise/coderd/httpmw/provisionerdaemon_test.go
Normal file
290
enterprise/coderd/httpmw/provisionerdaemon_test.go
Normal file
@ -0,0 +1,290 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
|
||||
const (
|
||||
//nolint:gosec // test key generated by test
|
||||
functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
|
||||
)
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts httpmw.ExtractProvisionerAuthConfig
|
||||
expectedStatusCode int
|
||||
expectedResponseMessage string
|
||||
provisionerKey string
|
||||
provisionerPSK string
|
||||
}{
|
||||
{
|
||||
name: "NoKeyProvided_Optional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: true,
|
||||
},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "NoKeyProvided_NotOptional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: false,
|
||||
},
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
expectedResponseMessage: "provisioner daemon key required",
|
||||
},
|
||||
{
|
||||
name: "ProvisionerKeyAndPSKProvided_NotOptional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: false,
|
||||
},
|
||||
provisionerKey: "key",
|
||||
provisionerPSK: "psk",
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed",
|
||||
},
|
||||
{
|
||||
name: "ProvisionerKeyAndPSKProvided_Optional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: true,
|
||||
},
|
||||
provisionerKey: "key",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "InvalidProvisionerKey_NotOptional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: false,
|
||||
},
|
||||
provisionerKey: "invalid",
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponseMessage: "provisioner daemon key invalid",
|
||||
},
|
||||
{
|
||||
name: "InvalidProvisionerKey_Optional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: true,
|
||||
},
|
||||
provisionerKey: "invalid",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "InvalidProvisionerPSK_NotOptional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: false,
|
||||
PSK: "psk",
|
||||
},
|
||||
provisionerPSK: "invalid",
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
expectedResponseMessage: "provisioner daemon psk invalid",
|
||||
},
|
||||
{
|
||||
name: "InvalidProvisionerPSK_Optional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: true,
|
||||
PSK: "psk",
|
||||
},
|
||||
provisionerPSK: "invalid",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "ValidProvisionerPSK_NotOptional",
|
||||
opts: httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: nil,
|
||||
Optional: false,
|
||||
PSK: "ThisIsAValidPSK",
|
||||
},
|
||||
provisionerPSK: "ThisIsAValidPSK",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
routeCtx := chi.NewRouteContext()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
if test.provisionerKey != "" {
|
||||
r.Header.Set(codersdk.ProvisionerDaemonKey, test.provisionerKey)
|
||||
}
|
||||
if test.provisionerPSK != "" {
|
||||
r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK)
|
||||
}
|
||||
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, r)
|
||||
|
||||
//nolint:bodyclose
|
||||
require.Equal(t, test.expectedStatusCode, res.Result().StatusCode)
|
||||
if test.expectedResponseMessage != "" {
|
||||
require.Contains(t, res.Body.String(), test.expectedResponseMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("ProvisionerKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
// nolint:gocritic // test
|
||||
key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
|
||||
Name: "dont-TEST-me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
routeCtx := chi.NewRouteContext()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key)
|
||||
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: db,
|
||||
Optional: false,
|
||||
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, r)
|
||||
|
||||
//nolint:bodyclose
|
||||
require.Equal(t, http.StatusOK, res.Result().StatusCode)
|
||||
})
|
||||
|
||||
t.Run("ProvisionerKey_NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
// nolint:gocritic // test
|
||||
_, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
|
||||
Name: "dont-TEST-me",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
routeCtx := chi.NewRouteContext()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
//nolint:gosec // test key generated by test
|
||||
pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
|
||||
r.Header.Set(codersdk.ProvisionerDaemonKey, pkey)
|
||||
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: db,
|
||||
Optional: false,
|
||||
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, r)
|
||||
|
||||
//nolint:bodyclose
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
|
||||
})
|
||||
|
||||
t.Run("ProvisionerKey_CompareFail", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockDB := dbmock.NewMockStore(ctrl)
|
||||
|
||||
gomock.InOrder(
|
||||
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{
|
||||
ID: uuid.New(),
|
||||
HashedSecret: []byte("hashedSecret"),
|
||||
}, nil),
|
||||
)
|
||||
|
||||
routeCtx := chi.NewRouteContext()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
|
||||
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: mockDB,
|
||||
Optional: false,
|
||||
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, r)
|
||||
|
||||
//nolint:bodyclose
|
||||
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
||||
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
|
||||
})
|
||||
|
||||
t.Run("ProvisionerKey_DBError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mockDB := dbmock.NewMockStore(ctrl)
|
||||
|
||||
gomock.InOrder(
|
||||
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")),
|
||||
)
|
||||
|
||||
routeCtx := chi.NewRouteContext()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
//nolint:gosec // test key generated by test
|
||||
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
|
||||
|
||||
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
||||
DB: mockDB,
|
||||
Optional: false,
|
||||
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})).ServeHTTP(res, r)
|
||||
|
||||
//nolint:bodyclose
|
||||
require.Equal(t, http.StatusInternalServerError, res.Result().StatusCode)
|
||||
require.Contains(t, res.Body.String(), "get provisioner daemon key")
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user