feat(coderd): add mark-all-as-read endpoint for inbox notifications (#16976)

[Resolve this issue](https://github.com/coder/internal/issues/506)

Add a mark-all-as-read endpoint which is marking as read all
notifications that are not read for the authenticated user.
Also adds the DB logic.
This commit is contained in:
Vincent Vielle
2025-03-20 13:41:54 +01:00
committed by GitHub
parent d8d4b9b86e
commit 4960a1e85a
15 changed files with 262 additions and 0 deletions

19
coderd/apidoc/docs.go generated
View File

@ -1705,6 +1705,25 @@ const docTemplate = `{
}
}
},
"/notifications/inbox/mark-all-as-read": {
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": [
"Notifications"
],
"summary": "Mark all unread notifications as read",
"operationId": "mark-all-unread-notifications-as-read",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/notifications/inbox/watch": {
"get": {
"security": [

View File

@ -1486,6 +1486,23 @@
}
}
},
"/notifications/inbox/mark-all-as-read": {
"put": {
"security": [
{
"CoderSessionToken": []
}
],
"tags": ["Notifications"],
"summary": "Mark all unread notifications as read",
"operationId": "mark-all-unread-notifications-as-read",
"responses": {
"204": {
"description": "No Content"
}
}
}
},
"/notifications/inbox/watch": {
"get": {
"security": [

View File

@ -1395,6 +1395,7 @@ func New(options *Options) *API {
r.Use(apiKeyMiddleware)
r.Route("/inbox", func(r chi.Router) {
r.Get("/", api.listInboxNotifications)
r.Put("/mark-all-as-read", api.markAllInboxNotificationsAsRead)
r.Get("/watch", api.watchInboxNotifications)
r.Put("/{id}/read-status", api.updateInboxNotificationReadStatus)
})

View File

@ -3554,6 +3554,16 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
}
func (q *querier) MarkAllInboxNotificationsAsRead(ctx context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error {
resource := rbac.ResourceInboxNotification.WithOwner(arg.UserID.String())
if err := q.authorizeContext(ctx, policy.ActionUpdate, resource); err != nil {
return err
}
return q.db.MarkAllInboxNotificationsAsRead(ctx, arg)
}
func (q *querier) OIDCClaimFieldValues(ctx context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
resource := rbac.ResourceIdpsyncSettings
if args.OrganizationID != uuid.Nil {

View File

@ -4653,6 +4653,15 @@ func (s *MethodTestSuite) TestNotifications() {
ReadAt: sql.NullTime{Time: readAt, Valid: true},
}).Asserts(rbac.ResourceInboxNotification.WithID(notifID).WithOwner(u.ID.String()), policy.ActionUpdate)
}))
s.Run("MarkAllInboxNotificationsAsRead", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
check.Args(database.MarkAllInboxNotificationsAsReadParams{
UserID: u.ID,
ReadAt: sql.NullTime{Time: dbtestutil.NowInDefaultTimezone(), Valid: true},
}).Asserts(rbac.ResourceInboxNotification.WithOwner(u.ID.String()), policy.ActionUpdate)
}))
}
func (s *MethodTestSuite) TestOAuth2ProviderApps() {

View File

@ -9500,6 +9500,21 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
return shares, nil
}
func (q *FakeQuerier) MarkAllInboxNotificationsAsRead(_ context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
for idx, notif := range q.inboxNotifications {
if notif.UserID == arg.UserID && !notif.ReadAt.Valid {
q.inboxNotifications[idx].ReadAt = arg.ReadAt
}
}
return nil
}
// nolint:forcetypeassert
func (q *FakeQuerier) OIDCClaimFieldValues(_ context.Context, args database.OIDCClaimFieldValuesParams) ([]string, error) {
orgMembers := q.getOrganizationMemberNoLock(args.OrganizationID)

View File

@ -2257,6 +2257,13 @@ func (m queryMetricsStore) ListWorkspaceAgentPortShares(ctx context.Context, wor
return r0, r1
}
func (m queryMetricsStore) MarkAllInboxNotificationsAsRead(ctx context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error {
start := time.Now()
r0 := m.s.MarkAllInboxNotificationsAsRead(ctx, arg)
m.queryLatencies.WithLabelValues("MarkAllInboxNotificationsAsRead").Observe(time.Since(start).Seconds())
return r0
}
func (m queryMetricsStore) OIDCClaimFieldValues(ctx context.Context, organizationID database.OIDCClaimFieldValuesParams) ([]string, error) {
start := time.Now()
r0, r1 := m.s.OIDCClaimFieldValues(ctx, organizationID)

View File

@ -4763,6 +4763,20 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(ctx, workspaceID a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), ctx, workspaceID)
}
// MarkAllInboxNotificationsAsRead mocks base method.
func (m *MockStore) MarkAllInboxNotificationsAsRead(ctx context.Context, arg database.MarkAllInboxNotificationsAsReadParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkAllInboxNotificationsAsRead", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// MarkAllInboxNotificationsAsRead indicates an expected call of MarkAllInboxNotificationsAsRead.
func (mr *MockStoreMockRecorder) MarkAllInboxNotificationsAsRead(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkAllInboxNotificationsAsRead", reflect.TypeOf((*MockStore)(nil).MarkAllInboxNotificationsAsRead), ctx, arg)
}
// OIDCClaimFieldValues mocks base method.
func (m *MockStore) OIDCClaimFieldValues(ctx context.Context, arg database.OIDCClaimFieldValuesParams) ([]string, error) {
m.ctrl.T.Helper()

View File

@ -469,6 +469,7 @@ type sqlcQuerier interface {
ListProvisionerKeysByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListProvisionerKeysByOrganizationExcludeReserved(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerKey, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error
OIDCClaimFieldValues(ctx context.Context, arg OIDCClaimFieldValuesParams) ([]string, error)
// OIDCClaimFields returns a list of distinct keys in the the merged_claims fields.
// This query is used to generate the list of available sync fields for idp sync settings.

View File

@ -4511,6 +4511,25 @@ func (q *sqlQuerier) InsertInboxNotification(ctx context.Context, arg InsertInbo
return i, err
}
const markAllInboxNotificationsAsRead = `-- name: MarkAllInboxNotificationsAsRead :exec
UPDATE
inbox_notifications
SET
read_at = $1
WHERE
user_id = $2 and read_at IS NULL
`
type MarkAllInboxNotificationsAsReadParams struct {
ReadAt sql.NullTime `db:"read_at" json:"read_at"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) MarkAllInboxNotificationsAsRead(ctx context.Context, arg MarkAllInboxNotificationsAsReadParams) error {
_, err := q.db.ExecContext(ctx, markAllInboxNotificationsAsRead, arg.ReadAt, arg.UserID)
return err
}
const updateInboxNotificationReadStatus = `-- name: UpdateInboxNotificationReadStatus :exec
UPDATE
inbox_notifications

View File

@ -57,3 +57,11 @@ SET
read_at = $1
WHERE
id = $2;
-- name: MarkAllInboxNotificationsAsRead :exec
UPDATE
inbox_notifications
SET
read_at = $1
WHERE
user_id = $2 and read_at IS NULL;

View File

@ -344,3 +344,31 @@ func (api *API) updateInboxNotificationReadStatus(rw http.ResponseWriter, r *htt
UnreadCount: int(unreadCount),
})
}
// markAllInboxNotificationsAsRead marks as read all unread notifications for authenticated user.
// @Summary Mark all unread notifications as read
// @ID mark-all-unread-notifications-as-read
// @Security CoderSessionToken
// @Tags Notifications
// @Success 204
// @Router /notifications/inbox/mark-all-as-read [put]
func (api *API) markAllInboxNotificationsAsRead(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
apikey = httpmw.APIKey(r)
)
err := api.Database.MarkAllInboxNotificationsAsRead(ctx, database.MarkAllInboxNotificationsAsReadParams{
UserID: apikey.UserID,
ReadAt: sql.NullTime{Time: dbtime.Now(), Valid: true},
})
if err != nil {
api.Logger.Error(ctx, "failed to mark all unread notifications as read", slog.Error(err))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to mark all unread notifications as read.",
})
return
}
rw.WriteHeader(http.StatusNoContent)
}

View File

@ -37,6 +37,7 @@ func TestInboxNotification_Watch(t *testing.T) {
// I skip these tests specifically on windows as for now they are flaky - only on Windows.
// For now the idea is that the runner takes too long to insert the entries, could be worth
// investigating a manual Tx.
// see: https://github.com/coder/internal/issues/503
if runtime.GOOS == "windows" {
t.Skip("our runners are randomly taking too long to insert entries")
}
@ -312,6 +313,7 @@ func TestInboxNotifications_List(t *testing.T) {
// I skip these tests specifically on windows as for now they are flaky - only on Windows.
// For now the idea is that the runner takes too long to insert the entries, could be worth
// investigating a manual Tx.
// see: https://github.com/coder/internal/issues/503
if runtime.GOOS == "windows" {
t.Skip("our runners are randomly taking too long to insert entries")
}
@ -595,6 +597,7 @@ func TestInboxNotifications_ReadStatus(t *testing.T) {
// I skip these tests specifically on windows as for now they are flaky - only on Windows.
// For now the idea is that the runner takes too long to insert the entries, could be worth
// investigating a manual Tx.
// see: https://github.com/coder/internal/issues/503
if runtime.GOOS == "windows" {
t.Skip("our runners are randomly taking too long to insert entries")
}
@ -730,3 +733,76 @@ func TestInboxNotifications_ReadStatus(t *testing.T) {
require.Empty(t, updatedNotif.Notification)
})
}
func TestInboxNotifications_MarkAllAsRead(t *testing.T) {
t.Parallel()
// I skip these tests specifically on windows as for now they are flaky - only on Windows.
// For now the idea is that the runner takes too long to insert the entries, could be worth
// investigating a manual Tx.
// see: https://github.com/coder/internal/issues/503
if runtime.GOOS == "windows" {
t.Skip("our runners are randomly taking too long to insert entries")
}
t.Run("ok", func(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{})
firstUser := coderdtest.CreateFirstUser(t, client)
client, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
notifs, err := client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{})
require.NoError(t, err)
require.NotNil(t, notifs)
require.Equal(t, 0, notifs.UnreadCount)
require.Empty(t, notifs.Notifications)
for i := range 20 {
dbgen.NotificationInbox(t, api.Database, database.InsertInboxNotificationParams{
ID: uuid.New(),
UserID: member.ID,
TemplateID: notifications.TemplateWorkspaceOutOfMemory,
Title: fmt.Sprintf("Notification %d", i),
Actions: json.RawMessage("[]"),
Content: fmt.Sprintf("Content of the notif %d", i),
CreatedAt: dbtime.Now(),
})
}
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{})
require.NoError(t, err)
require.NotNil(t, notifs)
require.Equal(t, 20, notifs.UnreadCount)
require.Len(t, notifs.Notifications, 20)
err = client.MarkAllInboxNotificationsAsRead(ctx)
require.NoError(t, err)
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{})
require.NoError(t, err)
require.NotNil(t, notifs)
require.Equal(t, 0, notifs.UnreadCount)
require.Len(t, notifs.Notifications, 20)
for i := range 10 {
dbgen.NotificationInbox(t, api.Database, database.InsertInboxNotificationParams{
ID: uuid.New(),
UserID: member.ID,
TemplateID: notifications.TemplateWorkspaceOutOfMemory,
Title: fmt.Sprintf("Notification %d", i),
Actions: json.RawMessage("[]"),
Content: fmt.Sprintf("Content of the notif %d", i),
CreatedAt: dbtime.Now(),
})
}
notifs, err = client.ListInboxNotifications(ctx, codersdk.ListInboxNotificationsRequest{})
require.NoError(t, err)
require.NotNil(t, notifs)
require.Equal(t, 10, notifs.UnreadCount)
require.Len(t, notifs.Notifications, 25)
})
}