chore: make scim auth header case insensitive for 'bearer' (#15538)

Fixes status codes to return more than 500. The way we were using the
package, it always returned a status code 500
This commit is contained in:
Steven Masley
2024-11-15 12:30:11 -06:00
committed by GitHub
parent 4cb807670d
commit 16ade985ae
3 changed files with 91 additions and 22 deletions

View File

@ -1,6 +1,7 @@
package coderd package coderd
import ( import (
"bytes"
"crypto/subtle" "crypto/subtle"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -26,16 +27,21 @@ import (
) )
func (api *API) scimVerifyAuthHeader(r *http.Request) bool { func (api *API) scimVerifyAuthHeader(r *http.Request) bool {
bearer := []byte("Bearer ") bearer := []byte("bearer ")
hdr := []byte(r.Header.Get("Authorization")) hdr := []byte(r.Header.Get("Authorization"))
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 { // Use toLower to make the comparison case-insensitive.
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 {
hdr = hdr[len(bearer):] hdr = hdr[len(bearer):]
} }
return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1 return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1
} }
func scimUnauthorized(rw http.ResponseWriter) {
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization")))
}
// scimServiceProviderConfig returns a static SCIM service provider configuration. // scimServiceProviderConfig returns a static SCIM service provider configuration.
// //
// @Summary SCIM 2.0: Service Provider Config // @Summary SCIM 2.0: Service Provider Config
@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques
//nolint:revive //nolint:revive
func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) { func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
if !api.scimVerifyAuthHeader(r) { if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) scimUnauthorized(rw)
return return
} }
@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
//nolint:revive //nolint:revive
func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) {
if !api.scimVerifyAuthHeader(r) { if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) scimUnauthorized(rw)
return return
} }
_ = handlerutil.WriteError(rw, spec.ErrNotFound) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404")))
} }
// We currently use our own struct instead of using the SCIM package. This was // We currently use our own struct instead of using the SCIM package. This was
@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{
func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if !api.scimVerifyAuthHeader(r) { if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) scimUnauthorized(rw)
return return
} }
@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
var sUser SCIMUser var sUser SCIMUser
err := json.NewDecoder(r.Body).Decode(&sUser) err := json.NewDecoder(r.Body).Decode(&sUser)
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
return return
} }
@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
} }
if email == "" { if email == "" {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"}) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided")))
return return
} }
@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
Username: sUser.UserName, Username: sUser.UserName,
}) })
if err != nil && !xerrors.Is(err, sql.ErrNoRows) { if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, err) // internal error
return return
} }
if err == nil { if err == nil {
@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
UpdatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(),
}) })
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, err) // internal error
return return
} }
aReq.New = newUser aReq.New = newUser
@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
//nolint:gocritic // SCIM operations are a system user //nolint:gocritic // SCIM operations are a system user
orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database) orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database)
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err)) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err)))
return return
} }
if orgSync.AssignDefault { if orgSync.AssignDefault {
//nolint:gocritic // SCIM operations are a system user //nolint:gocritic // SCIM operations are a system user
defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx)) defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err)))
return return
} }
organizations = append(organizations, defaultOrganization.ID) organizations = append(organizations, defaultOrganization.ID)
@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
SkipNotifications: true, SkipNotifications: true,
}) })
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err)))
return return
} }
aReq.New = dbUser aReq.New = dbUser
@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) { func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
if !api.scimVerifyAuthHeader(r) { if !api.scimVerifyAuthHeader(r) {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"}) scimUnauthorized(rw)
return return
} }
@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
var sUser SCIMUser var sUser SCIMUser
err := json.NewDecoder(r.Body).Decode(&sUser) err := json.NewDecoder(r.Body).Decode(&sUser)
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
return return
} }
sUser.ID = id sUser.ID = id
uid, err := uuid.Parse(id) uid, err := uuid.Parse(id)
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"}) _ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err)))
return return
} }
//nolint:gocritic // needed for SCIM //nolint:gocritic // needed for SCIM
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid) dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid)
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, err) // internal error
return return
} }
aReq.Old = dbUser aReq.Old = dbUser
@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
UpdatedAt: dbtime.Now(), UpdatedAt: dbtime.Now(),
}) })
if err != nil { if err != nil {
_ = handlerutil.WriteError(rw, err) _ = handlerutil.WriteError(rw, err) // internal error
return return
} }
dbUser = userNew dbUser = userNew

View File

@ -1,6 +1,11 @@
package scim package scim
import "time" import (
"encoding/json"
"time"
"github.com/imulab/go-scim/pkg/v2/spec"
)
type ServiceProviderConfig struct { type ServiceProviderConfig struct {
Schemas []string `json:"schemas"` Schemas []string `json:"schemas"`
@ -44,3 +49,37 @@ type AuthenticationScheme struct {
SpecURI string `json:"specUri"` SpecURI string `json:"specUri"`
DocURI string `json:"documentationUri"` DocURI string `json:"documentationUri"`
} }
// HTTPError wraps a *spec.Error for correct usage with
// 'handlerutil.WriteError'. This error type is cursed to be
// absolutely strange and specific to the SCIM library we use.
//
// The library expects *spec.Error to be returned on unwrap, and the
// internal error description to be returned by a json.Marshal of the
// top level error.
type HTTPError struct {
scim *spec.Error
internal error
}
func NewHTTPError(status int, eType string, err error) *HTTPError {
return &HTTPError{
scim: &spec.Error{
Status: status,
Type: eType,
},
internal: err,
}
}
func (e HTTPError) Error() string {
return e.internal.Error()
}
func (e HTTPError) MarshalJSON() ([]byte, error) {
return json.Marshal(e.internal)
}
func (e HTTPError) Unwrap() error {
return e.scim
}

View File

@ -6,11 +6,15 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/imulab/go-scim/pkg/v2/handlerutil"
"github.com/imulab/go-scim/pkg/v2/spec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
@ -22,6 +26,7 @@ import (
"github.com/coder/coder/v2/enterprise/coderd" "github.com/coder/coder/v2/enterprise/coderd"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/enterprise/coderd/scim"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) {
func setScimAuthBearer(key []byte) func(*http.Request) { func setScimAuthBearer(key []byte) func(*http.Request) {
return func(r *http.Request) { return func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+string(key)) // Do strange casing to ensure it's case-insensitive
r.Header.Set("Authorization", "beAreR "+string(key))
} }
} }
@ -111,7 +117,7 @@ func TestScim(t *testing.T) {
res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{}) res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{})
require.NoError(t, err) require.NoError(t, err)
defer res.Body.Close() defer res.Body.Close()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode) assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}) })
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
@ -454,7 +460,7 @@ func TestScim(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, _ = io.Copy(io.Discard, res.Body) _, _ = io.Copy(io.Discard, res.Body)
_ = res.Body.Close() _ = res.Body.Close()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode) assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}) })
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
@ -585,3 +591,21 @@ func TestScim(t *testing.T) {
}) })
}) })
} }
func TestScimError(t *testing.T) {
t.Parallel()
// Demonstrates that we cannot use the standard errors
rw := httptest.NewRecorder()
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
resp := rw.Result()
defer resp.Body.Close()
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
// Our error wrapper works
rw = httptest.NewRecorder()
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found")))
resp = rw.Result()
defer resp.Body.Close()
require.Equal(t, http.StatusNotFound, resp.StatusCode)
}