mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: Implement unified pagination and add template versions support (#1308)
* feat: Implement pagination for template versions * feat: Use unified pagination between users and template versions * Sync codepaths between users and template versions * Create requestOption type in codersdk and add test * Fix created_at edge case for pagination cursor in queries * feat: Add support for json omitempty and embedded structs in apitypings (#1318) * Add scripts/apitypings/main.go to Makefile
This commit is contained in:
committed by
GitHub
parent
dc115b8ca0
commit
2d3dc436a8
@ -172,25 +172,25 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
users := q.users
|
||||
// Avoid side-effect of sorting.
|
||||
users := make([]database.User, len(q.users))
|
||||
copy(users, q.users)
|
||||
|
||||
// Database orders by created_at
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
if users[i].CreatedAt.Equal(users[j].CreatedAt) {
|
||||
slices.SortFunc(users, func(a, b database.User) bool {
|
||||
if a.CreatedAt.Equal(b.CreatedAt) {
|
||||
// Technically the postgres database also orders by uuid. So match
|
||||
// that behavior
|
||||
return users[i].ID.String() < users[j].ID.String()
|
||||
return a.ID.String() < b.ID.String()
|
||||
}
|
||||
return users[i].CreatedAt.Before(users[j].CreatedAt)
|
||||
return a.CreatedAt.Before(b.CreatedAt)
|
||||
})
|
||||
|
||||
if params.AfterUser != uuid.Nil {
|
||||
if params.AfterID != uuid.Nil {
|
||||
found := false
|
||||
for i := range users {
|
||||
if users[i].ID == params.AfterUser {
|
||||
for i, v := range users {
|
||||
if v.ID == params.AfterID {
|
||||
// We want to return all users after index i.
|
||||
if i+1 >= len(users) {
|
||||
return []database.User{}, nil
|
||||
}
|
||||
users = users[i+1:]
|
||||
found = true
|
||||
break
|
||||
@ -199,7 +199,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
|
||||
// If no users after the time, then we return an empty list.
|
||||
if !found {
|
||||
return []database.User{}, nil
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
}
|
||||
|
||||
@ -227,7 +227,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
|
||||
if params.OffsetOpt > 0 {
|
||||
if int(params.OffsetOpt) > len(users)-1 {
|
||||
return []database.User{}, nil
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
users = users[params.OffsetOpt:]
|
||||
}
|
||||
@ -239,10 +239,7 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
|
||||
users = users[:params.LimitOpt]
|
||||
}
|
||||
|
||||
tmp := make([]database.User, len(users))
|
||||
copy(tmp, users)
|
||||
|
||||
return tmp, nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetAllUserRoles(_ context.Context, userID uuid.UUID) (database.GetAllUserRolesRow, error) {
|
||||
@ -621,20 +618,62 @@ func (q *fakeQuerier) GetTemplateByOrganizationAndName(_ context.Context, arg da
|
||||
return database.Template{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, templateID uuid.UUID) ([]database.TemplateVersion, error) {
|
||||
func (q *fakeQuerier) GetTemplateVersionsByTemplateID(_ context.Context, arg database.GetTemplateVersionsByTemplateIDParams) (version []database.TemplateVersion, err error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
version := make([]database.TemplateVersion, 0)
|
||||
for _, templateVersion := range q.templateVersions {
|
||||
if templateVersion.TemplateID.UUID.String() != templateID.String() {
|
||||
if templateVersion.TemplateID.UUID.String() != arg.TemplateID.String() {
|
||||
continue
|
||||
}
|
||||
version = append(version, templateVersion)
|
||||
}
|
||||
|
||||
// Database orders by created_at
|
||||
slices.SortFunc(version, func(a, b database.TemplateVersion) bool {
|
||||
if a.CreatedAt.Equal(b.CreatedAt) {
|
||||
// Technically the postgres database also orders by uuid. So match
|
||||
// that behavior
|
||||
return a.ID.String() < b.ID.String()
|
||||
}
|
||||
return a.CreatedAt.Before(b.CreatedAt)
|
||||
})
|
||||
|
||||
if arg.AfterID != uuid.Nil {
|
||||
found := false
|
||||
for i, v := range version {
|
||||
if v.ID == arg.AfterID {
|
||||
// We want to return all users after index i.
|
||||
version = version[i+1:]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no users after the time, then we return an empty list.
|
||||
if !found {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
}
|
||||
|
||||
if arg.OffsetOpt > 0 {
|
||||
if int(arg.OffsetOpt) > len(version)-1 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
version = version[arg.OffsetOpt:]
|
||||
}
|
||||
|
||||
if arg.LimitOpt > 0 {
|
||||
if int(arg.LimitOpt) > len(version) {
|
||||
arg.LimitOpt = int32(len(version))
|
||||
}
|
||||
version = version[:arg.LimitOpt]
|
||||
}
|
||||
|
||||
if len(version) == 0 {
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ type querier interface {
|
||||
GetTemplateVersionByID(ctx context.Context, id uuid.UUID) (TemplateVersion, error)
|
||||
GetTemplateVersionByJobID(ctx context.Context, jobID uuid.UUID) (TemplateVersion, error)
|
||||
GetTemplateVersionByTemplateIDAndName(ctx context.Context, arg GetTemplateVersionByTemplateIDAndNameParams) (TemplateVersion, error)
|
||||
GetTemplateVersionsByTemplateID(ctx context.Context, dollar_1 uuid.UUID) ([]TemplateVersion, error)
|
||||
GetTemplateVersionsByTemplateID(ctx context.Context, arg GetTemplateVersionsByTemplateIDParams) ([]TemplateVersion, error)
|
||||
GetTemplatesByIDs(ctx context.Context, ids []uuid.UUID) ([]Template, error)
|
||||
GetTemplatesByOrganization(ctx context.Context, arg GetTemplatesByOrganizationParams) ([]Template, error)
|
||||
GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error)
|
||||
|
@ -1908,10 +1908,48 @@ FROM
|
||||
template_versions
|
||||
WHERE
|
||||
template_id = $1 :: uuid
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
WHEN $2 :: uuid != '00000000-00000000-00000000-00000000' THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- rows after the cursor.
|
||||
(created_at, id) > (
|
||||
SELECT
|
||||
created_at, id
|
||||
FROM
|
||||
template_versions
|
||||
WHERE
|
||||
id = $2
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(created_at, id) ASC OFFSET $3
|
||||
LIMIT
|
||||
-- A null limit means "no limit", so -1 means return all
|
||||
NULLIF($4 :: int, -1)
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, dollar_1 uuid.UUID) ([]TemplateVersion, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getTemplateVersionsByTemplateID, dollar_1)
|
||||
type GetTemplateVersionsByTemplateIDParams struct {
|
||||
TemplateID uuid.UUID `db:"template_id" json:"template_id"`
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
|
||||
LimitOpt int32 `db:"limit_opt" json:"limit_opt"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetTemplateVersionsByTemplateID(ctx context.Context, arg GetTemplateVersionsByTemplateIDParams) ([]TemplateVersion, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getTemplateVersionsByTemplateID,
|
||||
arg.TemplateID,
|
||||
arg.AfterID,
|
||||
arg.OffsetOpt,
|
||||
arg.LimitOpt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2125,22 +2163,19 @@ WHERE
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
WHEN $1 :: uuid != '00000000-00000000-00000000-00000000' THEN (
|
||||
-- The pagination cursor is the last user of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- users after the cursor. We also want to include any users
|
||||
-- that share the created_at (super rare).
|
||||
created_at >= (
|
||||
SELECT
|
||||
created_at
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
id = $1
|
||||
)
|
||||
-- Omit the cursor from the final.
|
||||
AND id != $1
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- rows after the cursor.
|
||||
(created_at, id) > (
|
||||
SELECT
|
||||
created_at, id
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
id = $1
|
||||
)
|
||||
ELSE true
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
@ -2171,7 +2206,7 @@ LIMIT
|
||||
`
|
||||
|
||||
type GetUsersParams struct {
|
||||
AfterUser uuid.UUID `db:"after_user" json:"after_user"`
|
||||
AfterID uuid.UUID `db:"after_id" json:"after_id"`
|
||||
Search string `db:"search" json:"search"`
|
||||
Status string `db:"status" json:"status"`
|
||||
OffsetOpt int32 `db:"offset_opt" json:"offset_opt"`
|
||||
@ -2180,7 +2215,7 @@ type GetUsersParams struct {
|
||||
|
||||
func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getUsers,
|
||||
arg.AfterUser,
|
||||
arg.AfterID,
|
||||
arg.Search,
|
||||
arg.Status,
|
||||
arg.OffsetOpt,
|
||||
|
@ -4,7 +4,33 @@ SELECT
|
||||
FROM
|
||||
template_versions
|
||||
WHERE
|
||||
template_id = $1 :: uuid;
|
||||
template_id = @template_id :: uuid
|
||||
AND CASE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
WHEN @after_id :: uuid != '00000000-00000000-00000000-00000000' THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- rows after the cursor.
|
||||
(created_at, id) > (
|
||||
SELECT
|
||||
created_at, id
|
||||
FROM
|
||||
template_versions
|
||||
WHERE
|
||||
id = @after_id
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
-- Deterministic and consistent ordering of all rows, even if they share
|
||||
-- a timestamp. This is to ensure consistent pagination.
|
||||
(created_at, id) ASC OFFSET @offset_opt
|
||||
LIMIT
|
||||
-- A null limit means "no limit", so -1 means return all
|
||||
NULLIF(@limit_opt :: int, -1);
|
||||
|
||||
-- name: GetTemplateVersionByJobID :one
|
||||
SELECT
|
||||
|
@ -77,23 +77,20 @@ WHERE
|
||||
-- This allows using the last element on a page as effectively a cursor.
|
||||
-- This is an important option for scripts that need to paginate without
|
||||
-- duplicating or missing data.
|
||||
WHEN @after_user :: uuid != '00000000-00000000-00000000-00000000' THEN (
|
||||
-- The pagination cursor is the last user of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- users after the cursor. We also want to include any users
|
||||
-- that share the created_at (super rare).
|
||||
created_at >= (
|
||||
SELECT
|
||||
created_at
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
id = @after_user
|
||||
)
|
||||
-- Omit the cursor from the final.
|
||||
AND id != @after_user
|
||||
WHEN @after_id :: uuid != '00000000-00000000-00000000-00000000' THEN (
|
||||
-- The pagination cursor is the last ID of the previous page.
|
||||
-- The query is ordered by the created_at field, so select all
|
||||
-- rows after the cursor.
|
||||
(created_at, id) > (
|
||||
SELECT
|
||||
created_at, id
|
||||
FROM
|
||||
users
|
||||
WHERE
|
||||
id = @after_id
|
||||
)
|
||||
ELSE true
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
-- Start filters
|
||||
-- Filter by name, email or username
|
||||
|
57
coderd/pagination.go
Normal file
57
coderd/pagination.go
Normal file
@ -0,0 +1,57 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// parsePagination extracts pagination query params from the http request.
|
||||
// If an error is encountered, the error is written to w and ok is set to false.
|
||||
func parsePagination(w http.ResponseWriter, r *http.Request) (p codersdk.Pagination, ok bool) {
|
||||
var (
|
||||
afterID = uuid.Nil
|
||||
limit = -1 // Default to no limit and return all results.
|
||||
offset = 0
|
||||
)
|
||||
|
||||
var err error
|
||||
if s := r.URL.Query().Get("after_id"); s != "" {
|
||||
afterID, err = uuid.Parse(r.URL.Query().Get("after_id"))
|
||||
if err != nil {
|
||||
httpapi.Write(w, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("after_id must be a valid uuid: %s", err.Error()),
|
||||
})
|
||||
return p, false
|
||||
}
|
||||
}
|
||||
if s := r.URL.Query().Get("limit"); s != "" {
|
||||
limit, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
httpapi.Write(w, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("limit must be an integer: %s", err.Error()),
|
||||
})
|
||||
return p, false
|
||||
}
|
||||
}
|
||||
if s := r.URL.Query().Get("offset"); s != "" {
|
||||
offset, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
httpapi.Write(w, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("offset must be an integer: %s", err.Error()),
|
||||
})
|
||||
return p, false
|
||||
}
|
||||
}
|
||||
|
||||
return codersdk.Pagination{
|
||||
AfterID: afterID,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, true
|
||||
}
|
@ -75,9 +75,21 @@ func (api *api) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
|
||||
func (api *api) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Request) {
|
||||
template := httpmw.TemplateParam(r)
|
||||
|
||||
versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), template.ID)
|
||||
paginationParams, ok := parsePagination(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
apiVersion := []codersdk.TemplateVersion{}
|
||||
versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
|
||||
TemplateID: template.ID,
|
||||
AfterID: paginationParams.AfterID,
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
OffsetOpt: int32(paginationParams.Offset),
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
httpapi.Write(rw, http.StatusOK, apiVersion)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
@ -101,7 +113,6 @@ func (api *api) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
|
||||
jobByID[job.ID.String()] = job
|
||||
}
|
||||
|
||||
apiVersion := make([]codersdk.TemplateVersion, 0)
|
||||
for _, version := range versions {
|
||||
job, exists := jobByID[version.JobID.String()]
|
||||
if !exists {
|
||||
|
@ -6,10 +6,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisioner/echo"
|
||||
)
|
||||
|
||||
func TestTemplate(t *testing.T) {
|
||||
@ -63,7 +66,9 @@ func TestTemplateVersionsByTemplate(t *testing.T) {
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
versions, err := client.TemplateVersionsByTemplate(context.Background(), template.ID)
|
||||
versions, err := client.TemplateVersionsByTemplate(context.Background(), codersdk.TemplateVersionsByTemplateRequest{
|
||||
TemplateID: template.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, versions, 1)
|
||||
})
|
||||
@ -137,3 +142,96 @@ func TestPatchActiveTemplateVersion(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPaginatedTemplateVersions creates a list of template versions and paginate.
|
||||
func TestPaginatedTemplateVersions(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{APIRateLimit: -1})
|
||||
// Prepare database.
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
coderdtest.NewProvisionerDaemon(t, client)
|
||||
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
|
||||
_ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
|
||||
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
|
||||
|
||||
// Populate database with template versions.
|
||||
total := 9
|
||||
for i := 0; i < total; i++ {
|
||||
data, err := echo.Tar(nil)
|
||||
require.NoError(t, err)
|
||||
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)
|
||||
require.NoError(t, err)
|
||||
templateVersion, err := client.CreateTemplateVersion(ctx, user.OrganizationID, codersdk.CreateTemplateVersionRequest{
|
||||
TemplateID: template.ID,
|
||||
StorageSource: file.Hash,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = coderdtest.AwaitTemplateVersionJob(t, client, templateVersion.ID)
|
||||
}
|
||||
|
||||
templateVersions, err := client.TemplateVersionsByTemplate(ctx,
|
||||
codersdk.TemplateVersionsByTemplateRequest{
|
||||
TemplateID: template.ID,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, templateVersions, 10, "wrong number of template versions created")
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
pagination codersdk.Pagination
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []codersdk.TemplateVersion
|
||||
}{
|
||||
{
|
||||
name: "Single result",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 1}},
|
||||
want: templateVersions[:1],
|
||||
},
|
||||
{
|
||||
name: "Single result, second page",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 1, Offset: 1}},
|
||||
want: templateVersions[1:2],
|
||||
},
|
||||
{
|
||||
name: "Last two results",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, Offset: 8}},
|
||||
want: templateVersions[8:10],
|
||||
},
|
||||
{
|
||||
name: "AfterID returns next two results",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, AfterID: templateVersions[1].ID}},
|
||||
want: templateVersions[2:4],
|
||||
},
|
||||
{
|
||||
name: "No result after last AfterID",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, AfterID: templateVersions[9].ID}},
|
||||
want: []codersdk.TemplateVersion{},
|
||||
},
|
||||
{
|
||||
name: "No result after last Offset",
|
||||
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, Offset: 10}},
|
||||
want: []codersdk.TemplateVersion{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := client.TemplateVersionsByTemplate(tt.args.ctx, codersdk.TemplateVersionsByTemplateRequest{
|
||||
TemplateID: template.ID,
|
||||
Pagination: tt.args.pagination,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -106,55 +105,26 @@ func (api *api) postFirstUser(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (api *api) users(rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
afterArg = r.URL.Query().Get("after_user")
|
||||
limitArg = r.URL.Query().Get("limit")
|
||||
offsetArg = r.URL.Query().Get("offset")
|
||||
searchName = r.URL.Query().Get("search")
|
||||
statusFilter = r.URL.Query().Get("status")
|
||||
)
|
||||
|
||||
// createdAfter is a user uuid.
|
||||
createdAfter := uuid.Nil
|
||||
if afterArg != "" {
|
||||
after, err := uuid.Parse(afterArg)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("after_user must be a valid uuid: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
createdAfter = after
|
||||
}
|
||||
|
||||
// Default to no limit and return all users.
|
||||
pageLimit := -1
|
||||
if limitArg != "" {
|
||||
limit, err := strconv.Atoi(limitArg)
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("limit must be an integer: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
pageLimit = limit
|
||||
}
|
||||
|
||||
// The default for empty string is 0.
|
||||
offset, err := strconv.ParseInt(offsetArg, 10, 64)
|
||||
if offsetArg != "" && err != nil {
|
||||
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
|
||||
Message: fmt.Sprintf("offset must be an integer: %s", err.Error()),
|
||||
})
|
||||
paginationParams, ok := parsePagination(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
users, err := api.Database.GetUsers(r.Context(), database.GetUsersParams{
|
||||
AfterUser: createdAfter,
|
||||
OffsetOpt: int32(offset),
|
||||
LimitOpt: int32(pageLimit),
|
||||
AfterID: paginationParams.AfterID,
|
||||
OffsetOpt: int32(paginationParams.Offset),
|
||||
LimitOpt: int32(paginationParams.Limit),
|
||||
Search: searchName,
|
||||
Status: statusFilter,
|
||||
})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(rw, http.StatusOK, []codersdk.User{})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
Message: err.Error(),
|
||||
|
@ -722,8 +722,6 @@ func TestPaginatedUsers(t *testing.T) {
|
||||
allUsers = append(allUsers, me)
|
||||
specialUsers := make([]codersdk.User, 0)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// When 100 users exist
|
||||
total := 100
|
||||
// Create users
|
||||
@ -795,7 +793,9 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
|
||||
// Check the first page
|
||||
page, err := client.Users(ctx, opt(codersdk.UsersRequest{
|
||||
Limit: limit,
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: limit,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "first page")
|
||||
require.Equalf(t, page, allUsers[:limit], "first page, limit=%d", limit)
|
||||
@ -811,15 +811,19 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
// This is using a cursor, and only works if all users created_at
|
||||
// is unique.
|
||||
page, err = client.Users(ctx, opt(codersdk.UsersRequest{
|
||||
Limit: limit,
|
||||
AfterUser: afterCursor,
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: limit,
|
||||
AfterID: afterCursor,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "next cursor page")
|
||||
|
||||
// Also check page by offset
|
||||
offsetPage, err := client.Users(ctx, opt(codersdk.UsersRequest{
|
||||
Limit: limit,
|
||||
Offset: count,
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: limit,
|
||||
Offset: count,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "next offset page")
|
||||
|
||||
@ -834,8 +838,10 @@ func assertPagination(ctx context.Context, t *testing.T, client *codersdk.Client
|
||||
|
||||
// Also check the before
|
||||
prevPage, err := client.Users(ctx, opt(codersdk.UsersRequest{
|
||||
Offset: count - limit,
|
||||
Limit: limit,
|
||||
Pagination: codersdk.Pagination{
|
||||
Offset: count - limit,
|
||||
Limit: limit,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err, "prev page")
|
||||
require.Equal(t, allUsers[count-limit:count], prevPage, "prev users")
|
||||
|
Reference in New Issue
Block a user