Feat: Agent improvements

This commit is contained in:
Daniel Hougaard
2024-03-01 06:41:17 +01:00
parent fb8c4bd415
commit 5096ce3bdc
3 changed files with 167 additions and 57 deletions

View File

@ -71,19 +71,14 @@ type Template struct {
SourcePath string `yaml:"source-path"`
Base64TemplateContent string `yaml:"base64-template-content"`
DestinationPath string `yaml:"destination-path"`
}
type SecretsStateManager struct {
// etags should be stored in memory, and the key should be env-secretPath-projectID, and the value should be the actual etag
etags map[string]string
secretMutationChannel chan bool
}
func NewSecretsStateManager(secretMutationChannel chan bool) *SecretsStateManager {
return &SecretsStateManager{
etags: make(map[string]string),
secretMutationChannel: secretMutationChannel,
}
Config struct { // Configurations for the template
PollingInterval string `yaml:"polling-interval"` // How often to poll for changes in the secret
Exec struct {
Command string `yaml:"command"` // Command to execute once the template has been rendered
Timeout int64 `yaml:"timeout"` // Timeout for the command
} `yaml:"exec"` // Command to execute once the template has been rendered
} `yaml:"config"`
}
func ReadFile(filePath string) ([]byte, error) {
@ -183,29 +178,24 @@ func ParseAgentConfig(configFile []byte) (*Config, error) {
return config, nil
}
func secretTemplateFunction(accessToken string, secretStateManager *SecretsStateManager) func(string, string, string) ([]models.SingleEnvironmentVariable, error) {
func secretTemplateFunction(accessToken string, existingEtag string, currentEtag *string) func(string, string, string) ([]models.SingleEnvironmentVariable, error) {
return func(projectID, envSlug, secretPath string) ([]models.SingleEnvironmentVariable, error) {
res, err := util.GetPlainTextSecretsViaMachineIdentity(accessToken, projectID, envSlug, secretPath, false)
if err != nil {
return nil, err
}
if secretStateManager != nil {
key := fmt.Sprintf("%s-%s-%s", envSlug, secretPath, projectID)
oldEtag, ok := secretStateManager.etags[key] // if there's no etag, it means it's the first time we are fetching this secret. we should only notify the secretMutationChannel if the etag has changed, not if it's the first time we are fetching the secret
if ok && oldEtag != res.Hash {
secretStateManager.secretMutationChannel <- true
}
secretStateManager.etags[key] = res.Hash
if existingEtag != res.Hash {
*currentEtag = res.Hash
}
return res.Secrets, nil
}
}
func ProcessTemplate(templatePath string, data interface{}, accessToken string, secretStateManager *SecretsStateManager) (*bytes.Buffer, error) {
func ProcessTemplate(templatePath string, data interface{}, accessToken string, existingEtag string, currentEtag *string) (*bytes.Buffer, error) {
// custom template function to fetch secrets from Infisical
secretFunction := secretTemplateFunction(accessToken, secretStateManager)
secretFunction := secretTemplateFunction(accessToken, existingEtag, currentEtag)
funcs := template.FuncMap{
"secret": secretFunction,
}
@ -225,7 +215,7 @@ func ProcessTemplate(templatePath string, data interface{}, accessToken string,
return &buf, nil
}
func ProcessBase64Template(encodedTemplate string, data interface{}, accessToken string, secretStateManager *SecretsStateManager) (*bytes.Buffer, error) {
func ProcessBase64Template(encodedTemplate string, data interface{}, accessToken string, existingEtag string, currentEtag *string) (*bytes.Buffer, error) {
// custom template function to fetch secrets from Infisical
decoded, err := base64.StdEncoding.DecodeString(encodedTemplate)
if err != nil {
@ -234,7 +224,7 @@ func ProcessBase64Template(encodedTemplate string, data interface{}, accessToken
templateString := string(decoded)
secretFunction := secretTemplateFunction(accessToken, secretStateManager) // TODO: Fix this
secretFunction := secretTemplateFunction(accessToken, existingEtag, currentEtag) // TODO: Fix this
funcs := template.FuncMap{
"secret": secretFunction,
}
@ -266,13 +256,13 @@ type TokenManager struct {
clientIdPath string
clientSecretPath string
newAccessTokenNotificationChan chan bool
removeClientSecretOnRead bool
cachedClientSecret string
exitAfterAuth bool
removeClientSecretOnRead bool
cachedClientSecret string
exitAfterAuth bool
}
func NewTokenManager(fileDeposits []Sink, templates []Template, clientIdPath string, clientSecretPath string, newAccessTokenNotificationChan chan bool, removeClientSecretOnRead bool, exitAfterAuth bool) *TokenManager {
log.Info().Msgf("Token manager done, templates: %+v", templates[0])
return &TokenManager{
filePaths: fileDeposits,
templates: templates,
@ -282,6 +272,7 @@ func NewTokenManager(fileDeposits []Sink, templates []Template, clientIdPath str
removeClientSecretOnRead: removeClientSecretOnRead,
exitAfterAuth: exitAfterAuth,
}
}
func (tm *TokenManager) SetToken(token string, accessTokenTTL time.Duration, accessTokenMaxTTL time.Duration) {
@ -459,38 +450,83 @@ func (tm *TokenManager) WriteTokenToFiles() {
}
}
func (tm *TokenManager) FetchSecrets(secretStateManager *SecretsStateManager) {
func (tm *TokenManager) WriteTemplateToFile(bytes *bytes.Buffer, template *Template) {
log.Info().Msgf("template engine started...")
if err := WriteBytesToFile(bytes, template.DestinationPath); err != nil {
log.Error().Msgf("template engine: unable to write secrets to path because %s. Will try again on next cycle", err)
return
}
log.Info().Msgf("template engine: secret template at path %s has been rendered and saved to path %s", template.SourcePath, template.DestinationPath)
}
func (tm *TokenManager) MonitorSecretChanges(secretTemplate Template, sigChan chan os.Signal) {
pollingInterval := time.Duration(5 * time.Minute)
if secretTemplate.Config.PollingInterval != "" {
interval, err := util.ConvertPollingIntervalToTime(secretTemplate.Config.PollingInterval)
if err != nil {
log.Error().Msgf("unable to convert polling interval to time because %v", err)
sigChan <- syscall.SIGINT
return
} else {
pollingInterval = interval
}
}
var existingEtag string
var currentEtag string
var firstRun = true
execTimeout := secretTemplate.Config.Exec.Timeout
execCommand := secretTemplate.Config.Exec.Command
// Now you can use the `command` variable, which is guaranteed to be a string
for {
log.Info().Msg("polling")
token := tm.GetToken()
if token != "" {
for _, secretTemplate := range tm.templates {
var processedTemplate *bytes.Buffer
var err error
if secretTemplate.SourcePath != "" {
processedTemplate, err = ProcessTemplate(secretTemplate.SourcePath, nil, token, secretStateManager)
} else {
processedTemplate, err = ProcessBase64Template(secretTemplate.Base64TemplateContent, nil, token, secretStateManager)
}
if err != nil {
log.Error().Msgf("template engine: unable to render secrets because %s. Will try again on next cycle", err)
var processedTemplate *bytes.Buffer
var err error
continue
}
if err := WriteBytesToFile(processedTemplate, secretTemplate.DestinationPath); err != nil {
log.Error().Msgf("template engine: unable to write secrets to path because %s. Will try again on next cycle", err)
continue
}
log.Info().Msgf("template engine: secret template at path %s has been rendered and saved to path %s", secretTemplate.SourcePath, secretTemplate.DestinationPath)
if secretTemplate.SourcePath != "" {
processedTemplate, err = ProcessTemplate(secretTemplate.SourcePath, nil, token, existingEtag, &currentEtag)
} else {
processedTemplate, err = ProcessBase64Template(secretTemplate.Base64TemplateContent, nil, token, existingEtag, &currentEtag)
}
// fetch new secrets every 5 minutes (TODO: add PubSub in the future )
time.Sleep(5 * time.Second)
if err != nil {
log.Error().Msgf("unable to process template because %v", err)
} else {
if (existingEtag != currentEtag) || firstRun {
tm.WriteTemplateToFile(processedTemplate, &secretTemplate)
existingEtag = currentEtag
if !firstRun && execCommand != "" {
log.Info().Msgf("executing command: %s", execCommand)
err := ExecuteCommandWithTimeout(execCommand, execTimeout)
if err != nil {
log.Error().Msgf("unable to execute command because %v", err)
}
}
if firstRun {
firstRun = false
}
}
}
}
time.Sleep(pollingInterval)
}
}
@ -568,24 +604,22 @@ var agentCmd = &cobra.Command{
configUniversalAuthType := agentConfig.Auth.Config.(UniversalAuth)
tokenRefreshNotifier := make(chan bool)
secretMutationNotifier := make(chan bool)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
filePaths := agentConfig.Sinks
tm := NewTokenManager(filePaths, agentConfig.Templates, configUniversalAuthType.ClientIDPath, configUniversalAuthType.ClientSecretPath, tokenRefreshNotifier, configUniversalAuthType.RemoveClientSecretOnRead, agentConfig.Infisical.ExitAfterAuth)
ssm := NewSecretsStateManager(secretMutationNotifier)
go tm.ManageTokenLifecycle()
go tm.FetchSecrets(ssm)
for _, template := range agentConfig.Templates {
go tm.MonitorSecretChanges(template, sigChan)
}
for {
select {
case <-tokenRefreshNotifier:
go tm.WriteTokenToFiles()
case <-secretMutationNotifier:
log.Info().Msgf("Mashallah, a mutation has occurred")
case <-sigChan:
log.Info().Msg("agent is gracefully shutting...")
// TODO: check if we are in the middle of writing files to disk

View File

@ -4,6 +4,7 @@ Copyright (c) 2023 Infisical Inc.
package cmd
import (
"context"
"fmt"
"os"
"os/exec"
@ -11,6 +12,7 @@ import (
"runtime"
"strings"
"syscall"
"time"
"github.com/Infisical/infisical-merge/packages/models"
"github.com/Infisical/infisical-merge/packages/util"
@ -270,3 +272,36 @@ func execCmd(cmd *exec.Cmd) error {
os.Exit(waitStatus.ExitStatus())
return nil
}
func ExecuteCommandWithTimeout(command string, timeout int64) error {
shell := [2]string{"sh", "-c"}
if runtime.GOOS == "windows" {
shell = [2]string{"cmd", "/C"}
} else {
currentShell := os.Getenv("SHELL")
if currentShell != "" {
shell[0] = currentShell
}
}
ctx := context.Background()
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
}
cmd := exec.CommandContext(ctx, shell[0], shell[1], command)
if err := cmd.Run(); err != nil {
if exitError, ok := err.(*exec.ExitError); ok { // type assertion
if exitError.ProcessState.ExitCode() == -1 {
return fmt.Errorf("command timed out")
}
}
return err
} else {
return nil
}
}

View File

@ -0,0 +1,41 @@
package util
import (
"fmt"
"strconv"
"time"
)
// ConvertPollingIntervalToTime converts a string representation of a polling interval to a time.Duration
func ConvertPollingIntervalToTime(pollingInterval string) (time.Duration, error) {
length := len(pollingInterval)
if length < 2 {
return 0, fmt.Errorf("invalid format")
}
unit := pollingInterval[length-1:]
numberPart := pollingInterval[:length-1]
number, err := strconv.Atoi(numberPart)
if err != nil {
return 0, err
}
switch unit {
case "s":
if number < 60 {
return 0, fmt.Errorf("polling interval should be at least 60 seconds")
}
return time.Duration(number) * time.Second, nil
case "m":
return time.Duration(number) * time.Minute, nil
case "h":
return time.Duration(number) * time.Hour, nil
case "d":
return time.Duration(number) * 24 * time.Hour, nil
case "w":
return time.Duration(number) * 7 * 24 * time.Hour, nil
default:
return 0, fmt.Errorf("invalid time unit")
}
}