mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
feat: Guard search queries against common mistakes (#6404)
* feat: Error on excessive invalid search keys * feat: Guard search queries against common mistakes * Raise errors in FE on workspaces table * All errors should be on newlines
This commit is contained in:
@ -5,12 +5,13 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/coder/coder/codersdk"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
// QueryParamParser is a helper for parsing all query params and gathering all
|
||||
@ -20,16 +21,38 @@ type QueryParamParser struct {
|
||||
// Errors is the set of errors to return via the API. If the length
|
||||
// of this set is 0, there are no errors!.
|
||||
Errors []codersdk.ValidationError
|
||||
// Parsed is a map of all query params that were parsed. This is useful
|
||||
// for checking if extra query params were passed in.
|
||||
Parsed map[string]bool
|
||||
}
|
||||
|
||||
func NewQueryParamParser() *QueryParamParser {
|
||||
return &QueryParamParser{
|
||||
Errors: []codersdk.ValidationError{},
|
||||
Parsed: map[string]bool{},
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorExcessParams checks if any query params were passed in that were not
|
||||
// parsed. If so, it adds an error to the parser as these values are not valid
|
||||
// query parameters.
|
||||
func (p *QueryParamParser) ErrorExcessParams(values url.Values) {
|
||||
for k := range values {
|
||||
if _, ok := p.Parsed[k]; !ok {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: k,
|
||||
Detail: fmt.Sprintf("Query param %q is not a valid query param", k),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) addParsed(key string) {
|
||||
p.Parsed[key] = true
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int {
|
||||
v, err := parseQueryParam(vals, strconv.Atoi, def, queryParam)
|
||||
v, err := parseQueryParam(p, vals, strconv.Atoi, def, queryParam)
|
||||
if err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
@ -40,14 +63,16 @@ func (p *QueryParamParser) Int(vals url.Values, def int, queryParam string) int
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) UUIDorMe(vals url.Values, def uuid.UUID, me uuid.UUID, queryParam string) uuid.UUID {
|
||||
if vals.Get(queryParam) == "me" {
|
||||
return me
|
||||
}
|
||||
return p.UUID(vals, def, queryParam)
|
||||
return ParseCustom(p, vals, def, queryParam, func(v string) (uuid.UUID, error) {
|
||||
if v == "me" {
|
||||
return me, nil
|
||||
}
|
||||
return uuid.Parse(v)
|
||||
})
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) UUID(vals url.Values, def uuid.UUID, queryParam string) uuid.UUID {
|
||||
v, err := parseQueryParam(vals, uuid.Parse, def, queryParam)
|
||||
v, err := parseQueryParam(p, vals, uuid.Parse, def, queryParam)
|
||||
if err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
@ -58,54 +83,61 @@ func (p *QueryParamParser) UUID(vals url.Values, def uuid.UUID, queryParam strin
|
||||
}
|
||||
|
||||
func (p *QueryParamParser) UUIDs(vals url.Values, def []uuid.UUID, queryParam string) []uuid.UUID {
|
||||
v, err := parseQueryParam(vals, func(v string) ([]uuid.UUID, error) {
|
||||
var badValues []string
|
||||
strs := strings.Split(v, ",")
|
||||
ids := make([]uuid.UUID, 0, len(strs))
|
||||
for _, s := range strs {
|
||||
id, err := uuid.Parse(strings.TrimSpace(s))
|
||||
if err != nil {
|
||||
badValues = append(badValues, v)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ParseCustomList(p, vals, def, queryParam, func(v string) (uuid.UUID, error) {
|
||||
return uuid.Parse(strings.TrimSpace(v))
|
||||
})
|
||||
}
|
||||
|
||||
if len(badValues) > 0 {
|
||||
return []uuid.UUID{}, xerrors.Errorf("%s", strings.Join(badValues, ","))
|
||||
}
|
||||
return ids, nil
|
||||
func (p *QueryParamParser) Time(vals url.Values, def time.Time, queryParam string, format string) time.Time {
|
||||
v, err := parseQueryParam(p, vals, func(term string) (time.Time, error) {
|
||||
return time.Parse(format, term)
|
||||
}, def, queryParam)
|
||||
if err != nil {
|
||||
p.Errors = append(p.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
Detail: fmt.Sprintf("Query param %q has invalid uuids: %q", queryParam, err.Error()),
|
||||
Detail: fmt.Sprintf("Query param %q must be a valid date format (%s): %s", queryParam, format, err.Error()),
|
||||
})
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (*QueryParamParser) String(vals url.Values, def string, queryParam string) string {
|
||||
v, _ := parseQueryParam(vals, func(v string) (string, error) {
|
||||
func (p *QueryParamParser) String(vals url.Values, def string, queryParam string) string {
|
||||
v, _ := parseQueryParam(p, vals, func(v string) (string, error) {
|
||||
return v, nil
|
||||
}, def, queryParam)
|
||||
return v
|
||||
}
|
||||
|
||||
func (*QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
|
||||
v, _ := parseQueryParam(vals, func(v string) ([]string, error) {
|
||||
if v == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
return strings.Split(v, ","), nil
|
||||
}, def, queryParam)
|
||||
return v
|
||||
func (p *QueryParamParser) Strings(vals url.Values, def []string, queryParam string) []string {
|
||||
return ParseCustomList(p, vals, def, queryParam, func(v string) (string, error) {
|
||||
return v, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ValidEnum parses enum query params. Add more to the list as needed.
|
||||
type ValidEnum interface {
|
||||
database.ResourceType | database.AuditAction | database.BuildReason | database.UserStatus |
|
||||
database.WorkspaceStatus
|
||||
|
||||
// Valid is required on the enum type to be used with ParseEnum.
|
||||
Valid() bool
|
||||
}
|
||||
|
||||
// ParseEnum is a function that can be passed into ParseCustom that handles enum
|
||||
// validation.
|
||||
func ParseEnum[T ValidEnum](term string) (T, error) {
|
||||
enum := T(term)
|
||||
if enum.Valid() {
|
||||
return enum, nil
|
||||
}
|
||||
var empty T
|
||||
return empty, xerrors.Errorf("%q is not a valid value", term)
|
||||
}
|
||||
|
||||
// ParseCustom has to be a function, not a method on QueryParamParser because generics
|
||||
// cannot be used on struct methods.
|
||||
func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryParam string, parseFunc func(v string) (T, error)) T {
|
||||
v, err := parseQueryParam(vals, parseFunc, def, queryParam)
|
||||
v, err := parseQueryParam(parser, vals, parseFunc, def, queryParam)
|
||||
if err != nil {
|
||||
parser.Errors = append(parser.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
@ -115,10 +147,41 @@ func ParseCustom[T any](parser *QueryParamParser, vals url.Values, def T, queryP
|
||||
return v
|
||||
}
|
||||
|
||||
func parseQueryParam[T any](vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) {
|
||||
// ParseCustomList is a function that handles csv query params.
|
||||
func ParseCustomList[T any](parser *QueryParamParser, vals url.Values, def []T, queryParam string, parseFunc func(v string) (T, error)) []T {
|
||||
v, err := parseQueryParam(parser, vals, func(v string) ([]T, error) {
|
||||
terms := strings.Split(v, ",")
|
||||
var badValues []string
|
||||
var output []T
|
||||
for _, s := range terms {
|
||||
good, err := parseFunc(s)
|
||||
if err != nil {
|
||||
badValues = append(badValues, s)
|
||||
continue
|
||||
}
|
||||
output = append(output, good)
|
||||
}
|
||||
if len(badValues) > 0 {
|
||||
return []T{}, xerrors.Errorf("%s", strings.Join(badValues, ","))
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}, def, queryParam)
|
||||
if err != nil {
|
||||
parser.Errors = append(parser.Errors, codersdk.ValidationError{
|
||||
Field: queryParam,
|
||||
Detail: fmt.Sprintf("Query param %q has invalid values: %s", queryParam, err.Error()),
|
||||
})
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseQueryParam[T any](parser *QueryParamParser, vals url.Values, parse func(v string) (T, error), def T, queryParam string) (T, error) {
|
||||
parser.addParsed(queryParam)
|
||||
if !vals.Has(queryParam) || vals.Get(queryParam) == "" {
|
||||
return def, nil
|
||||
}
|
||||
|
||||
str := vals.Get(queryParam)
|
||||
return parse(str)
|
||||
}
|
||||
|
@ -5,10 +5,12 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
@ -26,6 +28,68 @@ type queryParamTestCase[T any] struct {
|
||||
func TestParseQueryParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Enum", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expParams := []queryParamTestCase[database.ResourceType]{
|
||||
{
|
||||
QueryParam: "resource_type",
|
||||
Value: string(database.ResourceTypeWorkspace),
|
||||
Expected: database.ResourceTypeWorkspace,
|
||||
},
|
||||
{
|
||||
QueryParam: "bad_type",
|
||||
Value: "foo",
|
||||
ExpectedErrorContains: "not a valid value",
|
||||
},
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
testQueryParams(t, expParams, parser, func(vals url.Values, def database.ResourceType, queryParam string) database.ResourceType {
|
||||
return httpapi.ParseCustom(parser, vals, def, queryParam, httpapi.ParseEnum[database.ResourceType])
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("EnumList", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expParams := []queryParamTestCase[[]database.ResourceType]{
|
||||
{
|
||||
QueryParam: "resource_type",
|
||||
Value: fmt.Sprintf("%s,%s", database.ResourceTypeWorkspace, database.ResourceTypeApiKey),
|
||||
Expected: []database.ResourceType{database.ResourceTypeWorkspace, database.ResourceTypeApiKey},
|
||||
},
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
testQueryParams(t, expParams, parser, func(vals url.Values, def []database.ResourceType, queryParam string) []database.ResourceType {
|
||||
return httpapi.ParseCustomList(parser, vals, def, queryParam, httpapi.ParseEnum[database.ResourceType])
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Time", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
const layout = "2006-01-02"
|
||||
|
||||
expParams := []queryParamTestCase[time.Time]{
|
||||
{
|
||||
QueryParam: "date",
|
||||
Value: "2010-01-01",
|
||||
Expected: must(time.Parse(layout, "2010-01-01")),
|
||||
},
|
||||
{
|
||||
QueryParam: "bad_date",
|
||||
Value: "2010",
|
||||
ExpectedErrorContains: "must be a valid date format",
|
||||
},
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
testQueryParams(t, expParams, parser, func(vals url.Values, def time.Time, queryParam string) time.Time {
|
||||
return parser.Time(vals, time.Time{}, queryParam, layout)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UUID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
me := uuid.New()
|
||||
@ -43,12 +107,12 @@ func TestParseQueryParams(t *testing.T) {
|
||||
{
|
||||
QueryParam: "invalid_id",
|
||||
Value: "bogus",
|
||||
ExpectedErrorContains: "must be a valid uuid",
|
||||
ExpectedErrorContains: "invalid UUID length",
|
||||
},
|
||||
{
|
||||
QueryParam: "long_id",
|
||||
Value: "afe39fbf-0f52-4a62-b0cc-58670145d773-123",
|
||||
ExpectedErrorContains: "must be a valid uuid",
|
||||
ExpectedErrorContains: "invalid UUID length",
|
||||
},
|
||||
{
|
||||
QueryParam: "no_value",
|
||||
@ -187,8 +251,8 @@ func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], par
|
||||
for _, c := range testCases {
|
||||
// !! Do not run these in parallel !!
|
||||
t.Run(c.QueryParam, func(t *testing.T) {
|
||||
v := parse(v, c.Default, c.QueryParam)
|
||||
require.Equal(t, c.Expected, v, fmt.Sprintf("param=%q value=%q", c.QueryParam, c.Value))
|
||||
value := parse(v, c.Default, c.QueryParam)
|
||||
require.Equal(t, c.Expected, value, fmt.Sprintf("param=%q value=%q", c.QueryParam, c.Value))
|
||||
if c.ExpectedErrorContains != "" {
|
||||
errors := parser.Errors
|
||||
require.True(t, len(errors) > 0, "error exist")
|
||||
@ -199,3 +263,10 @@ func testQueryParams[T any](t *testing.T, testCases []queryParamTestCase[T], par
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func must[T any](value T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
Reference in New Issue
Block a user