Files
coder/enterprise/coderd/httpmw/provisionerdaemon_test.go
Vincent Vielle 60ddcf5de2 chore: improve testing coverage on ExtractProvisionerDaemonAuthenticated middleware (#15622)
This change aims to resolve #15604.

Created table tests for the main cases, and kept the most complicated
cases as separate, isolated tests so that the table tests stay simple
enough.

Gives us full coverage of the middleware logic, for both the optional and
non-optional cases, with both PSK and ProvisionerKey authentication.
2024-11-26 04:02:20 +01:00

291 lines
8.6 KiB
Go

package httpmw_test
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
)
// TestExtractProvisionerDaemonAuthenticated exercises the
// ExtractProvisionerDaemonAuthenticated middleware for both PSK and
// provisioner-key authentication, with Optional enabled and disabled.
// Simple header/config combinations are table-driven; cases that need a real
// database (coderdenttest) or a mocked store (dbmock) are separate subtests.
func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
	const (
		// functionalKey is a well-formed provisioner key that is not stored in
		// any database used here. The DB-backed subtests rely on it passing the
		// middleware's format validation so that the key lookup is reached.
		//nolint:gosec // test key generated by test
		functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
	)

	t.Parallel()

	// serve sends a single GET request through the middleware configured with
	// opts, setting every non-empty value in headers, and returns the recorded
	// response. The wrapped handler only writes 200, so a 200 status means the
	// middleware passed the request through to the next handler.
	serve := func(opts httpmw.ExtractProvisionerAuthConfig, headers map[string]string) *httptest.ResponseRecorder {
		routeCtx := chi.NewRouteContext()
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
		for name, value := range headers {
			if value != "" {
				r.Header.Set(name, value)
			}
		}

		res := httptest.NewRecorder()
		httpmw.ExtractProvisionerDaemonAuthenticated(opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusOK)
		})).ServeHTTP(res, r)
		return res
	}

	tests := []struct {
		name string
		// opts configures the middleware under test.
		opts httpmw.ExtractProvisionerAuthConfig
		// expectedStatusCode is the status the recorder must observe.
		expectedStatusCode int
		// expectedResponseMessage, when non-empty, must appear in the body.
		expectedResponseMessage string
		// provisionerKey and provisionerPSK are sent as request headers when
		// non-empty.
		provisionerKey string
		provisionerPSK string
	}{
		{
			name: "NoKeyProvided_Optional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: true,
			},
			expectedStatusCode: http.StatusOK,
		},
		{
			name: "NoKeyProvided_NotOptional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: false,
			},
			expectedStatusCode:      http.StatusUnauthorized,
			expectedResponseMessage: "provisioner daemon key required",
		},
		{
			name: "ProvisionerKeyAndPSKProvided_NotOptional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: false,
			},
			provisionerKey:          "key",
			provisionerPSK:          "psk",
			expectedStatusCode:      http.StatusBadRequest,
			expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed",
		},
		{
			// Both headers must be sent here; otherwise this case only
			// duplicates InvalidProvisionerKey_Optional.
			name: "ProvisionerKeyAndPSKProvided_Optional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: true,
			},
			provisionerKey:     "key",
			provisionerPSK:     "psk",
			expectedStatusCode: http.StatusOK,
		},
		{
			name: "InvalidProvisionerKey_NotOptional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: false,
			},
			provisionerKey:          "invalid",
			expectedStatusCode:      http.StatusBadRequest,
			expectedResponseMessage: "provisioner daemon key invalid",
		},
		{
			name: "InvalidProvisionerKey_Optional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: true,
			},
			provisionerKey:     "invalid",
			expectedStatusCode: http.StatusOK,
		},
		{
			name: "InvalidProvisionerPSK_NotOptional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: false,
				PSK:      "psk",
			},
			provisionerPSK:          "invalid",
			expectedStatusCode:      http.StatusUnauthorized,
			expectedResponseMessage: "provisioner daemon psk invalid",
		},
		{
			name: "InvalidProvisionerPSK_Optional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: true,
				PSK:      "psk",
			},
			provisionerPSK:     "invalid",
			expectedStatusCode: http.StatusOK,
		},
		{
			name: "ValidProvisionerPSK_NotOptional",
			opts: httpmw.ExtractProvisionerAuthConfig{
				DB:       nil,
				Optional: false,
				PSK:      "ThisIsAValidPSK",
			},
			provisionerPSK:     "ThisIsAValidPSK",
			expectedStatusCode: http.StatusOK,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			res := serve(test.opts, map[string]string{
				codersdk.ProvisionerDaemonKey: test.provisionerKey,
				codersdk.ProvisionerDaemonPSK: test.provisionerPSK,
			})

			require.Equal(t, test.expectedStatusCode, res.Code)
			if test.expectedResponseMessage != "" {
				require.Contains(t, res.Body.String(), test.expectedResponseMessage)
			}
		})
	}

	t.Run("ProvisionerKey", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitShort)
		client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
			LicenseOptions: &coderdenttest.LicenseOptions{
				Features: license.Features{
					codersdk.FeatureExternalProvisionerDaemons: 1,
				},
			},
		})
		// nolint:gocritic // test
		key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
			Name: "dont-TEST-me",
		})
		require.NoError(t, err)

		res := serve(httpmw.ExtractProvisionerAuthConfig{
			DB:       db,
			Optional: false,
		}, map[string]string{
			codersdk.ProvisionerDaemonKey: key.Key,
		})

		require.Equal(t, http.StatusOK, res.Code)
	})

	t.Run("ProvisionerKey_NotFound", func(t *testing.T) {
		t.Parallel()

		ctx := testutil.Context(t, testutil.WaitShort)
		client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
			LicenseOptions: &coderdenttest.LicenseOptions{
				Features: license.Features{
					codersdk.FeatureExternalProvisionerDaemons: 1,
				},
			},
		})
		// A real key is created so the lookup runs against a non-empty table;
		// the request below deliberately sends a different, unknown key.
		// nolint:gocritic // test
		_, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
			Name: "dont-TEST-me",
		})
		require.NoError(t, err)

		res := serve(httpmw.ExtractProvisionerAuthConfig{
			DB:       db,
			Optional: false,
		}, map[string]string{
			codersdk.ProvisionerDaemonKey: functionalKey,
		})

		require.Equal(t, http.StatusUnauthorized, res.Code)
		require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
	})

	t.Run("ProvisionerKey_CompareFail", func(t *testing.T) {
		t.Parallel()

		ctrl := gomock.NewController(t)
		mockDB := dbmock.NewMockStore(ctrl)
		// The store "finds" a key whose HashedSecret is not the hash of
		// functionalKey, so the middleware must reject it as invalid.
		mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{
			ID:           uuid.New(),
			HashedSecret: []byte("hashedSecret"),
		}, nil)

		res := serve(httpmw.ExtractProvisionerAuthConfig{
			DB:       mockDB,
			Optional: false,
		}, map[string]string{
			codersdk.ProvisionerDaemonKey: functionalKey,
		})

		require.Equal(t, http.StatusUnauthorized, res.Code)
		require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
	})

	t.Run("ProvisionerKey_DBError", func(t *testing.T) {
		t.Parallel()

		ctrl := gomock.NewController(t)
		mockDB := dbmock.NewMockStore(ctrl)
		// A generic store error (not a not-found) is expected to surface as 500.
		mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error"))

		res := serve(httpmw.ExtractProvisionerAuthConfig{
			DB:       mockDB,
			Optional: false,
		}, map[string]string{
			codersdk.ProvisionerDaemonKey: functionalKey,
		})

		require.Equal(t, http.StatusInternalServerError, res.Code)
		require.Contains(t, res.Body.String(), "get provisioner daemon key")
	})
}