mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
feat: Backend api for filtering users using filter query string (#2553)
* User search query string
This commit is contained in:
@ -338,10 +338,11 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
|
||||
AssertAction: rbac.ActionRead,
|
||||
AssertObject: workspaceRBACObj,
|
||||
},
|
||||
"POST:/api/v2/users/{user}/organizations/": {
|
||||
"POST:/api/v2/users/{user}/organizations": {
|
||||
AssertAction: rbac.ActionCreate,
|
||||
AssertObject: rbac.ResourceOrganization,
|
||||
},
|
||||
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},
|
||||
|
||||
// These endpoints need payloads to get to the auth part. Payloads will be required
|
||||
"PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true},
|
||||
|
@ -285,19 +285,25 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if len(params.Status) == 0 {
|
||||
params.Status = []database.UserStatus{database.UserStatusActive}
|
||||
}
|
||||
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
for _, status := range params.Status {
|
||||
if user.Status == status {
|
||||
if len(params.Status) > 0 {
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.Contains(params.Status, user.Status) {
|
||||
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
}
|
||||
|
||||
if len(params.RbacRole) > 0 {
|
||||
usersFilteredByRole := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.Overlap(params.RbacRole, user.RBACRoles) {
|
||||
usersFilteredByRole = append(usersFilteredByRole, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
|
||||
if params.OffsetOpt > 0 {
|
||||
if int(params.OffsetOpt) > len(users)-1 {
|
||||
|
@ -30,3 +30,10 @@ func (d ProvisionerDaemon) RBACObject() rbac.Object {
|
||||
func (f File) RBACObject() rbac.Object {
|
||||
return rbac.ResourceFile.WithID(f.Hash).WithOwner(f.CreatedBy.String())
|
||||
}
|
||||
|
||||
// RBACObject returns the RBAC object for the site wide user resource.
|
||||
// If you are trying to get the RBAC object for the UserData, use
|
||||
// rbac.ResourceUserData
|
||||
func (u User) RBACObject() rbac.Object {
|
||||
return rbac.ResourceUser.WithID(u.ID.String())
|
||||
}
|
||||
|
@ -2571,27 +2571,33 @@ WHERE
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality($3 :: user_status[]) > 0 THEN (
|
||||
WHEN cardinality($3 :: user_status[]) > 0 THEN
|
||||
status = ANY($3 :: user_status[])
|
||||
)
|
||||
ELSE
|
||||
-- Only show active by default
|
||||
status = 'active'
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
|
||||
-- everyone is a member.
|
||||
WHEN cardinality($4 :: text[]) > 0 AND 'member' != ANY($4 :: text[]) THEN
|
||||
rbac_roles && $4 :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- End of filters
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all users, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(created_at, id) ASC OFFSET $4
|
||||
(created_at, id) ASC OFFSET $5
|
||||
LIMIT
|
||||
-- A null limit means "no limit", so -1 means return all
|
||||
NULLIF($5 :: int, -1)
|
||||
NULLIF($6 :: int, -1)
|
||||
`
|
||||
|
||||
type GetUsersParams struct {
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
Search string `db:"search" json:"search"`
|
||||
Status []UserStatus `db:"status" json:"status"`
|
||||
RbacRole []string `db:"rbac_role" json:"rbac_role"`
|
||||
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
|
||||
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
|
||||
}
|
||||
@ -2601,6 +2607,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
|
||||
arg.AfterID,
|
||||
arg.Search,
|
||||
pq.Array(arg.Status),
|
||||
pq.Array(arg.RbacRole),
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
|
@ -108,12 +108,17 @@ WHERE
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality(@status :: user_status[]) > 0 THEN (
|
||||
WHEN cardinality(@status :: user_status[]) > 0 THEN
|
||||
status = ANY(@status :: user_status[])
|
||||
)
|
||||
ELSE
|
||||
-- Only show active by default
|
||||
status = 'active'
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
|
||||
-- everyone is a member.
|
||||
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
|
||||
rbac_roles && @rbac_role :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- End of filters
|
||||
ORDER BY
|
||||
|
@ -83,14 +83,31 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string {
|
||||
v, err := parseQueryParam(vals, func(v string) (string, error) {
|
||||
func (*QueryParamParser) String(vals url.Values, def string, queryParam string) string {
|
||||
v, _ := parseQueryParam(vals, func(v string) (string, error) {
|
||||
return v, nil
|
||||
}, def, queryParam)
|
||||
return v
|
||||
}
|
||||
|
||||
func (*QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
|
||||
v, _ := parseQueryParam(vals, func(v string) ([]string, error) {
|
||||
if v == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
return strings.Split(v, ","), nil
|
||||
}, def, queryParam)
|
||||
return v
|
||||
}
|
||||
|
||||
// ParseCustom has to be a function, not a method on QueryParamParser because generics
|
||||
// cannot be used on struct methods.
|
||||
func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryParam string, parseFunc func(v string) (T, error)) T {
|
||||
v, err := parseQueryParam(vals, parseFunc, def, queryParam)
|
||||
if err != nil {
|
||||
p.Errors = append(p.Errors, Error{
|
||||
parser.Errors = append(parser.Errors, Error{
|
||||
Field: queryParam,
|
||||
Detail: fmt.Sprintf("Query param %q must be a valid string", queryParam),
|
||||
Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()),
|
||||
})
|
||||
}
|
||||
return v
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -119,35 +120,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
searchName = r.URL.Query().Get("search")
|
||||
statusFilters = r.URL.Query().Get("status")
|
||||
)
|
||||
|
||||
statuses := make([]database.UserStatus, 0)
|
||||
|
||||
if statusFilters != "" {
|
||||
// Split on commas if present to account for it being a list
|
||||
for _, filter := range strings.Split(statusFilters, ",") {
|
||||
switch database.UserStatus(filter) {
|
||||
case database.UserStatusSuspended, database.UserStatusActive:
|
||||
statuses = append(statuses, database.UserStatus(filter))
|
||||
default:
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("%q is not a valid user status.", filter),
|
||||
Validations: []httpapi.Error{
|
||||
{Field: "status", Detail: "invalid status"},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reading all users across the site.
|
||||
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUser) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
query := r.URL.Query().Get("q")
|
||||
params, errs := userSearchQuery(query)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: "Invalid user search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
}
|
||||
|
||||
paginationParams, ok := parsePagination(rw, r)
|
||||
@ -159,8 +138,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
AfterID: paginationParams.AfterID,
|
||||
OffsetOpt: int32(paginationParams.Offset),
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
Search: searchName,
|
||||
Status: statuses,
|
||||
Search: params.Search,
|
||||
Status: params.Status,
|
||||
RbacRole: params.RbacRole,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusOK, []codersdk.User{})
|
||||
@ -174,6 +154,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
users = AuthorizeFilter(api, r, rbac.ActionRead, users)
|
||||
userIDs := make([]uuid.UUID, 0, len(users))
|
||||
for _, user := range users {
|
||||
userIDs = append(userIDs, user.ID)
|
||||
@ -971,3 +952,56 @@ func findUser(id uuid.UUID, users []database.User) *database.User {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func userSearchQuery(query string) (database.GetUsersParams, []httpapi.Error) {
|
||||
searchParams := make(url.Values)
|
||||
if query == "" {
|
||||
// No filter
|
||||
return database.GetUsersParams{}, nil
|
||||
}
|
||||
// Because we do this in 2 passes, we want to maintain quotes on the first
|
||||
// pass.Further splitting occurs on the second pass and quotes will be
|
||||
// dropped.
|
||||
elements := splitQueryParameterByDelimiter(query, ' ', true)
|
||||
for _, element := range elements {
|
||||
parts := splitQueryParameterByDelimiter(element, ':', false)
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// No key:value pair.
|
||||
searchParams.Set("search", parts[0])
|
||||
case 2:
|
||||
searchParams.Set(parts[0], parts[1])
|
||||
default:
|
||||
return database.GetUsersParams{}, []httpapi.Error{
|
||||
{Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter := database.GetUsersParams{
|
||||
Search: parser.String(searchParams, "", "search"),
|
||||
Status: httpapi.ParseCustom(parser, searchParams, []database.UserStatus{}, "status", parseUserStatus),
|
||||
RbacRole: parser.Strings(searchParams, []string{}, "role"),
|
||||
}
|
||||
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
// parseUserStatus ensures proper enums are used for user statuses
|
||||
func parseUserStatus(v string) ([]database.UserStatus, error) {
|
||||
var statuses []database.UserStatus
|
||||
if v == "" {
|
||||
return statuses, nil
|
||||
}
|
||||
parts := strings.Split(v, ",")
|
||||
for _, part := range parts {
|
||||
switch database.UserStatus(part) {
|
||||
case database.UserStatusActive, database.UserStatusSuspended:
|
||||
statuses = append(statuses, database.UserStatus(part))
|
||||
default:
|
||||
return []database.UserStatus{}, xerrors.Errorf("%q is not a valid user status", part)
|
||||
}
|
||||
}
|
||||
return statuses, nil
|
||||
}
|
||||
|
@ -704,6 +704,131 @@ func TestGetUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestUsersFilter creates a set of users to run various filters against for testing.
|
||||
func TestUsersFilter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
firstUser, err := client.User(context.Background(), codersdk.Me)
|
||||
require.NoError(t, err, "fetch me")
|
||||
|
||||
users := make([]codersdk.User, 0)
|
||||
users = append(users, firstUser)
|
||||
for i := 0; i < 15; i++ {
|
||||
roles := []string{}
|
||||
if i%2 == 0 {
|
||||
roles = append(roles, rbac.RoleAdmin())
|
||||
}
|
||||
if i%3 == 0 {
|
||||
roles = append(roles, "auditor")
|
||||
}
|
||||
userClient := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, roles...)
|
||||
user, err := userClient.User(context.Background(), codersdk.Me)
|
||||
require.NoError(t, err, "fetch me")
|
||||
|
||||
if i%4 == 0 {
|
||||
user, err = client.UpdateUserStatus(context.Background(), user.ID.String(), codersdk.UserStatusSuspended)
|
||||
require.NoError(t, err, "suspend user")
|
||||
}
|
||||
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
// --- Setup done ---
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Filter codersdk.UsersRequest
|
||||
// If FilterF is true, we include it in the expected results
|
||||
FilterF func(f codersdk.UsersRequest, user codersdk.User) bool
|
||||
}{
|
||||
{
|
||||
Name: "All",
|
||||
Filter: codersdk.UsersRequest{
|
||||
Status: codersdk.UserStatusSuspended + "," + codersdk.UserStatusActive,
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Active",
|
||||
Filter: codersdk.UsersRequest{
|
||||
Status: codersdk.UserStatusActive,
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
return u.Status == codersdk.UserStatusActive
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Suspended",
|
||||
Filter: codersdk.UsersRequest{
|
||||
Status: codersdk.UserStatusSuspended,
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
return u.Status == codersdk.UserStatusSuspended
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "NameContains",
|
||||
Filter: codersdk.UsersRequest{
|
||||
Search: "a",
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
return (strings.Contains(u.Username, "a") || strings.Contains(u.Email, "a"))
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Admins",
|
||||
Filter: codersdk.UsersRequest{
|
||||
Role: rbac.RoleAdmin(),
|
||||
Status: codersdk.UserStatusSuspended + "," + codersdk.UserStatusActive,
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
for _, r := range u.Roles {
|
||||
if r.Name == rbac.RoleAdmin() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "SearchQuery",
|
||||
Filter: codersdk.UsersRequest{
|
||||
SearchQuery: "i role:admin status:active",
|
||||
},
|
||||
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
|
||||
for _, r := range u.Roles {
|
||||
if r.Name == rbac.RoleAdmin() {
|
||||
return (strings.Contains(u.Username, "i") || strings.Contains(u.Email, "i")) &&
|
||||
u.Status == codersdk.UserStatusActive
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
c := c
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
matched, err := client.Users(context.Background(), c.Filter)
|
||||
require.NoError(t, err, "fetch workspaces")
|
||||
|
||||
exp := make([]codersdk.User, 0)
|
||||
for _, made := range users {
|
||||
match := c.FilterF(c.Filter, made)
|
||||
if match {
|
||||
exp = append(exp, made)
|
||||
}
|
||||
}
|
||||
require.ElementsMatch(t, exp, matched, "expected workspaces returned")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("AllUsers", func(t *testing.T) {
|
||||
@ -754,7 +879,7 @@ func TestGetUsers(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := client.Users(context.Background(), codersdk.UsersRequest{
|
||||
Status: string(codersdk.UserStatusActive),
|
||||
Status: codersdk.UserStatusActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, active, users)
|
||||
|
@ -8,3 +8,15 @@ func Contains[T comparable](haystack []T, needle T) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Overlap returns if the 2 sets have any overlap (element(s) in common)
|
||||
func Overlap[T comparable](a []T, b []T) bool {
|
||||
// For each element in b, if at least 1 is contained in 'a',
|
||||
// return true.
|
||||
for _, element := range b {
|
||||
if Contains(a, element) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -21,6 +21,35 @@ func TestContains(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestOverlap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assertSetOverlaps(t, true, []int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5})
|
||||
assertSetOverlaps(t, true, []int{10}, []int{10})
|
||||
|
||||
assertSetOverlaps(t, false, []int{1, 2, 3, 4, 5}, []int{6, 7, 8, 9})
|
||||
assertSetOverlaps(t, false, []int{1, 2, 3, 4, 5}, []int{})
|
||||
assertSetOverlaps(t, false, []int{}, []int{})
|
||||
|
||||
assertSetOverlaps(t, true, []string{"hello", "world", "foo", "bar", "baz"}, []string{"hello", "world", "baz"})
|
||||
assertSetOverlaps(t, true,
|
||||
[]uuid.UUID{uuid.New(), uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5"), uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")},
|
||||
[]uuid.UUID{uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")},
|
||||
)
|
||||
}
|
||||
|
||||
func assertSetOverlaps[T comparable](t *testing.T, overlap bool, a []T, b []T) {
|
||||
t.Helper()
|
||||
for _, e := range a {
|
||||
require.True(t, slice.Overlap(a, []T{e}), "elements in set should overlap with itself")
|
||||
}
|
||||
for _, e := range b {
|
||||
require.True(t, slice.Overlap(b, []T{e}), "elements in set should overlap with itself")
|
||||
}
|
||||
|
||||
require.Equal(t, overlap, slice.Overlap(a, b))
|
||||
}
|
||||
|
||||
func assertSetContains[T comparable](t *testing.T, set []T, in []T, out []T) {
|
||||
t.Helper()
|
||||
for _, e := range set {
|
||||
|
Reference in New Issue
Block a user