feat: add count endpoint for users, enabling better pagination (#4848)

* Start on backend

* Hook up frontend

* Add to frontend test

* Add go test, wip

* Fix some test bugs

* Fix test

* Format

* Add to authorize.go

* copy user array into local variable

* Authorize route

* Log count error

* Authorize better

* Tweaks to authorization

* More authorization tweaks

* Make gen

* Fix test

Co-authored-by: Garrett <garrett@coder.com>
This commit is contained in:
Presley Pizzo
2022-11-08 10:58:44 -05:00
committed by GitHub
parent a4fbc74751
commit f496b149df
19 changed files with 620 additions and 216 deletions

View File

@ -437,6 +437,7 @@ func New(options *Options) *API {
)
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Get("/count", api.userCount)
r.Post("/logout", api.postLogout)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {

View File

@ -246,6 +246,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// Endpoints that use the SQLQuery filter.
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
"GET:/api/v2/workspaces/count": {StatusCode: http.StatusOK, NoAuthorize: true},
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
}
// Routes like proxy routes support all HTTP methods. A helper func to expand

View File

@ -457,6 +457,72 @@ func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
return active, nil
}
func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
return count, err
}
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
users := append([]database.User{}, q.users...)
if params.Deleted {
tmp := make([]database.User, 0, len(users))
for _, user := range users {
if user.Deleted {
tmp = append(tmp, user)
}
}
users = tmp
}
if params.Search != "" {
tmp := make([]database.User, 0, len(users))
for i, user := range users {
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
tmp = append(tmp, users[i])
}
}
users = tmp
}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
return strings.EqualFold(string(a), string(b))
}) {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
usersFilteredByRole := make([]database.User, 0, len(users))
for i, user := range users {
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
usersFilteredByRole = append(usersFilteredByRole, users[i])
}
}
users = usersFilteredByRole
}
for _, user := range q.workspaces {
// If the filter exists, ensure the object is authorized.
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
continue
}
}
return int64(len(users)), nil
}
func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.UpdateUserDeletedByIDParams) error {
q.mutex.Lock()
defer q.mutex.Unlock()

View File

@ -19,6 +19,7 @@ import (
type customQuerier interface {
templateQuerier
workspaceQuerier
userQuerier
}
type templateQuerier interface {
@ -169,8 +170,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
}
func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWorkspaceCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
// authorizedFilter between the end of the where clause and those statements.
filter := strings.Replace(getWorkspaceCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
// The name comment is for metric tracking
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaceCount :one\n%s", filter)
@ -187,3 +186,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWor
err := row.Scan(&count)
return count, err
}
type userQuerier interface {
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
}
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
row := q.db.QueryRowContext(ctx, query,
arg.Deleted,
arg.Search,
pq.Array(arg.Status),
pq.Array(arg.RbacRole),
)
var count int64
err := row.Scan(&count)
return count, err
}

View File

@ -44,6 +44,7 @@ type sqlcQuerier interface {
GetDeploymentID(ctx context.Context) (string, error)
GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error)
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error)
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)

View File

@ -3975,6 +3975,60 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
return i, err
}
const getFilteredUserCount = `-- name: GetFilteredUserCount :one
SELECT
COUNT(*)
FROM
users
WHERE
users.deleted = $1
-- Start filters
-- Filter by name, email or username
AND CASE
WHEN $2 :: text != '' THEN (
email ILIKE concat('%', $2, '%')
OR username ILIKE concat('%', $2, '%')
)
ELSE true
END
-- Filter by status
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality($3 :: user_status[]) > 0 THEN
status = ANY($3 :: user_status[])
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
WHEN cardinality($4 :: text[]) > 0 AND 'member' != ANY($4 :: text[])
THEN rbac_roles && $4 :: text[]
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
-- @authorize_filter
`
type GetFilteredUserCountParams struct {
Deleted bool `db:"deleted" json:"deleted"`
Search string `db:"search" json:"search"`
Status []UserStatus `db:"status" json:"status"`
RbacRole []string `db:"rbac_role" json:"rbac_role"`
}
func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) {
row := q.db.QueryRowContext(ctx, getFilteredUserCount,
arg.Deleted,
arg.Search,
pq.Array(arg.Status),
pq.Array(arg.RbacRole),
)
var count int64
err := row.Scan(&count)
return count, err
}
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at

View File

@ -39,6 +39,41 @@ FROM
WHERE
status = 'active'::user_status AND deleted = false;
-- name: GetFilteredUserCount :one
SELECT
COUNT(*)
FROM
users
WHERE
users.deleted = @deleted
-- Start filters
-- Filter by name, email or username
AND CASE
WHEN @search :: text != '' THEN (
email ILIKE concat('%', @search, '%')
OR username ILIKE concat('%', @search, '%')
)
ELSE true
END
-- Filter by status
AND CASE
-- @status needs to be a text because it can be empty, If it was
-- user_status enum, it would not.
WHEN cardinality(@status :: user_status[]) > 0 THEN
status = ANY(@status :: user_status[])
ELSE true
END
-- Filter by rbac_roles
AND CASE
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[])
THEN rbac_roles && @rbac_role :: text[]
ELSE true
END
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
-- @authorize_filter
;
-- name: InsertUser :one
INSERT INTO
users (

View File

@ -251,6 +251,42 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
}
func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query().Get("q")
params, errs := userSearchQuery(query)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user search query.",
Validations: errs,
})
return
}
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}
count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
}, sqlFilter)
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
Count: count,
})
}
// Creates a new user.
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

View File

@ -1255,6 +1255,58 @@ func TestGetUsers(t *testing.T) {
})
}
func TestGetFilteredUserCount(t *testing.T) {
t.Parallel()
t.Run("AllUsers", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
client.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "alice@email.com",
Username: "alice",
Password: "password",
OrganizationID: user.OrganizationID,
})
// No params is all users
response, err := client.UserCount(ctx, codersdk.UserCountRequest{})
require.NoError(t, err)
require.Equal(t, 2, int(response.Count))
})
t.Run("ActiveUsers", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.User(ctx, first.UserID.String())
require.NoError(t, err, "")
// Alice will be suspended
alice, err := client.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "alice@email.com",
Username: "alice",
Password: "password",
OrganizationID: first.OrganizationID,
})
require.NoError(t, err)
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
require.NoError(t, err)
response, err := client.UserCount(ctx, codersdk.UserCountRequest{
Status: codersdk.UserStatusActive,
})
require.NoError(t, err)
require.Equal(t, 1, int(response.Count))
})
}
func TestPostTokens(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)

View File

@ -166,7 +166,7 @@ func (api *API) workspaceCount(rw http.ResponseWriter, r *http.Request) {
filter, errs := workspaceSearchQuery(queryStr, codersdk.Pagination{})
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid audit search query.",
Message: "Invalid workspace search query.",
Validations: errs,
})
return