feat: add groups and group members to telemetry snapshot (#13655)

* feat: Added in groups and groups members to telemetry snapshot
* feat: adding in test to dbauthz for getting group members and groups
This commit is contained in:
austinrhode
2024-06-25 11:01:40 -07:00
committed by GitHub
parent 58325dfd14
commit 87ad560aff
14 changed files with 239 additions and 25 deletions

View File

@ -1321,11 +1321,25 @@ func (q *querier) GetGroupByOrgAndName(ctx context.Context, arg database.GetGrou
return fetch(q.log, q.auth, q.db.GetGroupByOrgAndName)(ctx, arg)
}
func (q *querier) GetGroupMembers(ctx context.Context, id uuid.UUID) ([]database.User, error) {
func (q *querier) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetGroupMembers(ctx)
}
func (q *querier) GetGroupMembersByGroupID(ctx context.Context, id uuid.UUID) ([]database.User, error) {
if _, err := q.GetGroupByID(ctx, id); err != nil { // AuthZ check
return nil, err
}
return q.db.GetGroupMembers(ctx, id)
return q.db.GetGroupMembersByGroupID(ctx, id)
}
func (q *querier) GetGroups(ctx context.Context) ([]database.Group, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.GetGroups(ctx)
}
func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {

View File

@ -314,11 +314,19 @@ func (s *MethodTestSuite) TestGroup() {
Name: g.Name,
}).Asserts(g, policy.ActionRead).Returns(g)
}))
s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) {
s.Run("GetGroupMembersByGroupID", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{})
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{})
check.Args(g.ID).Asserts(g, policy.ActionRead)
}))
s.Run("GetGroupMembers", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.GroupMember(s.T(), db, database.GroupMember{})
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.Group(s.T(), db, database.Group{})
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{})
gm := dbgen.GroupMember(s.T(), db, database.GroupMember{GroupID: g.ID})

View File

