chore: add templates search query to a filter (#13772)

* chore: add templates search query to a filter
This commit is contained in:
Steven Masley
2024-07-03 08:42:23 -10:00
committed by GitHub
parent 8778aa0f71
commit ccf34901bc
8 changed files with 246 additions and 23 deletions

View File

@ -1,6 +1,7 @@
package httpapi
import (
"database/sql"
"errors"
"fmt"
"net/url"
@ -104,6 +105,27 @@ func (p *QueryParamParser) PositiveInt32(vals url.Values, def int32, queryParam
return v
}
// NullableBoolean will return a null sql value if no input is provided.
// SQLc still uses sql.NullBool rather than the generic type. So converting from
// the generic type is required.
func (p *QueryParamParser) NullableBoolean(vals url.Values, def sql.NullBool, queryParam string) sql.NullBool {
v, err := parseNullableQueryParam[bool](p, vals, strconv.ParseBool, sql.Null[bool]{
V: def.Bool,
Valid: def.Valid,
}, queryParam)
if err != nil {
p.Errors = append(p.Errors, codersdk.ValidationError{
Field: queryParam,
Detail: fmt.Sprintf("Query param %q must be a valid boolean: %s", queryParam, err.Error()),
})
}
return sql.NullBool{
Bool: v.V,
Valid: v.Valid,
}
}
func (p *QueryParamParser) Boolean(vals url.Values, def bool, queryParam string) bool {
v, err := parseQueryParam(p, vals, strconv.ParseBool, def, queryParam)
if err != nil {
@ -294,9 +316,34 @@ func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T,
return v
}
func parseNullableQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def sql.Null[T], queryParam string) (sql.Null[T], error) {
setParse := parseSingle(parser, parse, def.V, queryParam)
return parseQueryParamSet[sql.Null[T]](parser, vals, func(set []string) (sql.Null[T], error) {
if len(set) == 0 {
return sql.Null[T]{
Valid: false,
}, nil
}
value, err := setParse(set)
if err != nil {
return sql.Null[T]{}, err
}
return sql.Null[T]{
V: value,
Valid: true,
}, nil
}, def, queryParam)
}
// parseQueryParam expects just 1 value set for the given query param.
func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) {
setParse := func(set []string) (T, error) {
setParse := parseSingle(parser, parse, def, queryParam)
return parseQueryParamSet(parser, vals, setParse, def, queryParam)
}
func parseSingle[T any](parser *QueryParamParser, parse func(v string) (T, error), def T, queryParam string) func(set []string) (T, error) {
return func(set []string) (T, error) {
if len(set) > 1 {
// Set as a parser.Error rather than return an error.
// Returned errors are errors from the passed in `parse` function, and
@ -311,7 +358,6 @@ func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse fun
}
return parse(set[0])
}
return parseQueryParamSet(parser, vals, setParse, def, queryParam)
}
func parseQueryParamSet[T any](parser *QueryParamParser, vals url.Values, parse func(set []string) (T, error), def T, queryParam string) (T, error) {

View File

@ -1,6 +1,7 @@
package httpapi_test
import (
"database/sql"
"fmt"
"net/http"
"net/url"
@ -220,6 +221,65 @@ func TestParseQueryParams(t *testing.T) {
testQueryParams(t, expParams, parser, parser.Boolean)
})
t.Run("NullableBoolean", func(t *testing.T) {
t.Parallel()
expParams := []queryParamTestCase[sql.NullBool]{
{
QueryParam: "valid_true",
Value: "true",
Expected: sql.NullBool{
Bool: true,
Valid: true,
},
},
{
QueryParam: "no_value_true_def",
NoSet: true,
Default: sql.NullBool{
Bool: true,
Valid: true,
},
Expected: sql.NullBool{
Bool: true,
Valid: true,
},
},
{
QueryParam: "no_value",
NoSet: true,
Expected: sql.NullBool{
Bool: false,
Valid: false,
},
},
{
QueryParam: "invalid_boolean",
Value: "yes",
Expected: sql.NullBool{
Bool: false,
Valid: false,
},
ExpectedErrorContains: "must be a valid boolean",
},
{
QueryParam: "unexpected_list",
Values: []string{"true", "false"},
ExpectedErrorContains: multipleValuesError,
// Expected value is a bit strange, but the error is raised
// in the parser, not as a parse failure. Maybe this should be
// fixed, but is how it is done atm.
Expected: sql.NullBool{
Bool: false,
Valid: true,
},
},
}
parser := httpapi.NewQueryParamParser()
testQueryParams(t, expParams, parser, parser.NullableBoolean)
})
t.Run("Int", func(t *testing.T) {
t.Parallel()
expParams := []queryParamTestCase[int]{

View File

@ -184,6 +184,52 @@ func Workspaces(query string, page codersdk.Pagination, agentInactiveDisconnectT
return filter, parser.Errors
}
func Templates(ctx context.Context, db database.Store, query string) (database.GetTemplatesWithFilterParams, []codersdk.ValidationError) {
// Always lowercase for all searches.
query = strings.ToLower(query)
values, errors := searchTerms(query, func(term string, values url.Values) error {
// Default to the template name
values.Add("name", term)
return nil
})
if len(errors) > 0 {
return database.GetTemplatesWithFilterParams{}, errors
}
parser := httpapi.NewQueryParamParser()
filter := database.GetTemplatesWithFilterParams{
Deleted: parser.Boolean(values, false, "deleted"),
// TODO: Should name be a fuzzy search?
ExactName: parser.String(values, "", "name"),
IDs: parser.UUIDs(values, []uuid.UUID{}, "ids"),
Deprecated: parser.NullableBoolean(values, sql.NullBool{}, "deprecated"),
}
// Convert the "organization" parameter to an organization uuid. This can require
// a database lookup.
organizationArg := parser.String(values, "", "organization")
if organizationArg != "" {
organizationID, err := uuid.Parse(organizationArg)
if err == nil {
filter.OrganizationID = organizationID
} else {
// Organization could be a name
organization, err := db.GetOrganizationByName(ctx, organizationArg)
if err != nil {
parser.Errors = append(parser.Errors, codersdk.ValidationError{
Field: "organization",
Detail: fmt.Sprintf("Organization %q either does not exist, or you are unauthorized to view it", organizationArg),
})
} else {
filter.OrganizationID = organization.ID
}
}
}
parser.ErrorExcessParams(values)
return filter, parser.Errors
}
func searchTerms(query string, defaultKey func(term string, values url.Values) error) (url.Values, []codersdk.ValidationError) {
searchValues := make(url.Values)

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -454,3 +455,45 @@ func TestSearchUsers(t *testing.T) {
})
}
}
func TestSearchTemplates(t *testing.T) {
t.Parallel()
testCases := []struct {
Name string
Query string
Expected database.GetTemplatesWithFilterParams
ExpectedErrorContains string
}{
{
Name: "Empty",
Query: "",
Expected: database.GetTemplatesWithFilterParams{},
},
}
for _, c := range testCases {
c := c
t.Run(c.Name, func(t *testing.T) {
t.Parallel()
// Do not use a real database, this is only used for an
// organization lookup.
db := dbmem.New()
values, errs := searchquery.Templates(context.Background(), db, c.Query)
if c.ExpectedErrorContains != "" {
require.True(t, len(errs) > 0, "expect some errors")
var s strings.Builder
for _, err := range errs {
_, _ = s.WriteString(fmt.Sprintf("%s: %s\n", err.Field, err.Detail))
}
require.Contains(t, s.String(), c.ExpectedErrorContains)
} else {
require.Len(t, errs, 0, "expected no error")
if c.Expected.IDs == nil {
// Nil and length 0 are the same
c.Expected.IDs = []uuid.UUID{}
}
require.Equal(t, c.Expected, values, "expected values")
}
})
}
}

View File

@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/schedule"
"github.com/coder/coder/v2/coderd/searchquery"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/workspacestats"
@ -457,20 +458,12 @@ func (api *API) fetchTemplates(mutate func(r *http.Request, arg *database.GetTem
return func(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p := httpapi.NewQueryParamParser()
values := r.URL.Query()
deprecated := sql.NullBool{}
if values.Has("deprecated") {
deprecated = sql.NullBool{
Bool: p.Boolean(values, false, "deprecated"),
Valid: true,
}
}
if len(p.Errors) > 0 {
queryStr := r.URL.Query().Get("q")
filter, errs := searchquery.Templates(ctx, api.Database, queryStr)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query params.",
Validations: p.Errors,
Message: "Invalid template search query.",
Validations: errs,
})
return
}
@ -484,9 +477,7 @@ func (api *API) fetchTemplates(mutate func(r *http.Request, arg *database.GetTem
return
}
args := database.GetTemplatesWithFilterParams{
Deprecated: deprecated,
}
args := filter
if mutate != nil {
mutate(r, &args)
}

View File

@ -420,7 +420,9 @@ func TestTemplatesByOrganization(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
templates, err := client.TemplatesByOrganization(ctx, user.OrganizationID)
templates, err := client.Templates(ctx, codersdk.TemplateFilter{
OrganizationID: user.OrganizationID,
})
require.NoError(t, err)
require.Len(t, templates, 1)
})
@ -440,7 +442,7 @@ func TestTemplatesByOrganization(t *testing.T) {
require.Len(t, templates, 2)
// Listing all should match
templates, err = client.Templates(ctx)
templates, err = client.Templates(ctx, codersdk.TemplateFilter{})
require.NoError(t, err)
require.Len(t, templates, 2)
@ -473,12 +475,19 @@ func TestTemplatesByOrganization(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
// All 4 are viewable by the owner
templates, err := client.Templates(ctx)
templates, err := client.Templates(ctx, codersdk.TemplateFilter{})
require.NoError(t, err)
require.Len(t, templates, 4)
// View a single organization from the owner
templates, err = client.Templates(ctx, codersdk.TemplateFilter{
OrganizationID: owner.OrganizationID,
})
require.NoError(t, err)
require.Len(t, templates, 2)
// Only 2 are viewable by the org user
templates, err = user.Templates(ctx)
templates, err = user.Templates(ctx, codersdk.TemplateFilter{})
require.NoError(t, err)
require.Len(t, templates, 2)
for _, tmpl := range templates {