mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
chore: make scim auth header case insensitive for 'bearer' (#15538)
Fixes status codes to return more than 500. The way we were using the package, it always returned a status code 500
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
package coderd
|
package coderd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -26,16 +27,21 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (api *API) scimVerifyAuthHeader(r *http.Request) bool {
|
func (api *API) scimVerifyAuthHeader(r *http.Request) bool {
|
||||||
bearer := []byte("Bearer ")
|
bearer := []byte("bearer ")
|
||||||
hdr := []byte(r.Header.Get("Authorization"))
|
hdr := []byte(r.Header.Get("Authorization"))
|
||||||
|
|
||||||
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(hdr[:len(bearer)], bearer) == 1 {
|
// Use toLower to make the comparison case-insensitive.
|
||||||
|
if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 {
|
||||||
hdr = hdr[len(bearer):]
|
hdr = hdr[len(bearer):]
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1
|
return len(api.SCIMAPIKey) != 0 && subtle.ConstantTimeCompare(hdr, api.SCIMAPIKey) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func scimUnauthorized(rw http.ResponseWriter) {
|
||||||
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusUnauthorized, "invalidAuthorization", xerrors.New("invalid authorization")))
|
||||||
|
}
|
||||||
|
|
||||||
// scimServiceProviderConfig returns a static SCIM service provider configuration.
|
// scimServiceProviderConfig returns a static SCIM service provider configuration.
|
||||||
//
|
//
|
||||||
// @Summary SCIM 2.0: Service Provider Config
|
// @Summary SCIM 2.0: Service Provider Config
|
||||||
@ -114,7 +120,7 @@ func (api *API) scimServiceProviderConfig(rw http.ResponseWriter, _ *http.Reques
|
|||||||
//nolint:revive
|
//nolint:revive
|
||||||
func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
|
||||||
if !api.scimVerifyAuthHeader(r) {
|
if !api.scimVerifyAuthHeader(r) {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
|
scimUnauthorized(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,11 +148,11 @@ func (api *API) scimGetUsers(rw http.ResponseWriter, r *http.Request) {
|
|||||||
//nolint:revive
|
//nolint:revive
|
||||||
func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) scimGetUser(rw http.ResponseWriter, r *http.Request) {
|
||||||
if !api.scimVerifyAuthHeader(r) {
|
if !api.scimVerifyAuthHeader(r) {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
|
scimUnauthorized(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("endpoint will always return 404")))
|
||||||
}
|
}
|
||||||
|
|
||||||
// We currently use our own struct instead of using the SCIM package. This was
|
// We currently use our own struct instead of using the SCIM package. This was
|
||||||
@ -192,7 +198,7 @@ var SCIMAuditAdditionalFields = map[string]string{
|
|||||||
func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
if !api.scimVerifyAuthHeader(r) {
|
if !api.scimVerifyAuthHeader(r) {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
|
scimUnauthorized(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,7 +215,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
var sUser SCIMUser
|
var sUser SCIMUser
|
||||||
err := json.NewDecoder(r.Body).Decode(&sUser)
|
err := json.NewDecoder(r.Body).Decode(&sUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,7 +228,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if email == "" {
|
if email == "" {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidEmail"})
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidEmail", xerrors.New("no primary email provided")))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,7 +238,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
Username: sUser.UserName,
|
Username: sUser.UserName,
|
||||||
})
|
})
|
||||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, err) // internal error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -248,7 +254,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
UpdatedAt: dbtime.Now(),
|
UpdatedAt: dbtime.Now(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, err) // internal error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aReq.New = newUser
|
aReq.New = newUser
|
||||||
@ -284,14 +290,14 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
//nolint:gocritic // SCIM operations are a system user
|
//nolint:gocritic // SCIM operations are a system user
|
||||||
orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database)
|
orgSync, err := api.IDPSync.OrganizationSyncSettings(dbauthz.AsSystemRestricted(ctx), api.Database)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, xerrors.Errorf("failed to get organization sync settings: %w", err))
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get organization sync settings: %w", err)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if orgSync.AssignDefault {
|
if orgSync.AssignDefault {
|
||||||
//nolint:gocritic // SCIM operations are a system user
|
//nolint:gocritic // SCIM operations are a system user
|
||||||
defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
defaultOrganization, err := api.Database.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to get default organization: %w", err)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
organizations = append(organizations, defaultOrganization.ID)
|
organizations = append(organizations, defaultOrganization.ID)
|
||||||
@ -309,7 +315,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
SkipNotifications: true,
|
SkipNotifications: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusInternalServerError, "internalError", xerrors.Errorf("failed to create user: %w", err)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aReq.New = dbUser
|
aReq.New = dbUser
|
||||||
@ -335,7 +341,7 @@ func (api *API) scimPostUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
if !api.scimVerifyAuthHeader(r) {
|
if !api.scimVerifyAuthHeader(r) {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusUnauthorized, Type: "invalidAuthorization"})
|
scimUnauthorized(rw)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -354,21 +360,21 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
var sUser SCIMUser
|
var sUser SCIMUser
|
||||||
err := json.NewDecoder(r.Body).Decode(&sUser)
|
err := json.NewDecoder(r.Body).Decode(&sUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidRequest", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
sUser.ID = id
|
sUser.ID = id
|
||||||
|
|
||||||
uid, err := uuid.Parse(id)
|
uid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, spec.Error{Status: http.StatusBadRequest, Type: "invalidId"})
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusBadRequest, "invalidId", xerrors.Errorf("id must be a uuid: %w", err)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocritic // needed for SCIM
|
//nolint:gocritic // needed for SCIM
|
||||||
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid)
|
dbUser, err := api.Database.GetUserByID(dbauthz.AsSystemRestricted(ctx), uid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, err) // internal error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aReq.Old = dbUser
|
aReq.Old = dbUser
|
||||||
@ -400,7 +406,7 @@ func (api *API) scimPatchUser(rw http.ResponseWriter, r *http.Request) {
|
|||||||
UpdatedAt: dbtime.Now(),
|
UpdatedAt: dbtime.Now(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = handlerutil.WriteError(rw, err)
|
_ = handlerutil.WriteError(rw, err) // internal error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dbUser = userNew
|
dbUser = userNew
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
package scim
|
package scim
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/imulab/go-scim/pkg/v2/spec"
|
||||||
|
)
|
||||||
|
|
||||||
type ServiceProviderConfig struct {
|
type ServiceProviderConfig struct {
|
||||||
Schemas []string `json:"schemas"`
|
Schemas []string `json:"schemas"`
|
||||||
@ -44,3 +49,37 @@ type AuthenticationScheme struct {
|
|||||||
SpecURI string `json:"specUri"`
|
SpecURI string `json:"specUri"`
|
||||||
DocURI string `json:"documentationUri"`
|
DocURI string `json:"documentationUri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HTTPError wraps a *spec.Error for correct usage with
|
||||||
|
// 'handlerutil.WriteError'. This error type is cursed to be
|
||||||
|
// absolutely strange and specific to the SCIM library we use.
|
||||||
|
//
|
||||||
|
// The library expects *spec.Error to be returned on unwrap, and the
|
||||||
|
// internal error description to be returned by a json.Marshal of the
|
||||||
|
// top level error.
|
||||||
|
type HTTPError struct {
|
||||||
|
scim *spec.Error
|
||||||
|
internal error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPError(status int, eType string, err error) *HTTPError {
|
||||||
|
return &HTTPError{
|
||||||
|
scim: &spec.Error{
|
||||||
|
Status: status,
|
||||||
|
Type: eType,
|
||||||
|
},
|
||||||
|
internal: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e HTTPError) Error() string {
|
||||||
|
return e.internal.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e HTTPError) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(e.internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e HTTPError) Unwrap() error {
|
||||||
|
return e.scim
|
||||||
|
}
|
||||||
|
@ -6,11 +6,15 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/imulab/go-scim/pkg/v2/handlerutil"
|
||||||
|
"github.com/imulab/go-scim/pkg/v2/spec"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/audit"
|
"github.com/coder/coder/v2/coderd/audit"
|
||||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
@ -22,6 +26,7 @@ import (
|
|||||||
"github.com/coder/coder/v2/enterprise/coderd"
|
"github.com/coder/coder/v2/enterprise/coderd"
|
||||||
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
|
||||||
"github.com/coder/coder/v2/enterprise/coderd/license"
|
"github.com/coder/coder/v2/enterprise/coderd/license"
|
||||||
|
"github.com/coder/coder/v2/enterprise/coderd/scim"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,7 +64,8 @@ func setScimAuth(key []byte) func(*http.Request) {
|
|||||||
|
|
||||||
func setScimAuthBearer(key []byte) func(*http.Request) {
|
func setScimAuthBearer(key []byte) func(*http.Request) {
|
||||||
return func(r *http.Request) {
|
return func(r *http.Request) {
|
||||||
r.Header.Set("Authorization", "Bearer "+string(key))
|
// Do strange casing to ensure it's case-insensitive
|
||||||
|
r.Header.Set("Authorization", "beAreR "+string(key))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +117,7 @@ func TestScim(t *testing.T) {
|
|||||||
res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{})
|
res, err := client.Request(ctx, "POST", "/scim/v2/Users", struct{}{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OK", func(t *testing.T) {
|
t.Run("OK", func(t *testing.T) {
|
||||||
@ -454,7 +460,7 @@ func TestScim(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, _ = io.Copy(io.Discard, res.Body)
|
_, _ = io.Copy(io.Discard, res.Body)
|
||||||
_ = res.Body.Close()
|
_ = res.Body.Close()
|
||||||
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OK", func(t *testing.T) {
|
t.Run("OK", func(t *testing.T) {
|
||||||
@ -585,3 +591,21 @@ func TestScim(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScimError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Demonstrates that we cannot use the standard errors
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
_ = handlerutil.WriteError(rw, spec.ErrNotFound)
|
||||||
|
resp := rw.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||||
|
|
||||||
|
// Our error wrapper works
|
||||||
|
rw = httptest.NewRecorder()
|
||||||
|
_ = handlerutil.WriteError(rw, scim.NewHTTPError(http.StatusNotFound, spec.ErrNotFound.Type, xerrors.New("not found")))
|
||||||
|
resp = rw.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user