package agent import ( "bufio" "bytes" "context" "crypto/rand" "crypto/rsa" "encoding/binary" "encoding/json" "errors" "flag" "fmt" "io" "net" "net/http" "net/netip" "os" "os/exec" "os/user" "path/filepath" "reflect" "runtime" "sort" "strconv" "strings" "sync" "time" "github.com/armon/circbuf" "github.com/gliderlabs/ssh" "github.com/google/uuid" "github.com/pkg/sftp" "github.com/spf13/afero" "go.uber.org/atomic" gossh "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/xerrors" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" "tailscale.com/types/netlogtype" "cdr.dev/slog" "github.com/coder/coder/agent/usershell" "github.com/coder/coder/buildinfo" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/codersdk" "github.com/coder/coder/codersdk/agentsdk" "github.com/coder/coder/pty" "github.com/coder/coder/tailnet" "github.com/coder/retry" ) const ( ProtocolReconnectingPTY = "reconnecting-pty" ProtocolSSH = "ssh" ProtocolDial = "dial" // MagicSessionErrorCode indicates that something went wrong with the session, rather than the // command just returning a nonzero exit code, and is chosen as an arbitrary, high number // unlikely to shadow other exit codes, which are typically 1, 2, 3, etc. MagicSessionErrorCode = 229 // MagicSSHSessionTypeEnvironmentVariable is used to track the purpose behind an SSH connection. // This is stripped from any commands being executed, and is counted towards connection stats. MagicSSHSessionTypeEnvironmentVariable = "CODER_SSH_SESSION_TYPE" // MagicSSHSessionTypeVSCode is set in the SSH config by the VS Code extension to identify itself. MagicSSHSessionTypeVSCode = "vscode" // MagicSSHSessionTypeJetBrains is set in the SSH config by the JetBrains extension to identify itself. 
MagicSSHSessionTypeJetBrains = "jetbrains" ) type Options struct { Filesystem afero.Fs LogDir string TempDir string ExchangeToken func(ctx context.Context) (string, error) Client Client ReconnectingPTYTimeout time.Duration EnvironmentVariables map[string]string Logger slog.Logger AgentPorts map[int]string SSHMaxTimeout time.Duration TailnetListenPort uint16 } type Client interface { Manifest(ctx context.Context) (agentsdk.Manifest, error) Listen(ctx context.Context) (net.Conn, error) ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error PostMetadata(ctx context.Context, key string, req agentsdk.PostMetadataRequest) error PatchStartupLogs(ctx context.Context, req agentsdk.PatchStartupLogs) error } func New(options Options) io.Closer { if options.ReconnectingPTYTimeout == 0 { options.ReconnectingPTYTimeout = 5 * time.Minute } if options.Filesystem == nil { options.Filesystem = afero.NewOsFs() } if options.TempDir == "" { options.TempDir = os.TempDir() } if options.LogDir == "" { if options.TempDir != os.TempDir() { options.Logger.Debug(context.Background(), "log dir not set, using temp dir", slog.F("temp_dir", options.TempDir)) } options.LogDir = options.TempDir } if options.ExchangeToken == nil { options.ExchangeToken = func(ctx context.Context) (string, error) { return "", nil } } ctx, cancelFunc := context.WithCancel(context.Background()) a := &agent{ tailnetListenPort: options.TailnetListenPort, reconnectingPTYTimeout: options.ReconnectingPTYTimeout, logger: options.Logger, closeCancel: cancelFunc, closed: make(chan struct{}), envVars: options.EnvironmentVariables, client: options.Client, exchangeToken: options.ExchangeToken, filesystem: options.Filesystem, logDir: options.LogDir, tempDir: options.TempDir, lifecycleUpdate: make(chan struct{}, 1), lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1), ignorePorts: options.AgentPorts, connStatsChan: make(chan *agentsdk.Stats, 1), sshMaxTimeout: options.SSHMaxTimeout, } a.init(ctx) return a } type agent struct { logger slog.Logger client Client exchangeToken func(ctx context.Context) (string, error) tailnetListenPort uint16 filesystem afero.Fs logDir string tempDir string // ignorePorts tells the api handler which ports to ignore when // listing all listening ports. This is helpful to hide ports that // are used by the agent, that the user does not care about. ignorePorts map[int]string reconnectingPTYs sync.Map reconnectingPTYTimeout time.Duration connCloseWait sync.WaitGroup closeCancel context.CancelFunc closeMutex sync.Mutex closed chan struct{} envVars map[string]string // manifest is atomic because values can change after reconnection. manifest atomic.Pointer[agentsdk.Manifest] sessionToken atomic.Pointer[string] sshServer *ssh.Server sshMaxTimeout time.Duration lifecycleUpdate chan struct{} lifecycleReported chan codersdk.WorkspaceAgentLifecycle lifecycleMu sync.RWMutex // Protects following. 
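	// lifecycleState is the most recently requested lifecycle state. It is
	// written by setLifecycle and read by reportLifecycleLoop, always with
	// lifecycleMu held.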
	lifecycleState    codersdk.WorkspaceAgentLifecycle

	network       *tailnet.Conn
	connStatsChan chan *agentsdk.Stats
	latestStat    atomic.Pointer[agentsdk.Stats]

	connCountVSCode          atomic.Int64
	connCountJetBrains       atomic.Int64
	connCountReconnectingPTY atomic.Int64
	connCountSSHSession      atomic.Int64
}

// runLoop attempts to start the agent in a retry loop.
// Coder may be offline temporarily, a connection issue
// may be happening, but regardless, after the intermittent
// failure you'll want the agent to reconnect.
func (a *agent) runLoop(ctx context.Context) {
	go a.reportLifecycleLoop(ctx)
	go a.reportMetadataLoop(ctx)

	for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
		a.logger.Info(ctx, "connecting to coderd")
		err := a.run(ctx)
		// Cancel after the run is complete to clean up any leaked resources!
		if err == nil {
			continue
		}
		if errors.Is(err, context.Canceled) {
			return
		}
		if a.isClosed() {
			return
		}
		if errors.Is(err, io.EOF) {
			a.logger.Info(ctx, "disconnected from coderd")
			continue
		}
		a.logger.Warn(ctx, "run exited with error", slog.Error(err))
	}
}

func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentMetadataDescription) *codersdk.WorkspaceAgentMetadataResult {
	var out bytes.Buffer
	result := &codersdk.WorkspaceAgentMetadataResult{
		// CollectedAt is set here for testing purposes and overridden by
		// the server to the time the server received the result to protect
		// against clock skew.
		//
		// In the future, the server may accept the timestamp from the agent
		// if it is certain the clocks are in sync.
		CollectedAt: time.Now(),
	}
	cmd, err := a.createCommand(ctx, md.Script, nil)
	if err != nil {
		result.Error = err.Error()
		return result
	}

	cmd.Stdout = &out
	cmd.Stderr = &out

	// The error isn't mutually exclusive with useful output.
	err = cmd.Run()

	const bufLimit = 10 << 10
	if out.Len() > bufLimit {
		err = errors.Join(
			err,
			xerrors.Errorf("output truncated from %v to %v bytes", out.Len(), bufLimit),
		)
		out.Truncate(bufLimit)
	}

	if err != nil {
		result.Error = err.Error()
	}
	result.Value = out.String()
	return result
}

func adjustIntervalForTests(i int64) time.Duration {
	// In tests we want to set shorter intervals because engineers are
	// impatient.
	base := time.Second
	if flag.Lookup("test.v") != nil {
		base = time.Millisecond * 100
	}
	return time.Duration(i) * base
}

type metadataResultAndKey struct {
	result *codersdk.WorkspaceAgentMetadataResult
	key    string
}

type trySingleflight struct {
	m sync.Map
}

func (t *trySingleflight) Do(key string, fn func()) {
	_, loaded := t.m.LoadOrStore(key, struct{}{})
	if loaded {
		// There is already a goroutine running for this key.
		return
	}

	defer t.m.Delete(key)
	fn()
}

func (a *agent) reportMetadataLoop(ctx context.Context) {
	baseInterval := adjustIntervalForTests(1)

	const metadataLimit = 128

	var (
		baseTicker       = time.NewTicker(baseInterval)
		lastCollectedAts = make(map[string]time.Time)
		metadataResults  = make(chan metadataResultAndKey, metadataLimit)
	)
	defer baseTicker.Stop()

	// We use a custom singleflight that immediately returns if there is already
	// a goroutine running for a given key. This is to prevent a build-up of
	// goroutines waiting on Do when the script takes many multiples of
	// baseInterval to run.
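	// For example, if a collection for a hypothetical key "disk" is still
	// running when the next tick fires, flight.Do("disk", ...) returns
	// immediately rather than queuing another run behind it.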
var flight trySingleflight for { select { case <-ctx.Done(): return case mr := <-metadataResults: lastCollectedAts[mr.key] = mr.result.CollectedAt err := a.client.PostMetadata(ctx, mr.key, *mr.result) if err != nil { a.logger.Error(ctx, "report metadata", slog.Error(err)) } case <-baseTicker.C: } if len(metadataResults) > 0 { // The inner collection loop expects the channel is empty before spinning up // all the collection goroutines. a.logger.Debug( ctx, "metadata collection backpressured", slog.F("queue_len", len(metadataResults)), ) continue } manifest := a.manifest.Load() if manifest == nil { continue } if len(manifest.Metadata) > metadataLimit { a.logger.Error( ctx, "metadata limit exceeded", slog.F("limit", metadataLimit), slog.F("got", len(manifest.Metadata)), ) continue } // If the manifest changes (e.g. on agent reconnect) we need to // purge old cache values to prevent lastCollectedAt from growing // boundlessly. for key := range lastCollectedAts { if slices.IndexFunc(manifest.Metadata, func(md codersdk.WorkspaceAgentMetadataDescription) bool { return md.Key == key }) < 0 { delete(lastCollectedAts, key) } } // Spawn a goroutine for each metadata collection, and use a // channel to synchronize the results and avoid both messy // mutex logic and overloading the API. for _, md := range manifest.Metadata { collectedAt, ok := lastCollectedAts[md.Key] if ok { // If the interval is zero, we assume the user just wants // a single collection at startup, not a spinning loop. if md.Interval == 0 { continue } // The last collected value isn't quite stale yet, so we skip it. if collectedAt.Add( adjustIntervalForTests(md.Interval), ).After(time.Now()) { continue } } md := md // We send the result to the channel in the goroutine to avoid // sending the same result multiple times. So, we don't care about // the return values. go flight.Do(md.Key, func() { timeout := md.Timeout if timeout == 0 { timeout = md.Interval } ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second, ) defer cancel() select { case <-ctx.Done(): case metadataResults <- metadataResultAndKey{ key: md.Key, result: a.collectMetadata(ctx, md), }: } }) } } } // reportLifecycleLoop reports the current lifecycle state once. // Only the latest state is reported, intermediate states may be // lost if the agent can't communicate with the API. func (a *agent) reportLifecycleLoop(ctx context.Context) { var lastReported codersdk.WorkspaceAgentLifecycle for { select { case <-a.lifecycleUpdate: case <-ctx.Done(): return } for r := retry.New(time.Second, 15*time.Second); r.Wait(ctx); { a.lifecycleMu.RLock() state := a.lifecycleState a.lifecycleMu.RUnlock() if state == lastReported { break } a.logger.Debug(ctx, "reporting lifecycle state", slog.F("state", state)) err := a.client.PostLifecycle(ctx, agentsdk.PostLifecycleRequest{ State: state, }) if err == nil { lastReported = state select { case a.lifecycleReported <- state: case <-a.lifecycleReported: a.lifecycleReported <- state } break } if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) { return } // If we fail to report the state we probably shouldn't exit, log only. a.logger.Error(ctx, "post state", slog.Error(err)) } } } // setLifecycle sets the lifecycle state and notifies the lifecycle loop. // The state is only updated if it's a valid state transition. 
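// Transitions are validated against codersdk.WorkspaceAgentLifecycleOrder: a
// state may be re-set or advanced in that ordering, but an attempt to move
// backward (e.g. from "ready" back to "starting") is logged and ignored.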
func (a *agent) setLifecycle(ctx context.Context, state codersdk.WorkspaceAgentLifecycle) { a.lifecycleMu.Lock() lastState := a.lifecycleState if slices.Index(codersdk.WorkspaceAgentLifecycleOrder, lastState) > slices.Index(codersdk.WorkspaceAgentLifecycleOrder, state) { a.logger.Warn(ctx, "attempted to set lifecycle state to a previous state", slog.F("last", lastState), slog.F("state", state)) a.lifecycleMu.Unlock() return } a.lifecycleState = state a.logger.Debug(ctx, "set lifecycle state", slog.F("state", state), slog.F("last", lastState)) a.lifecycleMu.Unlock() select { case a.lifecycleUpdate <- struct{}{}: default: } } func (a *agent) run(ctx context.Context) error { // This allows the agent to refresh it's token if necessary. // For instance identity this is required, since the instance // may not have re-provisioned, but a new agent ID was created. sessionToken, err := a.exchangeToken(ctx) if err != nil { return xerrors.Errorf("exchange token: %w", err) } a.sessionToken.Store(&sessionToken) manifest, err := a.client.Manifest(ctx) if err != nil { return xerrors.Errorf("fetch metadata: %w", err) } a.logger.Info(ctx, "fetched manifest", slog.F("manifest", manifest)) // Expand the directory and send it back to coderd so external // applications that rely on the directory can use it. // // An example is VS Code Remote, which must know the directory // before initializing a connection. manifest.Directory, err = expandDirectory(manifest.Directory) if err != nil { return xerrors.Errorf("expand directory: %w", err) } err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{ Version: buildinfo.Version(), ExpandedDirectory: manifest.Directory, }) if err != nil { return xerrors.Errorf("update workspace agent version: %w", err) } oldManifest := a.manifest.Swap(&manifest) // The startup script should only execute on the first run! if oldManifest == nil { a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStarting) // Perform overrides early so that Git auth can work even if users // connect to a workspace that is not yet ready. We don't run this // concurrently with the startup script to avoid conflicts between // them. if manifest.GitAuthConfigs > 0 { // If this fails, we should consider surfacing the error in the // startup log and setting the lifecycle state to be "start_error" // (after startup script completion), but for now we'll just log it. err := gitauth.OverrideVSCodeConfigs(a.filesystem) if err != nil { a.logger.Warn(ctx, "failed to override vscode git auth configs", slog.Error(err)) } } lifecycleState := codersdk.WorkspaceAgentLifecycleReady scriptDone := make(chan error, 1) scriptStart := time.Now() err = a.trackConnGoroutine(func() { defer close(scriptDone) scriptDone <- a.runStartupScript(ctx, manifest.StartupScript) }) if err != nil { return xerrors.Errorf("track startup script: %w", err) } go func() { var timeout <-chan time.Time // If timeout is zero, an older version of the coder // provider was used. Otherwise a timeout is always > 0. if manifest.StartupScriptTimeout > 0 { t := time.NewTimer(manifest.StartupScriptTimeout) defer t.Stop() timeout = t.C } var err error select { case err = <-scriptDone: case <-timeout: a.logger.Warn(ctx, "startup script timed out") a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleStartTimeout) err = <-scriptDone // The script can still complete after a timeout. } if errors.Is(err, context.Canceled) { return } // Only log if there was a startup script. 
if manifest.StartupScript != "" { execTime := time.Since(scriptStart) if err != nil { a.logger.Warn(ctx, "startup script failed", slog.F("execution_time", execTime), slog.Error(err)) lifecycleState = codersdk.WorkspaceAgentLifecycleStartError } else { a.logger.Info(ctx, "startup script completed", slog.F("execution_time", execTime)) } } a.setLifecycle(ctx, lifecycleState) }() } // This automatically closes when the context ends! appReporterCtx, appReporterCtxCancel := context.WithCancel(ctx) defer appReporterCtxCancel() go NewWorkspaceAppHealthReporter( a.logger, manifest.Apps, a.client.PostAppHealth)(appReporterCtx) a.closeMutex.Lock() network := a.network a.closeMutex.Unlock() if network == nil { network, err = a.createTailnet(ctx, manifest.DERPMap) if err != nil { return xerrors.Errorf("create tailnet: %w", err) } a.closeMutex.Lock() // Re-check if agent was closed while initializing the network. closed := a.isClosed() if !closed { a.network = network } a.closeMutex.Unlock() if closed { _ = network.Close() return xerrors.New("agent is closed") } a.startReportingConnectionStats(ctx) } else { // Update the DERP map! network.SetDERPMap(manifest.DERPMap) } a.logger.Debug(ctx, "running tailnet connection coordinator") err = a.runCoordinator(ctx, network) if err != nil { return xerrors.Errorf("run coordinator: %w", err) } return nil } func (a *agent) trackConnGoroutine(fn func()) error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.isClosed() { return xerrors.New("track conn goroutine: agent is closed") } a.connCloseWait.Add(1) go func() { defer a.connCloseWait.Done() fn() }() return nil } func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) { network, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, DERPMap: derpMap, Logger: a.logger.Named("tailnet"), ListenPort: a.tailnetListenPort, }) if err != nil { return nil, xerrors.Errorf("create tailnet: %w", err) } defer func() { if err != nil { network.Close() } }() sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentSSHPort)) if err != nil { return nil, xerrors.Errorf("listen on the ssh port: %w", err) } defer func() { if err != nil { _ = sshListener.Close() } }() if err = a.trackConnGoroutine(func() { var wg sync.WaitGroup for { conn, err := sshListener.Accept() if err != nil { break } wg.Add(1) closed := make(chan struct{}) go func() { select { case <-closed: case <-a.closed: _ = conn.Close() } wg.Done() }() go func() { defer close(closed) a.sshServer.HandleConn(conn) }() } wg.Wait() }); err != nil { return nil, err } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentReconnectingPTYPort)) if err != nil { return nil, xerrors.Errorf("listen for reconnecting pty: %w", err) } defer func() { if err != nil { _ = reconnectingPTYListener.Close() } }() if err = a.trackConnGoroutine(func() { logger := a.logger.Named("reconnecting-pty") var wg sync.WaitGroup for { conn, err := reconnectingPTYListener.Accept() if err != nil { if !a.isClosed() { logger.Debug(ctx, "accept pty failed", slog.Error(err)) } break } wg.Add(1) closed := make(chan struct{}) go func() { select { case <-closed: case <-a.closed: _ = conn.Close() } wg.Done() }() go func() { defer close(closed) // This cannot use a JSON decoder, since that can // buffer additional data that is required for the PTY. 
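				// The expected framing is a 2-byte little-endian length,
				// followed by that many bytes of JSON-encoded
				// codersdk.WorkspaceAgentReconnectingPTYInit; the rest of the
				// stream is handed to handleReconnectingPTY. A minimal sketch
				// of how a client would produce that framing (assuming msg and
				// conn are already in hand) is:
				//
				//	data, _ := json.Marshal(msg)
				//	buf := make([]byte, 2, 2+len(data))
				//	binary.LittleEndian.PutUint16(buf, uint16(len(data)))
				//	_, _ = conn.Write(append(buf, data...))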
rawLen := make([]byte, 2) _, err = conn.Read(rawLen) if err != nil { return } length := binary.LittleEndian.Uint16(rawLen) data := make([]byte, length) _, err = conn.Read(data) if err != nil { return } var msg codersdk.WorkspaceAgentReconnectingPTYInit err = json.Unmarshal(data, &msg) if err != nil { return } _ = a.handleReconnectingPTY(ctx, logger, msg, conn) }() } wg.Wait() }); err != nil { return nil, err } speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentSpeedtestPort)) if err != nil { return nil, xerrors.Errorf("listen for speedtest: %w", err) } defer func() { if err != nil { _ = speedtestListener.Close() } }() if err = a.trackConnGoroutine(func() { var wg sync.WaitGroup for { conn, err := speedtestListener.Accept() if err != nil { if !a.isClosed() { a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err)) } break } wg.Add(1) closed := make(chan struct{}) go func() { select { case <-closed: case <-a.closed: _ = conn.Close() } wg.Done() }() go func() { defer close(closed) _ = speedtest.ServeConn(conn) }() } wg.Wait() }); err != nil { return nil, err } apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.WorkspaceAgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("api listener: %w", err) } defer func() { if err != nil { _ = apiListener.Close() } }() if err = a.trackConnGoroutine(func() { defer apiListener.Close() server := &http.Server{ Handler: a.apiHandler(), ReadTimeout: 20 * time.Second, ReadHeaderTimeout: 20 * time.Second, WriteTimeout: 20 * time.Second, ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server"), slog.LevelInfo), } go func() { select { case <-ctx.Done(): case <-a.closed: } _ = server.Close() }() err := server.Serve(apiListener) if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") { a.logger.Critical(ctx, "serve HTTP API server", slog.Error(err)) } }); err != nil { return nil, err } return network, nil } // runCoordinator runs a coordinator and returns whether a reconnect // should occur. 
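// It blocks until the coordinator connection or the context ends; runLoop
// treats an io.EOF error from here as a disconnect and redials coderd.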
func (a *agent) runCoordinator(ctx context.Context, network *tailnet.Conn) error { ctx, cancel := context.WithCancel(ctx) defer cancel() coordinator, err := a.client.Listen(ctx) if err != nil { return err } defer coordinator.Close() a.logger.Info(ctx, "connected to coordination endpoint") sendNodes, errChan := tailnet.ServeCoordinator(coordinator, func(nodes []*tailnet.Node) error { return network.UpdateNodes(nodes, false) }) network.SetNodeCallback(sendNodes) select { case <-ctx.Done(): return ctx.Err() case err := <-errChan: return err } } func (a *agent) runStartupScript(ctx context.Context, script string) error { return a.runScript(ctx, "startup", script) } func (a *agent) runShutdownScript(ctx context.Context, script string) error { return a.runScript(ctx, "shutdown", script) } func (a *agent) runScript(ctx context.Context, lifecycle, script string) error { if script == "" { return nil } a.logger.Info(ctx, "running script", slog.F("lifecycle", lifecycle), slog.F("script", script)) fileWriter, err := a.filesystem.OpenFile(filepath.Join(a.logDir, fmt.Sprintf("coder-%s-script.log", lifecycle)), os.O_CREATE|os.O_RDWR, 0o600) if err != nil { return xerrors.Errorf("open %s script log file: %w", lifecycle, err) } defer func() { _ = fileWriter.Close() }() var writer io.Writer = fileWriter if lifecycle == "startup" { // Create pipes for startup logs reader and writer logsReader, logsWriter := io.Pipe() defer func() { _ = logsReader.Close() }() writer = io.MultiWriter(fileWriter, logsWriter) flushedLogs, err := a.trackScriptLogs(ctx, logsReader) if err != nil { return xerrors.Errorf("track script logs: %w", err) } defer func() { _ = logsWriter.Close() <-flushedLogs }() } cmd, err := a.createCommand(ctx, script, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } cmd.Stdout = writer cmd.Stderr = writer err = cmd.Run() if err != nil { // cmd.Run does not return a context canceled error, it returns "signal: killed". 
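		// Returning ctx.Err() here lets callers tell a canceled agent context
		// (e.g. shutdown) apart from a script that genuinely failed.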
if ctx.Err() != nil { return ctx.Err() } return xerrors.Errorf("run: %w", err) } return nil } func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan struct{}, error) { // Initialize variables for log management queuedLogs := make([]agentsdk.StartupLog, 0) var flushLogsTimer *time.Timer var logMutex sync.Mutex logsFlushed := sync.NewCond(&sync.Mutex{}) var logsSending bool defer func() { logMutex.Lock() if flushLogsTimer != nil { flushLogsTimer.Stop() } logMutex.Unlock() }() // sendLogs function uploads the queued logs to the server sendLogs := func() { // Lock logMutex and check if logs are already being sent logMutex.Lock() if logsSending { logMutex.Unlock() return } if flushLogsTimer != nil { flushLogsTimer.Stop() } if len(queuedLogs) == 0 { logMutex.Unlock() return } // Move the current queued logs to logsToSend and clear the queue logsToSend := queuedLogs logsSending = true queuedLogs = make([]agentsdk.StartupLog, 0) logMutex.Unlock() // Retry uploading logs until successful or a specific error occurs for r := retry.New(time.Second, 5*time.Second); r.Wait(ctx); { err := a.client.PatchStartupLogs(ctx, agentsdk.PatchStartupLogs{ Logs: logsToSend, }) if err == nil { break } var sdkErr *codersdk.Error if errors.As(err, &sdkErr) { if sdkErr.StatusCode() == http.StatusRequestEntityTooLarge { a.logger.Warn(ctx, "startup logs too large, dropping logs") break } } a.logger.Error(ctx, "upload startup logs", slog.Error(err), slog.F("to_send", logsToSend)) } // Reset logsSending flag logMutex.Lock() logsSending = false flushLogsTimer.Reset(100 * time.Millisecond) logMutex.Unlock() logsFlushed.Broadcast() } // queueLog function appends a log to the queue and triggers sendLogs if necessary queueLog := func(log agentsdk.StartupLog) { logMutex.Lock() defer logMutex.Unlock() // Append log to the queue queuedLogs = append(queuedLogs, log) // If there are more than 100 logs, send them immediately if len(queuedLogs) > 100 { // Don't early return after this, because we still want // to reset the timer just in case logs come in while // we're sending. go sendLogs() } // Reset or set the flushLogsTimer to trigger sendLogs after 100 milliseconds if flushLogsTimer != nil { flushLogsTimer.Reset(100 * time.Millisecond) return } flushLogsTimer = time.AfterFunc(100*time.Millisecond, sendLogs) } // It's important that we either flush or drop all logs before returning // because the startup state is reported after flush. // // It'd be weird for the startup state to be ready, but logs are still // coming in. logsFinished := make(chan struct{}) err := a.trackConnGoroutine(func() { scanner := bufio.NewScanner(reader) for scanner.Scan() { queueLog(agentsdk.StartupLog{ CreatedAt: database.Now(), Output: scanner.Text(), }) } defer close(logsFinished) logsFlushed.L.Lock() for { logMutex.Lock() if len(queuedLogs) == 0 { logMutex.Unlock() break } logMutex.Unlock() logsFlushed.Wait() } }) if err != nil { return nil, xerrors.Errorf("track conn goroutine: %w", err) } return logsFinished, nil } func (a *agent) init(ctx context.Context) { // Clients' should ignore the host key when connecting. // The agent needs to authenticate with coderd to SSH, // so SSH authentication doesn't improve security. 
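	// A fresh 2048-bit RSA key is therefore generated on every agent start and
	// never persisted, so host key verification would be meaningless anyway.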
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(err) } randomSigner, err := gossh.NewSignerFromKey(randomHostKey) if err != nil { panic(err) } sshLogger := a.logger.Named("ssh-server") forwardHandler := &ssh.ForwardedTCPHandler{} unixForwardHandler := &forwardedUnixHandler{log: a.logger} a.sshServer = &ssh.Server{ ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": ssh.DirectTCPIPHandler, "direct-streamlocal@openssh.com": directStreamLocalHandler, "session": ssh.DefaultSessionHandler, }, ConnectionFailedCallback: func(conn net.Conn, err error) { sshLogger.Info(ctx, "ssh connection ended", slog.Error(err)) }, Handler: func(session ssh.Session) { err := a.handleSSHSession(session) var exitError *exec.ExitError if xerrors.As(err, &exitError) { a.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) _ = session.Exit(exitError.ExitCode()) return } if err != nil { a.logger.Warn(ctx, "ssh session failed", slog.Error(err)) // This exit code is designed to be unlikely to be confused for a legit exit code // from the process. _ = session.Exit(MagicSessionErrorCode) return } _ = session.Exit(0) }, HostSigners: []ssh.Signer{randomSigner}, LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { // Allow local port forwarding all! sshLogger.Debug(ctx, "local port forward", slog.F("destination-host", destinationHost), slog.F("destination-port", destinationPort)) return true }, PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool { return true }, ReversePortForwardingCallback: func(ctx ssh.Context, bindHost string, bindPort uint32) bool { // Allow reverse port forwarding all! sshLogger.Debug(ctx, "local port forward", slog.F("bind-host", bindHost), slog.F("bind-port", bindPort)) return true }, RequestHandlers: map[string]ssh.RequestHandler{ "tcpip-forward": forwardHandler.HandleSSHRequest, "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, "streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, "cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest, }, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ NoClientAuth: true, } }, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": func(session ssh.Session) { ctx := session.Context() // Typically sftp sessions don't request a TTY, but if they do, // we must ensure the gliderlabs/ssh CRLF emulation is disabled. // Otherwise sftp will be broken. This can happen if a user sets // `RequestTTY force` in their SSH config. session.DisablePTYEmulation() var opts []sftp.ServerOption // Change current working directory to the users home // directory so that SFTP connections land there. homedir, err := userHomeDir() if err != nil { sshLogger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) } else { opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) } server, err := sftp.NewServer(session, opts...) if err != nil { sshLogger.Debug(ctx, "initialize sftp server", slog.Error(err)) return } defer server.Close() err = server.Serve() if errors.Is(err, io.EOF) { // Unless we call `session.Exit(0)` here, the client won't // receive `exit-status` because `(*sftp.Server).Close()` // calls `Close()` on the underlying connection (session), // which actually calls `channel.Close()` because it isn't // wrapped. This causes sftp clients to receive a non-zero // exit code. 
Typically sftp clients don't echo this exit // code but `scp` on macOS does (when using the default // SFTP backend). _ = session.Exit(0) return } sshLogger.Warn(ctx, "sftp server closed with error", slog.Error(err)) _ = session.Exit(1) }, }, MaxTimeout: a.sshMaxTimeout, } go a.runLoop(ctx) } // createCommand processes raw command input with OpenSSH-like behavior. // If the script provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. func (a *agent) createCommand(ctx context.Context, script string, env []string) (*exec.Cmd, error) { currentUser, err := user.Current() if err != nil { return nil, xerrors.Errorf("get current user: %w", err) } username := currentUser.Username shell, err := usershell.Get(username) if err != nil { return nil, xerrors.Errorf("get user shell: %w", err) } manifest := a.manifest.Load() if manifest == nil { return nil, xerrors.Errorf("no metadata was provided") } // OpenSSH executes all commands with the users current shell. // We replicate that behavior for IDE support. caller := "-c" if runtime.GOOS == "windows" { caller = "/c" } args := []string{caller, script} // gliderlabs/ssh returns a command slice of zero // when a shell is requested. if len(script) == 0 { args = []string{} if runtime.GOOS != "windows" { // On Linux and macOS, we should start a login // shell to consume juicy environment variables! args = append(args, "-l") } } cmd := exec.CommandContext(ctx, shell, args...) cmd.Dir = manifest.Directory // If the metadata directory doesn't exist, we run the command // in the users home directory. _, err = os.Stat(cmd.Dir) if cmd.Dir == "" || err != nil { // Default to user home if a directory is not set. homedir, err := userHomeDir() if err != nil { return nil, xerrors.Errorf("get home dir: %w", err) } cmd.Dir = homedir } cmd.Env = append(os.Environ(), env...) executablePath, err := os.Executable() if err != nil { return nil, xerrors.Errorf("getting os executable: %w", err) } // Set environment variables reliable detection of being inside a // Coder workspace. cmd.Env = append(cmd.Env, "CODER=true") cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username)) // Git on Windows resolves with UNIX-style paths. // If using backslashes, it's unable to find the executable. unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/") cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, unixExecutablePath)) // Specific Coder subcommands require the agent token exposed! cmd.Env = append(cmd.Env, fmt.Sprintf("CODER_AGENT_TOKEN=%s", *a.sessionToken.Load())) // Set SSH connection environment variables (these are also set by OpenSSH // and thus expected to be present by SSH clients). Since the agent does // networking in-memory, trying to provide accurate values here would be // nonsensical. For now, we hard code these values so that they're present. srcAddr, srcPort := "0.0.0.0", "0" dstAddr, dstPort := "0.0.0.0", "0" cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %s %s", srcAddr, srcPort, dstPort)) cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", srcAddr, srcPort, dstAddr, dstPort)) // This adds the ports dialog to code-server that enables // proxying a port dynamically. cmd.Env = append(cmd.Env, fmt.Sprintf("VSCODE_PROXY_URI=%s", manifest.VSCodePortProxyURI)) // Hide Coder message on code-server's "Getting Started" page cmd.Env = append(cmd.Env, "CS_DISABLE_GETTING_STARTED_OVERRIDE=true") // Load environment variables passed via the agent. 
// These should override all variables we manually specify. for envKey, value := range manifest.EnvironmentVariables { // Expanding environment variables allows for customization // of the $PATH, among other variables. Customers can prepend // or append to the $PATH, so allowing expand is required! cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, os.ExpandEnv(value))) } // Agent-level environment variables should take over all! // This is used for setting agent-specific variables like "CODER_AGENT_TOKEN". for envKey, value := range a.envVars { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value)) } return cmd, nil } func (a *agent) handleSSHSession(session ssh.Session) (retErr error) { ctx := session.Context() env := session.Environ() var magicType string for index, kv := range env { if !strings.HasPrefix(kv, MagicSSHSessionTypeEnvironmentVariable) { continue } magicType = strings.TrimPrefix(kv, MagicSSHSessionTypeEnvironmentVariable+"=") env = append(env[:index], env[index+1:]...) } switch magicType { case MagicSSHSessionTypeVSCode: a.connCountVSCode.Add(1) defer a.connCountVSCode.Add(-1) case MagicSSHSessionTypeJetBrains: a.connCountJetBrains.Add(1) defer a.connCountJetBrains.Add(-1) case "": a.connCountSSHSession.Add(1) defer a.connCountSSHSession.Add(-1) default: a.logger.Warn(ctx, "invalid magic ssh session type specified", slog.F("type", magicType)) } cmd, err := a.createCommand(ctx, session.RawCommand(), env) if err != nil { return err } if ssh.AgentRequested(session) { l, err := ssh.NewAgentListener() if err != nil { return xerrors.Errorf("new agent listener: %w", err) } defer l.Close() go ssh.ForwardAgentConnections(l, session) cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", "SSH_AUTH_SOCK", l.Addr().String())) } sshPty, windowSize, isPty := session.Pty() if isPty { // Disable minimal PTY emulation set by gliderlabs/ssh (NL-to-CRNL). // See https://github.com/coder/coder/issues/3371. session.DisablePTYEmulation() if !isQuietLogin(session.RawCommand()) { manifest := a.manifest.Load() if manifest != nil { err = showMOTD(session, manifest.MOTDFile) if err != nil { a.logger.Error(ctx, "show MOTD", slog.Error(err)) } } else { a.logger.Warn(ctx, "metadata lookup failed, unable to show MOTD") } } cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term)) // The pty package sets `SSH_TTY` on supported platforms. ptty, process, err := pty.Start(cmd, pty.WithPTYOption( pty.WithSSHRequest(sshPty), pty.WithLogger(slog.Stdlib(ctx, a.logger, slog.LevelInfo)), )) if err != nil { return xerrors.Errorf("start command: %w", err) } var wg sync.WaitGroup defer func() { defer wg.Wait() closeErr := ptty.Close() if closeErr != nil { a.logger.Warn(ctx, "failed to close tty", slog.Error(closeErr)) if retErr == nil { retErr = closeErr } } }() go func() { for win := range windowSize { resizeErr := ptty.Resize(uint16(win.Height), uint16(win.Width)) // If the pty is closed, then command has exited, no need to log. if resizeErr != nil && !errors.Is(resizeErr, pty.ErrClosed) { a.logger.Warn(ctx, "failed to resize tty", slog.Error(resizeErr)) } } }() // We don't add input copy to wait group because // it won't return until the session is closed. go func() { _, _ = io.Copy(ptty.Input(), session) }() // In low parallelism scenarios, the command may exit and we may close // the pty before the output copy has started. This can result in the // output being lost. To avoid this, we wait for the output copy to // start before waiting for the command to exit. 
This ensures that the // output copy goroutine will be scheduled before calling close on the // pty. This shouldn't be needed because of `pty.Dup()` below, but it // may not be supported on all platforms. outputCopyStarted := make(chan struct{}) ptyOutput := func() io.ReadCloser { defer close(outputCopyStarted) // Try to dup so we can separate stdin and stdout closure. // Once the original pty is closed, the dup will return // input/output error once the buffered data has been read. stdout, err := ptty.Dup() if err == nil { return stdout } // If we can't dup, we shouldn't close // the fd since it's tied to stdin. return readNopCloser{ptty.Output()} } wg.Add(1) go func() { // Ensure data is flushed to session on command exit, if we // close the session too soon, we might lose data. defer wg.Done() stdout := ptyOutput() defer stdout.Close() _, _ = io.Copy(session, stdout) }() <-outputCopyStarted err = process.Wait() var exitErr *exec.ExitError // ExitErrors just mean the command we run returned a non-zero exit code, which is normal // and not something to be concerned about. But, if it's something else, we should log it. if err != nil && !xerrors.As(err, &exitErr) { a.logger.Warn(ctx, "wait error", slog.Error(err)) } return err } cmd.Stdout = session cmd.Stderr = session.Stderr() // This blocks forever until stdin is received if we don't // use StdinPipe. It's unknown what causes this. stdinPipe, err := cmd.StdinPipe() if err != nil { return xerrors.Errorf("create stdin pipe: %w", err) } go func() { _, _ = io.Copy(stdinPipe, session) _ = stdinPipe.Close() }() err = cmd.Start() if err != nil { return xerrors.Errorf("start: %w", err) } return cmd.Wait() } type readNopCloser struct{ io.Reader } // Close implements io.Closer. func (readNopCloser) Close() error { return nil } func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, msg codersdk.WorkspaceAgentReconnectingPTYInit, conn net.Conn) (retErr error) { defer conn.Close() a.connCountReconnectingPTY.Add(1) defer a.connCountReconnectingPTY.Add(-1) connectionID := uuid.NewString() logger = logger.With(slog.F("id", msg.ID), slog.F("connection_id", connectionID)) defer func() { if err := retErr; err != nil { a.closeMutex.Lock() closed := a.isClosed() a.closeMutex.Unlock() // If the agent is closed, we don't want to // log this as an error since it's expected. if closed { logger.Debug(ctx, "session error after agent close", slog.Error(err)) } else { logger.Error(ctx, "session error", slog.Error(err)) } } logger.Debug(ctx, "session closed") }() var rpty *reconnectingPTY rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID) if ok { logger.Debug(ctx, "connecting to existing session") rpty, ok = rawRPTY.(*reconnectingPTY) if !ok { return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY) } } else { logger.Debug(ctx, "creating new session") // Empty command will default to the users shell! cmd, err := a.createCommand(ctx, msg.Command, nil) if err != nil { return xerrors.Errorf("create command: %w", err) } cmd.Env = append(cmd.Env, "TERM=xterm-256color") // Default to buffer 64KiB. 
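		// The circular buffer holds the most recent PTY output so that a
		// client reconnecting to the same ID can replay recent scrollback
		// before receiving live output; older output is silently dropped.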
circularBuffer, err := circbuf.NewBuffer(64 << 10) if err != nil { return xerrors.Errorf("create circular buffer: %w", err) } ptty, process, err := pty.Start(cmd) if err != nil { return xerrors.Errorf("start command: %w", err) } ctx, cancelFunc := context.WithCancel(ctx) rpty = &reconnectingPTY{ activeConns: map[string]net.Conn{ // We have to put the connection in the map instantly otherwise // the connection won't be closed if the process instantly dies. connectionID: conn, }, ptty: ptty, // Timeouts created with an after func can be reset! timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc), circularBuffer: circularBuffer, } a.reconnectingPTYs.Store(msg.ID, rpty) go func() { // CommandContext isn't respected for Windows PTYs right now, // so we need to manually track the lifecycle. // When the context has been completed either: // 1. The timeout completed. // 2. The parent context was canceled. <-ctx.Done() _ = process.Kill() }() go func() { // If the process dies randomly, we should // close the pty. _ = process.Wait() rpty.Close() }() if err = a.trackConnGoroutine(func() { buffer := make([]byte, 1024) for { read, err := rpty.ptty.Output().Read(buffer) if err != nil { // When the PTY is closed, this is triggered. break } part := buffer[:read] rpty.circularBufferMutex.Lock() _, err = rpty.circularBuffer.Write(part) rpty.circularBufferMutex.Unlock() if err != nil { logger.Error(ctx, "write to circular buffer", slog.Error(err)) break } rpty.activeConnsMutex.Lock() for _, conn := range rpty.activeConns { _, _ = conn.Write(part) } rpty.activeConnsMutex.Unlock() } // Cleanup the process, PTY, and delete it's // ID from memory. _ = process.Kill() rpty.Close() a.reconnectingPTYs.Delete(msg.ID) }); err != nil { return xerrors.Errorf("start routine: %w", err) } } // Resize the PTY to initial height + width. err := rpty.ptty.Resize(msg.Height, msg.Width) if err != nil { // We can continue after this, it's not fatal! logger.Error(ctx, "resize", slog.Error(err)) } // Write any previously stored data for the TTY. rpty.circularBufferMutex.RLock() prevBuf := slices.Clone(rpty.circularBuffer.Bytes()) rpty.circularBufferMutex.RUnlock() // Note that there is a small race here between writing buffered // data and storing conn in activeConns. This is likely a very minor // edge case, but we should look into ways to avoid it. Holding // activeConnsMutex would be one option, but holding this mutex // while also holding circularBufferMutex seems dangerous. _, err = conn.Write(prevBuf) if err != nil { return xerrors.Errorf("write buffer to conn: %w", err) } // Multiple connections to the same TTY are permitted. // This could easily be used for terminal sharing, but // we do it because it's a nice user experience to // copy/paste a terminal URL and have it _just work_. rpty.activeConnsMutex.Lock() rpty.activeConns[connectionID] = conn rpty.activeConnsMutex.Unlock() // Resetting this timeout prevents the PTY from exiting. rpty.timeout.Reset(a.reconnectingPTYTimeout) ctx, cancelFunc := context.WithCancel(ctx) defer cancelFunc() heartbeat := time.NewTicker(a.reconnectingPTYTimeout / 2) defer heartbeat.Stop() go func() { // Keep updating the activity while this // connection is alive! for { select { case <-ctx.Done(): return case <-heartbeat.C: } rpty.timeout.Reset(a.reconnectingPTYTimeout) } }() defer func() { // After this connection ends, remove it from // the PTYs active connections. If it isn't // removed, all PTY data will be sent to it. 
rpty.activeConnsMutex.Lock() delete(rpty.activeConns, connectionID) rpty.activeConnsMutex.Unlock() }() decoder := json.NewDecoder(conn) var req codersdk.ReconnectingPTYRequest for { err = decoder.Decode(&req) if xerrors.Is(err, io.EOF) { return nil } if err != nil { logger.Warn(ctx, "read conn", slog.Error(err)) return nil } _, err = rpty.ptty.Input().Write([]byte(req.Data)) if err != nil { logger.Warn(ctx, "write to pty", slog.Error(err)) return nil } // Check if a resize needs to happen! if req.Height == 0 || req.Width == 0 { continue } err = rpty.ptty.Resize(req.Height, req.Width) if err != nil { // We can continue after this, it's not fatal! logger.Error(ctx, "resize", slog.Error(err)) } } } // startReportingConnectionStats runs the connection stats reporting goroutine. func (a *agent) startReportingConnectionStats(ctx context.Context) { reportStats := func(networkStats map[netlogtype.Connection]netlogtype.Counts) { stats := &agentsdk.Stats{ ConnectionCount: int64(len(networkStats)), ConnectionsByProto: map[string]int64{}, } for conn, counts := range networkStats { stats.ConnectionsByProto[conn.Proto.String()]++ stats.RxBytes += int64(counts.RxBytes) stats.RxPackets += int64(counts.RxPackets) stats.TxBytes += int64(counts.TxBytes) stats.TxPackets += int64(counts.TxPackets) } // The count of active sessions. stats.SessionCountSSH = a.connCountSSHSession.Load() stats.SessionCountVSCode = a.connCountVSCode.Load() stats.SessionCountJetBrains = a.connCountJetBrains.Load() stats.SessionCountReconnectingPTY = a.connCountReconnectingPTY.Load() // Compute the median connection latency! var wg sync.WaitGroup var mu sync.Mutex status := a.network.Status() durations := []float64{} ctx, cancelFunc := context.WithTimeout(ctx, 5*time.Second) defer cancelFunc() for nodeID, peer := range status.Peer { if !peer.Active { continue } addresses, found := a.network.NodeAddresses(nodeID) if !found { continue } if len(addresses) == 0 { continue } wg.Add(1) go func() { defer wg.Done() duration, _, _, err := a.network.Ping(ctx, addresses[0].Addr()) if err != nil { return } mu.Lock() durations = append(durations, float64(duration.Microseconds())) mu.Unlock() }() } wg.Wait() sort.Float64s(durations) durationsLength := len(durations) if durationsLength == 0 { stats.ConnectionMedianLatencyMS = -1 } else if durationsLength%2 == 0 { stats.ConnectionMedianLatencyMS = (durations[durationsLength/2-1] + durations[durationsLength/2]) / 2 } else { stats.ConnectionMedianLatencyMS = durations[durationsLength/2] } // Convert from microseconds to milliseconds. stats.ConnectionMedianLatencyMS /= 1000 lastStat := a.latestStat.Load() if lastStat != nil && reflect.DeepEqual(lastStat, stats) { a.logger.Info(ctx, "skipping stat because nothing changed") return } a.latestStat.Store(stats) select { case a.connStatsChan <- stats: case <-a.closed: } } // Report statistics from the created network. cl, err := a.client.ReportStats(ctx, a.logger, a.connStatsChan, func(d time.Duration) { a.network.SetConnStatsCallback(d, 2048, func(_, _ time.Time, virtual, _ map[netlogtype.Connection]netlogtype.Counts) { reportStats(virtual) }, ) }) if err != nil { a.logger.Error(ctx, "report stats", slog.Error(err)) } else { if err = a.trackConnGoroutine(func() { // This is OK because the agent never re-creates the tailnet // and the only shutdown indicator is agent.Close(). <-a.closed _ = cl.Close() }); err != nil { a.logger.Debug(ctx, "report stats goroutine", slog.Error(err)) _ = cl.Close() } } } // isClosed returns whether the API is closed or not. 
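// It is safe to call from any goroutine: a.closed is only ever closed once,
// by Close while holding closeMutex.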
func (a *agent) isClosed() bool { select { case <-a.closed: return true default: return false } } func (a *agent) Close() error { a.closeMutex.Lock() defer a.closeMutex.Unlock() if a.isClosed() { return nil } ctx := context.Background() a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShuttingDown) lifecycleState := codersdk.WorkspaceAgentLifecycleOff if manifest := a.manifest.Load(); manifest != nil && manifest.ShutdownScript != "" { scriptDone := make(chan error, 1) scriptStart := time.Now() go func() { defer close(scriptDone) scriptDone <- a.runShutdownScript(ctx, manifest.ShutdownScript) }() var timeout <-chan time.Time // If timeout is zero, an older version of the coder // provider was used. Otherwise a timeout is always > 0. if manifest.ShutdownScriptTimeout > 0 { t := time.NewTimer(manifest.ShutdownScriptTimeout) defer t.Stop() timeout = t.C } var err error select { case err = <-scriptDone: case <-timeout: a.logger.Warn(ctx, "shutdown script timed out") a.setLifecycle(ctx, codersdk.WorkspaceAgentLifecycleShutdownTimeout) err = <-scriptDone // The script can still complete after a timeout. } execTime := time.Since(scriptStart) if err != nil { a.logger.Warn(ctx, "shutdown script failed", slog.F("execution_time", execTime), slog.Error(err)) lifecycleState = codersdk.WorkspaceAgentLifecycleShutdownError } else { a.logger.Info(ctx, "shutdown script completed", slog.F("execution_time", execTime)) } } // Set final state and wait for it to be reported because context // cancellation will stop the report loop. a.setLifecycle(ctx, lifecycleState) // Wait for the lifecycle to be reported, but don't wait forever so // that we don't break user expectations. ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() lifecycleWaitLoop: for { select { case <-ctx.Done(): break lifecycleWaitLoop case s := <-a.lifecycleReported: if s == lifecycleState { break lifecycleWaitLoop } } } close(a.closed) a.closeCancel() _ = a.sshServer.Close() if a.network != nil { _ = a.network.Close() } a.connCloseWait.Wait() return nil } type reconnectingPTY struct { activeConnsMutex sync.Mutex activeConns map[string]net.Conn circularBuffer *circbuf.Buffer circularBufferMutex sync.RWMutex timeout *time.Timer ptty pty.PTY } // Close ends all connections to the reconnecting // PTY and clear the circular buffer. func (r *reconnectingPTY) Close() { r.activeConnsMutex.Lock() defer r.activeConnsMutex.Unlock() for _, conn := range r.activeConns { _ = conn.Close() } _ = r.ptty.Close() r.circularBufferMutex.Lock() r.circularBuffer.Reset() r.circularBufferMutex.Unlock() r.timeout.Stop() } // Bicopy copies all of the data between the two connections and will close them // after one or both of them are done writing. If the context is canceled, both // of the connections will be closed. func Bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { ctx, cancel := context.WithCancel(ctx) defer cancel() defer func() { _ = c1.Close() _ = c2.Close() }() var wg sync.WaitGroup copyFunc := func(dst io.WriteCloser, src io.Reader) { defer func() { wg.Done() // If one side of the copy fails, ensure the other one exits as // well. cancel() }() _, _ = io.Copy(dst, src) } wg.Add(2) go copyFunc(c1, c2) go copyFunc(c2, c1) // Convert waitgroup to a channel so we can also wait on the context. done := make(chan struct{}) go func() { defer close(done) wg.Wait() }() select { case <-ctx.Done(): case <-done: } } // isQuietLogin checks if the SSH server should perform a quiet login or not. 
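// A login is quiet when a command is being executed (i.e. not a login shell)
// or when the user has a ~/.hushlogin file in their home directory; in both
// cases the MOTD is suppressed.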
// // https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L816 func isQuietLogin(rawCommand string) bool { // We are always quiet unless this is a login shell. if len(rawCommand) != 0 { return true } // Best effort, if we can't get the home directory, // we can't lookup .hushlogin. homedir, err := userHomeDir() if err != nil { return false } _, err = os.Stat(filepath.Join(homedir, ".hushlogin")) return err == nil } // showMOTD will output the message of the day from // the given filename to dest, if the file exists. // // https://github.com/openssh/openssh-portable/blob/25bd659cc72268f2858c5415740c442ee950049f/session.c#L784 func showMOTD(dest io.Writer, filename string) error { if filename == "" { return nil } f, err := os.Open(filename) if err != nil { if xerrors.Is(err, os.ErrNotExist) { // This is not an error, there simply isn't a MOTD to show. return nil } return xerrors.Errorf("open MOTD: %w", err) } defer f.Close() s := bufio.NewScanner(f) for s.Scan() { // Carriage return ensures each line starts // at the beginning of the terminal. _, err = fmt.Fprint(dest, s.Text()+"\r\n") if err != nil { return xerrors.Errorf("write MOTD: %w", err) } } if err := s.Err(); err != nil { return xerrors.Errorf("read MOTD: %w", err) } return nil } // userHomeDir returns the home directory of the current user, giving // priority to the $HOME environment variable. func userHomeDir() (string, error) { // First we check the environment. homedir, err := os.UserHomeDir() if err == nil { return homedir, nil } // As a fallback, we try the user information. u, err := user.Current() if err != nil { return "", xerrors.Errorf("current user: %w", err) } return u.HomeDir, nil } // expandDirectory converts a directory path to an absolute path. // It primarily resolves the home directory and any environment // variables that may be set func expandDirectory(dir string) (string, error) { if dir == "" { return "", nil } if dir[0] == '~' { home, err := userHomeDir() if err != nil { return "", err } dir = filepath.Join(home, dir[1:]) } return os.ExpandEnv(dir), nil }
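// Illustrative expansions for expandDirectory above, assuming a hypothetical
// environment with HOME=/home/coder and PROJECT=api:
//
//	expandDirectory("~/src")          // "/home/coder/src"
//	expandDirectory("$HOME/$PROJECT") // "/home/coder/api"
//	expandDirectory("")               // ""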