feat(agent): wire up agentssh server to allow exec into container (#16638)

Builds on top of https://github.com/coder/coder/pull/16623/ and wires up
the ReconnectingPTY server. This does nothing to wire up the web
terminal yet but the added test demonstrates the functionality working.

Other changes:
* Refactors and moves the `EnvInfoer` interface (with its default
`SystemEnvInfo` implementation) to the `agent/usershell` package to
address follow-up from
https://github.com/coder/coder/pull/16623#discussion_r1967580249
* Marks `usershell.Get` as deprecated. Consumers should use the
`EnvInfoer` interface instead.

---------

Co-authored-by: Mathias Fredriksson <mafredri@gmail.com>
Co-authored-by: Danny Kopping <danny@coder.com>
This commit is contained in:
Cian Johnston
2025-02-26 09:03:27 +00:00
committed by GitHub
parent a3223397cb
commit 172e52317c
15 changed files with 260 additions and 82 deletions

View File

@ -88,6 +88,8 @@ type Options struct {
BlockFileTransfer bool
Execer agentexec.Execer
ContainerLister agentcontainers.Lister
ExperimentalContainersEnabled bool
}
type Client interface {
@ -188,6 +190,8 @@ func New(options Options) Agent {
metrics: newAgentMetrics(prometheusRegistry),
execer: options.Execer,
lister: options.ContainerLister,
experimentalDevcontainersEnabled: options.ExperimentalContainersEnabled,
}
// Initially, we have a closed channel, reflecting the fact that we are not initially connected.
// Each time we connect we replace the channel (while holding the closeMutex) with a new one
@ -258,6 +262,8 @@ type agent struct {
metrics *agentMetrics
execer agentexec.Execer
lister agentcontainers.Lister
experimentalDevcontainersEnabled bool
}
func (a *agent) TailnetConn() *tailnet.Conn {
@ -297,6 +303,9 @@ func (a *agent) init() {
a.sshServer,
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
func(s *reconnectingpty.Server) {
s.ExperimentalContainersEnabled = a.experimentalDevcontainersEnabled
},
)
go a.runLoop()
}

View File

@ -25,8 +25,14 @@ import (
"testing"
"time"
"go.uber.org/goleak"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"github.com/bramvdbogaerde/go-scp"
"github.com/google/uuid"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/pion/udp"
"github.com/pkg/sftp"
"github.com/prometheus/client_golang/prometheus"
@ -34,15 +40,13 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/agenttest"
@ -1761,6 +1765,74 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
}
}
// TestAgent_ReconnectingPTYContainer tests the end-to-end functionality of
// connecting to a running container and executing a command through the
// reconnecting PTY. It creates a real Docker container and runs a command
// in it. As such, it does not run by default in CI.
// You can run it manually as follows:
//
//	CODER_TEST_USE_DOCKER=1 go test -count=1 ./agent -run TestAgent_ReconnectingPTYContainer
func TestAgent_ReconnectingPTYContainer(t *testing.T) {
	t.Parallel()
	if os.Getenv("CODER_TEST_USE_DOCKER") != "1" {
		t.Skip("Set CODER_TEST_USE_DOCKER=1 to run this test")
	}

	ctx := testutil.Context(t, testutil.WaitLong)

	pool, err := dockertest.NewPool("")
	require.NoError(t, err, "Could not connect to docker")
	ct, err := pool.RunWithOptions(&dockertest.RunOptions{
		Repository: "busybox",
		Tag:        "latest",
		// Keep the container alive until the test purges it.
		Cmd: []string{"sleep", "infinity"},
	}, func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no"}
	})
	require.NoError(t, err, "Could not start container")
	t.Cleanup(func() {
		err := pool.Purge(ct)
		require.NoError(t, err, "Could not stop container")
	})
	// Wait for container to start
	require.Eventually(t, func() bool {
		ct, ok := pool.ContainerByName(ct.Container.Name)
		return ok && ct.Container.State.Running
	}, testutil.WaitShort, testutil.IntervalSlow, "Container did not start in time")

	// Container exec is opt-in, so the agent under test must enable it.
	// nolint: dogsled
	conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
		o.ExperimentalContainersEnabled = true
	})
	// Target the container by ID; no container user is set.
	ac, err := conn.ReconnectingPTY(ctx, uuid.New(), 80, 80, "/bin/sh", func(arp *workspacesdk.AgentReconnectingPTYInit) {
		arp.Container = ct.Container.ID
	})
	require.NoError(t, err, "failed to create ReconnectingPTY")
	defer ac.Close()
	tr := testutil.NewTerminalReader(t, ac)

	// A line containing "#" or "$" is treated as the shell prompt.
	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, "#") || strings.Contains(line, "$")
	}), "find prompt")

	require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
		Data: "hostname\r",
	}), "write hostname")
	// First the echoed command, then its output, which must contain the
	// hostname from the container's config.
	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, "hostname")
	}), "find hostname command")

	require.NoError(t, tr.ReadUntil(ctx, func(line string) bool {
		return strings.Contains(line, ct.Container.Config.Hostname)
	}), "find hostname output")
	require.NoError(t, json.NewEncoder(ac).Encode(workspacesdk.ReconnectingPTYRequest{
		Data: "exit\r",
	}), "write exit command")

	// Wait for the connection to close.
	require.ErrorIs(t, tr.ReadUntil(ctx, nil), io.EOF)
}
func TestAgent_Dial(t *testing.T) {
t.Parallel()

View File

@ -6,7 +6,6 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"os/user"
"slices"
"sort"
@ -15,6 +14,7 @@ import (
"time"
"github.com/coder/coder/v2/agent/agentexec"
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk"
"golang.org/x/exp/maps"
@ -37,6 +37,7 @@ func NewDocker(execer agentexec.Execer) Lister {
// DockerEnvInfoer is an implementation of agentssh.EnvInfoer that returns
// information about a container.
type DockerEnvInfoer struct {
usershell.SystemEnvInfo
container string
user *user.User
userShell string
@ -122,26 +123,13 @@ func EnvInfo(ctx context.Context, execer agentexec.Execer, container, containerU
return &dei, nil
}
func (dei *DockerEnvInfoer) CurrentUser() (*user.User, error) {
// User returns a copy of the user held by the DockerEnvInfoer, so that
// changes made by the caller do not affect the stored value.
func (dei *DockerEnvInfoer) User() (*user.User, error) {
	clone := new(user.User)
	*clone = *dei.user
	return clone, nil
}
func (*DockerEnvInfoer) Environ() []string {
// Return a clone of the environment so that the caller can't modify it
return os.Environ()
}
func (*DockerEnvInfoer) UserHomeDir() (string, error) {
// We default the working directory of the command to the user's home
// directory. Since this came from inside the container, we cannot guarantee
// that this exists on the host. Return the "real" home directory of the user
// instead.
return os.UserHomeDir()
}
func (dei *DockerEnvInfoer) UserShell(string) (string, error) {
// Shell returns the shell stored in dei.userShell. The username argument is
// ignored and the returned error is always nil.
func (dei *DockerEnvInfoer) Shell(string) (string, error) {
return dei.userShell, nil
}

View File

@ -502,15 +502,15 @@ func TestDockerEnvInfoer(t *testing.T) {
dei, err := EnvInfo(ctx, agentexec.DefaultExecer, ct.Container.ID, tt.containerUser)
require.NoError(t, err, "Expected no error from DockerEnvInfo()")
u, err := dei.CurrentUser()
u, err := dei.User()
require.NoError(t, err, "Expected no error from CurrentUser()")
require.Equal(t, tt.expectedUsername, u.Username, "Expected username to match")
hd, err := dei.UserHomeDir()
hd, err := dei.HomeDir()
require.NoError(t, err, "Expected no error from UserHomeDir()")
require.NotEmpty(t, hd, "Expected user homedir to be non-empty")
sh, err := dei.UserShell(tt.containerUser)
sh, err := dei.Shell(tt.containerUser)
require.NoError(t, err, "Expected no error from UserShell()")
require.Equal(t, tt.expectedUserShell, sh, "Expected user shell to match")

View File

@ -698,45 +698,6 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) {
_ = session.Exit(1)
}
// EnvInfoer encapsulates external information required by CreateCommand.
type EnvInfoer interface {
// CurrentUser returns the current user.
CurrentUser() (*user.User, error)
// Environ returns the environment variables of the current process.
Environ() []string
// UserHomeDir returns the home directory of the current user.
UserHomeDir() (string, error)
// UserShell returns the shell of the given user.
UserShell(username string) (string, error)
}
type systemEnvInfoer struct{}
var defaultEnvInfoer EnvInfoer = &systemEnvInfoer{}
// DefaultEnvInfoer returns a default implementation of
// EnvInfoer. This reads information using the default Go
// implementations.
func DefaultEnvInfoer() EnvInfoer {
return defaultEnvInfoer
}
func (systemEnvInfoer) CurrentUser() (*user.User, error) {
return user.Current()
}
func (systemEnvInfoer) Environ() []string {
return os.Environ()
}
func (systemEnvInfoer) UserHomeDir() (string, error) {
return userHomeDir()
}
func (systemEnvInfoer) UserShell(username string) (string, error) {
return usershell.Get(username)
}
// CreateCommand processes raw command input with OpenSSH-like behavior.
// If the script provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
@ -744,17 +705,17 @@ func (systemEnvInfoer) UserShell(username string) (string, error) {
// alternative implementations for the dependencies of CreateCommand.
// This is useful when creating a command to be run in a separate environment
// (for example, a Docker container). Pass in nil to use the default.
func (s *Server) CreateCommand(ctx context.Context, script string, env []string, deps EnvInfoer) (*pty.Cmd, error) {
if deps == nil {
deps = DefaultEnvInfoer()
func (s *Server) CreateCommand(ctx context.Context, script string, env []string, ei usershell.EnvInfoer) (*pty.Cmd, error) {
if ei == nil {
ei = &usershell.SystemEnvInfo{}
}
currentUser, err := deps.CurrentUser()
currentUser, err := ei.User()
if err != nil {
return nil, xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username
shell, err := deps.UserShell(username)
shell, err := ei.Shell(username)
if err != nil {
return nil, xerrors.Errorf("get user shell: %w", err)
}
@ -802,7 +763,18 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
}
}
cmd := s.Execer.PTYCommandContext(ctx, name, args...)
// Modify command prior to execution. This will usually be a no-op, but not
// always. For example, to run a command in a Docker container, we need to
// modify the command to be `docker exec -it <container> <command>`.
modifiedName, modifiedArgs := ei.ModifyCommand(name, args...)
// Log if the command was modified. A change to either the name or the
// arguments counts as a modification, so the two checks are OR-ed.
if modifiedName != name || slices.Compare(modifiedArgs, args) != 0 {
	s.logger.Debug(ctx, "modified command",
		slog.F("before", append([]string{name}, args...)),
		slog.F("after", append([]string{modifiedName}, modifiedArgs...)),
	)
}
cmd := s.Execer.PTYCommandContext(ctx, modifiedName, modifiedArgs...)
cmd.Dir = s.config.WorkingDirectory()
// If the metadata directory doesn't exist, we run the command
@ -810,13 +782,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
_, err = os.Stat(cmd.Dir)
if cmd.Dir == "" || err != nil {
// Default to user home if a directory is not set.
homedir, err := deps.UserHomeDir()
homedir, err := ei.HomeDir()
if err != nil {
return nil, xerrors.Errorf("get home dir: %w", err)
}
cmd.Dir = homedir
}
cmd.Env = append(deps.Environ(), env...)
cmd.Env = append(ei.Environ(), env...)
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", username))
// Set SSH connection environment variables (these are also set by OpenSSH

View File

@ -124,7 +124,7 @@ type fakeEnvInfoer struct {
UserShellFn func(string) (string, error)
}
func (f *fakeEnvInfoer) CurrentUser() (u *user.User, err error) {
// User returns the result of calling CurrentUserFn.
func (f *fakeEnvInfoer) User() (u *user.User, err error) {
return f.CurrentUserFn()
}
@ -132,14 +132,18 @@ func (f *fakeEnvInfoer) Environ() []string {
return f.EnvironFn()
}
func (f *fakeEnvInfoer) UserHomeDir() (string, error) {
// HomeDir returns the result of calling UserHomeDirFn.
func (f *fakeEnvInfoer) HomeDir() (string, error) {
return f.UserHomeDirFn()
}
func (f *fakeEnvInfoer) UserShell(u string) (string, error) {
// Shell passes u to UserShellFn and returns its result.
func (f *fakeEnvInfoer) Shell(u string) (string, error) {
return f.UserShellFn(u)
}
// ModifyCommand returns cmd and args unchanged.
func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []string) {
return cmd, args
}
func TestNewServer_CloseActiveConnections(t *testing.T) {
t.Parallel()

View File

@ -14,7 +14,9 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/agent/agentcontainers"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/agent/usershell"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
@ -26,20 +28,26 @@ type Server struct {
connCount atomic.Int64
reconnectingPTYs sync.Map
timeout time.Duration
ExperimentalContainersEnabled bool
}
// NewServer returns a new ReconnectingPTY server
func NewServer(logger slog.Logger, commandCreator *agentssh.Server,
connectionsTotal prometheus.Counter, errorsTotal *prometheus.CounterVec,
timeout time.Duration,
timeout time.Duration, opts ...func(*Server),
) *Server {
return &Server{
s := &Server{
logger: logger,
commandCreator: commandCreator,
connectionsTotal: connectionsTotal,
errorsTotal: errorsTotal,
timeout: timeout,
}
for _, o := range opts {
o(s)
}
return s
}
func (s *Server) Serve(ctx, hardCtx context.Context, l net.Listener) (retErr error) {
@ -116,7 +124,7 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
}
connectionID := uuid.NewString()
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID))
connLogger := logger.With(slog.F("message_id", msg.ID), slog.F("connection_id", connectionID), slog.F("container", msg.Container), slog.F("container_user", msg.ContainerUser))
connLogger.Debug(ctx, "starting handler")
defer func() {
@ -158,8 +166,17 @@ func (s *Server) handleConn(ctx context.Context, logger slog.Logger, conn net.Co
}
}()
var ei usershell.EnvInfoer
if s.ExperimentalContainersEnabled && msg.Container != "" {
dei, err := agentcontainers.EnvInfo(ctx, s.commandCreator.Execer, msg.Container, msg.ContainerUser)
if err != nil {
return xerrors.Errorf("get container env info: %w", err)
}
ei = dei
s.logger.Info(ctx, "got container env info", slog.F("container", msg.Container))
}
// Empty command will default to the users shell!
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, nil)
cmd, err := s.commandCreator.CreateCommand(ctx, msg.Command, nil, ei)
if err != nil {
s.errorsTotal.WithLabelValues("create_command").Add(1)
return xerrors.Errorf("create command: %w", err)

View File

@ -0,0 +1,66 @@
package usershell
import (
"os"
"os/user"
"golang.org/x/xerrors"
)
// HomeDir resolves the home directory of the current user. The result of
// os.UserHomeDir is preferred; if that call fails, the HomeDir field of
// user.Current is returned instead.
//
// Deprecated: use EnvInfoer.HomeDir() instead.
func HomeDir() (string, error) {
	// Preferred source: os.UserHomeDir.
	if dir, err := os.UserHomeDir(); err == nil {
		return dir, nil
	}

	// Fallback: look the current user up directly.
	current, err := user.Current()
	if err != nil {
		return "", xerrors.Errorf("current user: %w", err)
	}
	return current.HomeDir, nil
}
// EnvInfoer encapsulates external information about the environment.
// SystemEnvInfo is the implementation that uses the default Go
// implementations.
type EnvInfoer interface {
// User returns the current user.
User() (*user.User, error)
// Environ returns the environment variables of the current process.
Environ() []string
// HomeDir returns the home directory of the current user.
HomeDir() (string, error)
// Shell returns the shell of the given user.
Shell(username string) (string, error)
// ModifyCommand modifies the command and arguments before execution based on
// the environment. This is useful for executing a command inside a container.
// In the default case, the command and arguments are returned unchanged.
// The return values are the command name and argument list to execute.
ModifyCommand(name string, args ...string) (string, []string)
}
// SystemEnvInfo encapsulates the information about the environment
// just using the default Go implementations.
type SystemEnvInfo struct{}
// User returns the result of user.Current.
func (SystemEnvInfo) User() (*user.User, error) {
return user.Current()
}
// Environ returns the result of os.Environ.
func (SystemEnvInfo) Environ() []string {
return os.Environ()
}
// HomeDir returns the result of the package-level HomeDir function.
func (SystemEnvInfo) HomeDir() (string, error) {
return HomeDir()
}
// Shell returns the result of Get for the given username.
func (SystemEnvInfo) Shell(username string) (string, error) {
return Get(username)
}
// ModifyCommand returns name and args unchanged.
func (SystemEnvInfo) ModifyCommand(name string, args ...string) (string, []string) {
return name, args
}

View File

@ -10,6 +10,7 @@ import (
)
// Get returns the $SHELL environment variable.
//
// Deprecated: use SystemEnvInfo.Shell instead.
func Get(username string) (string, error) {
// This command will output "UserShell: /bin/zsh" if successful, we
// can ignore the error since we have fallback behavior.

View File

@ -11,6 +11,7 @@ import (
)
// Get returns the /etc/passwd entry for the username provided.
//
// Deprecated: use SystemEnvInfo.Shell instead.
func Get(username string) (string, error) {
contents, err := os.ReadFile("/etc/passwd")
if err != nil {

View File

@ -3,6 +3,7 @@ package usershell
import "os/exec"
// Get returns the command prompt binary name.
//
// Deprecated: use SystemEnvInfo.Shell instead.
func Get(username string) (string, error) {
_, err := exec.LookPath("pwsh.exe")
if err == nil {

View File

@ -351,6 +351,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command {
BlockFileTransfer: blockFileTransfer,
Execer: execer,
ContainerLister: containerLister,
ExperimentalContainersEnabled: devcontainersEnabled,
})
promHandler := agent.PrometheusMetricsHandler(prometheusRegistry, logger)

View File

@ -653,6 +653,8 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
reconnect := parser.RequiredNotEmpty("reconnect").UUID(values, uuid.New(), "reconnect")
height := parser.UInt(values, 80, "height")
width := parser.UInt(values, 80, "width")
container := parser.String(values, "", "container")
containerUser := parser.String(values, "", "container_user")
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameters.",
@ -690,7 +692,10 @@ func (s *Server) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
}
defer release()
log.Debug(ctx, "dialed workspace agent")
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"), func(arp *workspacesdk.AgentReconnectingPTYInit) {
arp.Container = container
arp.ContainerUser = containerUser
})
if err != nil {
log.Debug(ctx, "dial reconnecting pty server in workspace agent", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))

View File

@ -93,6 +93,24 @@ type AgentReconnectingPTYInit struct {
Height uint16
Width uint16
Command string
// Container, if set, will attempt to exec into a running container visible to the agent.
// This should be a unique container ID (implementation-dependent).
Container string
// ContainerUser, if set, will set the target user when execing into a container.
// This can be a username or UID, depending on the underlying implementation.
// This is ignored if Container is not set.
ContainerUser string
}
// AgentReconnectingPTYInitOption is a functional option that mutates an
// AgentReconnectingPTYInit.
type AgentReconnectingPTYInitOption func(*AgentReconnectingPTYInit)

// AgentReconnectingPTYInitWithContainer returns an option that records the
// given container and container user on the reconnecting PTY init message.
func AgentReconnectingPTYInitWithContainer(container, containerUser string) AgentReconnectingPTYInitOption {
	apply := func(rptyInit *AgentReconnectingPTYInit) {
		rptyInit.Container, rptyInit.ContainerUser = container, containerUser
	}
	return apply
}
// ReconnectingPTYRequest is sent from the client to the server
@ -107,7 +125,7 @@ type ReconnectingPTYRequest struct {
// ReconnectingPTY spawns a new reconnecting terminal session.
// `ReconnectingPTYRequest` should be JSON marshaled and written to the returned net.Conn.
// Raw terminal output will be read from the returned net.Conn.
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
@ -119,12 +137,16 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w
if err != nil {
return nil, err
}
data, err := json.Marshal(AgentReconnectingPTYInit{
rptyInit := AgentReconnectingPTYInit{
ID: id,
Height: height,
Width: width,
Command: command,
})
}
for _, o := range initOpts {
o(&rptyInit)
}
data, err := json.Marshal(rptyInit)
if err != nil {
_ = conn.Close()
return nil, err

View File

@ -12,12 +12,14 @@ import (
"strconv"
"strings"
"github.com/google/uuid"
"golang.org/x/xerrors"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/capture"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
@ -305,6 +307,16 @@ type WorkspaceAgentReconnectingPTYOpts struct {
// issue-reconnecting-pty-signed-token endpoint. If set, the session token
// on the client will not be sent.
SignedToken string
// Experimental: Container, if set, will attempt to exec into a running container
// visible to the agent. This should be a unique container ID
// (implementation-dependent).
// ContainerUser is the user to exec into the container as.
// NOTE: This feature is experimental and opt-in.
// To use it, the agent must have the environment variable
// CODER_AGENT_DEVCONTAINERS_ENABLE set to "true".
Container string
ContainerUser string
}
// AgentReconnectingPTY spawns a PTY that reconnects using the token provided.
@ -320,6 +332,12 @@ func (c *Client) AgentReconnectingPTY(ctx context.Context, opts WorkspaceAgentRe
q.Set("width", strconv.Itoa(int(opts.Width)))
q.Set("height", strconv.Itoa(int(opts.Height)))
q.Set("command", opts.Command)
if opts.Container != "" {
q.Set("container", opts.Container)
}
if opts.ContainerUser != "" {
q.Set("container_user", opts.ContainerUser)
}
// If we're using a signed token, set the query parameter.
if opts.SignedToken != "" {
q.Set(codersdk.SignedAppTokenQueryParameter, opts.SignedToken)