mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat: add count to get users endpoint (#5016)
This commit is contained in:
@ -431,7 +431,6 @@ func New(options *Options) *API {
|
||||
)
|
||||
r.Post("/", api.postUser)
|
||||
r.Get("/", api.users)
|
||||
r.Get("/count", api.userCount)
|
||||
r.Post("/logout", api.postLogout)
|
||||
// These routes query information about site wide roles.
|
||||
r.Route("/roles", func(r chi.Router) {
|
||||
|
@ -245,7 +245,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
|
||||
// Endpoints that use the SQLQuery filter.
|
||||
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
|
||||
}
|
||||
|
||||
// Routes like proxy routes support all HTTP methods. A helper func to expand
|
||||
|
@ -538,7 +538,7 @@ func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.U
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) {
|
||||
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
@ -579,7 +579,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
|
||||
// If no users after the time, then we return an empty list.
|
||||
if !found {
|
||||
return nil, sql.ErrNoRows
|
||||
return []database.GetUsersRow{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -617,9 +617,11 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
users = usersFilteredByRole
|
||||
}
|
||||
|
||||
beforePageCount := len(users)
|
||||
|
||||
if params.OffsetOpt > 0 {
|
||||
if int(params.OffsetOpt) > len(users)-1 {
|
||||
return nil, sql.ErrNoRows
|
||||
return []database.GetUsersRow{}, nil
|
||||
}
|
||||
users = users[params.OffsetOpt:]
|
||||
}
|
||||
@ -631,7 +633,30 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
users = users[:params.LimitOpt]
|
||||
}
|
||||
|
||||
return users, nil
|
||||
return convertUsers(users, int64(beforePageCount)), nil
|
||||
}
|
||||
|
||||
func convertUsers(users []database.User, count int64) []database.GetUsersRow {
|
||||
rows := make([]database.GetUsersRow, len(users))
|
||||
for i, u := range users {
|
||||
rows[i] = database.GetUsersRow{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
HashedPassword: u.HashedPassword,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
Status: u.Status,
|
||||
RBACRoles: u.RBACRoles,
|
||||
LoginType: u.LoginType,
|
||||
AvatarURL: u.AvatarURL,
|
||||
Deleted: u.Deleted,
|
||||
LastSeenAt: u.LastSeenAt,
|
||||
Count: count,
|
||||
}
|
||||
}
|
||||
|
||||
return rows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) {
|
||||
|
@ -71,3 +71,25 @@ func (User) RBACObject() rbac.Object {
|
||||
func (License) RBACObject() rbac.Object {
|
||||
return rbac.ResourceLicense
|
||||
}
|
||||
|
||||
func ConvertUserRows(rows []GetUsersRow) []User {
|
||||
users := make([]User, len(rows))
|
||||
for i, r := range rows {
|
||||
users[i] = User{
|
||||
ID: r.ID,
|
||||
Email: r.Email,
|
||||
Username: r.Username,
|
||||
HashedPassword: r.HashedPassword,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
Status: r.Status,
|
||||
RBACRoles: r.RBACRoles,
|
||||
LoginType: r.LoginType,
|
||||
AvatarURL: r.AvatarURL,
|
||||
Deleted: r.Deleted,
|
||||
LastSeenAt: r.LastSeenAt,
|
||||
}
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
@ -93,7 +93,7 @@ type sqlcQuerier interface {
|
||||
GetUserGroups(ctx context.Context, userID uuid.UUID) ([]Group, error)
|
||||
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
|
||||
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
|
||||
GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error)
|
||||
GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error)
|
||||
// This shouldn't check for deleted, because it's frequently used
|
||||
// to look up references to actions. eg. a user could build a workspace
|
||||
// for another user, then be deleted... we still want them to appear!
|
||||
|
@ -4178,7 +4178,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) {
|
||||
|
||||
const getUsers = `-- name: GetUsers :many
|
||||
SELECT
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
|
||||
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, COUNT(*) OVER() AS count
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
@ -4247,7 +4247,23 @@ type GetUsersParams struct {
|
||||
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) {
|
||||
type GetUsersRow struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
Email string `db:"email" json:"email"`
|
||||
Username string `db:"username" json:"username"`
|
||||
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
Status UserStatus `db:"status" json:"status"`
|
||||
RBACRoles pq.StringArray `db:"rbac_roles" json:"rbac_roles"`
|
||||
LoginType LoginType `db:"login_type" json:"login_type"`
|
||||
AvatarURL sql.NullString `db:"avatar_url" json:"avatar_url"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
LastSeenAt time.Time `db:"last_seen_at" json:"last_seen_at"`
|
||||
Count int64 `db:"count" json:"count"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUsers,
|
||||
arg.Deleted,
|
||||
arg.AfterID,
|
||||
@ -4261,9 +4277,9 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []User
|
||||
var items []GetUsersRow
|
||||
for rows.Next() {
|
||||
var i User
|
||||
var i GetUsersRow
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.Email,
|
||||
@ -4277,6 +4293,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
|
||||
&i.AvatarURL,
|
||||
&i.Deleted,
|
||||
&i.LastSeenAt,
|
||||
&i.Count,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ WHERE
|
||||
|
||||
-- name: GetUsers :many
|
||||
SELECT
|
||||
*
|
||||
*, COUNT(*) OVER() AS count
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
|
@ -350,10 +350,11 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
users, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
|
||||
userRows, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get users: %w", err)
|
||||
}
|
||||
users := database.ConvertUserRows(userRows)
|
||||
var firstUser database.User
|
||||
for _, dbUser := range users {
|
||||
if dbUser.Status != database.UserStatusActive {
|
||||
|
@ -198,7 +198,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
users, err := api.Database.GetUsers(ctx, database.GetUsersParams{
|
||||
userRows, err := api.Database.GetUsers(ctx, database.GetUsersParams{
|
||||
AfterID: paginationParams.AfterID,
|
||||
OffsetOpt: int32(paginationParams.Offset),
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
@ -206,10 +206,6 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
Status: params.Status,
|
||||
RbacRole: params.RbacRole,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, []codersdk.User{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching users.",
|
||||
@ -217,8 +213,17 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// GetUsers does not return ErrNoRows because it uses a window function to get the count.
|
||||
// So we need to check if the userRows is empty and return an empty array if so.
|
||||
if len(userRows) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.GetUsersResponse{
|
||||
Users: []codersdk.User{},
|
||||
Count: 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
|
||||
users, err := AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, database.ConvertUserRows(userRows))
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching users.",
|
||||
@ -248,42 +253,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
render.Status(r, http.StatusOK)
|
||||
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
|
||||
}
|
||||
|
||||
func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
query := r.URL.Query().Get("q")
|
||||
params, errs := userSearchQuery(query)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid user search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error preparing sql filter.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
|
||||
Search: params.Search,
|
||||
Status: params.Status,
|
||||
RbacRole: params.RbacRole,
|
||||
}, sqlFilter)
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
|
||||
Count: count,
|
||||
render.JSON(rw, r, codersdk.GetUsersResponse{
|
||||
Users: convertUsers(users, organizationIDsByUserID),
|
||||
Count: int(userRows[0].Count),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -78,14 +78,14 @@ func TestFirstUser(t *testing.T) {
|
||||
|
||||
_ = coderdtest.CreateAnotherUser(t, client, firstUserResp.OrganizationID)
|
||||
|
||||
allUsers, err := client.Users(ctx, codersdk.UsersRequest{})
|
||||
allUsersRes, err := client.Users(ctx, codersdk.UsersRequest{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, allUsers, 2)
|
||||
require.Len(t, allUsersRes.Users, 2)
|
||||
|
||||
// We sent the "GET Users" request with the first user, but the second user
|
||||
// should be Never since they haven't performed a request.
|
||||
for _, user := range allUsers {
|
||||
for _, user := range allUsersRes.Users {
|
||||
if user.ID == firstUser.ID {
|
||||
require.WithinDuration(t, firstUser.LastSeenAt, database.Now(), testutil.WaitShort)
|
||||
} else {
|
||||
@ -1186,7 +1186,7 @@ func TestUsersFilter(t *testing.T) {
|
||||
exp = append(exp, made)
|
||||
}
|
||||
}
|
||||
require.ElementsMatch(t, exp, matched, "expected workspaces returned")
|
||||
require.ElementsMatch(t, exp, matched.Users, "expected workspaces returned")
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1208,10 +1208,10 @@ func TestGetUsers(t *testing.T) {
|
||||
OrganizationID: user.OrganizationID,
|
||||
})
|
||||
// No params is all users
|
||||
users, err := client.Users(ctx, codersdk.UsersRequest{})
|
||||
res, err := client.Users(ctx, codersdk.UsersRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
require.Len(t, users[0].OrganizationIDs, 1)
|
||||
require.Len(t, res.Users, 2)
|
||||
require.Len(t, res.Users[0].OrganizationIDs, 1)
|
||||
})
|
||||
t.Run("ActiveUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@ -1247,64 +1247,66 @@ func TestGetUsers(t *testing.T) {
|
||||
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := client.Users(ctx, codersdk.UsersRequest{
|
||||
res, err := client.Users(ctx, codersdk.UsersRequest{
|
||||
Status: codersdk.UserStatusActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, active, users)
|
||||
require.ElementsMatch(t, active, res.Users)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredUserCount(t *testing.T) {
|
||||
func TestGetUsersPagination(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("AllUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
|
||||
client.CreateUser(ctx, codersdk.CreateUserRequest{
|
||||
Email: "alice@email.com",
|
||||
Username: "alice",
|
||||
Password: "password",
|
||||
OrganizationID: user.OrganizationID,
|
||||
})
|
||||
// No params is all users
|
||||
response, err := client.UserCount(ctx, codersdk.UserCountRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, int(response.Count))
|
||||
_, err := client.User(ctx, first.UserID.String())
|
||||
require.NoError(t, err, "")
|
||||
|
||||
_, err = client.CreateUser(ctx, codersdk.CreateUserRequest{
|
||||
Email: "alice@email.com",
|
||||
Username: "alice",
|
||||
Password: "password",
|
||||
OrganizationID: first.OrganizationID,
|
||||
})
|
||||
t.Run("ActiveUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
first := coderdtest.CreateFirstUser(t, client)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
res, err := client.Users(ctx, codersdk.UsersRequest{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Users, 2)
|
||||
require.Equal(t, res.Count, 2)
|
||||
|
||||
_, err := client.User(ctx, first.UserID.String())
|
||||
require.NoError(t, err, "")
|
||||
|
||||
// Alice will be suspended
|
||||
alice, err := client.CreateUser(ctx, codersdk.CreateUserRequest{
|
||||
Email: "alice@email.com",
|
||||
Username: "alice",
|
||||
Password: "password",
|
||||
OrganizationID: first.OrganizationID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
|
||||
require.NoError(t, err)
|
||||
|
||||
response, err := client.UserCount(ctx, codersdk.UserCountRequest{
|
||||
Status: codersdk.UserStatusActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, int(response.Count))
|
||||
res, err = client.Users(ctx, codersdk.UsersRequest{
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: 1,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Users, 1)
|
||||
require.Equal(t, res.Count, 2)
|
||||
|
||||
res, err = client.Users(ctx, codersdk.UsersRequest{
|
||||
Pagination: codersdk.Pagination{
|
||||
Offset: 1,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Users, 1)
|
||||
require.Equal(t, res.Count, 2)
|
||||
|
||||
// if offset is higher than the count postgres returns an empty array
|
||||
// and not an ErrNoRows error. This also means the count must be 0.
|
||||
res, err = client.Users(ctx, codersdk.UsersRequest{
|
||||
Pagination: codersdk.Pagination{
|
||||
Offset: 3,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Users, 0)
|
||||
require.Equal(t, res.Count, 0)
|
||||
}
|
||||
|
||||
func TestPostTokens(t *testing.T) {
|
||||
@ -1420,7 +1422,7 @@ func TestSuspendedPagination(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, page, "expected page")
|
||||
require.Equal(t, expected, page.Users, "expected page")
|
||||
}
|
||||
|
||||
// TestPaginatedUsers creates a list of users, then tries to paginate through
|
||||
@ -1546,15 +1548,15 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "first page")
|
||||
require.Equalf(t, page, allUsers[:limit], "first page, limit=%d", limit)
|
||||
count += len(page)
|
||||
require.Equalf(t, page.Users, allUsers[:limit], "first page, limit=%d", limit)
|
||||
count += len(page.Users)
|
||||
|
||||
for {
|
||||
if len(page) == 0 {
|
||||
if len(page.Users) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
afterCursor := page[len(page)-1].ID
|
||||
afterCursor := page.Users[len(page.Users)-1].ID
|
||||
// Assert each page is the next expected page
|
||||
// This is using a cursor, and only works if all users created_at
|
||||
// is unique.
|
||||
@ -1581,8 +1583,8 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
} else {
|
||||
expected = allUsers[count : count+limit]
|
||||
}
|
||||
require.Equalf(t, page, expected, "next users, after=%s, limit=%d", afterCursor, limit)
|
||||
require.Equalf(t, offsetPage, expected, "offset users, offset=%d, limit=%d", count, limit)
|
||||
require.Equalf(t, page.Users, expected, "next users, after=%s, limit=%d", afterCursor, limit)
|
||||
require.Equalf(t, offsetPage.Users, expected, "offset users, offset=%d, limit=%d", count, limit)
|
||||
|
||||
// Also check the before
|
||||
prevPage, err := client.Users(ctx, opt(codersdk.UsersRequest{
|
||||
@ -1592,8 +1594,8 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "prev page")
|
||||
require.Equal(t, allUsers[count-limit:count], prevPage, "prev users")
|
||||
count += len(page)
|
||||
require.Equal(t, allUsers[count-limit:count], prevPage.Users, "prev users")
|
||||
count += len(page.Users)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user