feat: add count to get users endpoint (#5016)

This commit is contained in:
Garrett Delfosse
2022-11-14 17:22:57 -05:00
committed by GitHub
parent 49b340e039
commit 88f3691dcc
25 changed files with 425 additions and 483 deletions

View File

@ -431,7 +431,6 @@ func New(options *Options) *API {
)
r.Post("/", api.postUser)
r.Get("/", api.users)
r.Get("/count", api.userCount)
r.Post("/logout", api.postLogout)
// These routes query information about site wide roles.
r.Route("/roles", func(r chi.Router) {

View File

@ -245,7 +245,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
// Endpoints that use the SQLQuery filter.
"GET:/api/v2/workspaces/": {StatusCode: http.StatusOK, NoAuthorize: true},
"GET:/api/v2/users/count": {StatusCode: http.StatusOK, NoAuthorize: true},
}
// Routes like proxy routes support all HTTP methods. A helper func to expand

View File

@ -538,7 +538,7 @@ func (q *fakeQuerier) UpdateUserDeletedByID(_ context.Context, params database.U
return sql.ErrNoRows
}
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) {
func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.GetUsersRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -579,7 +579,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
// If no users after the time, then we return an empty list.
if !found {
return nil, sql.ErrNoRows
return []database.GetUsersRow{}, nil
}
}
@ -617,9 +617,11 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = usersFilteredByRole
}
beforePageCount := len(users)
if params.OffsetOpt > 0 {
if int(params.OffsetOpt) > len(users)-1 {
return nil, sql.ErrNoRows
return []database.GetUsersRow{}, nil
}
users = users[params.OffsetOpt:]
}
@ -631,7 +633,30 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = users[:params.LimitOpt]
}
return users, nil
return convertUsers(users, int64(beforePageCount)), nil
}
func convertUsers(users []database.User, count int64) []database.GetUsersRow {
rows := make([]database.GetUsersRow, len(users))
for i, u := range users {
rows[i] = database.GetUsersRow{
ID: u.ID,
Email: u.Email,
Username: u.Username,
HashedPassword: u.HashedPassword,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
Status: u.Status,
RBACRoles: u.RBACRoles,
LoginType: u.LoginType,
AvatarURL: u.AvatarURL,
Deleted: u.Deleted,
LastSeenAt: u.LastSeenAt,
Count: count,
}
}
return rows
}
func (q *fakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]database.User, error) {

View File

@ -71,3 +71,25 @@ func (User) RBACObject() rbac.Object {
func (License) RBACObject() rbac.Object {
return rbac.ResourceLicense
}
func ConvertUserRows(rows []GetUsersRow) []User {
users := make([]User, len(rows))
for i, r := range rows {
users[i] = User{
ID: r.ID,
Email: r.Email,
Username: r.Username,
HashedPassword: r.HashedPassword,
CreatedAt: r.CreatedAt,
UpdatedAt: r.UpdatedAt,
Status: r.Status,
RBACRoles: r.RBACRoles,
LoginType: r.LoginType,
AvatarURL: r.AvatarURL,
Deleted: r.Deleted,
LastSeenAt: r.LastSeenAt,
}
}
return users
}

View File

@ -93,7 +93,7 @@ type sqlcQuerier interface {
GetUserGroups(ctx context.Context, userID uuid.UUID) ([]Group, error)
GetUserLinkByLinkedID(ctx context.Context, linkedID string) (UserLink, error)
GetUserLinkByUserIDLoginType(ctx context.Context, arg GetUserLinkByUserIDLoginTypeParams) (UserLink, error)
GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error)
GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error)
// This shouldn't check for deleted, because it's frequently used
// to look up references to actions. eg. a user could build a workspace
// for another user, then be deleted... we still want them to appear!

View File

@ -4178,7 +4178,7 @@ func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) {
const getUsers = `-- name: GetUsers :many
SELECT
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at
id, email, username, hashed_password, created_at, updated_at, status, rbac_roles, login_type, avatar_url, deleted, last_seen_at, COUNT(*) OVER() AS count
FROM
users
WHERE
@ -4247,7 +4247,23 @@ type GetUsersParams struct {
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
}
func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) {
type GetUsersRow struct {
ID uuid.UUID `db:"id" json:"id"`
Email string `db:"email" json:"email"`
Username string `db:"username" json:"username"`
HashedPassword []byte `db:"hashed_password" json:"hashed_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Status UserStatus `db:"status" json:"status"`
RBACRoles pq.StringArray `db:"rbac_roles" json:"rbac_roles"`
LoginType LoginType `db:"login_type" json:"login_type"`
AvatarURL sql.NullString `db:"avatar_url" json:"avatar_url"`
Deleted bool `db:"deleted" json:"deleted"`
LastSeenAt time.Time `db:"last_seen_at" json:"last_seen_at"`
Count int64 `db:"count" json:"count"`
}
func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]GetUsersRow, error) {
rows, err := q.db.QueryContext(ctx, getUsers,
arg.Deleted,
arg.AfterID,
@ -4261,9 +4277,9 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
return nil, err
}
defer rows.Close()
var items []User
var items []GetUsersRow
for rows.Next() {
var i User
var i GetUsersRow
if err := rows.Scan(
&i.ID,
&i.Email,
@ -4277,6 +4293,7 @@ func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User,
&i.AvatarURL,
&i.Deleted,
&i.LastSeenAt,
&i.Count,
); err != nil {
return nil, err
}

View File

@ -128,7 +128,7 @@ WHERE
-- name: GetUsers :many
SELECT
*
*, COUNT(*) OVER() AS count
FROM
users
WHERE

View File

@ -350,10 +350,11 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
users, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
userRows, err := r.options.Database.GetUsers(ctx, database.GetUsersParams{})
if err != nil {
return xerrors.Errorf("get users: %w", err)
}
users := database.ConvertUserRows(userRows)
var firstUser database.User
for _, dbUser := range users {
if dbUser.Status != database.UserStatusActive {

View File

@ -198,7 +198,7 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
return
}
users, err := api.Database.GetUsers(ctx, database.GetUsersParams{
userRows, err := api.Database.GetUsers(ctx, database.GetUsersParams{
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
@ -206,10 +206,6 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
Status: params.Status,
RbacRole: params.RbacRole,
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusOK, []codersdk.User{})
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
@ -217,8 +213,17 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
})
return
}
// GetUsers does not return ErrNoRows because it uses a window function to get the count.
// So we need to check if the userRows is empty and return an empty array if so.
if len(userRows) == 0 {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.GetUsersResponse{
Users: []codersdk.User{},
Count: 0,
})
return
}
users, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, users)
users, err := AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, database.ConvertUserRows(userRows))
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching users.",
@ -248,42 +253,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
}
render.Status(r, http.StatusOK)
render.JSON(rw, r, convertUsers(users, organizationIDsByUserID))
}
func (api *API) userCount(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
query := r.URL.Query().Get("q")
params, errs := userSearchQuery(query)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid user search query.",
Validations: errs,
})
return
}
sqlFilter, err := api.HTTPAuth.AuthorizeSQLFilter(r, rbac.ActionRead, rbac.ResourceUser.Type)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error preparing sql filter.",
Detail: err.Error(),
})
return
}
count, err := api.Database.GetAuthorizedUserCount(ctx, database.GetFilteredUserCountParams{
Search: params.Search,
Status: params.Status,
RbacRole: params.RbacRole,
}, sqlFilter)
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserCountResponse{
Count: count,
render.JSON(rw, r, codersdk.GetUsersResponse{
Users: convertUsers(users, organizationIDsByUserID),
Count: int(userRows[0].Count),
})
}

