chore: Add more dbgen functions (#6005)

This commit is contained in:
Steven Masley
2023-02-02 17:21:29 -06:00
committed by GitHub
parent 5fe4819669
commit 41e52310bf
3 changed files with 119 additions and 1 deletions

View File

@ -6,6 +6,7 @@ import (
"database/sql"
"encoding/hex"
"fmt"
"net"
"testing"
"time"
@ -21,6 +22,34 @@ import (
// All methods take in a 'seed' object. Any provided fields in the seed will be
// maintained. Any fields omitted will have sensible defaults generated.
func AuditLog(t *testing.T, db database.Store, seed database.AuditLog) database.AuditLog {
log, err := db.InsertAuditLog(context.Background(), database.InsertAuditLogParams{
ID: takeFirst(seed.ID, uuid.New()),
Time: takeFirst(seed.Time, time.Now()),
UserID: takeFirst(seed.UserID, uuid.New()),
OrganizationID: takeFirst(seed.OrganizationID, uuid.New()),
Ip: pqtype.Inet{
IPNet: takeFirstIP(seed.Ip.IPNet, net.IPNet{}),
Valid: takeFirst(seed.Ip.Valid, false),
},
UserAgent: sql.NullString{
String: takeFirst(seed.UserAgent.String, ""),
Valid: takeFirst(seed.UserAgent.Valid, false),
},
ResourceType: takeFirst(seed.ResourceType, database.ResourceTypeOrganization),
ResourceID: takeFirst(seed.ResourceID, uuid.New()),
ResourceTarget: takeFirst(seed.ResourceTarget, uuid.NewString()),
Action: takeFirst(seed.Action, database.AuditActionCreate),
Diff: takeFirstBytes(seed.Diff, []byte("{}")),
StatusCode: takeFirst(seed.StatusCode, 200),
AdditionalFields: takeFirstBytes(seed.Diff, []byte("{}")),
RequestID: takeFirst(seed.RequestID, uuid.New()),
ResourceIcon: takeFirst(seed.ResourceIcon, ""),
})
require.NoError(t, err, "insert audit log")
return log
}
func Template(t *testing.T, db database.Store, seed database.Template) database.Template {
template, err := db.InsertTemplate(context.Background(), database.InsertTemplateParams{
ID: takeFirst(seed.ID, uuid.New()),
@ -66,6 +95,47 @@ func APIKey(t *testing.T, db database.Store, seed database.APIKey) (key database
return key, fmt.Sprintf("%s-%s", key.ID, secret)
}
func WorkspaceAgent(t *testing.T, db database.Store, orig database.WorkspaceAgent) database.WorkspaceAgent {
workspace, err := db.InsertWorkspaceAgent(context.Background(), database.InsertWorkspaceAgentParams{
ID: takeFirst(orig.ID, uuid.New()),
CreatedAt: takeFirst(orig.CreatedAt, time.Now()),
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
Name: takeFirst(orig.Name, namesgenerator.GetRandomName(1)),
ResourceID: takeFirst(orig.ResourceID, uuid.New()),
AuthToken: takeFirst(orig.AuthToken, uuid.New()),
AuthInstanceID: sql.NullString{
String: takeFirst(orig.AuthInstanceID.String, namesgenerator.GetRandomName(1)),
Valid: takeFirst(orig.AuthInstanceID.Valid, true),
},
Architecture: takeFirst(orig.Architecture, "amd64"),
EnvironmentVariables: pqtype.NullRawMessage{
RawMessage: takeFirstBytes(orig.EnvironmentVariables.RawMessage, []byte("{}")),
Valid: takeFirst(orig.EnvironmentVariables.Valid, false),
},
OperatingSystem: takeFirst(orig.OperatingSystem, "linux"),
StartupScript: sql.NullString{
String: takeFirst(orig.StartupScript.String, ""),
Valid: takeFirst(orig.StartupScript.Valid, false),
},
Directory: takeFirst(orig.Directory, ""),
InstanceMetadata: pqtype.NullRawMessage{
RawMessage: takeFirstBytes(orig.ResourceMetadata.RawMessage, []byte("{}")),
Valid: takeFirst(orig.ResourceMetadata.Valid, false),
},
ResourceMetadata: pqtype.NullRawMessage{
RawMessage: takeFirstBytes(orig.ResourceMetadata.RawMessage, []byte("{}")),
Valid: takeFirst(orig.ResourceMetadata.Valid, false),
},
ConnectionTimeoutSeconds: takeFirst(orig.ConnectionTimeoutSeconds, 3600),
TroubleshootingURL: takeFirst(orig.TroubleshootingURL, "https://example.com"),
MOTDFile: takeFirst(orig.TroubleshootingURL, ""),
LoginBeforeReady: takeFirst(orig.LoginBeforeReady, false),
StartupScriptTimeoutSeconds: takeFirst(orig.StartupScriptTimeoutSeconds, 3600),
})
require.NoError(t, err, "insert workspace agent")
return workspace
}
func Workspace(t *testing.T, db database.Store, orig database.Workspace) database.Workspace {
workspace, err := db.InsertWorkspace(context.Background(), database.InsertWorkspaceParams{
ID: takeFirst(orig.ID, uuid.New()),
@ -89,7 +159,7 @@ func WorkspaceBuild(t *testing.T, db database.Store, orig database.WorkspaceBuil
UpdatedAt: takeFirst(orig.UpdatedAt, time.Now()),
WorkspaceID: takeFirst(orig.WorkspaceID, uuid.New()),
TemplateVersionID: takeFirst(orig.TemplateVersionID, uuid.New()),
BuildNumber: takeFirst(orig.BuildNumber, 0),
BuildNumber: takeFirst(orig.BuildNumber, 1),
Transition: takeFirst(orig.Transition, database.WorkspaceTransitionStart),
InitiatorID: takeFirst(orig.InitiatorID, uuid.New()),
JobID: takeFirst(orig.JobID, uuid.New()),
@ -140,6 +210,20 @@ func Group(t *testing.T, db database.Store, orig database.Group) database.Group
return group
}
func GroupMember(t *testing.T, db database.Store, orig database.GroupMember) database.GroupMember {
member := database.GroupMember{
UserID: takeFirst(orig.UserID, uuid.New()),
GroupID: takeFirst(orig.GroupID, uuid.New()),
}
//nolint:gosimple
err := db.InsertGroupMember(context.Background(), database.InsertGroupMemberParams{
UserID: member.UserID,
GroupID: member.GroupID,
})
require.NoError(t, err, "insert group member")
return member
}
func ProvisionerJob(t *testing.T, db database.Store, orig database.ProvisionerJob) database.ProvisionerJob {
job, err := db.InsertProvisionerJob(context.Background(), database.InsertProvisionerJobParams{
ID: takeFirst(orig.ID, uuid.New()),

View File

@ -14,6 +14,14 @@ import (
func TestGenerator(t *testing.T) {
t.Parallel()
t.Run("AuditLog", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
_ = dbgen.AuditLog(t, db, database.AuditLog{})
logs := must(db.GetAuditLogsOffset(context.Background(), database.GetAuditLogsOffsetParams{Limit: 1}))
require.Len(t, logs, 1)
})
t.Run("APIKey", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
@ -56,6 +64,17 @@ func TestGenerator(t *testing.T) {
require.Equal(t, exp, must(db.GetGroupByID(context.Background(), exp.ID)))
})
t.Run("GroupMember", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
g := dbgen.Group(t, db, database.Group{})
u := dbgen.User(t, db, database.User{})
exp := []database.User{u}
dbgen.GroupMember(t, db, database.GroupMember{GroupID: g.ID, UserID: u.ID})
require.Equal(t, exp, must(db.GetGroupMembers(context.Background(), g.ID)))
})
t.Run("Organization", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
@ -70,6 +89,13 @@ func TestGenerator(t *testing.T) {
require.Equal(t, exp, must(db.GetWorkspaceByID(context.Background(), exp.ID)))
})
t.Run("WorkspaceAgent", func(t *testing.T) {
t.Parallel()
db := databasefake.New()
exp := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{})
require.Equal(t, exp, must(db.GetWorkspaceAgentByID(context.Background(), exp.ID)))
})
t.Run("Template", func(t *testing.T) {
t.Parallel()
db := databasefake.New()

View File

@ -1,5 +1,13 @@
package dbgen
import "net"
func takeFirstIP(values ...net.IPNet) net.IPNet {
return takeFirstF(values, func(v net.IPNet) bool {
return len(v.IP) != 0 && len(v.Mask) != 0
})
}
// takeFirstBytes implements takeFirst for []byte.
// []byte is not a comparable type.
func takeFirstBytes(values ...[]byte) []byte {