feat: Handle pagination cases where after_id does not exist (#1947)

* feat: Handle pagination cases where after_id does not exist

Throw an error to the user in these cases
- Templateversions
- Workspacebuilds

User pagination does not need it as suspended users still
have rows in the database
This commit is contained in:
Steven Masley
2022-06-02 09:01:45 -05:00
committed by GitHub
parent 419dc6b036
commit b9983e417f
8 changed files with 212 additions and 73 deletions

View File

@ -231,17 +231,19 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams
users = tmp
}
if len(params.Status) > 0 {
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
for _, status := range params.Status {
if user.Status == status {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
if len(params.Status) == 0 {
params.Status = []database.UserStatus{database.UserStatusActive}
}
usersFilteredByStatus := make([]database.User, 0, len(users))
for i, user := range users {
for _, status := range params.Status {
if user.Status == status {
usersFilteredByStatus = append(usersFilteredByStatus, users[i])
}
}
users = usersFilteredByStatus
}
users = usersFilteredByStatus
if params.OffsetOpt > 0 {
if int(params.OffsetOpt) > len(users)-1 {

View File

@ -57,8 +57,8 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
}
var organization database.Organization
err = api.Database.InTx(func(db database.Store) error {
organization, err = api.Database.InsertOrganization(r.Context(), database.InsertOrganizationParams{
err = api.Database.InTx(func(store database.Store) error {
organization, err = store.InsertOrganization(r.Context(), database.InsertOrganizationParams{
ID: uuid.New(),
Name: req.Name,
CreatedAt: database.Now(),
@ -67,7 +67,7 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
if err != nil {
return xerrors.Errorf("create organization: %w", err)
}
_, err = api.Database.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
_, err = store.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
OrganizationID: organization.ID,
UserID: apiKey.UserID,
CreatedAt: database.Now(),

View File

@ -385,51 +385,77 @@ func (api *API) templateVersionsByTemplate(rw http.ResponseWriter, r *http.Reque
return
}
apiVersion := []codersdk.TemplateVersion{}
versions, err := api.Database.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
TemplateID: template.ID,
AfterID: paginationParams.AfterID,
LimitOpt: int32(paginationParams.Limit),
OffsetOpt: int32(paginationParams.Offset),
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, apiVersion)
return
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get template version: %s", err),
})
return
}
jobIDs := make([]uuid.UUID, 0, len(versions))
for _, version := range versions {
jobIDs = append(jobIDs, version.JobID)
}
jobs, err := api.Database.GetProvisionerJobsByIDs(r.Context(), jobIDs)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get jobs: %s", err),
})
return
}
jobByID := map[string]database.ProvisionerJob{}
for _, job := range jobs {
jobByID[job.ID.String()] = job
}
for _, version := range versions {
job, exists := jobByID[version.JobID.String()]
if !exists {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("job %q doesn't exist for version %q", version.JobID, version.ID),
})
return
var err error
apiVersions := []codersdk.TemplateVersion{}
err = api.Database.InTx(func(store database.Store) error {
if paginationParams.AfterID != uuid.Nil {
// See if the record exists first. If the record does not exist, the pagination
// query will not work.
_, err := store.GetTemplateVersionByID(r.Context(), paginationParams.AfterID)
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("record at \"after_id\" (%q) does not exists", paginationParams.AfterID.String()),
})
return err
} else if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get template version at after_id: %s", err),
})
return err
}
}
apiVersion = append(apiVersion, convertTemplateVersion(version, convertProvisionerJob(job)))
versions, err := store.GetTemplateVersionsByTemplateID(r.Context(), database.GetTemplateVersionsByTemplateIDParams{
TemplateID: template.ID,
AfterID: paginationParams.AfterID,
LimitOpt: int32(paginationParams.Limit),
OffsetOpt: int32(paginationParams.Offset),
})
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusOK, apiVersions)
return err
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get template version: %s", err),
})
return err
}
jobIDs := make([]uuid.UUID, 0, len(versions))
for _, version := range versions {
jobIDs = append(jobIDs, version.JobID)
}
jobs, err := store.GetProvisionerJobsByIDs(r.Context(), jobIDs)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get jobs: %s", err),
})
return err
}
jobByID := map[string]database.ProvisionerJob{}
for _, job := range jobs {
jobByID[job.ID.String()] = job
}
for _, version := range versions {
job, exists := jobByID[version.JobID.String()]
if !exists {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("job %q doesn't exist for version %q", version.JobID, version.ID),
})
return err
}
apiVersions = append(apiVersions, convertTemplateVersion(version, convertProvisionerJob(job)))
}
return nil
})
if err != nil {
return
}
httpapi.Write(rw, http.StatusOK, apiVersion)
httpapi.Write(rw, http.StatusOK, apiVersions)
}
func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
@ -582,7 +608,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}
provisionerJob, err = api.Database.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
provisionerJob, err = db.InsertProvisionerJob(r.Context(), database.InsertProvisionerJobParams{
ID: jobID,
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
@ -606,7 +632,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
}
}
templateVersion, err = api.Database.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
templateVersion, err = db.InsertTemplateVersion(r.Context(), database.InsertTemplateVersionParams{
ID: uuid.New(),
TemplateID: templateID,
OrganizationID: organization.ID,

View File

@ -694,9 +694,10 @@ func TestPaginatedTemplateVersions(t *testing.T) {
pagination codersdk.Pagination
}
tests := []struct {
name string
args args
want []codersdk.TemplateVersion
name string
args args
want []codersdk.TemplateVersion
expectedError string
}{
{
name: "Single result",
@ -728,6 +729,11 @@ func TestPaginatedTemplateVersions(t *testing.T) {
args: args{ctx: ctx, pagination: codersdk.Pagination{Limit: 2, Offset: 10}},
want: []codersdk.TemplateVersion{},
},
{
name: "After_id does not exist",
args: args{ctx: ctx, pagination: codersdk.Pagination{AfterID: uuid.New()}},
expectedError: "does not exist",
},
}
for _, tt := range tests {
tt := tt
@ -737,8 +743,13 @@ func TestPaginatedTemplateVersions(t *testing.T) {
TemplateID: template.ID,
Pagination: tt.args.pagination,
})
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
if tt.expectedError != "" {
require.Error(t, err)
require.ErrorContains(t, err, tt.expectedError)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@ -830,6 +830,50 @@ func TestWorkspacesByUser(t *testing.T) {
})
}
// TestSuspendedPagination is when the after_id is a suspended record.
// The database query should still return the correct page, as the after_id
// is in a subquery that finds the record regardless of its status.
// This is mainly to confirm the db fake has the same behavior.
func TestSuspendedPagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := coderdtest.New(t, &coderdtest.Options{APIRateLimit: -1})
coderdtest.CreateFirstUser(t, client)
me, err := client.User(context.Background(), codersdk.Me)
require.NoError(t, err)
orgID := me.OrganizationIDs[0]
total := 10
users := make([]codersdk.User, 0, total)
// Create users
for i := 0; i < total; i++ {
email := fmt.Sprintf("%d@coder.com", i)
username := fmt.Sprintf("user%d", i)
user, err := client.CreateUser(context.Background(), codersdk.CreateUserRequest{
Email: email,
Username: username,
Password: "password",
OrganizationID: orgID,
})
require.NoError(t, err)
users = append(users, user)
}
sortUsers(users)
deletedUser := users[2]
expected := users[3:8]
_, err = client.UpdateUserStatus(ctx, deletedUser.ID.String(), codersdk.UserStatusSuspended)
require.NoError(t, err, "suspend user")
page, err := client.Users(ctx, codersdk.UsersRequest{
Pagination: codersdk.Pagination{
Limit: len(expected),
AfterID: deletedUser.ID,
},
})
require.NoError(t, err)
require.Equal(t, expected, page, "expected page")
}
// TestPaginatedUsers creates a list of users, then tries to paginate through
// them using different page sizes.
func TestPaginatedUsers(t *testing.T) {

View File

@ -51,22 +51,51 @@ func (api *API) workspaceBuilds(rw http.ResponseWriter, r *http.Request) {
if !ok {
return
}
req := database.GetWorkspaceBuildByWorkspaceIDParams{
WorkspaceID: workspace.ID,
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
}
builds, err := api.Database.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
if xerrors.Is(err, sql.ErrNoRows) {
err = nil
}
var builds []database.WorkspaceBuild
// Ensure all db calls happen in the same tx
err := api.Database.InTx(func(store database.Store) error {
var err error
if paginationParams.AfterID != uuid.Nil {
// See if the record exists first. If the record does not exist, the pagination
// query will not work.
_, err := store.GetWorkspaceBuildByID(r.Context(), paginationParams.AfterID)
if err != nil && xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("record at \"after_id\" (%q) does not exist", paginationParams.AfterID.String()),
})
return err
} else if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace build at after_id: %s", err),
})
return err
}
}
req := database.GetWorkspaceBuildByWorkspaceIDParams{
WorkspaceID: workspace.ID,
AfterID: paginationParams.AfterID,
OffsetOpt: int32(paginationParams.Offset),
LimitOpt: int32(paginationParams.Limit),
}
builds, err = store.GetWorkspaceBuildByWorkspaceID(r.Context(), req)
if xerrors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace builds: %s", err),
})
return err
}
return nil
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("get workspace builds: %s", err),
})
return
}
jobIDs := make([]uuid.UUID, 0, len(builds))
for _, version := range builds {
jobIDs = append(jobIDs, version.JobID)

View File

@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
@ -44,6 +45,30 @@ func TestWorkspaceBuilds(t *testing.T) {
require.NoError(t, err)
})
t.Run("PaginateNonExistentRow", func(t *testing.T) {
t.Parallel()
ctx := context.Background()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
_, err := client.WorkspaceBuilds(ctx, codersdk.WorkspaceBuildsRequest{
WorkspaceID: workspace.ID,
Pagination: codersdk.Pagination{
AfterID: uuid.New(),
},
})
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusBadRequest, apiError.StatusCode())
require.Contains(t, apiError.Message, "does not exist")
})
t.Run("PaginateLimitOffset", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})

View File

@ -18,6 +18,8 @@ services:
condition: service_healthy
database:
image: "postgres:14.2"
ports:
- "5432:5432"
environment:
POSTGRES_USER: ${POSTGRES_USER:-username} # The PostgreSQL user (useful to connect to the database)
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password} # The PostgreSQL password (useful to connect to the database)