View File

@ -78,14 +78,14 @@ func TestFirstUser(t *testing.T) {
_ = coderdtest.CreateAnotherUser(t, client, firstUserResp.OrganizationID)
allUsers, err := client.Users(ctx, codersdk.UsersRequest{})
allUsersRes, err := client.Users(ctx, codersdk.UsersRequest{})
require.NoError(t, err)
require.Len(t, allUsers, 2)
require.Len(t, allUsersRes.Users, 2)
// We sent the "GET Users" request with the first user, but the second user
// should be Never since they haven't performed a request.
for _, user := range allUsers {
for _, user := range allUsersRes.Users {
if user.ID == firstUser.ID {
require.WithinDuration(t, firstUser.LastSeenAt, database.Now(), testutil.WaitShort)
} else {
@ -1186,7 +1186,7 @@ func TestUsersFilter(t *testing.T) {
exp = append(exp, made)
}
}
require.ElementsMatch(t, exp, matched, "expected workspaces returned")
require.ElementsMatch(t, exp, matched.Users, "expected workspaces returned")
})
}
}
@ -1208,10 +1208,10 @@ func TestGetUsers(t *testing.T) {
OrganizationID: user.OrganizationID,
})
// No params is all users
users, err := client.Users(ctx, codersdk.UsersRequest{})
res, err := client.Users(ctx, codersdk.UsersRequest{})
require.NoError(t, err)
require.Len(t, users, 2)
require.Len(t, users[0].OrganizationIDs, 1)
require.Len(t, res.Users, 2)
require.Len(t, res.Users[0].OrganizationIDs, 1)
})
t.Run("ActiveUsers", func(t *testing.T) {
t.Parallel()
@ -1247,64 +1247,66 @@ func TestGetUsers(t *testing.T) {
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
require.NoError(t, err)
users, err := client.Users(ctx, codersdk.UsersRequest{
res, err := client.Users(ctx, codersdk.UsersRequest{
Status: codersdk.UserStatusActive,
})
require.NoError(t, err)
require.ElementsMatch(t, active, users)
require.ElementsMatch(t, active, res.Users)
})
}
func TestGetFilteredUserCount(t *testing.T) {
func TestGetUsersPagination(t *testing.T) {
t.Parallel()
t.Run("AllUsers", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
client.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "alice@email.com",
Username: "alice",
Password: "password",
OrganizationID: user.OrganizationID,
})
// No params is all users
response, err := client.UserCount(ctx, codersdk.UserCountRequest{})
require.NoError(t, err)
require.Equal(t, 2, int(response.Count))
_, err := client.User(ctx, first.UserID.String())
require.NoError(t, err, "")
_, err = client.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "alice@email.com",
Username: "alice",
Password: "password",
OrganizationID: first.OrganizationID,
})
t.Run("ActiveUsers", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
first := coderdtest.CreateFirstUser(t, client)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
res, err := client.Users(ctx, codersdk.UsersRequest{})
require.NoError(t, err)
require.Len(t, res.Users, 2)
require.Equal(t, res.Count, 2)
_, err := client.User(ctx, first.UserID.String())
require.NoError(t, err, "")
// Alice will be suspended
alice, err := client.CreateUser(ctx, codersdk.CreateUserRequest{
Email: "alice@email.com",
Username: "alice",
Password: "password",
OrganizationID: first.OrganizationID,
})
require.NoError(t, err)
_, err = client.UpdateUserStatus(ctx, alice.Username, codersdk.UserStatusSuspended)
require.NoError(t, err)
response, err := client.UserCount(ctx, codersdk.UserCountRequest{
Status: codersdk.UserStatusActive,
})
require.NoError(t, err)
require.Equal(t, 1, int(response.Count))
res, err = client.Users(ctx, codersdk.UsersRequest{
Pagination: codersdk.Pagination{
Limit: 1,
},
})
require.NoError(t, err)
require.Len(t, res.Users, 1)
require.Equal(t, res.Count, 2)
res, err = client.Users(ctx, codersdk.UsersRequest{
Pagination: codersdk.Pagination{
Offset: 1,
},
})
require.NoError(t, err)
require.Len(t, res.Users, 1)
require.Equal(t, res.Count, 2)
// if offset is higher than the count postgres returns an empty array
// and not an ErrNoRows error. This also means the count must be 0.
res, err = client.Users(ctx, codersdk.UsersRequest{
Pagination: codersdk.Pagination{
Offset: 3,
},
})
require.NoError(t, err)
require.Len(t, res.Users, 0)
require.Equal(t, res.Count, 0)
}
func TestPostTokens(t *testing.T) {
@ -1420,7 +1422,7 @@ func TestSuspendedPagination(t *testing.T) {
},
})
require.NoError(t, err)
require.Equal(t, expected, page, "expected page")
require.Equal(t, expected, page.Users, "expected page")
}
// TestPaginatedUsers creates a list of users, then tries to paginate through
@ -1546,15 +1548,15 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
},
}))
require.NoError(t, err, "first page")
require.Equalf(t, page, allUsers[:limit], "first page, limit=%d", limit)
count += len(page)
require.Equalf(t, page.Users, allUsers[:limit], "first page, limit=%d", limit)
count += len(page.Users)
for {
if len(page) == 0 {
if len(page.Users) == 0 {
break
}
afterCursor := page[len(page)-1].ID
afterCursor := page.Users[len(page.Users)-1].ID
// Assert each page is the next expected page
// This is using a cursor, and only works if all users created_at
// is unique.
@ -1581,8 +1583,8 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
} else {
expected = allUsers[count : count+limit]
}
require.Equalf(t, page, expected, "next users, after=%s, limit=%d", afterCursor, limit)
require.Equalf(t, offsetPage, expected, "offset users, offset=%d, limit=%d", count, limit)
require.Equalf(t, page.Users, expected, "next users, after=%s, limit=%d", afterCursor, limit)
require.Equalf(t, offsetPage.Users, expected, "offset users, offset=%d, limit=%d", count, limit)
// Also check the before
prevPage, err := client.Users(ctx, opt(codersdk.UsersRequest{
@ -1592,8 +1594,8 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
},
}))
require.NoError(t, err, "prev page")
require.Equal(t, allUsers[count-limit:count], prevPage, "prev users")
count += len(page)
require.Equal(t, allUsers[count-limit:count], prevPage.Users, "prev users")
count += len(page.Users)
}
}