mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore: Allow RecordingAuthorizer to record multiple rbac authz calls (#6024)
* chore: Allow RecordingAuthorizer to record multiple rbac authz calls Prior iteration only recorded the last call. This is required for more comprehensive testing
This commit is contained in:
@ -7,17 +7,17 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/rbac/regosql"
|
||||
"github.com/coder/coder/codersdk"
|
||||
@ -443,7 +443,9 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
|
||||
|
||||
func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck, skipRoutes map[string]string) {
|
||||
// Always fail auth from this point forward
|
||||
a.authorizer.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil)
|
||||
a.authorizer.Wrapped = &FakeAuthorizer{
|
||||
AlwaysReturn: rbac.ForbiddenWithInternal(xerrors.New("fake implementation"), nil, nil),
|
||||
}
|
||||
|
||||
routeMissing := make(map[string]bool)
|
||||
for k, v := range assertRoute {
|
||||
@ -483,7 +485,7 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
||||
return nil
|
||||
}
|
||||
a.t.Run(name, func(t *testing.T) {
|
||||
a.authorizer.reset()
|
||||
a.authorizer.Reset()
|
||||
routeKey := strings.TrimRight(name, "/")
|
||||
|
||||
routeAssertions, ok := assertRoute[routeKey]
|
||||
@ -514,18 +516,19 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
||||
assert.Equal(t, http.StatusForbidden, resp.StatusCode, "expect unauthorized")
|
||||
}
|
||||
}
|
||||
if a.authorizer.Called != nil {
|
||||
if a.authorizer.lastCall() != nil {
|
||||
last := a.authorizer.lastCall()
|
||||
if routeAssertions.AssertAction != "" {
|
||||
assert.Equal(t, routeAssertions.AssertAction, a.authorizer.Called.Action, "resource action")
|
||||
assert.Equal(t, routeAssertions.AssertAction, last.Action, "resource action")
|
||||
}
|
||||
if routeAssertions.AssertObject.Type != "" {
|
||||
assert.Equal(t, routeAssertions.AssertObject.Type, a.authorizer.Called.Object.Type, "resource type")
|
||||
assert.Equal(t, routeAssertions.AssertObject.Type, last.Object.Type, "resource type")
|
||||
}
|
||||
if routeAssertions.AssertObject.Owner != "" {
|
||||
assert.Equal(t, routeAssertions.AssertObject.Owner, a.authorizer.Called.Object.Owner, "resource owner")
|
||||
assert.Equal(t, routeAssertions.AssertObject.Owner, last.Object.Owner, "resource owner")
|
||||
}
|
||||
if routeAssertions.AssertObject.OrgID != "" {
|
||||
assert.Equal(t, routeAssertions.AssertObject.OrgID, a.authorizer.Called.Object.OrgID, "resource org")
|
||||
assert.Equal(t, routeAssertions.AssertObject.OrgID, last.Object.OrgID, "resource org")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -539,52 +542,195 @@ func (a *AuthTester) Test(ctx context.Context, assertRoute map[string]RouteCheck
|
||||
}
|
||||
|
||||
type authCall struct {
|
||||
Subject rbac.Subject
|
||||
Action rbac.Action
|
||||
Object rbac.Object
|
||||
}
|
||||
Actor rbac.Subject
|
||||
Action rbac.Action
|
||||
Object rbac.Object
|
||||
|
||||
type RecordingAuthorizer struct {
|
||||
Called *authCall
|
||||
AlwaysReturn error
|
||||
asserted bool
|
||||
}
|
||||
|
||||
var _ rbac.Authorizer = (*RecordingAuthorizer)(nil)
|
||||
|
||||
// AuthorizeSQL does not record the call. This matches the postgres behavior
|
||||
// of not calling Authorize()
|
||||
func (r *RecordingAuthorizer) AuthorizeSQL(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
|
||||
return r.AlwaysReturn
|
||||
// RecordingAuthorizer wraps any rbac.Authorizer and records all Authorize()
|
||||
// calls made. This is useful for testing as these calls can later be asserted.
|
||||
type RecordingAuthorizer struct {
|
||||
sync.RWMutex
|
||||
Called []authCall
|
||||
Wrapped rbac.Authorizer
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) Authorize(_ context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
|
||||
r.Called = &authCall{
|
||||
Subject: subject,
|
||||
Action: action,
|
||||
Object: object,
|
||||
type ActionObjectPair struct {
|
||||
Action rbac.Action
|
||||
Object rbac.Object
|
||||
}
|
||||
|
||||
// Pair is on the RecordingAuthorizer to be easy to find and keep the pkg
|
||||
// interface smaller.
|
||||
func (*RecordingAuthorizer) Pair(action rbac.Action, object rbac.Objecter) ActionObjectPair {
|
||||
return ActionObjectPair{
|
||||
Action: action,
|
||||
Object: object.RBACObject(),
|
||||
}
|
||||
return r.AlwaysReturn
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
return &fakePreparedAuthorizer{
|
||||
Original: r,
|
||||
Subject: subject,
|
||||
Action: action,
|
||||
HardCodedSQLString: "true",
|
||||
// AllAsserted returns an error if all calls to Authorize() have not been
|
||||
// asserted and checked. This is useful for testing to ensure that all
|
||||
// Authorize() calls are checked in the unit test.
|
||||
func (r *RecordingAuthorizer) AllAsserted() error {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
missed := []authCall{}
|
||||
for _, c := range r.Called {
|
||||
if !c.asserted {
|
||||
missed = append(missed, c)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missed) > 0 {
|
||||
return xerrors.Errorf("missed calls: %+v", missed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssertActor asserts in order. If the order of authz calls does not match,
|
||||
// this will fail.
|
||||
func (r *RecordingAuthorizer) AssertActor(t *testing.T, actor rbac.Subject, did ...ActionObjectPair) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
ptr := 0
|
||||
for i, call := range r.Called {
|
||||
if ptr == len(did) {
|
||||
// Finished all assertions
|
||||
return
|
||||
}
|
||||
if call.Actor.ID == actor.ID {
|
||||
action, object := did[ptr].Action, did[ptr].Object
|
||||
assert.Equalf(t, action, call.Action, "assert action %d", ptr)
|
||||
assert.Equalf(t, object, call.Object, "assert object %d", ptr)
|
||||
r.Called[i].asserted = true
|
||||
ptr++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, len(did), ptr, "assert actor: didn't find all actions, %d missing actions", len(did)-ptr)
|
||||
}
|
||||
|
||||
// recordAuthorize is the internal method that records the Authorize() call.
|
||||
func (r *RecordingAuthorizer) recordAuthorize(subject rbac.Subject, action rbac.Action, object rbac.Object) {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.Called = append(r.Called, authCall{
|
||||
Actor: subject,
|
||||
Action: action,
|
||||
Object: object,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action rbac.Action, object rbac.Object) error {
|
||||
r.recordAuthorize(subject, action, object)
|
||||
if r.Wrapped == nil {
|
||||
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
|
||||
}
|
||||
return r.Wrapped.Authorize(ctx, subject, action, object)
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) Prepare(ctx context.Context, subject rbac.Subject, action rbac.Action, objectType string) (rbac.PreparedAuthorized, error) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
if r.Wrapped == nil {
|
||||
panic("Developer error: RecordingAuthorizer.Wrapped is nil")
|
||||
}
|
||||
|
||||
prep, err := r.Wrapped.Prepare(ctx, subject, action, objectType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PreparedRecorder{
|
||||
rec: r,
|
||||
prepped: prep,
|
||||
subject: subject,
|
||||
action: action,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *RecordingAuthorizer) reset() {
|
||||
// Reset clears the recorded Authorize() calls.
|
||||
func (r *RecordingAuthorizer) Reset() {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.Called = nil
|
||||
}
|
||||
|
||||
// lastCall is implemented to support legacy tests.
|
||||
// Deprecated
|
||||
func (r *RecordingAuthorizer) lastCall() *authCall {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
if len(r.Called) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &r.Called[len(r.Called)-1]
|
||||
}
|
||||
|
||||
// PreparedRecorder is the prepared version of the RecordingAuthorizer.
|
||||
// It records the Authorize() calls to the original recorder. If the caller
|
||||
// uses CompileToSQL, all recording stops. This is to support parity between
|
||||
// memory and SQL backed dbs.
|
||||
type PreparedRecorder struct {
|
||||
rec *RecordingAuthorizer
|
||||
prepped rbac.PreparedAuthorized
|
||||
subject rbac.Subject
|
||||
action rbac.Action
|
||||
|
||||
rw sync.Mutex
|
||||
usingSQL bool
|
||||
}
|
||||
|
||||
func (s *PreparedRecorder) Authorize(ctx context.Context, object rbac.Object) error {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
|
||||
if !s.usingSQL {
|
||||
s.rec.recordAuthorize(s.subject, s.action, object)
|
||||
}
|
||||
return s.prepped.Authorize(ctx, object)
|
||||
}
|
||||
func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.ConvertConfig) (string, error) {
|
||||
s.rw.Lock()
|
||||
defer s.rw.Unlock()
|
||||
|
||||
s.usingSQL = true
|
||||
return s.prepped.CompileToSQL(ctx, cfg)
|
||||
}
|
||||
|
||||
// FakeAuthorizer is an Authorizer that always returns the same error.
|
||||
type FakeAuthorizer struct {
|
||||
// AlwaysReturn is the error that will be returned by Authorize.
|
||||
AlwaysReturn error
|
||||
}
|
||||
|
||||
var _ rbac.Authorizer = (*FakeAuthorizer)(nil)
|
||||
|
||||
func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ rbac.Action, _ rbac.Object) error {
|
||||
return d.AlwaysReturn
|
||||
}
|
||||
|
||||
func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action rbac.Action, _ string) (rbac.PreparedAuthorized, error) {
|
||||
return &fakePreparedAuthorizer{
|
||||
Original: d,
|
||||
Subject: subject,
|
||||
Action: action,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ rbac.PreparedAuthorized = (*fakePreparedAuthorizer)(nil)
|
||||
|
||||
// fakePreparedAuthorizer is the prepared version of a FakeAuthorizer. It will
|
||||
// return the same error as the original FakeAuthorizer.
|
||||
type fakePreparedAuthorizer struct {
|
||||
Original *RecordingAuthorizer
|
||||
Subject rbac.Subject
|
||||
Action rbac.Action
|
||||
HardCodedSQLString string
|
||||
HardCodedRegoString string
|
||||
sync.RWMutex
|
||||
Original *FakeAuthorizer
|
||||
Subject rbac.Subject
|
||||
Action rbac.Action
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Object) error {
|
||||
@ -593,17 +739,6 @@ func (f *fakePreparedAuthorizer) Authorize(ctx context.Context, object rbac.Obje
|
||||
|
||||
// CompileToSQL returns a compiled version of the authorizer that will work for
|
||||
// in memory databases. This fake version will not work against a SQL database.
|
||||
func (fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
|
||||
return "", xerrors.New("not implemented")
|
||||
}
|
||||
|
||||
func (f *fakePreparedAuthorizer) Eval(object rbac.Object) bool {
|
||||
return f.Original.AuthorizeSQL(context.Background(), f.Subject, f.Action, object) == nil
|
||||
}
|
||||
|
||||
func (f fakePreparedAuthorizer) RegoString() string {
|
||||
if f.HardCodedRegoString != "" {
|
||||
return f.HardCodedRegoString
|
||||
}
|
||||
panic("not implemented")
|
||||
func (*fakePreparedAuthorizer) CompileToSQL(_ context.Context, _ regosql.ConvertConfig) (string, error) {
|
||||
return "not a valid sql string", nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user