chore: refactor audit page to use window function for count (#5133)

* Move count query to window function

* Unpack count and update types

* Remove count endpoint

* Update tests, wip

* Fix tests

* Update frontend, wip

* Remove space

* Fix frontend test

* Don't hang on error

* Handle no results

* Don't omit count

* Fix frontend tests
This commit is contained in:
Presley Pizzo
2022-11-21 11:30:41 -05:00
committed by GitHub
parent 7a369e0a30
commit 67941b4f80
15 changed files with 54 additions and 347 deletions

View File

@ -57,45 +57,19 @@ func (api *API) auditLogs(rw http.ResponseWriter, r *http.Request) {
httpapi.InternalServerError(rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: convertAuditLogs(dblogs),
})
}
func (api *API) auditLogCount(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceAuditLog) {
httpapi.Forbidden(rw)
return
}
queryStr := r.URL.Query().Get("q")
filter, errs := auditSearchQuery(queryStr)
if len(errs) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid audit search query.",
Validations: errs,
// GetAuditLogsOffset does not return ErrNoRows because it uses a window function to get the count.
// So we need to check if the dblogs is empty and return an empty array if so.
if len(dblogs) == 0 {
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: []codersdk.AuditLog{},
Count: 0,
})
return
}
count, err := api.Database.GetAuditLogCount(ctx, database.GetAuditLogCountParams{
ResourceType: filter.ResourceType,
ResourceID: filter.ResourceID,
Action: filter.Action,
Username: filter.Username,
Email: filter.Email,
DateFrom: filter.DateFrom,
DateTo: filter.DateTo,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogCountResponse{
Count: count,
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AuditLogResponse{
AuditLogs: convertAuditLogs(dblogs),
Count: dblogs[0].Count,
})
}

View File

@ -25,9 +25,6 @@ func TestAuditLogs(t *testing.T) {
err := client.CreateTestAuditLog(ctx, codersdk.CreateTestAuditLogRequest{})
require.NoError(t, err)
count, err := client.AuditLogCount(ctx, codersdk.AuditLogCountRequest{})
require.NoError(t, err)
alogs, err := client.AuditLogs(ctx, codersdk.AuditLogsRequest{
Pagination: codersdk.Pagination{
Limit: 1,
@ -35,7 +32,7 @@ func TestAuditLogs(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, int64(1), count.Count)
require.Equal(t, int64(1), alogs.Count)
require.Len(t, alogs.AuditLogs, 1)
})
}
@ -161,16 +158,7 @@ func TestAuditLogsFilter(t *testing.T) {
})
require.NoError(t, err, "fetch audit logs")
require.Len(t, auditLogs.AuditLogs, testCase.ExpectedResult, "expected audit logs returned")
})
// Test count filtering
t.Run("GetCount"+testCase.Name, func(t *testing.T) {
t.Parallel()
response, err := client.AuditLogCount(ctx, codersdk.AuditLogCountRequest{
SearchQuery: testCase.SearchQuery,
})
require.NoError(t, err, "fetch audit logs count")
require.Equal(t, int(response.Count), testCase.ExpectedResult, "expected audit logs count returned")
require.Equal(t, testCase.ExpectedResult, int(auditLogs.Count), "expected audit log count returned")
})
}
})

View File

