feat: Backend api for filtering users using filter query string (#2553)

* User search query string
This commit is contained in:
Steven Masley
2022-06-24 10:02:23 -05:00
committed by GitHub
parent 981fb2764f
commit d21ab2115d
13 changed files with 344 additions and 85 deletions

View File

@ -338,10 +338,11 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"POST:/api/v2/users/{user}/organizations/": {
"POST:/api/v2/users/{user}/organizations": {
AssertAction: rbac.ActionCreate,
AssertObject: rbac.ResourceOrganization,
},
"GET:/api/v2/users": {StatusCode: http.StatusOK, AssertObject: rbac.ResourceUser},
// These endpoints need payloads to get to the auth part. Payloads will be required
"PUT:/api/v2/users/{user}/roles": {StatusCode: http.StatusBadRequest, NoAuthorize: true},

View File

@ -285,19 +285,25 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = tmp
}
if len(params.Status) == 0 {
params.Status = []database.UserStatus{database.UserStatusActive}
}
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
for _, status := range params.Status {
if user.Status == status {
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.Contains(params.Status, user.Status) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
if len(params.RbacRole) > 0 {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.Overlap(params.RbacRole, user.RBACRoles) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
users = usersFilteredByStatus
if params.OffsetOpt > 0 {
if int(params.OffsetOpt) > len(users)-1 {

View File

@ -30,3 +30,10 @@ func (d ProvisionerDaemon) RBACObject() rbac.Object {
func (f File) RBACObject() rbac.Object {
return rbac.ResourceFile.WithID(f.Hash).WithOwner(f.CreatedBy.String())
}
// RBACObject returns the RBAC object for the site wide user resource.
// If you are trying to get the RBAC object for the UserData, use
// rbac.ResourceUserData
func (u User) RBACObject() rbac.Object {
return rbac.ResourceUser.WithID(u.ID.String())
}

View File

@ -2571,27 +2571,33 @@ WHERE
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality($3 :: user_status[]) > 0 THEN (
WHEN cardinality($3 :: user_status[]) > 0 THEN
status = ANY($3 :: user_status[])
)
ELSE
-- Only show active by default
status = 'active'
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
WHEN cardinality($4 :: text[]) > 0 AND 'member' != ANY($4 :: text[]) THEN
rbac_roles && $4 :: text[]
ELSE true
END
-- End of filters
ORDER BY
-- Deterministic and consistent ordering of all users, even if they share
-- a timestamp. This is to ensure consistent pagination.
(created_at, id) ASC OFFSET $4
(created_at, id) ASC OFFSET $5
LIMIT
-- A null limit means "no limit", so -1 means return all
NULLIF($5 :: int, -1)
NULLIF($6 :: int, -1)
`
type GetUsersParams struct {
AfterID uuid.UUID `db:"after_id" json:"after_id"`
Search string `db:"search" json:"search"`
Status []UserStatus `db:"status" json:"status"`
RbacRole []string `db:"rbac_role" json:"rbac_role"`
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
}
@ -2601,6 +2607,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
arg.AfterID,
arg.Search,
pq.Array(arg.Status),
pq.Array(arg.RbacRole),
arg.OffsetOpt,
arg.LimitOpt,
)

View File

@ -108,12 +108,17 @@ WHERE
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality(@status :: user_status[]) > 0 THEN (
WHEN cardinality(@status :: user_status[]) > 0 THEN
status = ANY(@status :: user_status[])
)
ELSE
-- Only show active by default
status = 'active'
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as
-- everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[]) THEN
rbac_roles && @rbac_role :: text[]
ELSE true
END
-- End of filters
ORDER BY

View File

@ -83,14 +83,31 @@ func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam st
return v
}
func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string {
v, err := parseQueryParam(vals, func(v string) (string, error) {
func (*QueryParamParser) String(vals url.Values, def string, queryParam string) string {
v, _ := parseQueryParam(vals, func(v string) (string, error) {
return v, nil
}, def, queryParam)
return v
}
func (*QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
v, _ := parseQueryParam(vals, func(v string) ([]string, error) {
if v == "" {
return []string{}, nil
}
return strings.Split(v, ","), nil
}, def, queryParam)
return v
}
// ParseCustom has to be a function, not a method on QueryParamParser because generics
// cannot be used on struct methods.
func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryParam string, parseFunc func(v string) (T, error)) T {
v, err := parseQueryParam(vals, parseFunc, def, queryParam)
if err != nil {
p.Errors = append(p.Errors, Error{
parser.Errors = append(parser.Errors, Error{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must be a valid string", queryParam),
Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()),
})
}
return v

View File

@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
@ -119,35 +120,13 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) {
}
func (api *API) users(rw http.ResponseWriter, r *http.Request) {
var (
searchName = r.URL.Query().Get("search")
statusFilters = r.URL.Query().Get("status")
)
statuses := make([]database.UserStatus, 0)
if statusFilters != "" {
// Split on commas if present to account for it being a list
for _, filter := range strings.Split(statusFilters, ",") {
switch database.UserStatus(filter) {
case database.UserStatusSuspended, database.UserStatusActive:
statuses = append(statuses, database.UserStatus(filter))
default:
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("%q is not a valid user status.", filter),
Validations: []httpapi.Error{
{Field: "status", Detail: "invalid status"},
},
})
return
}
}
}
// Reading all users across the site.
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceUser) {
httpapi.Forbidden(rw)
return
query := r.URL.Query().Get("q")
params, errs := userSearchQuery(query)
if len(errs) > 0 {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: "Invalid user search query.",
Validations: errs,
})
}
paginationParams, ok := parsePagination(rw, r)
@ -159,8 +138,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
Search: searchName,
Status: statuses,
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, []codersdk.User{})
@ -174,6 +154,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}
users = AuthorizeFilter(api, r, rbac.ActionRead, users)
userIDs := make([]uuid.UUID, 0, len(users))
for _, user := range users {
userIDs = append(userIDs, user.ID)
@ -971,3 +952,56 @@ func findUser(id uuid.UUID, users []database.User) *database.User {
}
return nil
}
func userSearchQuery(query string) (database.GetUsersParams, []httpapi.Error) {
searchParams := make(url.Values)
if query == "" {
// No filter
return database.GetUsersParams{}, nil
}
// Because we do this in 2 passes, we want to maintain quotes on the first
// pass.Further splitting occurs on the second pass and quotes will be
// dropped.
elements := splitQueryParameterByDelimiter(query, ' ', true)
for _, element := range elements {
parts := splitQueryParameterByDelimiter(element, ':', false)
switch len(parts) {
case 1:
// No key:value pair.
searchParams.Set("search", parts[0])
case 2:
searchParams.Set(parts[0], parts[1])
default:
return database.GetUsersParams{}, []httpapi.Error{
{Field: "q", Detail: fmt.Sprintf("Query element %q can only contain 1 ':'", element)},
}
}
}
parser := httpapi.NewQueryParamParser()
filter := database.GetUsersParams{
Search: parser.String(searchParams, "", "search"),
Status: httpapi.ParseCustom(parser, searchParams, []database.UserStatus{}, "status", parseUserStatus),
RbacRole: parser.Strings(searchParams, []string{}, "role"),
}
return filter, parser.Errors
}
// parseUserStatus ensures proper enums are used for user statuses
func parseUserStatus(v string) ([]database.UserStatus, error) {
var statuses []database.UserStatus
if v == "" {
return statuses, nil
}
parts := strings.Split(v, ",")
for _, part := range parts {
switch database.UserStatus(part) {
case database.UserStatusActive, database.UserStatusSuspended:
statuses = append(statuses, database.UserStatus(part))
default:
return []database.UserStatus{}, xerrors.Errorf("%q is not a valid user status", part)
}
}
return statuses, nil
}

View File

@ -704,6 +704,131 @@ func TestGetUser(t *testing.T) {
})
}
// TestUsersFilter creates a set of users to run various filters against for testing.
func TestUsersFilter(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
first := coderdtest.CreateFirstUser(t, client)
firstUser, err := client.User(context.Background(), codersdk.Me)
require.NoError(t, err, "fetch me")
users := make([]codersdk.User, 0)
users = append(users, firstUser)
for i := 0; i < 15; i++ {
roles := []string{}
if i%2 == 0 {
roles = append(roles, rbac.RoleAdmin())
}
if i%3 == 0 {
roles = append(roles, "auditor")
}
userClient := coderdtest.CreateAnotherUser(t, client, first.OrganizationID, roles...)
user, err := userClient.User(context.Background(), codersdk.Me)
require.NoError(t, err, "fetch me")
if i%4 == 0 {
user, err = client.UpdateUserStatus(context.Background(), user.ID.String(), codersdk.UserStatusSuspended)
require.NoError(t, err, "suspend user")
}
users = append(users, user)
}
// --- Setup done ---
testCases := []struct {
Name string
Filter codersdk.UsersRequest
// If FilterF is true, we include it in the expected results
FilterF func(f codersdk.UsersRequest, user codersdk.User) bool
}{
{
Name: "All",
Filter: codersdk.UsersRequest{
Status: codersdk.UserStatusSuspended + "," + codersdk.UserStatusActive,
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
return true
},
},
{
Name: "Active",
Filter: codersdk.UsersRequest{
Status: codersdk.UserStatusActive,
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
return u.Status == codersdk.UserStatusActive
},
},
{
Name: "Suspended",
Filter: codersdk.UsersRequest{
Status: codersdk.UserStatusSuspended,
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
return u.Status == codersdk.UserStatusSuspended
},
},
{
Name: "NameContains",
Filter: codersdk.UsersRequest{
Search: "a",
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
return (strings.Contains(u.Username, "a") || strings.Contains(u.Email, "a"))
},
},
{
Name: "Admins",
Filter: codersdk.UsersRequest{
Role: rbac.RoleAdmin(),
Status: codersdk.UserStatusSuspended + "," + codersdk.UserStatusActive,
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
for _, r := range u.Roles {
if r.Name == rbac.RoleAdmin() {
return true
}
}
return false
},
},
{
Name: "SearchQuery",
Filter: codersdk.UsersRequest{
SearchQuery: "i role:admin status:active",
},
FilterF: func(_ codersdk.UsersRequest, u codersdk.User) bool {
for _, r := range u.Roles {
if r.Name == rbac.RoleAdmin() {
return (strings.Contains(u.Username, "i") || strings.Contains(u.Email, "i")) &&
u.Status == codersdk.UserStatusActive
}
}
return false
},
},
}
for _, c := range testCases {
c := c
t.Run(c.Name, func(t *testing.T) {
t.Parallel()
matched, err := client.Users(context.Background(), c.Filter)
require.NoError(t, err, "fetch workspaces")
exp := make([]codersdk.User, 0)
for _, made := range users {
match := c.FilterF(c.Filter, made)
if match {
exp = append(exp, made)
}
}
require.ElementsMatch(t, exp, matched, "expected workspaces returned")
})
}
}
func TestGetUsers(t *testing.T) {
t.Parallel()
t.Run("AllUsers", func(t *testing.T) {
@ -754,7 +879,7 @@ func TestGetUsers(t *testing.T) {
require.NoError(t, err)
users, err := client.Users(context.Background(), codersdk.UsersRequest{
Status: string(codersdk.UserStatusActive),
Status: codersdk.UserStatusActive,
})
require.NoError(t, err)
require.ElementsMatch(t, active, users)

View File

@ -8,3 +8,15 @@ func Contains[T comparable](haystack []T, needle T) bool {
}
return false
}
// Overlap returns if the 2 sets have any overlap (element(s) in common)
func Overlap[T comparable](a []T, b []T) bool {
// For each element in b, if at least 1 is contained in 'a',
// return true.
for _, element := range b {
if Contains(a, element) {
return true
}
}
return false
}

View File

@ -21,6 +21,35 @@ func TestContains(t *testing.T) {
)
}
func TestOverlap(t *testing.T) {
t.Parallel()
assertSetOverlaps(t, true, []int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5})
assertSetOverlaps(t, true, []int{10}, []int{10})
assertSetOverlaps(t, false, []int{1, 2, 3, 4, 5}, []int{6, 7, 8, 9})
assertSetOverlaps(t, false, []int{1, 2, 3, 4, 5}, []int{})
assertSetOverlaps(t, false, []int{}, []int{})
assertSetOverlaps(t, true, []string{"hello", "world", "foo", "bar", "baz"}, []string{"hello", "world", "baz"})
assertSetOverlaps(t, true,
[]uuid.UUID{uuid.New(), uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5"), uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")},
[]uuid.UUID{uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")},
)
}
func assertSetOverlaps[T comparable](t *testing.T, overlap bool, a []T, b []T) {
t.Helper()
for _, e := range a {
require.True(t, slice.Overlap(a, []T{e}), "elements in set should overlap with itself")
}
for _, e := range b {
require.True(t, slice.Overlap(b, []T{e}), "elements in set should overlap with itself")
}
require.Equal(t, overlap, slice.Overlap(a, b))
}
func assertSetContains[T comparable](t *testing.T, set []T, in []T, out []T) {
t.Helper()
for _, e := range set {