mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat: add count endpoint for users, enabling better pagination (#4848)
* Start on backend * Hook up frontend * Add to frontend test * Add go test, wip * Fix some test bugs * Fix test * Format * Add to authorize.go * copy user array into local variable * Authorize route * Log count error * Authorize better * Tweaks to authorization * More authorization tweaks * Make gen * Fix test Co-authored-by: Garrett <garrett@coder.com>
This commit is contained in:
@ -437,6 +437,7 @@ func New(options *Options) *API {
|
||||
)
|
||||
r.Post("/", api.postUser)
|
||||
r.Get("/", api.users)
|
||||
r.Get("/count", api.userCount)
|
||||
r.Post("/logout", api.postLogout)
|
||||
// These routes query information about site wide roles.
|
||||
r.Route("/roles", func(r chi.Router) {
|
||||
|
@ -246,6 +246,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
"GET:/api/v2/workspaces/count": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
}
|
||||
|
||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||
|
@ -457,6 +457,72 @@ func (q *fakeQuerier) GetActiveUserCount(_ context.Context) (int64, error) {
|
||||
return active, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetFilteredUserCount(ctx context.Context, arg database.GetFilteredUserCountParams) (int64, error) {
|
||||
count, err := q.GetAuthorizedUserCount(ctx, arg, nil)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAuthorizedUserCount(_ context.Context, params database.GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
users := append([]database.User{}, q.users...)
|
||||
|
||||
if params.Deleted {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
if user.Deleted {
|
||||
tmp = append(tmp, user)
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if params.Search != "" {
|
||||
tmp := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
} else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) {
|
||||
tmp = append(tmp, users[i])
|
||||
}
|
||||
}
|
||||
users = tmp
|
||||
}
|
||||
|
||||
if len(params.Status) > 0 {
|
||||
usersFilteredByStatus := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool {
|
||||
return strings.EqualFold(string(a), string(b))
|
||||
}) {
|
||||
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
|
||||
}
|
||||
}
|
||||
users = usersFilteredByStatus
|
||||
}
|
||||
|
||||
if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) {
|
||||
usersFilteredByRole := make([]database.User, 0, len(users))
|
||||
for i, user := range users {
|
||||
if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) {
|
||||
usersFilteredByRole = append(usersFilteredByRole, users[i])
|
||||
}
|
||||
}
|
||||
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
for _, user := range q.workspaces {
|
||||
// If the filter exists, ensure the object is authorized.
|
||||
if authorizedFilter != nil && !authorizedFilter.Eval(user.RBACObject()) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len(users)), nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.UpdateUserDeletedByIDParams) error {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
type customQuerier interface {
|
||||
templateQuerier
|
||||
workspaceQuerier
|
||||
userQuerier
|
||||
}
|
||||
|
||||
type templateQuerier interface {
|
||||
@ -169,8 +170,6 @@ func (q *sqlQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg GetWorkspa
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWorkspaceCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
// In order to properly use ORDER BY, OFFSET, and LIMIT, we need to inject the
|
||||
// authorizedFilter between the end of the where clause and those statements.
|
||||
filter := strings.Replace(getWorkspaceCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
// The name comment is for metric tracking
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedWorkspaceCount :one\n%s", filter)
|
||||
@ -187,3 +186,21 @@ func (q *sqlQuerier) GetAuthorizedWorkspaceCount(ctx context.Context, arg GetWor
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
type userQuerier interface {
|
||||
GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetAuthorizedUserCount(ctx context.Context, arg GetFilteredUserCountParams, authorizedFilter rbac.AuthorizeFilter) (int64, error) {
|
||||
filter := strings.Replace(getFilteredUserCount, "-- @authorize_filter", fmt.Sprintf(" AND %s", authorizedFilter.SQLString(rbac.NoACLConfig())), 1)
|
||||
query := fmt.Sprintf("-- name: GetAuthorizedUserCount :one\n%s", filter)
|
||||
row := q.db.QueryRowContext(ctx, query,
|
||||
arg.Deleted,
|
||||
arg.Search,
|
||||
pq.Array(arg.Status),
|
||||
pq.Array(arg.RbacRole),
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
@ -44,6 +44,7 @@ type sqlcQuerier interface {
|
||||
GetDeploymentID(ctx context.Context) (string, error)
|
||||
GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error)
|
||||
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
|
||||
GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error)
|
||||
GetGitAuthLink(ctx context.Context, arg GetGitAuthLinkParams) (GitAuthLink, error)
|
||||
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
|
||||
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
|
||||
|
@ -3975,6 +3975,60 @@ func (q *sqlQuerier) GetAuthorizationUserRoles(ctx context.Context, userID uuid.
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getFilteredUserCount = `-- name: GetFilteredUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
users.deleted = $1
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
AND CASE
|
||||
WHEN $2 :: text != '' THEN (
|
||||
email ILIKE concat('%', $2, '%')
|
||||
OR username ILIKE concat('%', $2, '%')
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by status
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality($3 :: user_status[]) > 0 THEN
|
||||
status = ANY($3 :: user_status[])
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
|
||||
WHEN cardinality($4 :: text[]) > 0 AND 'member' != ANY($4 :: text[])
|
||||
THEN rbac_roles && $4 :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type GetFilteredUserCountParams struct {
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
Search string `db:"search" json:"search"`
|
||||
Status []UserStatus `db:"status" json:"status"`
|
||||
RbacRole []string `db:"rbac_role" json:"rbac_role"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetFilteredUserCount(ctx context.Context, arg GetFilteredUserCountParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, getFilteredUserCount,
|
||||
arg.Deleted,
|
||||
arg.Search,
|
||||
pq.Array(arg.Status),
|
||||
pq.Array(arg.RbacRole),
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const getUserByEmailOrUsername = `-- name: GetUserByEmailOrUsername :one
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
|
||||
|
@ -39,6 +39,41 @@ FROM
|
||||
WHERE
|
||||
status = 'active'::user_status AND deleted = false;
|
||||
|
||||
-- name: GetFilteredUserCount :one
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
users.deleted = @deleted
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
AND CASE
|
||||
WHEN @search :: text != '' THEN (
|
||||
email ILIKE concat('%', @search, '%')
|
||||
OR username ILIKE concat('%', @search, '%')
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by status
|
||||
AND CASE
|
||||
-- @status needs to be a text because it can be empty, If it was
|
||||
-- user_status enum, it would not.
|
||||
WHEN cardinality(@status :: user_status[]) > 0 THEN
|
||||
status = ANY(@status :: user_status[])
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by rbac_roles
|
||||
AND CASE
|
||||
-- @rbac_role allows filtering by rbac roles. If 'member' is included, show everyone, as everyone is a member.
|
||||
WHEN cardinality(@rbac_role :: text[]) > 0 AND 'member' != ANY(@rbac_role :: text[])
|
||||
THEN rbac_roles && @rbac_role :: text[]
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in GetAuthorizedUserCount
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: InsertUser :one
|
||||
INSERT INTO
|
||||
users (
|
||||
|
@ -251,6 +251,42 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
|
||||
}
|
||||
|
||||
func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
query := r.URL.Query().Get("q")
|
||||
params, errs := userSearchQuery(query)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid user search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error preparing sql filter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
|
||||
Search: params.Search,
|
||||
Status: params.Status,
|
||||
RbacRole: params.RbacRole,
|
||||
}, sqlFilter)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
|
||||
Count: count,
|
||||
})
|
||||
}
|
||||
|
||||
// Creates a new user.
|
||||
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
@ -1255,6 +1255,58 @@ func TestGetUsers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredUserCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("AllUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client.CreateUser(ctx, codersdk.CreateUserRequest{
|
||||
Email: "alice@email.com",
|
||||
Username: "alice",
|
||||
Password: "password",
|
||||
OrganizationID: user.OrganizationID,
|
||||
})
|
||||
// No params is all users
|
||||
response, err := client.UserCount(ctx, codersdk.UserCountRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, int(response.Count))
|
||||
})
|
||||
t.Run("ActiveUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.User(ctx, first.UserID.String())
|
||||
require.NoError(t, err, "")
|
||||
|
||||
// Alice will be suspended
|
||||
alice, err := client.CreateUser(ctx, codersdk.CreateUserRequest{
|
||||
Email: "alice@email.com",
|
||||
Username: "alice",
|
||||
Password: "password",
|
||||
OrganizationID: first.OrganizationID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := client.UserCount(ctx, codersdk.UserCountRequest{
|
||||
Status: codersdk.UserStatusActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, int(response.Count))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
|
@ -166,7 +166,7 @@ func (api *API) workspaceCount(rw http.ResponseWriter, r *http.Request) {
|
||||
filter, errs := workspaceSearchQuery(queryStr, codersdk.Pagination{})
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid audit search query.",
|
||||
Message: "Invalid workspace search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
|
Reference in New Issue
Block a user