mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
feat: add template RBAC/groups (#4235)
This commit is contained in:
@ -54,6 +54,7 @@ type Authorization struct {
|
||||
ID uuid.UUID
|
||||
Username string
|
||||
Roles []string
|
||||
Groups []string
|
||||
Scope database.APIKeyScope
|
||||
}
|
||||
|
||||
@ -360,6 +361,7 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
|
||||
Username: roles.Username,
|
||||
Roles: roles.Roles,
|
||||
Scope: key.Scope,
|
||||
Groups: roles.Groups,
|
||||
})
|
||||
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
|
@ -4,23 +4,22 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tabbed/pqtype"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func TestExtractUserRoles(t *testing.T) {
|
||||
@ -71,14 +70,49 @@ func TestExtractUserRoles(t *testing.T) {
|
||||
return user, append(roles, append(orgRoles, rbac.RoleMember(), rbac.RoleOrgMember(org.ID))...), token
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "MultipleOrgMember",
|
||||
AddUser: func(db database.Store) (database.User, []string, string) {
|
||||
roles := []string{}
|
||||
user, token := addUser(t, db, roles...)
|
||||
roles = append(roles, rbac.RoleMember())
|
||||
for i := 0; i < 3; i++ {
|
||||
organization, err := db.InsertOrganization(context.Background(), database.InsertOrganizationParams{
|
||||
ID: uuid.New(),
|
||||
Name: fmt.Sprintf("testorg%d", i),
|
||||
Description: "test",
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRoles := []string{}
|
||||
if i%2 == 0 {
|
||||
orgRoles = append(orgRoles, rbac.RoleOrgAdmin(organization.ID))
|
||||
}
|
||||
_, err = db.InsertOrganizationMember(context.Background(), database.InsertOrganizationMemberParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
Roles: orgRoles,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
roles = append(roles, orgRoles...)
|
||||
roles = append(roles, rbac.RoleOrgMember(organization.ID))
|
||||
}
|
||||
return user, roles, token
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
c := c
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db = databasefake.New()
|
||||
db, _ = dbtestutil.NewDB(t)
|
||||
user, expRoles, token = c.AddUser(db)
|
||||
rw = httptest.NewRecorder()
|
||||
rtr = chi.NewRouter()
|
||||
@ -118,6 +152,7 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
|
||||
Email: "admin@email.com",
|
||||
Username: "admin",
|
||||
RBACRoles: roles,
|
||||
LoginType: database.LoginTypePassword,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -129,6 +164,13 @@ func addUser(t *testing.T, db database.Store, roles ...string) (database.User, s
|
||||
ExpiresAt: database.Now().Add(time.Minute),
|
||||
LoginType: database.LoginTypePassword,
|
||||
Scope: database.APIKeyScopeAll,
|
||||
IPAddress: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
Mask: net.IPMask{0, 0, 0, 0},
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
56
coderd/httpmw/groupparam.go
Normal file
56
coderd/httpmw/groupparam.go
Normal file
@ -0,0 +1,56 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
type groupParamContextKey struct{}
|
||||
|
||||
// GroupParam returns the group extracted via the ExtraGroupParam middleware.
|
||||
func GroupParam(r *http.Request) database.Group {
|
||||
group, ok := r.Context().Value(groupParamContextKey{}).(database.Group)
|
||||
if !ok {
|
||||
panic("developer error: group param middleware not provided")
|
||||
}
|
||||
return group
|
||||
}
|
||||
|
||||
// ExtraGroupParam grabs a group from the "group" URL parameter.
|
||||
func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
groupID, parsed := parseUUID(rw, r, "group")
|
||||
if !parsed {
|
||||
return
|
||||
}
|
||||
|
||||
group, err := db.GetGroupByID(r.Context(), groupID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching group.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, groupParamContextKey{}, group)
|
||||
chi.RouteContext(ctx).URLParams.Add("organization", group.OrganizationID.String())
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
103
coderd/httpmw/groupparam_test.go
Normal file
103
coderd/httpmw/groupparam_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestGroupParam(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setup := func(t *testing.T) (database.Store, database.Group) {
|
||||
t.Helper()
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
db := databasefake.New()
|
||||
|
||||
orgID := uuid.New()
|
||||
organization, err := db.InsertOrganization(ctx, database.InsertOrganizationParams{
|
||||
ID: orgID,
|
||||
Name: "banana",
|
||||
Description: "wowie",
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := db.InsertGroup(ctx, database.InsertGroupParams{
|
||||
ID: uuid.New(),
|
||||
Name: "yeww",
|
||||
OrganizationID: organization.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return db, group
|
||||
}
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, group = setup(t)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(httpmw.ExtractGroupParam(db))
|
||||
router.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
g := httpmw.GroupParam(r)
|
||||
require.Equal(t, group, g)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("group", group.ID.String())
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusOK, res.StatusCode)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
db, group = setup(t)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = httptest.NewRecorder()
|
||||
)
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Use(httpmw.ExtractGroupParam(db))
|
||||
router.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
g := httpmw.GroupParam(r)
|
||||
require.Equal(t, group, g)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("group", uuid.NewString())
|
||||
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||
})
|
||||
}
|
@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
@ -46,8 +47,20 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand
|
||||
return
|
||||
}
|
||||
|
||||
template, err := db.GetTemplateByID(r.Context(), templateVersion.TemplateID.UUID)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching template.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, templateVersionParamContextKey{}, templateVersion)
|
||||
chi.RouteContext(ctx).URLParams.Add("organization", templateVersion.OrganizationID.String())
|
||||
|
||||
ctx = context.WithValue(ctx, templateParamContextKey{}, template)
|
||||
|
||||
next.ServeHTTP(rw, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user