feat: Add initial AuthzQuerier implementation (#5919)

feat: Add initial AuthzQuerier implementation
- Adds package database/dbauthz that adds a database.Store implementation where each method goes through AuthZ checks
- Implements all database.Store methods on AuthzQuerier
- Updates and fixes unit tests where required
- Updates coderd initialization to use AuthzQuerier if codersdk.ExperimentAuthzQuerier is enabled
This commit is contained in:
Steven Masley
2023-02-14 08:27:06 -06:00
committed by GitHub
parent ebdfdc749d
commit 6fb8aff6d0
59 changed files with 5013 additions and 136 deletions

View File

@ -19,6 +19,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
@ -159,7 +160,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
return
}
key, err := cfg.DB.GetAPIKeyByID(r.Context(), keyID)
//nolint:gocritic // System needs to fetch API key to check if it's valid.
key, err := cfg.DB.GetAPIKeyByID(dbauthz.AsSystem(ctx), keyID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
optionalWrite(http.StatusUnauthorized, codersdk.Response{
@ -192,7 +194,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
changed = false
)
if key.LoginType == database.LoginTypeGithub || key.LoginType == database.LoginTypeOIDC {
link, err = cfg.DB.GetUserLinkByUserIDLoginType(r.Context(), database.GetUserLinkByUserIDLoginTypeParams{
//nolint:gocritic // System needs to fetch UserLink to check if it's valid.
link, err = cfg.DB.GetUserLinkByUserIDLoginType(dbauthz.AsSystem(ctx), database.GetUserLinkByUserIDLoginTypeParams{
UserID: key.UserID,
LoginType: key.LoginType,
})
@ -275,7 +278,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
}
}
if changed {
err := cfg.DB.UpdateAPIKeyByID(r.Context(), database.UpdateAPIKeyByIDParams{
//nolint:gocritic // System needs to update API Key LastUsed
err := cfg.DB.UpdateAPIKeyByID(dbauthz.AsSystem(ctx), database.UpdateAPIKeyByIDParams{
ID: key.ID,
LastUsed: key.LastUsed,
ExpiresAt: key.ExpiresAt,
@ -291,7 +295,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// If the API Key is associated with a user_link (e.g. Github/OIDC)
// then we want to update the relevant oauth fields.
if link.UserID != uuid.Nil {
link, err = cfg.DB.UpdateUserLink(r.Context(), database.UpdateUserLinkParams{
// nolint:gocritic
link, err = cfg.DB.UpdateUserLink(dbauthz.AsSystem(ctx), database.UpdateUserLinkParams{
UserID: link.UserID,
LoginType: link.LoginType,
OAuthAccessToken: link.OAuthAccessToken,
@ -310,7 +315,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// We only want to update this occasionally to reduce DB write
// load. We update alongside the UserLink and APIKey since it's
// easier on the DB to colocate writes.
_, err = cfg.DB.UpdateUserLastSeenAt(ctx, database.UpdateUserLastSeenAtParams{
// nolint:gocritic
_, err = cfg.DB.UpdateUserLastSeenAt(dbauthz.AsSystem(ctx), database.UpdateUserLastSeenAtParams{
ID: key.UserID,
LastSeenAt: database.Now(),
UpdatedAt: database.Now(),
@ -327,7 +333,8 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
// If the key is valid, we also fetch the user roles and status.
// The roles are used for RBAC authorize checks, and the status
// is to block 'suspended' users from accessing the platform.
roles, err := cfg.DB.GetAuthorizationUserRoles(r.Context(), key.UserID)
// nolint:gocritic
roles, err := cfg.DB.GetAuthorizationUserRoles(dbauthz.AsSystem(ctx), key.UserID)
if err != nil {
write(http.StatusUnauthorized, codersdk.Response{
Message: internalErrorMessage,
@ -343,16 +350,20 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
return
}
// Actor is the user's authorization context.
actor := rbac.Subject{
ID: key.UserID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.ScopeName(key.Scope),
}
ctx = context.WithValue(ctx, apiKeyContextKey{}, key)
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
Username: roles.Username,
Actor: rbac.Subject{
ID: key.UserID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.ScopeName(key.Scope),
},
Actor: actor,
})
// Set the auth context for the authzquerier as well.
ctx = dbauthz.As(ctx, actor)
next.ServeHTTP(rw, r.WithContext(ctx))
})

37
coderd/httpmw/authz.go Normal file
View File

@ -0,0 +1,37 @@
package httpmw
import (
"net/http"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/go-chi/chi/v5"
)
// AsAuthzSystem is a chained handler that temporarily sets the dbauthz context
// to System for the inner handlers, and resets the context afterwards.
//
// TODO: Refactor the middleware functions to not require this.
// This is a bit of a kludge for now as some middleware functions require
// usage as a system user in some cases, but not all cases. To avoid large
// refactors, we use this middleware to temporarily set the context to a system.
func AsAuthzSystem(mws ...func(http.Handler) http.Handler) func(http.Handler) http.Handler {
chain := chi.Chain(mws...)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
before, beforeExists := dbauthz.ActorFromContext(r.Context())
if !beforeExists {
// AsRemoveActor will actually remove the actor from the context.
before = dbauthz.AsRemoveActor
}
// nolint:gocritic // AsAuthzSystem needs to do this.
r = r.WithContext(dbauthz.AsSystem(ctx))
chain.Handler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
r = r.WithContext(dbauthz.As(r.Context(), before))
next.ServeHTTP(rw, r)
})).ServeHTTP(rw, r)
})
}
}

View File

@ -0,0 +1,97 @@
package httpmw_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpmw"
)
func TestAsAuthzSystem(t *testing.T) {
t.Parallel()
userActor := coderdtest.RandomRBACSubject()
base := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
actor, ok := dbauthz.ActorFromContext(r.Context())
assert.True(t, ok, "actor should exist")
assert.True(t, userActor.Equal(actor), "actor should be the user actor")
})
mwSetUser := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
r = r.WithContext(dbauthz.As(r.Context(), userActor))
next.ServeHTTP(rw, r)
})
}
mwAssertSystem := mwAssert(func(req *http.Request) {
actor, ok := dbauthz.ActorFromContext(req.Context())
assert.True(t, ok, "actor should exist")
assert.False(t, userActor.Equal(actor), "systemActor should not be the user actor")
assert.Contains(t, actor.Roles.Names(), "system", "should have system role")
})
mwAssertUser := mwAssert(func(req *http.Request) {
actor, ok := dbauthz.ActorFromContext(req.Context())
assert.True(t, ok, "actor should exist")
assert.True(t, userActor.Equal(actor), "should be the useractor")
})
mwAssertNoUser := mwAssert(func(req *http.Request) {
_, ok := dbauthz.ActorFromContext(req.Context())
assert.False(t, ok, "actor should not exist")
})
// Request as the user actor
const pattern = "/"
req := httptest.NewRequest("GET", pattern, nil)
res := httptest.NewRecorder()
handler := chi.NewRouter()
handler.Route(pattern, func(r chi.Router) {
r.Use(
// First assert there is no actor context
mwAssertNoUser,
httpmw.AsAuthzSystem(
// Assert the system actor
mwAssertSystem,
mwAssertSystem,
),
// Assert no user present outside of the AsAuthzSystem chain
mwAssertNoUser,
// ----
// Set to the user actor
mwSetUser,
// Assert the user actor
mwAssertUser,
httpmw.AsAuthzSystem(
// Assert the system actor
mwAssertSystem,
mwAssertSystem,
),
// Check the user actor was returned to the context
mwAssertUser,
)
r.Handle("/", base)
r.NotFound(func(writer http.ResponseWriter, request *http.Request) {
assert.Fail(t, "should not hit not found, the route should be correct")
})
})
handler.ServeHTTP(res, req)
}
func mwAssert(assertF func(req *http.Request)) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assertF(r)
next.ServeHTTP(rw, r)
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
@ -68,7 +69,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
})
return
}
user, err = db.GetUserByID(ctx, apiKey.UserID)
//nolint:gocritic // System needs to be able to get user from param.
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), apiKey.UserID)
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.ResourceNotFound(rw)
return
@ -81,8 +83,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
return
}
} else if userID, err := uuid.Parse(userQuery); err == nil {
// If the userQuery is a valid uuid
user, err = db.GetUserByID(ctx, userID)
//nolint:gocritic // If the userQuery is a valid uuid
user, err = db.GetUserByID(dbauthz.AsSystem(ctx), userID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: userErrorMessage,
@ -90,8 +92,8 @@ func ExtractUserParam(db database.Store, redirectToLoginOnMe bool) func(http.Han
return
}
} else {
// Try as a username last
user, err = db.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
// nolint:gocritic // Try as a username last
user, err = db.GetUserByEmailOrUsername(dbauthz.AsSystem(ctx), database.GetUserByEmailOrUsernameParams{
Username: userQuery,
})
if err != nil {

View File

@ -10,7 +10,9 @@ import (
"github.com/google/uuid"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbauthz"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
)
@ -45,7 +47,8 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
})
return
}
agent, err := db.GetWorkspaceAgentByAuthToken(ctx, token)
//nolint:gocritic // System needs to be able to get workspace agents.
agent, err := db.GetWorkspaceAgentByAuthToken(dbauthz.AsSystem(ctx), token)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusUnauthorized, codersdk.Response{
@ -62,8 +65,50 @@ func ExtractWorkspaceAgent(db database.Store) func(http.Handler) http.Handler {
return
}
//nolint:gocritic // System needs to be able to get workspace agents.
subject, err := getAgentSubject(dbauthz.AsSystem(ctx), db, agent)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace agent.",
Detail: err.Error(),
})
return
}
ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent)
// Also set the dbauthz actor for the request.
ctx = dbauthz.As(ctx, subject)
next.ServeHTTP(rw, r.WithContext(ctx))
})
}
}
func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) {
// TODO: make a different query that gets the workspace owner and roles along with the agent.
workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID)
if err != nil {
return rbac.Subject{}, err
}
user, err := db.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return rbac.Subject{}, err
}
roles, err := db.GetAuthorizationUserRoles(ctx, user.ID)
if err != nil {
return rbac.Subject{}, err
}
// A user that creates a workspace can use this agent auth token and
// impersonate the workspace. So to prevent privilege escalation, the
// subject inherits the roles of the user that owns the workspace.
// We then add a workspace-agent scope to limit the permissions
// to only what the workspace agent needs.
return rbac.Subject{
ID: user.ID.String(),
Roles: rbac.RoleNames(roles.Roles),
Groups: roles.Groups,
Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID),
}, nil
}

