diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 41fa20392f..8bfede13e8 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -575,11 +575,6 @@ func (q *querier) canAssignRoles(ctx context.Context, orgID *uuid.UUID, added, r return nil } -func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { - // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. - return q.GetTemplatesWithFilter(ctx, arg) -} - func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) error { deleteF := func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateTemplateDeletedByID(ctx, database.UpdateTemplateDeletedByIDParams{ @@ -591,34 +586,6 @@ func (q *querier) SoftDeleteTemplateByID(ctx context.Context, id uuid.UUID) erro return deleteQ(q.log, q.auth, q.db.GetTemplateByID, deleteF)(ctx, id) } -func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - // An actor is authorized to read template group roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateGroupRoles(ctx, id) -} - -func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - // An actor is authorized to query template user roles if they are authorized to read the template. - template, err := q.db.GetTemplateByID(ctx, id) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { - return nil, err - } - return q.db.GetTemplateUserRoles(ctx, id) -} - -func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - return q.db.GetAuthorizedUserCount(ctx, arg, prepared) -} - func (q *querier) GetUsersWithCount(ctx context.Context, arg database.GetUsersParams) ([]database.User, int64, error) { // TODO Implement this with a SQL filter. The count is incorrect without it. rowUsers, err := q.db.GetUsers(ctx, arg) @@ -655,11 +622,6 @@ func (q *querier) SoftDeleteUserByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetUserByID, deleteF)(ctx, id) } -func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. - return q.GetWorkspaces(ctx, arg) -} - func (q *querier) SoftDeleteWorkspaceByID(ctx context.Context, id uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetWorkspaceByID, func(ctx context.Context, id uuid.UUID) error { return q.db.UpdateWorkspaceDeletedByID(ctx, database.UpdateWorkspaceDeletedByIDParams{ @@ -2642,3 +2604,41 @@ func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (d } return q.db.UpsertTailnetCoordinator(ctx, id) } + +func (q *querier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, _ rbac.PreparedAuthorized) ([]database.Template, error) { + // TODO Delete this function, all GetTemplates should be authorized. For now just call getTemplates on the authz querier. + return q.GetTemplatesWithFilter(ctx, arg) +} + +func (q *querier) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + // An actor is authorized to read template group roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateGroupRoles(ctx, id) +} + +func (q *querier) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + // An actor is authorized to query template user roles if they are authorized to read the template. + template, err := q.db.GetTemplateByID(ctx, id) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, rbac.ActionRead, template); err != nil { + return nil, err + } + return q.db.GetTemplateUserRoles(ctx, id) +} + +func (q *querier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, _ rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + // TODO Delete this function, all GetWorkspaces should be authorized. For now just call GetWorkspaces on the authz querier. + return q.GetWorkspaces(ctx, arg) +} + +func (q *querier) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.GetAuthorizedUserCount(ctx, arg, prepared) +} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 4aa8334e5f..02e70639cd 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -266,80 +266,6 @@ func (q *FakeQuerier) getUserByIDNoLock(id uuid.UUID) (database.User, error) { return database.User{}, sql.ErrNoRows } -func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - if err := validateDatabaseType(params); err != nil { - return 0, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return -1, err - } - } - - users := make([]database.User, 0, len(q.users)) - - for _, user := range q.users { - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { - continue - } - - users = append(users, user) - } - - // Filter out deleted since they should never be returned.. - tmp := make([]database.User, 0, len(users)) - for _, user := range users { - if !user.Deleted { - tmp = append(tmp, user) - } - } - users = tmp - - if params.Search != "" { - tmp := make([]database.User, 0, len(users)) - for i, user := range users { - if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { - tmp = append(tmp, users[i]) - } - } - users = tmp - } - - if len(params.Status) > 0 { - usersFilteredByStatus := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { - return strings.EqualFold(string(a), string(b)) - }) { - usersFilteredByStatus = append(usersFilteredByStatus, users[i]) - } - } - users = usersFilteredByStatus - } - - if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { - usersFilteredByRole := make([]database.User, 0, len(users)) - for i, user := range users { - if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { - usersFilteredByRole = append(usersFilteredByRole, users[i]) - } - } - - users = usersFilteredByRole - } - - return int64(len(users)), nil -} - func convertUsers(users []database.User, count int64) []database.GetUsersRow { rows := make([]database.GetUsersRow, len(users)) for i, u := range users { @@ -363,259 +289,6 @@ func convertUsers(users []database.User, count int64) []database.GetUsersRow { return rows } -//nolint:gocyclo -func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - if prepared != nil { - // Call this to match the same function calls as the SQL implementation. - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) - if err != nil { - return nil, err - } - } - - workspaces := make([]database.Workspace, 0) - for _, workspace := range q.workspaces { - if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { - continue - } - - if arg.OwnerUsername != "" { - owner, err := q.getUserByIDNoLock(workspace.OwnerID) - if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { - continue - } - } - - if arg.TemplateName != "" { - template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) - if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { - continue - } - } - - if !arg.Deleted && workspace.Deleted { - continue - } - - if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { - continue - } - - if arg.Status != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - // This logic should match the logic in the workspace.sql file. - var statusMatch bool - switch database.WorkspaceStatus(arg.Status) { - case database.WorkspaceStatusPending: - statusMatch = isNull(job.StartedAt) - case database.WorkspaceStatusStarting: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStart - - case database.WorkspaceStatusRunning: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStart - - case database.WorkspaceStatusStopping: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionStop - - case database.WorkspaceStatusStopped: - statusMatch = isNotNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionStop - case database.WorkspaceStatusFailed: - statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || - (isNotNull(job.CompletedAt) && isNotNull(job.Error)) - - case database.WorkspaceStatusCanceling: - statusMatch = isNotNull(job.CanceledAt) && - isNull(job.CompletedAt) - - case database.WorkspaceStatusCanceled: - statusMatch = isNotNull(job.CanceledAt) && - isNotNull(job.CompletedAt) - - case database.WorkspaceStatusDeleted: - statusMatch = isNotNull(job.StartedAt) && - isNull(job.CanceledAt) && - isNotNull(job.CompletedAt) && - time.Since(job.UpdatedAt) < 30*time.Second && - build.Transition == database.WorkspaceTransitionDelete && - isNull(job.Error) - - case database.WorkspaceStatusDeleting: - statusMatch = isNull(job.CompletedAt) && - isNull(job.CanceledAt) && - isNull(job.Error) && - build.Transition == database.WorkspaceTransitionDelete - - default: - return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) - } - if !statusMatch { - continue - } - } - - if arg.HasAgent != "" { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) - if err != nil { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err != nil { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace resources: %w", err) - } - - var workspaceResourceIDs []uuid.UUID - for _, wr := range workspaceResources { - workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) - } - - workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) - if err != nil { - return nil, xerrors.Errorf("get workspace agents: %w", err) - } - - var hasAgentMatched bool - for _, wa := range workspaceAgents { - if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { - hasAgentMatched = true - } - } - - if !hasAgentMatched { - continue - } - } - - if len(arg.TemplateIds) > 0 { - match := false - for _, id := range arg.TemplateIds { - if workspace.TemplateID == id { - match = true - break - } - } - if !match { - continue - } - } - - // If the filter exists, ensure the object is authorized. - if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { - continue - } - workspaces = append(workspaces, workspace) - } - - // Sort workspaces (ORDER BY) - isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { - return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart - } - - preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} - preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} - preloadedUsers := map[uuid.UUID]database.User{} - - for _, w := range workspaces { - build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) - if err == nil { - preloadedWorkspaceBuilds[w.ID] = build - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get latest build: %w", err) - } - - job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) - if err == nil { - preloadedProvisionerJobs[w.ID] = job - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get provisioner job: %w", err) - } - - user, err := q.getUserByIDNoLock(w.OwnerID) - if err == nil { - preloadedUsers[w.ID] = user - } else if !errors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user: %w", err) - } - } - - sort.Slice(workspaces, func(i, j int) bool { - w1 := workspaces[i] - w2 := workspaces[j] - - // Order by: running first - w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) - w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) - - if w1IsRunning && !w2IsRunning { - return true - } - - if !w1IsRunning && w2IsRunning { - return false - } - - // Order by: usernames - if w1.ID != w2.ID { - return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) - } - - // Order by: workspace names - return sort.StringsAreSorted([]string{w1.Name, w2.Name}) - }) - - beforePageCount := len(workspaces) - - if arg.Offset > 0 { - if int(arg.Offset) > len(workspaces) { - return []database.GetWorkspacesRow{}, nil - } - workspaces = workspaces[arg.Offset:] - } - if arg.Limit > 0 { - if int(arg.Limit) > len(workspaces) { - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil - } - workspaces = workspaces[:arg.Limit] - } - - return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil -} - // mapAgentStatus determines the agent status based on different timestamps like created_at, last_connected_at, disconnected_at, etc. // The function must be in sync with: coderd/workspaceagents.go:convertWorkspaceAgent. func mapAgentStatus(dbAgent database.WorkspaceAgent, agentInactiveDisconnectTimeoutSeconds int64) string { @@ -778,66 +451,6 @@ func (q *FakeQuerier) getTemplateByIDNoLock(_ context.Context, id uuid.UUID) (da return database.Template{}, sql.ErrNoRows } -func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - if err := validateDatabaseType(arg); err != nil { - return nil, err - } - - q.mutex.RLock() - defer q.mutex.RUnlock() - - // Call this to match the same function calls as the SQL implementation. - if prepared != nil { - _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) - if err != nil { - return nil, err - } - } - - var templates []database.Template - for _, template := range q.templates { - if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { - continue - } - - if template.Deleted != arg.Deleted { - continue - } - if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { - continue - } - - if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { - continue - } - - if len(arg.IDs) > 0 { - match := false - for _, id := range arg.IDs { - if template.ID == id { - match = true - break - } - } - if !match { - continue - } - } - templates = append(templates, template.DeepCopy()) - } - if len(templates) > 0 { - slices.SortFunc(templates, func(i, j database.Template) bool { - if i.Name != j.Name { - return i.Name < j.Name - } - return i.ID.String() < j.ID.String() - }) - return templates, nil - } - - return nil, sql.ErrNoRows -} - func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVersionID uuid.UUID) (database.TemplateVersion, error) { for _, templateVersion := range q.templateVersions { if templateVersion.ID != templateVersionID { @@ -848,84 +461,6 @@ func (q *FakeQuerier) getTemplateVersionByIDNoLock(_ context.Context, templateVe return database.TemplateVersion{}, sql.ErrNoRows } -func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - users := make([]database.TemplateUser, 0, len(template.UserACL)) - for k, v := range template.UserACL { - user, err := q.getUserByIDNoLock(uuid.MustParse(k)) - if err != nil && xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get user by ID: %w", err) - } - // We don't delete users from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - if user.Deleted || user.Status == database.UserStatusSuspended { - continue - } - - users = append(users, database.TemplateUser{ - User: user, - Actions: v, - }) - } - - return users, nil -} - -func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - q.mutex.RLock() - defer q.mutex.RUnlock() - - var template database.Template - for _, t := range q.templates { - if t.ID == id { - template = t - break - } - } - - if template.ID == uuid.Nil { - return nil, sql.ErrNoRows - } - - groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) - for k, v := range template.GroupACL { - group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) - if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - return nil, xerrors.Errorf("get group by ID: %w", err) - } - // We don't delete groups from the map if they - // get deleted so just skip. - if xerrors.Is(err, sql.ErrNoRows) { - continue - } - - groups = append(groups, database.TemplateGroup{ - Group: group, - Actions: v, - }) - } - - return groups, nil -} - func (q *FakeQuerier) getWorkspaceAgentByIDNoLock(_ context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { // The schema sorts this by created at, so we iterate the array backwards. for i := len(q.workspaceAgents) - 1; i >= 0; i-- { @@ -5438,3 +4973,468 @@ func (*FakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetC func (*FakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { return database.TailnetCoordinator{}, ErrUnimplemented } + +func (q *FakeQuerier) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithACL()) + if err != nil { + return nil, err + } + } + + var templates []database.Template + for _, template := range q.templates { + if prepared != nil && prepared.Authorize(ctx, template.RBACObject()) != nil { + continue + } + + if template.Deleted != arg.Deleted { + continue + } + if arg.OrganizationID != uuid.Nil && template.OrganizationID != arg.OrganizationID { + continue + } + + if arg.ExactName != "" && !strings.EqualFold(template.Name, arg.ExactName) { + continue + } + + if len(arg.IDs) > 0 { + match := false + for _, id := range arg.IDs { + if template.ID == id { + match = true + break + } + } + if !match { + continue + } + } + templates = append(templates, template.DeepCopy()) + } + if len(templates) > 0 { + slices.SortFunc(templates, func(i, j database.Template) bool { + if i.Name != j.Name { + return i.Name < j.Name + } + return i.ID.String() < j.ID.String() + }) + return templates, nil + } + + return nil, sql.ErrNoRows +} + +func (q *FakeQuerier) GetTemplateGroupRoles(_ context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break + } + } + + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows + } + + groups := make([]database.TemplateGroup, 0, len(template.GroupACL)) + for k, v := range template.GroupACL { + group, err := q.getGroupByIDNoLock(context.Background(), uuid.MustParse(k)) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get group by ID: %w", err) + } + // We don't delete groups from the map if they + // get deleted so just skip. + if xerrors.Is(err, sql.ErrNoRows) { + continue + } + + groups = append(groups, database.TemplateGroup{ + Group: group, + Actions: v, + }) + } + + return groups, nil +} + +func (q *FakeQuerier) GetTemplateUserRoles(_ context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var template database.Template + for _, t := range q.templates { + if t.ID == id { + template = t + break + } + } + + if template.ID == uuid.Nil { + return nil, sql.ErrNoRows + } + + users := make([]database.TemplateUser, 0, len(template.UserACL)) + for k, v := range template.UserACL { + user, err := q.getUserByIDNoLock(uuid.MustParse(k)) + if err != nil && xerrors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user by ID: %w", err) + } + // We don't delete users from the map if they + // get deleted so just skip. + if xerrors.Is(err, sql.ErrNoRows) { + continue + } + + if user.Deleted || user.Status == database.UserStatusSuspended { + continue + } + + users = append(users, database.TemplateUser{ + User: user, + Actions: v, + }) + } + + return users, nil +} + +//nolint:gocyclo +func (q *FakeQuerier) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + if err := validateDatabaseType(arg); err != nil { + return nil, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + if prepared != nil { + // Call this to match the same function calls as the SQL implementation. + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return nil, err + } + } + + workspaces := make([]database.Workspace, 0) + for _, workspace := range q.workspaces { + if arg.OwnerID != uuid.Nil && workspace.OwnerID != arg.OwnerID { + continue + } + + if arg.OwnerUsername != "" { + owner, err := q.getUserByIDNoLock(workspace.OwnerID) + if err == nil && !strings.EqualFold(arg.OwnerUsername, owner.Username) { + continue + } + } + + if arg.TemplateName != "" { + template, err := q.getTemplateByIDNoLock(ctx, workspace.TemplateID) + if err == nil && !strings.EqualFold(arg.TemplateName, template.Name) { + continue + } + } + + if !arg.Deleted && workspace.Deleted { + continue + } + + if arg.Name != "" && !strings.Contains(strings.ToLower(workspace.Name), strings.ToLower(arg.Name)) { + continue + } + + if arg.Status != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + + // This logic should match the logic in the workspace.sql file. + var statusMatch bool + switch database.WorkspaceStatus(arg.Status) { + case database.WorkspaceStatusPending: + statusMatch = isNull(job.StartedAt) + case database.WorkspaceStatusStarting: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStart + + case database.WorkspaceStatusRunning: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStart + + case database.WorkspaceStatusStopping: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionStop + + case database.WorkspaceStatusStopped: + statusMatch = isNotNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionStop + case database.WorkspaceStatusFailed: + statusMatch = (isNotNull(job.CanceledAt) && isNotNull(job.Error)) || + (isNotNull(job.CompletedAt) && isNotNull(job.Error)) + + case database.WorkspaceStatusCanceling: + statusMatch = isNotNull(job.CanceledAt) && + isNull(job.CompletedAt) + + case database.WorkspaceStatusCanceled: + statusMatch = isNotNull(job.CanceledAt) && + isNotNull(job.CompletedAt) + + case database.WorkspaceStatusDeleted: + statusMatch = isNotNull(job.StartedAt) && + isNull(job.CanceledAt) && + isNotNull(job.CompletedAt) && + time.Since(job.UpdatedAt) < 30*time.Second && + build.Transition == database.WorkspaceTransitionDelete && + isNull(job.Error) + + case database.WorkspaceStatusDeleting: + statusMatch = isNull(job.CompletedAt) && + isNull(job.CanceledAt) && + isNull(job.Error) && + build.Transition == database.WorkspaceTransitionDelete + + default: + return nil, xerrors.Errorf("unknown workspace status in filter: %q", arg.Status) + } + if !statusMatch { + continue + } + } + + if arg.HasAgent != "" { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, workspace.ID) + if err != nil { + return nil, xerrors.Errorf("get latest build: %w", err) + } + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err != nil { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + + workspaceResources, err := q.getWorkspaceResourcesByJobIDNoLock(ctx, job.ID) + if err != nil { + return nil, xerrors.Errorf("get workspace resources: %w", err) + } + + var workspaceResourceIDs []uuid.UUID + for _, wr := range workspaceResources { + workspaceResourceIDs = append(workspaceResourceIDs, wr.ID) + } + + workspaceAgents, err := q.getWorkspaceAgentsByResourceIDsNoLock(ctx, workspaceResourceIDs) + if err != nil { + return nil, xerrors.Errorf("get workspace agents: %w", err) + } + + var hasAgentMatched bool + for _, wa := range workspaceAgents { + if mapAgentStatus(wa, arg.AgentInactiveDisconnectTimeoutSeconds) == arg.HasAgent { + hasAgentMatched = true + } + } + + if !hasAgentMatched { + continue + } + } + + if len(arg.TemplateIds) > 0 { + match := false + for _, id := range arg.TemplateIds { + if workspace.TemplateID == id { + match = true + break + } + } + if !match { + continue + } + } + + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, workspace.RBACObject()) != nil { + continue + } + workspaces = append(workspaces, workspace) + } + + // Sort workspaces (ORDER BY) + isRunning := func(build database.WorkspaceBuild, job database.ProvisionerJob) bool { + return job.CompletedAt.Valid && !job.CanceledAt.Valid && !job.Error.Valid && build.Transition == database.WorkspaceTransitionStart + } + + preloadedWorkspaceBuilds := map[uuid.UUID]database.WorkspaceBuild{} + preloadedProvisionerJobs := map[uuid.UUID]database.ProvisionerJob{} + preloadedUsers := map[uuid.UUID]database.User{} + + for _, w := range workspaces { + build, err := q.getLatestWorkspaceBuildByWorkspaceIDNoLock(ctx, w.ID) + if err == nil { + preloadedWorkspaceBuilds[w.ID] = build + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get latest build: %w", err) + } + + job, err := q.getProvisionerJobByIDNoLock(ctx, build.JobID) + if err == nil { + preloadedProvisionerJobs[w.ID] = job + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get provisioner job: %w", err) + } + + user, err := q.getUserByIDNoLock(w.OwnerID) + if err == nil { + preloadedUsers[w.ID] = user + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, xerrors.Errorf("get user: %w", err) + } + } + + sort.Slice(workspaces, func(i, j int) bool { + w1 := workspaces[i] + w2 := workspaces[j] + + // Order by: running first + w1IsRunning := isRunning(preloadedWorkspaceBuilds[w1.ID], preloadedProvisionerJobs[w1.ID]) + w2IsRunning := isRunning(preloadedWorkspaceBuilds[w2.ID], preloadedProvisionerJobs[w2.ID]) + + if w1IsRunning && !w2IsRunning { + return true + } + + if !w1IsRunning && w2IsRunning { + return false + } + + // Order by: usernames + if w1.ID != w2.ID { + return sort.StringsAreSorted([]string{preloadedUsers[w1.ID].Username, preloadedUsers[w2.ID].Username}) + } + + // Order by: workspace names + return sort.StringsAreSorted([]string{w1.Name, w2.Name}) + }) + + beforePageCount := len(workspaces) + + if arg.Offset > 0 { + if int(arg.Offset) > len(workspaces) { + return []database.GetWorkspacesRow{}, nil + } + workspaces = workspaces[arg.Offset:] + } + if arg.Limit > 0 { + if int(arg.Limit) > len(workspaces) { + return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil + } + workspaces = workspaces[:arg.Limit] + } + + return q.convertToWorkspaceRowsNoLock(ctx, workspaces, int64(beforePageCount)), nil +} + +func (q *FakeQuerier) GetAuthorizedUserCount(ctx context.Context, params database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + if err := validateDatabaseType(params); err != nil { + return 0, err + } + + q.mutex.RLock() + defer q.mutex.RUnlock() + + // Call this to match the same function calls as the SQL implementation. + if prepared != nil { + _, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL()) + if err != nil { + return -1, err + } + } + + users := make([]database.User, 0, len(q.users)) + + for _, user := range q.users { + // If the filter exists, ensure the object is authorized. + if prepared != nil && prepared.Authorize(ctx, user.RBACObject()) != nil { + continue + } + + users = append(users, user) + } + + // Filter out deleted since they should never be returned.. + tmp := make([]database.User, 0, len(users)) + for _, user := range users { + if !user.Deleted { + tmp = append(tmp, user) + } + } + users = tmp + + if params.Search != "" { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if strings.Contains(strings.ToLower(user.Email), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } else if strings.Contains(strings.ToLower(user.Username), strings.ToLower(params.Search)) { + tmp = append(tmp, users[i]) + } + } + users = tmp + } + + if len(params.Status) > 0 { + usersFilteredByStatus := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.ContainsCompare(params.Status, user.Status, func(a, b database.UserStatus) bool { + return strings.EqualFold(string(a), string(b)) + }) { + usersFilteredByStatus = append(usersFilteredByStatus, users[i]) + } + } + users = usersFilteredByStatus + } + + if len(params.RbacRole) > 0 && !slice.Contains(params.RbacRole, rbac.RoleMember()) { + usersFilteredByRole := make([]database.User, 0, len(users)) + for i, user := range users { + if slice.OverlapCompare(params.RbacRole, user.RBACRoles, strings.EqualFold) { + usersFilteredByRole = append(usersFilteredByRole, users[i]) + } + } + + users = usersFilteredByRole + } + + return int64(len(users)), nil +} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index ec28fd428a..11d857a302 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -16,6 +16,12 @@ import ( "github.com/coder/coder/coderd/rbac" ) +var ( + // Force these imports, for some reason the autogen does not include them. + _ uuid.UUID + _ rbac.Action +) + const wrapname = "dbmetrics.metricsStore" // New returns a database.Store that registers metrics for all queries to reg. @@ -73,41 +79,6 @@ func (m metricsStore) InTx(f func(database.Store) error, options *sql.TxOptions) return err } -func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { - start := time.Now() - templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) - return templates, err -} - -func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { - start := time.Now() - roles, err := m.s.GetTemplateGroupRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { - start := time.Now() - roles, err := m.s.GetTemplateUserRoles(ctx, id) - m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) - return roles, err -} - -func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { - start := time.Now() - workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) - return workspaces, err -} - -func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { - start := time.Now() - count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) - m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) - return count, err -} - func (m metricsStore) AcquireLock(ctx context.Context, pgAdvisoryXactLock int64) error { start := time.Now() err := m.s.AcquireLock(ctx, pgAdvisoryXactLock) @@ -1639,3 +1610,38 @@ func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) return m.s.UpsertTailnetCoordinator(ctx, id) } + +func (m metricsStore) GetAuthorizedTemplates(ctx context.Context, arg database.GetTemplatesWithFilterParams, prepared rbac.PreparedAuthorized) ([]database.Template, error) { + start := time.Now() + templates, err := m.s.GetAuthorizedTemplates(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedTemplates").Observe(time.Since(start).Seconds()) + return templates, err +} + +func (m metricsStore) GetTemplateGroupRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateGroup, error) { + start := time.Now() + roles, err := m.s.GetTemplateGroupRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateGroupRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetTemplateUserRoles(ctx context.Context, id uuid.UUID) ([]database.TemplateUser, error) { + start := time.Now() + roles, err := m.s.GetTemplateUserRoles(ctx, id) + m.queryLatencies.WithLabelValues("GetTemplateUserRoles").Observe(time.Since(start).Seconds()) + return roles, err +} + +func (m metricsStore) GetAuthorizedWorkspaces(ctx context.Context, arg database.GetWorkspacesParams, prepared rbac.PreparedAuthorized) ([]database.GetWorkspacesRow, error) { + start := time.Now() + workspaces, err := m.s.GetAuthorizedWorkspaces(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedWorkspaces").Observe(time.Since(start).Seconds()) + return workspaces, err +} + +func (m metricsStore) GetAuthorizedUserCount(ctx context.Context, arg database.GetFilteredUserCountParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + count, err := m.s.GetAuthorizedUserCount(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("GetAuthorizedUserCount").Observe(time.Since(start).Seconds()) + return count, err +} diff --git a/scripts/dbgen/main.go b/scripts/dbgen/main.go index 6141980428..0eeda09ba9 100644 --- a/scripts/dbgen/main.go +++ b/scripts/dbgen/main.go @@ -418,21 +418,44 @@ type querierFunction struct { // readQuerierFunctions reads the functions from coderd/database/querier.go func readQuerierFunctions() ([]querierFunction, error) { + f, err := parseDBFile("querier.go") + if err != nil { + return nil, xerrors.Errorf("parse querier.go: %w", err) + } + funcs, err := loadInterfaceFuncs(f, "sqlcQuerier") + if err != nil { + return nil, xerrors.Errorf("load interface %s funcs: %w", "sqlcQuerier", err) + } + + customFile, err := parseDBFile("modelqueries.go") + if err != nil { + return nil, xerrors.Errorf("parse modelqueriers.go: %w", err) + } + // Custom funcs should be appended after the regular functions + customFuncs, err := loadInterfaceFuncs(customFile, "customQuerier") + if err != nil { + return nil, xerrors.Errorf("load interface %s funcs: %w", "customQuerier", err) + } + + return append(funcs, customFuncs...), nil +} + +func parseDBFile(filename string) (*dst.File, error) { localPath, err := localFilePath() if err != nil { return nil, err } - querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", "querier.go") + querierPath := filepath.Join(localPath, "..", "..", "..", "coderd", "database", filename) querierData, err := os.ReadFile(querierPath) if err != nil { - return nil, xerrors.Errorf("read querier: %w", err) + return nil, xerrors.Errorf("read %s: %w", filename, err) } f, err := decorator.Parse(querierData) - if err != nil { - return nil, err - } + return f, err +} +func loadInterfaceFuncs(f *dst.File, interfaceName string) ([]querierFunction, error) { var querier *dst.InterfaceType for _, decl := range f.Decls { genDecl, ok := decl.(*dst.GenDecl) @@ -447,7 +470,7 @@ func readQuerierFunctions() ([]querierFunction, error) { } // This is the name of the interface. If that ever changes, // this will need to be updated. - if typeSpec.Name.Name != "sqlcQuerier" { + if typeSpec.Name.Name != interfaceName { continue } querier, ok = typeSpec.Type.(*dst.InterfaceType) @@ -461,7 +484,8 @@ func readQuerierFunctions() ([]querierFunction, error) { return nil, xerrors.Errorf("querier not found") } funcs := []querierFunction{} - for _, method := range querier.Methods.List { + allMethods := interfaceMethods(querier) + for _, method := range allMethods { funcType, ok := method.Type.(*dst.FuncType) if !ok { continue @@ -540,3 +564,30 @@ func nameFromSnakeCase(s string) string { } return ret } + +// interfaceMethods returns all embedded methods of an interface. +func interfaceMethods(i *dst.InterfaceType) []*dst.Field { + var allMethods []*dst.Field + for _, field := range i.Methods.List { + switch fieldType := field.Type.(type) { + case *dst.FuncType: + allMethods = append(allMethods, field) + case *dst.InterfaceType: + allMethods = append(allMethods, interfaceMethods(fieldType)...) + case *dst.Ident: + // Embedded interfaces are Idents -> TypeSpec -> InterfaceType + // If the embedded interface is not in the parsed file, then + // the Obj will be nil. + if fieldType.Obj != nil { + objDecl, ok := fieldType.Obj.Decl.(*dst.TypeSpec) + if ok { + isInterface, ok := objDecl.Type.(*dst.InterfaceType) + if ok { + allMethods = append(allMethods, interfaceMethods(isInterface)...) + } + } + } + } + } + return allMethods +}