feat: Allow inheriting parameters from previous template_versions when updating a template (#2397)

* WIP: feat: Update templates also updates parameters
* Insert params for template version update
* Working implementation of inherited params
* Add "--always-prompt" flag and logging info
This commit is contained in:
Steven Masley
2022-06-17 12:22:28 -05:00
committed by GitHub
parent 18973a65c1
commit 64b92eea67
16 changed files with 548 additions and 168 deletions

View File

@ -12,6 +12,7 @@ import (
"golang.org/x/exp/slices"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/util/slice"
)
// New returns an in-memory fake of the database.
@ -103,6 +104,19 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
return database.ProvisionerJob{}, sql.ErrNoRows
}
func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) {
q.mutex.Lock()
defer q.mutex.Unlock()
for _, parameterValue := range q.parameterValues {
if parameterValue.ID.String() != id.String() {
continue
}
return parameterValue, nil
}
return database.ParameterValue{}, sql.ErrNoRows
}
func (q *fakeQuerier) DeleteParameterValueByID(_ context.Context, id uuid.UUID) error {
q.mutex.Lock()
defer q.mutex.Unlock()
@ -744,17 +758,27 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UU
return organizations, nil
}
func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database.GetParameterValuesByScopeParams) ([]database.ParameterValue, error) {
func (q *fakeQuerier) ParameterValues(_ context.Context, arg database.ParameterValuesParams) ([]database.ParameterValue, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
parameterValues := make([]database.ParameterValue, 0)
for _, parameterValue := range q.parameterValues {
if parameterValue.Scope != arg.Scope {
continue
if len(arg.Scopes) > 0 {
if !slice.Contains(arg.Scopes, parameterValue.Scope) {
continue
}
}
if parameterValue.ScopeID != arg.ScopeID {
continue
if len(arg.ScopeIds) > 0 {
if !slice.Contains(arg.ScopeIds, parameterValue.ScopeID) {
continue
}
}
if len(arg.Ids) > 0 {
if !slice.Contains(arg.Ids, parameterValue.ID) {
continue
}
}
parameterValues = append(parameterValues, parameterValue)
}

View File

@ -45,7 +45,6 @@ type querier interface {
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]ParameterSchema, error)
GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error)
GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error)
GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) (ProvisionerDaemon, error)
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
@ -109,6 +108,8 @@ type querier interface {
InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error)
InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error)
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
ParameterValue(ctx context.Context, id uuid.UUID) (ParameterValue, error)
ParameterValues(ctx context.Context, arg ParameterValuesParams) ([]ParameterValue, error)
UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error
UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error
UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error)

View File