@ -105,7 +105,7 @@ func TestGenerator(t *testing.T) {
exp := []database.User{u}
dbgen.GroupMember(t, db, database.GroupMember{GroupID: g.ID, UserID: u.ID})
require.Equal(t, exp, must(db.GetGroupMembers(context.Background(), g.ID)))
require.Equal(t, exp, must(db.GetGroupMembersByGroupID(context.Background(), g.ID)))
})
t.Run("Organization", func(t *testing.T) {

View File

@ -2370,7 +2370,16 @@ func (q *FakeQuerier) GetGroupByOrgAndName(_ context.Context, arg database.GetGr
return database.Group{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]database.User, error) {
func (q *FakeQuerier) GetGroupMembers(_ context.Context) ([]database.GroupMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
out := make([]database.GroupMember, len(q.groupMembers))
copy(out, q.groupMembers)
return out, nil
}
func (q *FakeQuerier) GetGroupMembersByGroupID(_ context.Context, id uuid.UUID) ([]database.User, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@ -2399,6 +2408,15 @@ func (q *FakeQuerier) GetGroupMembers(_ context.Context, id uuid.UUID) ([]databa
return users, nil
}
func (q *FakeQuerier) GetGroups(_ context.Context) ([]database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
out := make([]database.Group, len(q.groups))
copy(out, q.groups)
return out, nil
}
func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
err := validateDatabaseType(arg)
if err != nil {

View File

@ -585,13 +585,27 @@ func (m metricsStore) GetGroupByOrgAndName(ctx context.Context, arg database.Get
return group, err
}
func (m metricsStore) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]database.User, error) {
func (m metricsStore) GetGroupMembers(ctx context.Context) ([]database.GroupMember, error) {
start := time.Now()
users, err := m.s.GetGroupMembers(ctx, groupID)
r0, r1 := m.s.GetGroupMembers(ctx)
m.queryLatencies.WithLabelValues("GetGroupMembers").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]database.User, error) {
start := time.Now()
users, err := m.s.GetGroupMembersByGroupID(ctx, groupID)
m.queryLatencies.WithLabelValues("GetGroupMembersByGroupID").Observe(time.Since(start).Seconds())
return users, err
}
func (m metricsStore) GetGroups(ctx context.Context) ([]database.Group, error) {
start := time.Now()
r0, r1 := m.s.GetGroups(ctx)
m.queryLatencies.WithLabelValues("GetGroups").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
start := time.Now()
r0, r1 := m.s.GetGroupsByOrganizationAndUserID(ctx, arg)

View File

@ -1139,18 +1139,48 @@ func (mr *MockStoreMockRecorder) GetGroupByOrgAndName(arg0, arg1 any) *gomock.Ca
}
// GetGroupMembers mocks base method.
func (m *MockStore) GetGroupMembers(arg0 context.Context, arg1 uuid.UUID) ([]database.User, error) {
func (m *MockStore) GetGroupMembers(arg0 context.Context) ([]database.GroupMember, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupMembers", arg0, arg1)
ret0, _ := ret[0].([]database.User)
ret := m.ctrl.Call(m, "GetGroupMembers", arg0)
ret0, _ := ret[0].([]database.GroupMember)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupMembers indicates an expected call of GetGroupMembers.
func (mr *MockStoreMockRecorder) GetGroupMembers(arg0, arg1 any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetGroupMembers(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembers", reflect.TypeOf((*MockStore)(nil).GetGroupMembers), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembers", reflect.TypeOf((*MockStore)(nil).GetGroupMembers), arg0)
}
// GetGroupMembersByGroupID mocks base method.
func (m *MockStore) GetGroupMembersByGroupID(arg0 context.Context, arg1 uuid.UUID) ([]database.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupMembersByGroupID", arg0, arg1)
ret0, _ := ret[0].([]database.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupMembersByGroupID indicates an expected call of GetGroupMembersByGroupID.
func (mr *MockStoreMockRecorder) GetGroupMembersByGroupID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupMembersByGroupID", reflect.TypeOf((*MockStore)(nil).GetGroupMembersByGroupID), arg0, arg1)
}
// GetGroups mocks base method.
func (m *MockStore) GetGroups(arg0 context.Context) ([]database.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroups", arg0)
ret0, _ := ret[0].([]database.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroups indicates an expected call of GetGroups.
func (mr *MockStoreMockRecorder) GetGroups(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockStore)(nil).GetGroups), arg0)
}
// GetGroupsByOrganizationAndUserID mocks base method.

View File

@ -124,9 +124,11 @@ type sqlcQuerier interface {
GetGitSSHKey(ctx context.Context, userID uuid.UUID) (GitSSHKey, error)
GetGroupByID(ctx context.Context, id uuid.UUID) (Group, error)
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context) ([]GroupMember, error)
// If the group is a user made group, then we need to check the group_members table.
// If it is the "Everyone" group, then we need to check the organization_members table.
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroups(ctx context.Context) ([]Group, error)
GetGroupsByOrganizationAndUserID(ctx context.Context, arg GetGroupsByOrganizationAndUserIDParams) ([]Group, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetHealthSettings(ctx context.Context) (string, error)

View File

@ -1312,6 +1312,33 @@ func (q *sqlQuerier) DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteG
}
const getGroupMembers = `-- name: GetGroupMembers :many
SELECT user_id, group_id FROM group_members
`
func (q *sqlQuerier) GetGroupMembers(ctx context.Context) ([]GroupMember, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GroupMember
for rows.Next() {
var i GroupMember
if err := rows.Scan(&i.UserID, &i.GroupID); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupMembersByGroupID = `-- name: GetGroupMembersByGroupID :many
SELECT
users.id, users.email, users.username, users.hashed_password, users.created_at, users.updated_at, users.status, users.rbac_roles, users.login_type, users.avatar_url, users.deleted, users.last_seen_at, users.quiet_hours_schedule, users.theme_preference, users.name
FROM
@ -1337,8 +1364,8 @@ AND
// If the group is a user made group, then we need to check the group_members table.
// If it is the "Everyone" group, then we need to check the organization_members table.
func (q *sqlQuerier) GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembers, groupID)
func (q *sqlQuerier) GetGroupMembersByGroupID(ctx context.Context, groupID uuid.UUID) ([]User, error) {
rows, err := q.db.QueryContext(ctx, getGroupMembersByGroupID, groupID)
if err != nil {
return nil, err
}
@ -1507,6 +1534,41 @@ func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrg
return i, err
}
const getGroups = `-- name: GetGroups :many
SELECT id, name, organization_id, avatar_url, quota_allowance, display_name, source FROM groups
`
func (q *sqlQuerier) GetGroups(ctx context.Context) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroups)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(
&i.ID,
&i.Name,
&i.OrganizationID,
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupsByOrganizationAndUserID = `-- name: GetGroupsByOrganizationAndUserID :many
SELECT
groups.id, groups.name, groups.organization_id, groups.avatar_url, groups.quota_allowance, groups.display_name, groups.source

View File

@ -1,4 +1,7 @@
-- name: GetGroupMembers :many
SELECT * FROM group_members;
-- name: GetGroupMembersByGroupID :many
SELECT
users.*
FROM

View File

@ -1,3 +1,6 @@
-- name: GetGroups :many
SELECT * FROM groups;
-- name: GetGroupByID :one
SELECT
*

View File

@ -344,9 +344,6 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
users := database.ConvertUserRows(userRows)
var firstUser database.User
for _, dbUser := range users {
if dbUser.Status != database.UserStatusActive {
continue
}
if firstUser.CreatedAt.IsZero() {
firstUser = dbUser
}
@ -366,6 +363,28 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
}
return nil
})
eg.Go(func() error {
groups, err := r.options.Database.GetGroups(ctx)
if err != nil {
return xerrors.Errorf("get groups: %w", err)
}
snapshot.Groups = make([]Group, 0, len(groups))
for _, group := range groups {
snapshot.Groups = append(snapshot.Groups, ConvertGroup(group))
}
return nil
})
eg.Go(func() error {
groupMembers, err := r.options.Database.GetGroupMembers(ctx)
if err != nil {
return xerrors.Errorf("get groups: %w", err)
}
snapshot.GroupMembers = make([]GroupMember, 0, len(groupMembers))
for _, member := range groupMembers {
snapshot.GroupMembers = append(snapshot.GroupMembers, ConvertGroupMember(member))
}
return nil
})
eg.Go(func() error {
workspaceRows, err := r.options.Database.GetWorkspaces(ctx, database.GetWorkspacesParams{})
if err != nil {
@ -642,6 +661,26 @@ func ConvertUser(dbUser database.User) User {
EmailHashed: emailHashed,
RBACRoles: dbUser.RBACRoles,
CreatedAt: dbUser.CreatedAt,
Status: dbUser.Status,
}
}
func ConvertGroup(group database.Group) Group {
return Group{
ID: group.ID,
Name: group.Name,
OrganizationID: group.OrganizationID,
AvatarURL: group.AvatarURL,
QuotaAllowance: group.QuotaAllowance,
DisplayName: group.DisplayName,
Source: group.Source,
}
}
func ConvertGroupMember(member database.GroupMember) GroupMember {
return GroupMember{
GroupID: member.GroupID,
UserID: member.UserID,
}
}
@ -746,6 +785,8 @@ type Snapshot struct {
TemplateVersions []TemplateVersion `json:"template_versions"`
Templates []Template `json:"templates"`
Users []User `json:"users"`
Groups []Group `json:"groups"`
GroupMembers []GroupMember `json:"group_members"`
WorkspaceAgentStats []WorkspaceAgentStat `json:"workspace_agent_stats"`
WorkspaceAgents []WorkspaceAgent `json:"workspace_agents"`
WorkspaceApps []WorkspaceApp `json:"workspace_apps"`
@ -797,6 +838,21 @@ type User struct {
Status database.UserStatus `json:"status"`
}
type Group struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
OrganizationID uuid.UUID `json:"organization_id"`
AvatarURL string `json:"avatar_url"`
QuotaAllowance int32 `json:"quota_allowance"`
DisplayName string `json:"display_name"`
Source database.GroupSource `json:"source"`
}
type GroupMember struct {
UserID uuid.UUID `json:"user_id"`
GroupID uuid.UUID `json:"group_id"`
}
type WorkspaceResource struct {
ID uuid.UUID `json:"id"`
CreatedAt time.Time `json:"created_at"`

View File

@ -55,6 +55,8 @@ func TestTelemetry(t *testing.T) {
SharingLevel: database.AppSharingLevelOwner,
Health: database.WorkspaceAppHealthDisabled,
})
_ = dbgen.Group(t, db, database.Group{})
_ = dbgen.GroupMember(t, db, database.GroupMember{})
wsagent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{})
// Update the workspace agent to have a valid subsystem.
err = db.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
@ -91,6 +93,8 @@ func TestTelemetry(t *testing.T) {
require.Len(t, snapshot.Templates, 1)
require.Len(t, snapshot.TemplateVersions, 1)
require.Len(t, snapshot.Users, 1)
require.Len(t, snapshot.Groups, 2)
require.Len(t, snapshot.GroupMembers, 1)
require.Len(t, snapshot.Workspaces, 1)
require.Len(t, snapshot.WorkspaceApps, 1)
require.Len(t, snapshot.WorkspaceAgents, 1)