View File

@ -19,11 +19,10 @@ import (
func TestWorkspaceAgent(t *testing.T) {
t.Parallel()
setup := func(db database.Store) (*http.Request, uuid.UUID) {
token := uuid.New()
setup := func(db database.Store, token uuid.UUID) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set(codersdk.SessionTokenHeader, token.String())
return r, token
return r
}
t.Run("None", func(t *testing.T) {
@ -34,7 +33,7 @@ func TestWorkspaceAgent(t *testing.T) {
httpmw.ExtractWorkspaceAgent(db),
)
rtr.Get("/", nil)
r, _ := setup(db)
r := setup(db, uuid.New())
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)
@ -46,6 +45,24 @@ func TestWorkspaceAgent(t *testing.T) {
t.Run("Found", func(t *testing.T) {
t.Parallel()
db := dbfake.New()
var (
user = dbgen.User(t, db, database.User{})
workspace = dbgen.Workspace(t, db, database.Workspace{
OwnerID: user.ID,
})
job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{})
resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{
JobID: job.ID,
})
_ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
JobID: job.ID,
})
agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
ResourceID: resource.ID,
})
)
rtr := chi.NewRouter()
rtr.Use(
httpmw.ExtractWorkspaceAgent(db),
@ -54,10 +71,7 @@ func TestWorkspaceAgent(t *testing.T) {
_ = httpmw.WorkspaceAgent(r)
rw.WriteHeader(http.StatusOK)
})
r, token := setup(db)
_ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{
AuthToken: token,
})
r := setup(db, agent.AuthToken)
rw := httptest.NewRecorder()
rtr.ServeHTTP(rw, r)