feat: Use environment variables and startup script in agent (#1147)

These values were ignored. Environment variables are applied to
new sessions, and are refreshed on reconnect. This is cool because
a workspace could be updated with new environment variables without
requiring a complete start/stop.

The startup script is only ran once regardless of changes, which
feels like the expected behavior.
This commit is contained in:
Kyle Carberry
2022-04-25 13:30:39 -05:00
committed by GitHub
parent 09405ddc40
commit a2dd618849
10 changed files with 189 additions and 28 deletions

View File

@ -11,9 +11,13 @@ import (
"os"
"os/exec"
"os/user"
"runtime"
"sync"
"time"
gsyslog "github.com/hashicorp/go-syslog"
"go.uber.org/atomic"
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
@ -29,10 +33,11 @@ import (
)
type Options struct {
Logger slog.Logger
EnvironmentVariables map[string]string
StartupScript string
}
type Dialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)
func New(dialer Dialer, logger slog.Logger) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
@ -55,16 +60,21 @@ type agent struct {
closeMutex sync.Mutex
closed chan struct{}
sshServer *ssh.Server
// Environment variables sent by Coder to inject for shell sessions.
// This is atomic because values can change after reconnect.
envVars atomic.Value
startupScript atomic.Bool
sshServer *ssh.Server
}
func (a *agent) run(ctx context.Context) {
var options *Options
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = a.dialer(ctx, a.logger)
options, peerListener, err = a.dialer(ctx, a.logger)
if err != nil {
if errors.Is(err, context.Canceled) {
return
@ -83,6 +93,20 @@ func (a *agent) run(ctx context.Context) {
return
default:
}
a.envVars.Store(options.EnvironmentVariables)
if a.startupScript.CAS(false, true) {
// The startup script has not ran yet!
go func() {
err := a.runStartupScript(ctx, options.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
}
}()
}
for {
conn, err := peerListener.Accept()
@ -101,6 +125,48 @@ func (a *agent) run(ctx context.Context) {
}
}
func (*agent) runStartupScript(ctx context.Context, script string) error {
if script == "" {
return nil
}
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username
shell, err := usershell.Get(username)
if err != nil {
return xerrors.Errorf("get user shell: %w", err)
}
var writer io.WriteCloser
// Attempt to use the syslog to write startup information.
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
if err != nil {
// If the syslog isn't supported or cannot be created, use a text file in temp.
writer, err = os.CreateTemp("", "coder-startup-script.txt")
if err != nil {
return xerrors.Errorf("open startup script log file: %w", err)
}
}
defer func() {
_ = writer.Close()
}()
caller := "-c"
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(ctx, shell, caller, script)
cmd.Stdout = writer
cmd.Stderr = writer
err = cmd.Run()
if err != nil {
return xerrors.Errorf("run: %w", err)
}
return nil
}
func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() {
select {
@ -230,8 +296,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
// OpenSSH executes all commands with the users current shell.
// We replicate that behavior for IDE support.
cmd := exec.CommandContext(session.Context(), shell, "-c", command)
caller := "-c"
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(session.Context(), shell, caller, command)
cmd.Env = append(os.Environ(), session.Environ()...)
// Load environment variables passed via the agent.
envVars := a.envVars.Load()
if envVars != nil {
envVarMap, ok := envVars.(map[string]string)
if ok {
for key, value := range envVarMap {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
}
}
}
executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)