mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: Add initial AuthzQuerier implementation (#5919)
feat: Add initial AuthzQuerier implementation - Adds package database/dbauthz that adds a database.Store implementation where each method goes through AuthZ checks - Implements all database.Store methods on AuthzQuerier - Updates and fixes unit tests where required - Updates coderd initialization to use AuthzQuerier if codersdk.ExperimentAuthzQuerier is enabled
This commit is contained in:
@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
@ -159,7 +160,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID)
|
||||
//nolint:gocritic // System needs to fetch API key to check if it's valid.
|
||||
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
optionalWrite(http.StatusUnauthorized, codersdk.Response{
|
||||
@ -192,7 +194,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
changed = false
|
||||
)
|
||||
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
|
||||
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystem(ctx), database.GetUserLinkByUserIDLoginTypeParams{
|
||||
UserID: key.UserID,
|
||||
LoginType: key.LoginType,
|
||||
})
|
||||
@ -275,7 +278,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
|
||||
//nolint:gocritic // System needs to update API Key LastUsed
|
||||
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystem(ctx), database.UpdateAPIKeyByIDParams{
|
||||
ID: key.ID,
|
||||
LastUsed: key.LastUsed,
|
||||
ExpiresAt: key.ExpiresAt,
|
||||
@ -291,7 +295,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
// If the API Key is associated with a user_link (e.g. Github/OIDC)
|
||||
// then we want to update the relevant oauth fields.
|
||||
if link.UserID != uuid.Nil {
|
||||
link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
|
||||
// nolint:gocritic
|
||||
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
|
||||
UserID: link.UserID,
|
||||
LoginType: link.LoginType,
|
||||
OAuthAccessToken: link.OAuthAccessToken,
|
||||
@ -310,7 +315,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
// We only want to update this occasionally to reduce DB write
|
||||
// load. We update alongside the UserLink and APIKey since it's
|
||||
// easier on the DB to colocate writes.
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(ctx, database.UpdateUserLastSeenAtParams{
|
||||
// nolint:gocritic
|
||||
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystem(ctx), database.UpdateUserLastSeenAtParams{
|
||||
ID: key.UserID,
|
||||
LastSeenAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
@ -327,7 +333,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
// If the key is valid, we also fetch the user roles and status.
|
||||
// The roles are used for RBAC authorize checks, and the status
|
||||
// is to block 'suspended' users from accessing the platform.
|
||||
roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID)
|
||||
// nolint:gocritic
|
||||
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), key.UserID)
|
||||
if err != nil {
|
||||
write(http.StatusUnauthorized, codersdk.Response{
|
||||
Message: internalErrorMessage,
|
||||
@ -343,16 +350,20 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
// Actor is the user's authorization context.
|
||||
actor := rbac.Subject{
|
||||
ID: key.UserID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.ScopeName(key.Scope),
|
||||
}
|
||||
ctx = context.WithValue(ctx, apiKeyContextKey{}, key)
|
||||
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
|
||||
Username: roles.Username,
|
||||
Actor: rbac.Subject{
|
||||
ID: key.UserID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.ScopeName(key.Scope),
|
||||
},
|
||||
Actor: actor,
|
||||
})
|
||||
// Set the auth context for the authzquerier as well.
|
||||
ctx = dbauthz.As(ctx, actor)
|
||||
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
|
37
coderd/httpmw/authz.go
Normal file
37
coderd/httpmw/authz.go
Normal file
@ -0,0 +1,37 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// AsAuthzSystem is a chained handler that temporarily sets the dbauthz context
|
||||
// to System for the inner handlers, and resets the context afterwards.
|
||||
//
|
||||
// TODO: Refactor the middleware functions to not require this.
|
||||
// This is a bit of a kludge for now as some middleware functions require
|
||||
// usage as a system user in some cases, but not all cases. To avoid large
|
||||
// refactors, we use this middleware to temporarily set the context to a system.
|
||||
func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
|
||||
chain := chi.Chain(mws...)
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
before, beforeExists := dbauthz.ActorFromContext(r.Context())
|
||||
if !beforeExists {
|
||||
// AsRemoveActor will actually remove the actor from the context.
|
||||
before = dbauthz.AsRemoveActor
|
||||
}
|
||||
|
||||
// nolint:gocritic // AsAuthzSystem needs to do this.
|
||||
r = r.WithContext(dbauthz.AsSystem(ctx))
|
||||
chain.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(dbauthz.As(r.Context(), before))
|
||||
next.ServeHTTP(rw, r)
|
||||
})).ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
97
coderd/httpmw/authz_test.go
Normal file
97
coderd/httpmw/authz_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
func TestAsAuthzSystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
userActor := coderdtest.RandomRBACSubject()
|
||||
|
||||
base := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(r.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.True(t, userActor.Equal(actor), "actor should be the user actor")
|
||||
})
|
||||
|
||||
mwSetUser := func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(dbauthz.As(r.Context(), userActor))
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
mwAssertSystem := mwAssert(func(req *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.False(t, userActor.Equal(actor), "systemActor should not be the user actor")
|
||||
assert.Contains(t, actor.Roles.Names(), "system", "should have system role")
|
||||
})
|
||||
|
||||
mwAssertUser := mwAssert(func(req *http.Request) {
|
||||
actor, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.True(t, ok, "actor should exist")
|
||||
assert.True(t, userActor.Equal(actor), "should be the useractor")
|
||||
})
|
||||
|
||||
mwAssertNoUser := mwAssert(func(req *http.Request) {
|
||||
_, ok := dbauthz.ActorFromContext(req.Context())
|
||||
assert.False(t, ok, "actor should not exist")
|
||||
})
|
||||
|
||||
// Request as the user actor
|
||||
const pattern = "/"
|
||||
req := httptest.NewRequest("GET", pattern, nil)
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
handler := chi.NewRouter()
|
||||
handler.Route(pattern, func(r chi.Router) {
|
||||
r.Use(
|
||||
// First assert there is no actor context
|
||||
mwAssertNoUser,
|
||||
httpmw.AsAuthzSystem(
|
||||
// Assert the system actor
|
||||
mwAssertSystem,
|
||||
mwAssertSystem,
|
||||
),
|
||||
// Assert no user present outside of the AsAuthzSystem chain
|
||||
mwAssertNoUser,
|
||||
// ----
|
||||
// Set to the user actor
|
||||
mwSetUser,
|
||||
// Assert the user actor
|
||||
mwAssertUser,
|
||||
httpmw.AsAuthzSystem(
|
||||
// Assert the system actor
|
||||
mwAssertSystem,
|
||||
mwAssertSystem,
|
||||
),
|
||||
// Check the user actor was returned to the context
|
||||
mwAssertUser,
|
||||
)
|
||||
r.Handle("/", base)
|
||||
r.NotFound(func(writer http.ResponseWriter, request *http.Request) {
|
||||
assert.Fail(t, "should not hit not found, the route should be correct")
|
||||
})
|
||||
})
|
||||
|
||||
handler.ServeHTTP(res, req)
|
||||
}
|
||||
|
||||
func mwAssert(assertF func(req *http.Request)) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
assertF(r)
|
||||
next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
@ -68,7 +69,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
||||
})
|
||||
return
|
||||
}
|
||||
user, err = db.GetUserByID(ctx, apiKey.UserID)
|
||||
//nolint:gocritic // System needs to be able to get user from param.
|
||||
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), apiKey.UserID)
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
@ -81,8 +83,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
||||
return
|
||||
}
|
||||
} else if userID, err := uuid.Parse(userQuery); err == nil {
|
||||
// If the userQuery is a valid uuid
|
||||
user, err = db.GetUserByID(ctx, userID)
|
||||
//nolint:gocritic // If the userQuery is a valid uuid
|
||||
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), userID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: userErrorMessage,
|
||||
@ -90,8 +92,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Try as a username last
|
||||
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
|
||||
// nolint:gocritic // Try as a username last
|
||||
user, err = db.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
|
||||
Username: userQuery,
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -10,7 +10,9 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
@ -45,7 +47,8 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
|
||||
})
|
||||
return
|
||||
}
|
||||
agent, err := db.GetWorkspaceAgentByAuthToken(ctx, token)
|
||||
//nolint:gocritic // System needs to be able to get workspace agents.
|
||||
agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystem(ctx), token)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
|
||||
@ -62,8 +65,50 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
//nolint:gocritic // System needs to be able to get workspace agents.
|
||||
subject, err := getAgentSubject(dbauthz.AsSystem(ctx), db, agent)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching workspace agent.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent)
|
||||
// Also set the dbauthz actor for the request.
|
||||
ctx = dbauthz.As(ctx, subject)
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) {
|
||||
// TODO: make a different query that gets the workspace owner and roles along with the agent.
|
||||
workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
user, err := db.GetUserByID(ctx, workspace.OwnerID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
roles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
|
||||
if err != nil {
|
||||
return rbac.Subject{}, err
|
||||
}
|
||||
|
||||
// A user that creates a workspace can use this agent auth token and
|
||||
// impersonate the workspace. So to prevent privilege escalation, the
|
||||
// subject inherits the roles of the user that owns the workspace.
|
||||
// We then add a workspace-agent scope to limit the permissions
|
||||
// to only what the workspace agent needs.
|
||||
return rbac.Subject{
|
||||
ID: user.ID.String(),
|
||||
Roles: rbac.RoleNames(roles.Roles),
|
||||
Groups: roles.Groups,
|
||||
Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID),
|
||||
}, nil
|
||||
}
|
||||
|
@ -19,11 +19,10 @@ import (
|
||||
func TestWorkspaceAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(db database.Store) (*http.Request, uuid.UUID) {
|
||||
token := uuid.New()
|
||||
setup := func(db database.Store, token uuid.UUID) *http.Request {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r.Header.Set(codersdk.SessionTokenHeader, token.String())
|
||||
return r, token
|
||||
return r
|
||||
}
|
||||
|
||||
t.Run("None", func(t *testing.T) {
|
||||
@ -34,7 +33,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
httpmw.ExtractWorkspaceAgent(db),
|
||||
)
|
||||
rtr.Get("/", nil)
|
||||
r, _ := setup(db)
|
||||
r := setup(db, uuid.New())
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
@ -46,6 +45,24 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
t.Run("Found", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := dbfake.New()
|
||||
var (
|
||||
user = dbgen.User(t, db, database.User{})
|
||||
workspace = dbgen.Workspace(t, db, database.Workspace{
|
||||
OwnerID: user.ID,
|
||||
})
|
||||
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
|
||||
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
|
||||
JobID: job.ID,
|
||||
})
|
||||
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
JobID: job.ID,
|
||||
})
|
||||
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
ResourceID: resource.ID,
|
||||
})
|
||||
)
|
||||
|
||||
rtr := chi.NewRouter()
|
||||
rtr.Use(
|
||||
httpmw.ExtractWorkspaceAgent(db),
|
||||
@ -54,10 +71,7 @@ func TestWorkspaceAgent(t *testing.T) {
|
||||
_ = httpmw.WorkspaceAgent(r)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r, token := setup(db)
|
||||
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
|
||||
AuthToken: token,
|
||||
})
|
||||
r := setup(db, agent.AuthToken)
|
||||
rw := httptest.NewRecorder()
|
||||
rtr.ServeHTTP(rw, r)
|
||||
|
||||
|
Reference in New Issue
Block a user