improve parameters validation

This commit is contained in:
defelmnq
2025-03-13 22:41:57 +00:00
parent 796bcd07d5
commit 75c310d4d7
4 changed files with 65 additions and 111 deletions

View File

@ -6,12 +6,9 @@ import (
"encoding/json"
"net/http"
"slices"
"strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
@ -25,42 +22,6 @@ import (
"github.com/coder/websocket"
)
// convertInboxNotificationParameters parses and validates the common parameters used in get and list endpoints for inbox notifications
func convertInboxNotificationParameters(ctx context.Context, logger slog.Logger, targetsParam string, templatesParam string, readStatusParam string) ([]uuid.UUID, []uuid.UUID, string, error) {
var targets []uuid.UUID
if targetsParam != "" {
splitTargets := strings.Split(targetsParam, ",")
for _, target := range splitTargets {
id, err := uuid.Parse(target)
if err != nil {
logger.Error(ctx, "unable to parse target id", slog.Error(err))
return nil, nil, "", xerrors.New("unable to parse target id")
}
targets = append(targets, id)
}
}
var templates []uuid.UUID
if templatesParam != "" {
splitTemplates := strings.Split(templatesParam, ",")
for _, template := range splitTemplates {
id, err := uuid.Parse(template)
if err != nil {
logger.Error(ctx, "unable to parse template id", slog.Error(err))
return nil, nil, "", xerrors.New("unable to parse template id")
}
templates = append(templates, id)
}
}
readStatus := string(database.InboxNotificationReadStatusAll)
if readStatusParam != "" {
readStatus = readStatusParam
}
return targets, templates, readStatus, nil
}
// convertInboxNotificationResponse works as a util function to transform a database.InboxNotification to codersdk.InboxNotification
func convertInboxNotificationResponse(ctx context.Context, logger slog.Logger, notif database.InboxNotification) codersdk.InboxNotification {
return codersdk.InboxNotification{
@ -102,22 +63,33 @@ func convertInboxNotificationResponse(ctx context.Context, logger slog.Logger, n
// @Success 200 {object} codersdk.GetInboxNotificationResponse
// @Router /notifications/inbox/watch [get]
func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) {
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()
var (
ctx = r.Context()
apikey = httpmw.APIKey(r)
)
var req codersdk.WatchInboxNotificationsRequest
if !httpapi.Read(ctx, rw, r, &req) {
targets = p.UUIDs(vals, []uuid.UUID{}, "targets")
templates = p.UUIDs(vals, []uuid.UUID{}, "templates")
readStatus = p.String(vals, "all", "read_status")
)
p.ErrorExcessParams(vals)
if len(p.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: p.Errors,
})
return
}
targets, templates, readStatusParam, err := convertInboxNotificationParameters(ctx, api.Logger, req.Targets, req.Templates, req.Targets)
if err != nil {
if !slices.Contains([]string{
string(database.InboxNotificationReadStatusAll),
string(database.InboxNotificationReadStatusRead),
string(database.InboxNotificationReadStatusUnread),
}, readStatus) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameter.",
Detail: err.Error(),
Message: "starting_before query parameter should be any of 'all', 'read', 'unread'.",
})
return
}
@ -165,12 +137,12 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request)
}
// filter out notifications that don't match the read status
if readStatusParam != "" {
if readStatusParam == string(database.InboxNotificationReadStatusRead) {
if readStatus != "" {
if readStatus == string(database.InboxNotificationReadStatusRead) {
if payload.InboxNotification.ReadAt == nil {
return
}
} else if readStatusParam == string(database.InboxNotificationReadStatusUnread) {
} else if readStatus == string(database.InboxNotificationReadStatusUnread) {
if payload.InboxNotification.ReadAt != nil {
return
}
@ -227,35 +199,48 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request)
// @Success 200 {object} codersdk.ListInboxNotificationsResponse
// @Router /notifications/inbox [get]
func (api *API) listInboxNotifications(rw http.ResponseWriter, r *http.Request) {
p := httpapi.NewQueryParamParser()
vals := r.URL.Query()
var (
ctx = r.Context()
apikey = httpmw.APIKey(r)
targets = p.UUIDs(vals, []uuid.UUID{}, "targets")
templates = p.UUIDs(vals, []uuid.UUID{}, "templates")
readStatus = p.String(vals, "all", "read_status")
startingBefore = p.UUID(vals, uuid.Nil, "starting_before")
)
var req codersdk.ListInboxNotificationsRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
targets, templates, readStatus, err := convertInboxNotificationParameters(ctx, api.Logger, req.Targets, req.Templates, req.ReadStatus)
if err != nil {
p.ErrorExcessParams(vals)
if len(p.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameter.",
Detail: err.Error(),
Message: "Query parameters have invalid values.",
Validations: p.Errors,
})
return
}
startingBefore := dbtime.Now()
if req.StartingBefore != uuid.Nil {
lastNotif, err := api.Database.GetInboxNotificationByID(ctx, req.StartingBefore)
if !slices.Contains([]string{
string(database.InboxNotificationReadStatusAll),
string(database.InboxNotificationReadStatusRead),
string(database.InboxNotificationReadStatusUnread),
}, readStatus) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "starting_before query parameter should be any of 'all', 'read', 'unread'.",
})
return
}
createdBefore := dbtime.Now()
if startingBefore != uuid.Nil {
lastNotif, err := api.Database.GetInboxNotificationByID(ctx, startingBefore)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to get notification by id.",
})
return
}
startingBefore = lastNotif.CreatedAt
createdBefore = lastNotif.CreatedAt
}
notifs, err := api.Database.GetFilteredInboxNotificationsByUserID(ctx, database.GetFilteredInboxNotificationsByUserIDParams{
@ -263,7 +248,7 @@ func (api *API) listInboxNotifications(rw http.ResponseWriter, r *http.Request)
Templates: templates,
Targets: targets,
ReadStatus: database.InboxNotificationReadStatus(readStatus),
CreatedAtOpt: startingBefore,
CreatedAtOpt: createdBefore,
})
if err != nil {
api.Logger.Error(ctx, "failed to get filtered inbox notifications", slog.Error(err))
@ -304,29 +289,23 @@ func (api *API) listInboxNotifications(rw http.ResponseWriter, r *http.Request)
// @Success 200 {object} codersdk.Response
// @Router /notifications/inbox/{id}/read-status [put]
func (api *API) updateInboxNotificationReadStatus(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var (
apikey = httpmw.APIKey(r)
notifID = chi.URLParam(r, "id")
ctx = r.Context()
apikey = httpmw.APIKey(r)
)
notificationID, ok := httpmw.ParseUUIDParam(rw, r, "id")
if !ok {
}
var body codersdk.UpdateInboxNotificationReadStatusRequest
if !httpapi.Read(ctx, rw, r, &body) {
return
}
parsedNotifID, err := uuid.Parse(notifID)
if err != nil {
api.Logger.Error(ctx, "failed to parse notification uuid", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to parse notification uuid.",
})
return
}
err = api.Database.UpdateInboxNotificationReadStatus(ctx, database.UpdateInboxNotificationReadStatusParams{
ID: parsedNotifID,
err := api.Database.UpdateInboxNotificationReadStatus(ctx, database.UpdateInboxNotificationReadStatusParams{
ID: notificationID,
ReadAt: func() sql.NullTime {
if body.IsRead {
return sql.NullTime{
@ -355,7 +334,7 @@ func (api *API) updateInboxNotificationReadStatus(rw http.ResponseWriter, r *htt
return
}
updatedNotification, err := api.Database.GetInboxNotificationByID(ctx, parsedNotifID)
updatedNotification, err := api.Database.GetInboxNotificationByID(ctx, notificationID)
if err != nil {
api.Logger.Error(ctx, "failed to get notification by id", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{

View File

@ -133,7 +133,7 @@ func TestInboxNotifications_List(t *testing.T) {
}
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{
Templates: []uuid.UUID{notifications.TemplateWorkspaceOutOfMemory},
Templates: notifications.TemplateWorkspaceOutOfMemory.String(),
})
require.NoError(t, err)
require.NotNil(t, notifs)
@ -181,7 +181,7 @@ func TestInboxNotifications_List(t *testing.T) {
}
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{
Targets: []uuid.UUID{filteredTarget},
Targets: filteredTarget.String(),
})
require.NoError(t, err)
require.NotNil(t, notifs)
@ -235,8 +235,8 @@ func TestInboxNotifications_List(t *testing.T) {
}
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{
Targets: []uuid.UUID{filteredTarget},
Templates: []uuid.UUID{notifications.TemplateWorkspaceOutOfDisk},
Targets: filteredTarget.String(),
Templates: notifications.TemplateWorkspaceOutOfDisk.String(),
})
require.NoError(t, err)
require.NotNil(t, notifs)

View File

@ -1,17 +0,0 @@
package uuid
import (
"strings"
"github.com/google/uuid"
)
func FromSliceToString(uuids []uuid.UUID, separator string) string {
uuidStrings := make([]string, 0, len(uuids))
for _, u := range uuids {
uuidStrings = append(uuidStrings, u.String())
}
return strings.Join(uuidStrings, separator)
}

View File

@ -8,8 +8,6 @@ import (
"time"
"github.com/google/uuid"
utiluuid "github.com/coder/coder/v2/coderd/util/uuid"
)
type InboxNotification struct {
@ -30,12 +28,6 @@ type InboxNotificationAction struct {
URL string `json:"url"`
}
type WatchInboxNotificationsRequest struct {
Targets string `json:"targets,omitempty"`
Templates string `json:"templates,omitempty"`
ReadStatus string `json:"read_status,omitempty" validate:"omitempty,oneof=read unread all"`
}
type GetInboxNotificationResponse struct {
Notification InboxNotification `json:"notification"`
UnreadCount int `json:"unread_count"`
@ -56,10 +48,10 @@ type ListInboxNotificationsResponse struct {
func ListInboxNotificationsRequestToQueryParams(req ListInboxNotificationsRequest) []RequestOption {
var opts []RequestOption
if len(req.Targets) > 0 {
opts = append(opts, WithQueryParam("targets", utiluuid.FromSliceToString(req.Targets, ",")))
opts = append(opts, WithQueryParam("targets", req.Targets))
}
if len(req.Templates) > 0 {
opts = append(opts, WithQueryParam("templates", utiluuid.FromSliceToString(req.Templates, ",")))
opts = append(opts, WithQueryParam("templates", req.Templates))
}
if req.ReadStatus != "" {
opts = append(opts, WithQueryParam("read_status", req.ReadStatus))