mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
This aims to resolve #15604. It adds table tests for the main cases, plus two isolated tests for the most complicated cases so that the table tests stay simple. Together they give full coverage of the middleware logic for both the optional and non-optional cases, for both PSK and provisioner-key authentication.
291 lines
8.6 KiB
Go
291 lines
8.6 KiB
Go
package httpmw_test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
|
"github.com/coder/coder/v2/enterprise/coderd/license"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestExtractProvisionerDaemonAuthenticated(t *testing.T) {
|
|
const (
|
|
//nolint:gosec // test key generated by test
|
|
functionalKey = "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
|
|
)
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
opts httpmw.ExtractProvisionerAuthConfig
|
|
expectedStatusCode int
|
|
expectedResponseMessage string
|
|
provisionerKey string
|
|
provisionerPSK string
|
|
}{
|
|
{
|
|
name: "NoKeyProvided_Optional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: true,
|
|
},
|
|
expectedStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "NoKeyProvided_NotOptional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: false,
|
|
},
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
expectedResponseMessage: "provisioner daemon key required",
|
|
},
|
|
{
|
|
name: "ProvisionerKeyAndPSKProvided_NotOptional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: false,
|
|
},
|
|
provisionerKey: "key",
|
|
provisionerPSK: "psk",
|
|
expectedStatusCode: http.StatusBadRequest,
|
|
expectedResponseMessage: "provisioner daemon key and psk provided, but only one is allowed",
|
|
},
|
|
{
|
|
name: "ProvisionerKeyAndPSKProvided_Optional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: true,
|
|
},
|
|
provisionerKey: "key",
|
|
expectedStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "InvalidProvisionerKey_NotOptional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: false,
|
|
},
|
|
provisionerKey: "invalid",
|
|
expectedStatusCode: http.StatusBadRequest,
|
|
expectedResponseMessage: "provisioner daemon key invalid",
|
|
},
|
|
{
|
|
name: "InvalidProvisionerKey_Optional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: true,
|
|
},
|
|
provisionerKey: "invalid",
|
|
expectedStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "InvalidProvisionerPSK_NotOptional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: false,
|
|
PSK: "psk",
|
|
},
|
|
provisionerPSK: "invalid",
|
|
expectedStatusCode: http.StatusUnauthorized,
|
|
expectedResponseMessage: "provisioner daemon psk invalid",
|
|
},
|
|
{
|
|
name: "InvalidProvisionerPSK_Optional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: true,
|
|
PSK: "psk",
|
|
},
|
|
provisionerPSK: "invalid",
|
|
expectedStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "ValidProvisionerPSK_NotOptional",
|
|
opts: httpmw.ExtractProvisionerAuthConfig{
|
|
DB: nil,
|
|
Optional: false,
|
|
PSK: "ThisIsAValidPSK",
|
|
},
|
|
provisionerPSK: "ThisIsAValidPSK",
|
|
expectedStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
test := test
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
routeCtx := chi.NewRouteContext()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
|
res := httptest.NewRecorder()
|
|
|
|
if test.provisionerKey != "" {
|
|
r.Header.Set(codersdk.ProvisionerDaemonKey, test.provisionerKey)
|
|
}
|
|
if test.provisionerPSK != "" {
|
|
r.Header.Set(codersdk.ProvisionerDaemonPSK, test.provisionerPSK)
|
|
}
|
|
|
|
httpmw.ExtractProvisionerDaemonAuthenticated(test.opts)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(res, r)
|
|
|
|
//nolint:bodyclose
|
|
require.Equal(t, test.expectedStatusCode, res.Result().StatusCode)
|
|
if test.expectedResponseMessage != "" {
|
|
require.Contains(t, res.Body.String(), test.expectedResponseMessage)
|
|
}
|
|
})
|
|
}
|
|
|
|
t.Run("ProvisionerKey", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
Features: license.Features{
|
|
codersdk.FeatureExternalProvisionerDaemons: 1,
|
|
},
|
|
},
|
|
})
|
|
// nolint:gocritic // test
|
|
key, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
|
|
Name: "dont-TEST-me",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
routeCtx := chi.NewRouteContext()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
|
res := httptest.NewRecorder()
|
|
|
|
r.Header.Set(codersdk.ProvisionerDaemonKey, key.Key)
|
|
|
|
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
|
DB: db,
|
|
Optional: false,
|
|
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(res, r)
|
|
|
|
//nolint:bodyclose
|
|
require.Equal(t, http.StatusOK, res.Result().StatusCode)
|
|
})
|
|
|
|
t.Run("ProvisionerKey_NotFound", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitShort)
|
|
client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
|
|
LicenseOptions: &coderdenttest.LicenseOptions{
|
|
Features: license.Features{
|
|
codersdk.FeatureExternalProvisionerDaemons: 1,
|
|
},
|
|
},
|
|
})
|
|
// nolint:gocritic // test
|
|
_, err := client.CreateProvisionerKey(ctx, user.OrganizationID, codersdk.CreateProvisionerKeyRequest{
|
|
Name: "dont-TEST-me",
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
routeCtx := chi.NewRouteContext()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
|
res := httptest.NewRecorder()
|
|
|
|
//nolint:gosec // test key generated by test
|
|
pkey := "5Hl2Qw9kX3nM7vB4jR8pY6tA1cF0eD5uI2oL9gN3mZ4"
|
|
r.Header.Set(codersdk.ProvisionerDaemonKey, pkey)
|
|
|
|
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
|
DB: db,
|
|
Optional: false,
|
|
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(res, r)
|
|
|
|
//nolint:bodyclose
|
|
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
|
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
|
|
})
|
|
|
|
t.Run("ProvisionerKey_CompareFail", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockDB := dbmock.NewMockStore(ctrl)
|
|
|
|
gomock.InOrder(
|
|
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{
|
|
ID: uuid.New(),
|
|
HashedSecret: []byte("hashedSecret"),
|
|
}, nil),
|
|
)
|
|
|
|
routeCtx := chi.NewRouteContext()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
|
res := httptest.NewRecorder()
|
|
|
|
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
|
|
|
|
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
|
DB: mockDB,
|
|
Optional: false,
|
|
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(res, r)
|
|
|
|
//nolint:bodyclose
|
|
require.Equal(t, http.StatusUnauthorized, res.Result().StatusCode)
|
|
require.Contains(t, res.Body.String(), "provisioner daemon key invalid")
|
|
})
|
|
|
|
t.Run("ProvisionerKey_DBError", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
mockDB := dbmock.NewMockStore(ctrl)
|
|
|
|
gomock.InOrder(
|
|
mockDB.EXPECT().GetProvisionerKeyByHashedSecret(gomock.Any(), gomock.Any()).Times(1).Return(database.ProvisionerKey{}, xerrors.New("error")),
|
|
)
|
|
|
|
routeCtx := chi.NewRouteContext()
|
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeCtx))
|
|
res := httptest.NewRecorder()
|
|
|
|
//nolint:gosec // test key generated by test
|
|
r.Header.Set(codersdk.ProvisionerDaemonKey, functionalKey)
|
|
|
|
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
|
|
DB: mockDB,
|
|
Optional: false,
|
|
})(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})).ServeHTTP(res, r)
|
|
|
|
//nolint:bodyclose
|
|
require.Equal(t, http.StatusInternalServerError, res.Result().StatusCode)
|
|
require.Contains(t, res.Body.String(), "get provisioner daemon key")
|
|
})
|
|
}
|