@ -318,7 +318,6 @@ func New(options *Options) *API {
)
r.Get("/", api.auditLogs)
r.Get("/count", api.auditLogCount)
r.Post("/testgenerate", api.generateFakeAuditLog)
})
r.Route("/files", func(r chi.Router) {

View File

@ -3100,6 +3100,7 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu
UserCreatedAt: sql.NullTime{Time: user.CreatedAt, Valid: userValid},
UserStatus: user.Status,
UserRoles: user.RBACRoles,
Count: 0,
})
if len(logs) >= int(arg.Limit) {
@ -3107,52 +3108,12 @@ func (q *fakeQuerier) GetAuditLogsOffset(ctx context.Context, arg database.GetAu
}
}
return logs, nil
}
func (q *fakeQuerier) GetAuditLogCount(_ context.Context, arg database.GetAuditLogCountParams) (int64, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
logs := make([]database.AuditLog, 0)
for _, alog := range q.auditLogs {
if arg.Action != "" && !strings.Contains(string(alog.Action), arg.Action) {
continue
}
if arg.ResourceType != "" && !strings.Contains(string(alog.ResourceType), arg.ResourceType) {
continue
}
if arg.ResourceID != uuid.Nil && alog.ResourceID != arg.ResourceID {
continue
}
if arg.Username != "" {
user, err := q.GetUserByID(context.Background(), alog.UserID)
if err == nil && !strings.EqualFold(arg.Username, user.Username) {
continue
}
}
if arg.Email != "" {
user, err := q.GetUserByID(context.Background(), alog.UserID)
if err == nil && !strings.EqualFold(arg.Email, user.Email) {
continue
}
}
if !arg.DateFrom.IsZero() {
if alog.Time.Before(arg.DateFrom) {
continue
}
}
if !arg.DateTo.IsZero() {
if alog.Time.After(arg.DateTo) {
continue
}
}
logs = append(logs, alog)
count := int64(len(logs))
for i := range logs {
logs[i].Count = count
}
return int64(len(logs)), nil
return logs, nil
}
func (q *fakeQuerier) InsertAuditLog(_ context.Context, arg database.InsertAuditLogParams) (database.AuditLog, error) {

View File

@ -33,7 +33,6 @@ type sqlcQuerier interface {
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
GetActiveUserCount(ctx context.Context) (int64, error)
GetAllOrganizationMembers(ctx context.Context, organizationID uuid.UUID) ([]User, error)
GetAuditLogCount(ctx context.Context, arg GetAuditLogCountParams) (int64, error)
// GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided
// ID.
GetAuditLogsOffset(ctx context.Context, arg GetAuditLogsOffsetParams) ([]GetAuditLogsOffsetRow, error)

View File

@ -365,89 +365,6 @@ func (q *sqlQuerier) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDP
return err
}
const getAuditLogCount = `-- name: GetAuditLogCount :one
SELECT
COUNT(*) as count
FROM
audit_logs
WHERE
-- Filter resource_type
CASE
WHEN $1 :: text != '' THEN
resource_type = $1 :: resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
resource_id = $2
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN $3 :: text != '' THEN
resource_target = $3
ELSE true
END
-- Filter action
AND CASE
WHEN $4 :: text != '' THEN
action = $4 :: audit_action
ELSE true
END
-- Filter by username
AND CASE
WHEN $5 :: text != '' THEN
user_id = (SELECT id from users WHERE users.username = $5 )
ELSE true
END
-- Filter by user_email
AND CASE
WHEN $6 :: text != '' THEN
user_id = (SELECT id from users WHERE users.email = $6 )
ELSE true
END
-- Filter by date_from
AND CASE
WHEN $7 :: timestamp with time zone != '0001-01-01 00:00:00' THEN
"time" >= $7
ELSE true
END
-- Filter by date_to
AND CASE
WHEN $8 :: timestamp with time zone != '0001-01-01 00:00:00' THEN
"time" <= $8
ELSE true
END
`
type GetAuditLogCountParams struct {
ResourceType string `db:"resource_type" json:"resource_type"`
ResourceID uuid.UUID `db:"resource_id" json:"resource_id"`
ResourceTarget string `db:"resource_target" json:"resource_target"`
Action string `db:"action" json:"action"`
Username string `db:"username" json:"username"`
Email string `db:"email" json:"email"`
DateFrom time.Time `db:"date_from" json:"date_from"`
DateTo time.Time `db:"date_to" json:"date_to"`
}
func (q *sqlQuerier) GetAuditLogCount(ctx context.Context, arg GetAuditLogCountParams) (int64, error) {
row := q.db.QueryRowContext(ctx, getAuditLogCount,
arg.ResourceType,
arg.ResourceID,
arg.ResourceTarget,
arg.Action,
arg.Username,
arg.Email,
arg.DateFrom,
arg.DateTo,
)
var count int64
err := row.Scan(&count)
return count, err
}
const getAuditLogsOffset = `-- name: GetAuditLogsOffset :many
SELECT
audit_logs.id, audit_logs.time, audit_logs.user_id, audit_logs.organization_id, audit_logs.ip, audit_logs.user_agent, audit_logs.resource_type, audit_logs.resource_id, audit_logs.resource_target, audit_logs.action, audit_logs.diff, audit_logs.status_code, audit_logs.additional_fields, audit_logs.request_id, audit_logs.resource_icon,
@ -456,7 +373,8 @@ SELECT
users.created_at AS user_created_at,
users.status AS user_status,
users.rbac_roles AS user_roles,
users.avatar_url AS user_avatar_url
users.avatar_url AS user_avatar_url,
COUNT(audit_logs.*) OVER() AS count
FROM
audit_logs
LEFT JOIN
@ -553,6 +471,7 @@ type GetAuditLogsOffsetRow struct {
UserStatus UserStatus `db:"user_status" json:"user_status"`
UserRoles []string `db:"user_roles" json:"user_roles"`
UserAvatarUrl sql.NullString `db:"user_avatar_url" json:"user_avatar_url"`
Count int64 `db:"count" json:"count"`
}
// GetAuditLogsBefore retrieves `row_limit` number of audit logs before the provided
@ -599,6 +518,7 @@ func (q *sqlQuerier) GetAuditLogsOffset(ctx context.Context, arg GetAuditLogsOff
&i.UserStatus,
pq.Array(&i.UserRoles),
&i.UserAvatarUrl,
&i.Count,
); err != nil {
return nil, err
}

View File

@ -8,7 +8,8 @@ SELECT
users.created_at AS user_created_at,
users.status AS user_status,
users.rbac_roles AS user_roles,
users.avatar_url AS user_avatar_url
users.avatar_url AS user_avatar_url,
COUNT(audit_logs.*) OVER() AS count
FROM
audit_logs
LEFT JOIN
@ -69,61 +70,6 @@ LIMIT
OFFSET
$2;
-- name: GetAuditLogCount :one
SELECT
COUNT(*) as count
FROM
audit_logs
WHERE
-- Filter resource_type
CASE
WHEN @resource_type :: text != '' THEN
resource_type = @resource_type :: resource_type
ELSE true
END
-- Filter resource_id
AND CASE
WHEN @resource_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
resource_id = @resource_id
ELSE true
END
-- Filter by resource_target
AND CASE
WHEN @resource_target :: text != '' THEN
resource_target = @resource_target
ELSE true
END
-- Filter action
AND CASE
WHEN @action :: text != '' THEN
action = @action :: audit_action
ELSE true
END
-- Filter by username
AND CASE
WHEN @username :: text != '' THEN
user_id = (SELECT id from users WHERE users.username = @username )
ELSE true
END
-- Filter by user_email
AND CASE
WHEN @email :: text != '' THEN
user_id = (SELECT id from users WHERE users.email = @email )
ELSE true
END
-- Filter by date_from
AND CASE
WHEN @date_from :: timestamp with time zone != '0001-01-01 00:00:00' THEN
"time" >= @date_from
ELSE true
END
-- Filter by date_to
AND CASE
WHEN @date_to :: timestamp with time zone != '0001-01-01 00:00:00' THEN
"time" <= @date_to
ELSE true
END;
-- name: InsertAuditLog :one
INSERT INTO
audit_logs (