@ -1079,54 +1079,6 @@ func (q *sqlQuerier) GetParameterValueByScopeAndName(ctx context.Context, arg Ge
return i, err
}
const getParameterValuesByScope = `-- name: GetParameterValuesByScope :many
SELECT
id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme
FROM
parameter_values
WHERE
scope = $1
AND scope_id = $2
`
type GetParameterValuesByScopeParams struct {
Scope ParameterScope `db:"scope" json:"scope"`
ScopeID uuid.UUID `db:"scope_id" json:"scope_id"`
}
func (q *sqlQuerier) GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error) {
rows, err := q.db.QueryContext(ctx, getParameterValuesByScope, arg.Scope, arg.ScopeID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ParameterValue
for rows.Next() {
var i ParameterValue
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Scope,
&i.ScopeID,
&i.Name,
&i.SourceScheme,
&i.SourceValue,
&i.DestinationScheme,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertParameterValue = `-- name: InsertParameterValue :one
INSERT INTO
parameter_values (
@ -1183,6 +1135,103 @@ func (q *sqlQuerier) InsertParameterValue(ctx context.Context, arg InsertParamet
return i, err
}
const parameterValue = `-- name: ParameterValue :one
SELECT id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme FROM
parameter_values
WHERE
id = $1
`
func (q *sqlQuerier) ParameterValue(ctx context.Context, id uuid.UUID) (ParameterValue, error) {
row := q.db.QueryRowContext(ctx, parameterValue, id)
var i ParameterValue
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Scope,
&i.ScopeID,
&i.Name,
&i.SourceScheme,
&i.SourceValue,
&i.DestinationScheme,
)
return i, err
}
const parameterValues = `-- name: ParameterValues :many
SELECT
id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme
FROM
parameter_values
WHERE
CASE
WHEN cardinality($1 :: parameter_scope[]) > 0 THEN
scope = ANY($1 :: parameter_scope[])
ELSE true
END
AND CASE
WHEN cardinality($2 :: uuid[]) > 0 THEN
scope_id = ANY($2 :: uuid[])
ELSE true
END
AND CASE
WHEN cardinality($3 :: uuid[]) > 0 THEN
id = ANY($3 :: uuid[])
ELSE true
END
AND CASE
WHEN cardinality($4 :: text[]) > 0 THEN
"name" = ANY($4 :: text[])
ELSE true
END
`
type ParameterValuesParams struct {
Scopes []ParameterScope `db:"scopes" json:"scopes"`
ScopeIds []uuid.UUID `db:"scope_ids" json:"scope_ids"`
Ids []uuid.UUID `db:"ids" json:"ids"`
Names []string `db:"names" json:"names"`
}
func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesParams) ([]ParameterValue, error) {
rows, err := q.db.QueryContext(ctx, parameterValues,
pq.Array(arg.Scopes),
pq.Array(arg.ScopeIds),
pq.Array(arg.Ids),
pq.Array(arg.Names),
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ParameterValue
for rows.Next() {
var i ParameterValue
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.Scope,
&i.ScopeID,
&i.Name,
&i.SourceScheme,
&i.SourceValue,
&i.DestinationScheme,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
SELECT
id, created_at, updated_at, name, provisioners

View File

@ -1,17 +1,43 @@
-- name: ParameterValue :one
SELECT * FROM
parameter_values
WHERE
id = $1;
-- name: DeleteParameterValueByID :exec
DELETE FROM
parameter_values
WHERE
id = $1;
-- name: GetParameterValuesByScope :many
-- name: ParameterValues :many
SELECT
*
FROM
parameter_values
WHERE
scope = $1
AND scope_id = $2;
CASE
WHEN cardinality(@scopes :: parameter_scope[]) > 0 THEN
scope = ANY(@scopes :: parameter_scope[])
ELSE true
END
AND CASE
WHEN cardinality(@scope_ids :: uuid[]) > 0 THEN
scope_id = ANY(@scope_ids :: uuid[])
ELSE true
END
AND CASE
WHEN cardinality(@ids :: uuid[]) > 0 THEN
id = ANY(@ids :: uuid[])
ELSE true
END
AND CASE
WHEN cardinality(@names :: text[]) > 0 THEN
"name" = ANY(@names :: text[])
ELSE true
END
;
-- name: GetParameterValueByScopeAndName :one
SELECT

View File

@ -61,9 +61,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options
}
// Job parameters come second!
err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{
Scope: database.ParameterScopeImportJob,
ScopeID: scope.TemplateImportJobID,
err = compute.injectScope(ctx, database.ParameterValuesParams{
Scopes: []database.ParameterScope{database.ParameterScopeImportJob},
ScopeIds: []uuid.UUID{scope.TemplateImportJobID},
})
if err != nil {
return nil, err
@ -105,9 +105,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options
if scope.TemplateID.Valid {
// Template parameters come third!
err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{
Scope: database.ParameterScopeTemplate,
ScopeID: scope.TemplateID.UUID,
err = compute.injectScope(ctx, database.ParameterValuesParams{
Scopes: []database.ParameterScope{database.ParameterScopeTemplate},
ScopeIds: []uuid.UUID{scope.TemplateID.UUID},
})
if err != nil {
return nil, err
@ -116,9 +116,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options
if scope.WorkspaceID.Valid {
// Workspace parameters come last!
err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{
Scope: database.ParameterScopeWorkspace,
ScopeID: scope.WorkspaceID.UUID,
err = compute.injectScope(ctx, database.ParameterValuesParams{
Scopes: []database.ParameterScope{database.ParameterScopeWorkspace},
ScopeIds: []uuid.UUID{scope.WorkspaceID.UUID},
})
if err != nil {
return nil, err
@ -148,13 +148,13 @@ type compute struct {
}
// Validates and computes the value for parameters; setting the value on "parameterByName".
func (c *compute) injectScope(ctx context.Context, scopeParams database.GetParameterValuesByScopeParams) error {
scopedParameters, err := c.db.GetParameterValuesByScope(ctx, scopeParams)
func (c *compute) injectScope(ctx context.Context, scopeParams database.ParameterValuesParams) error {
scopedParameters, err := c.db.ParameterValues(ctx, scopeParams)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
return xerrors.Errorf("get %s parameters: %w", scopeParams.Scope, err)
return xerrors.Errorf("get %s parameters: %w", scopeParams.Scopes, err)
}
for _, scopedParameter := range scopedParameters {

View File

@ -91,9 +91,9 @@ func (api *API) parameters(rw http.ResponseWriter, r *http.Request) {
return
}
parameterValues, err := api.Database.GetParameterValuesByScope(r.Context(), database.GetParameterValuesByScopeParams{
Scope: scope,
ScopeID: scopeID,
parameterValues, err := api.Database.ParameterValues(r.Context(), database.ParameterValuesParams{
Scopes: []database.ParameterScope{scope},
ScopeIds: []uuid.UUID{scopeID},
})
if errors.Is(err, sql.ErrNoRows) {
err = nil
@ -214,6 +214,8 @@ func (api *API) parameterRBACResource(rw http.ResponseWriter, r *http.Request, s
switch scope {
case database.ParameterScopeWorkspace:
resource, err = api.Database.GetWorkspaceByID(ctx, scopeID)
case database.ParameterScopeImportJob:
resource, err = api.Database.GetTemplateVersionByJobID(ctx, scopeID)
case database.ParameterScopeTemplate:
resource, err = api.Database.GetTemplateByID(ctx, scopeID)
default:
@ -237,12 +239,9 @@ func (api *API) parameterRBACResource(rw http.ResponseWriter, r *http.Request, s
}
func readScopeAndID(rw http.ResponseWriter, r *http.Request) (database.ParameterScope, uuid.UUID, bool) {
var scope database.ParameterScope
switch chi.URLParam(r, "scope") {
case string(codersdk.ParameterTemplate):
scope = database.ParameterScopeTemplate
case string(codersdk.ParameterWorkspace):
scope = database.ParameterScopeWorkspace
scope := database.ParameterScope(chi.URLParam(r, "scope"))
switch scope {
case database.ParameterScopeTemplate, database.ParameterScopeImportJob, database.ParameterScopeWorkspace:
default:
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
Message: fmt.Sprintf("Invalid scope %q.", scope),

View File

@ -220,7 +220,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
CreatedAt: database.Now(),
UpdatedAt: database.Now(),
Scope: database.ParameterScopeTemplate,
ScopeID: dbTemplate.ID,
ScopeID: template.ID,
SourceScheme: database.ParameterSourceScheme(parameterValue.SourceScheme),
SourceValue: parameterValue.SourceValue,
DestinationScheme: database.ParameterDestinationScheme(parameterValue.DestinationScheme),

View File

@ -559,9 +559,16 @@ func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Reque
})
return
}
err = api.Database.UpdateTemplateActiveVersionByID(r.Context(), database.UpdateTemplateActiveVersionByIDParams{
ID: template.ID,
ActiveVersionID: req.ID,
err = api.Database.InTx(func(store database.Store) error {
err = store.UpdateTemplateActiveVersionByID(r.Context(), database.UpdateTemplateActiveVersionByIDParams{
ID: template.ID,
ActiveVersionID: req.ID,
})
if err != nil {
return xerrors.Errorf("update active version: %w", err)
}
return nil
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@ -631,7 +638,53 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
var provisionerJob database.ProvisionerJob
err = api.Database.InTx(func(db database.Store) error {
jobID := uuid.New()
inherits := make([]uuid.UUID, 0)
for _, parameterValue := range req.ParameterValues {
if parameterValue.CloneID != uuid.Nil {
inherits = append(inherits, parameterValue.CloneID)
}
}
// Expand inherited params
if len(inherits) > 0 {
if req.TemplateID == uuid.Nil {
return xerrors.Errorf("cannot inherit parameters if template_id is not set")
}
inheritedParams, err := db.ParameterValues(r.Context(), database.ParameterValuesParams{
Ids: inherits,
})
if err != nil {
return xerrors.Errorf("fetch inherited params: %w", err)
}
for _, copy := range inheritedParams {
// This is a bit inefficient, as we make a new db call for each
// param.
version, err := db.GetTemplateVersionByJobID(r.Context(), copy.ScopeID)
if err != nil {
return xerrors.Errorf("fetch template version for param %q: %w", copy.Name, err)
}
if !version.TemplateID.Valid || version.TemplateID.UUID != req.TemplateID {
return xerrors.Errorf("cannot inherit parameters from other templates")
}
if copy.Scope != database.ParameterScopeImportJob {
return xerrors.Errorf("copy parameter scope is %q, must be %q", copy.Scope, database.ParameterScopeImportJob)
}
// Add the copied param to the list to process
req.ParameterValues = append(req.ParameterValues, codersdk.CreateParameterRequest{
Name: copy.Name,
SourceValue: copy.SourceValue,
SourceScheme: codersdk.ParameterSourceScheme(copy.SourceScheme),
DestinationScheme: codersdk.ParameterDestinationScheme(copy.DestinationScheme),
})
}
}
for _, parameterValue := range req.ParameterValues {
if parameterValue.CloneID != uuid.Nil {
continue
}
_, err = db.InsertParameterValue(r.Context(), database.InsertParameterValueParams{
ID: uuid.New(),
Name: parameterValue.Name,

View File

@ -0,0 +1,10 @@
package slice
func Contains[T comparable](haystack []T, needle T) bool {
for _, hay := range haystack {
if needle == hay {
return true
}
}
return false
}

View File

@ -0,0 +1,35 @@
package slice_test
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/util/slice"
)
func TestContains(t *testing.T) {
t.Parallel()
assertSetContains(t, []int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5}, []int{0, 6, -1, -2, 100})
assertSetContains(t, []string{"hello", "world", "foo", "bar", "baz"}, []string{"hello", "world", "baz"}, []string{"not", "words", "in", "set"})
assertSetContains(t,
[]uuid.UUID{uuid.New(), uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5"), uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")},
[]uuid.UUID{uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")},
[]uuid.UUID{uuid.MustParse("1d00e27d-8de6-46f8-80d5-1da0ca83874a")},
)
}
func assertSetContains[T comparable](t *testing.T, set []T, in []T, out []T) {
t.Helper()
for _, e := range set {
require.True(t, slice.Contains(set, e), "elements in set should be in the set")
}
for _, e := range in {
require.True(t, slice.Contains(set, e), "expect element in set")
}
for _, e := range out {
require.False(t, slice.Contains(set, e), "expect element in set")
}
}