Files
coder/cli/ssh_internal_test.go
Mathias Fredriksson bacdc28881 fix(cli)!: enforce selection for multiple agents rather than use randomness (#18427)
In the past we randomly selected workspace agent if there were multiple.
Unless both are running on the same machine with the same configuration,
this would be very confusing behavior for a user.

With the introduction of sub agents (devcontainer agents), we have now
made this an error state and require the specifying of agent when there
is more than one (either normal agent or sub agent).

This aligns with the behavior of e.g. Coder Desktop.

Fixes coder/internal#696
2025-06-19 22:13:35 +03:00

444 lines
11 KiB
Go

package cli
import (
"context"
"fmt"
"io"
"net"
"net/url"
"sync"
"testing"
"time"
gliderssh "github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/quartz"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
const (
fakeOwnerName = "fake-owner-name"
fakeServerURL = "https://fake-foo-url"
fakeWorkspaceName = "fake-workspace-name"
)
func TestVerifyWorkspaceOutdated(t *testing.T) {
t.Parallel()
serverURL, err := url.Parse(fakeServerURL)
require.NoError(t, err)
client := codersdk.Client{URL: serverURL}
t.Run("Up-to-date", func(t *testing.T) {
t.Parallel()
workspace := codersdk.Workspace{Name: fakeWorkspaceName, OwnerName: fakeOwnerName}
_, outdated := verifyWorkspaceOutdated(&client, workspace)
assert.False(t, outdated, "workspace should be up-to-date")
})
t.Run("Outdated", func(t *testing.T) {
t.Parallel()
workspace := codersdk.Workspace{Name: fakeWorkspaceName, OwnerName: fakeOwnerName, Outdated: true}
updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(&client, workspace)
assert.True(t, outdated, "workspace should be outdated")
assert.NotEmpty(t, updateWorkspaceBanner, "workspace banner should be present")
})
}
func TestBuildWorkspaceLink(t *testing.T) {
t.Parallel()
serverURL, err := url.Parse(fakeServerURL)
require.NoError(t, err)
workspace := codersdk.Workspace{Name: fakeWorkspaceName, OwnerName: fakeOwnerName}
workspaceLink := buildWorkspaceLink(serverURL, workspace)
assert.Equal(t, workspaceLink.String(), fakeServerURL+"/@"+fakeOwnerName+"/"+fakeWorkspaceName)
}
func TestCloserStack_Mainline(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
func() {
defer uut.close(nil)
err := uut.push("fc0", fc0)
require.NoError(t, err)
err = uut.push("fc1", fc1)
require.NoError(t, err)
}()
// order reversed
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes)
}
func TestCloserStack_Empty(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closed := make(chan struct{})
go func() {
defer close(closed)
uut.close(nil)
}()
testutil.TryReceive(ctx, t, closed)
}
func TestCloserStack_Context(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
logger := testutil.Logger(t)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
err := uut.push("fc0", fc0)
require.NoError(t, err)
err = uut.push("fc1", fc1)
require.NoError(t, err)
cancel()
require.Eventually(t, func() bool {
uut.Lock()
defer uut.Unlock()
return uut.closed
}, testutil.WaitShort, testutil.IntervalFast)
}
func TestCloserStack_PushAfterClose(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
closes := new([]*fakeCloser)
fc0 := &fakeCloser{closes: closes}
fc1 := &fakeCloser{closes: closes}
err := uut.push("fc0", fc0)
require.NoError(t, err)
exErr := xerrors.New("test")
uut.close(exErr)
require.Equal(t, []*fakeCloser{fc0}, *closes)
err = uut.push("fc1", fc1)
require.ErrorIs(t, err, exErr)
require.Equal(t, []*fakeCloser{fc1, fc0}, *closes, "should close fc1")
}
func TestCloserStack_CloseAfterContext(t *testing.T) {
t.Parallel()
testCtx := testutil.Context(t, testutil.WaitShort)
ctx, cancel := context.WithCancel(testCtx)
defer cancel()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
uut := newCloserStack(ctx, logger, quartz.NewMock(t))
ac := newAsyncCloser(testCtx, t)
defer ac.complete()
err := uut.push("async", ac)
require.NoError(t, err)
cancel()
testutil.TryReceive(testCtx, t, ac.started)
closed := make(chan struct{})
go func() {
defer close(closed)
uut.close(nil)
}()
// since the asyncCloser is still waiting, we shouldn't complete uut.close()
select {
case <-time.After(testutil.IntervalFast):
// OK!
case <-closed:
t.Fatal("closed before stack was finished")
}
ac.complete()
testutil.TryReceive(testCtx, t, closed)
}
func TestCloserStack_Timeout(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
mClock := quartz.NewMock(t)
trap := mClock.Trap().TickerFunc("closerStack")
defer trap.Close()
uut := newCloserStack(ctx, logger, mClock)
var ac [3]*asyncCloser
for i := range ac {
ac[i] = newAsyncCloser(ctx, t)
err := uut.push(fmt.Sprintf("async %d", i), ac[i])
require.NoError(t, err)
}
defer func() {
for _, a := range ac {
a.complete()
}
}()
closed := make(chan struct{})
go func() {
defer close(closed)
uut.close(nil)
}()
trap.MustWait(ctx).MustRelease(ctx)
// top starts right away, but it hangs
testutil.TryReceive(ctx, t, ac[2].started)
// timer pops and we start the middle one
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
testutil.TryReceive(ctx, t, ac[1].started)
// middle one finishes
ac[1].complete()
// bottom starts, but also hangs
testutil.TryReceive(ctx, t, ac[0].started)
// timer has to pop twice to time out.
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
mClock.Advance(gracefulShutdownTimeout).MustWait(ctx)
testutil.TryReceive(ctx, t, closed)
}
func TestCoderConnectStdio(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
stack := newCloserStack(ctx, logger, quartz.NewMock(t))
clientOutput, clientInput := io.Pipe()
serverOutput, serverInput := io.Pipe()
defer func() {
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
_ = c.Close()
}
}()
server := newSSHServer("127.0.0.1:0")
ln, err := net.Listen("tcp", server.server.Addr)
require.NoError(t, err)
go func() {
_ = server.Serve(ln)
}()
t.Cleanup(func() {
_ = server.Close()
})
stdioDone := make(chan struct{})
go func() {
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
assert.NoError(t, err)
close(stdioDone)
}()
conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{
Reader: serverOutput,
Writer: clientInput,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
defer conn.Close()
sshClient := ssh.NewClient(conn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
// We're not connected to a real shell
err = session.Run("")
require.NoError(t, err)
err = sshClient.Close()
require.NoError(t, err)
_ = clientOutput.Close()
<-stdioDone
}
type sshServer struct {
server *gliderssh.Server
}
func newSSHServer(addr string) *sshServer {
return &sshServer{
server: &gliderssh.Server{
Addr: addr,
Handler: func(s gliderssh.Session) {
_, _ = io.WriteString(s.Stderr(), "Connected!")
},
},
}
}
func (s *sshServer) Serve(ln net.Listener) error {
return s.server.Serve(ln)
}
func (s *sshServer) Close() error {
return s.server.Close()
}
type fakeCloser struct {
closes *[]*fakeCloser
err error
}
func (c *fakeCloser) Close() error {
*c.closes = append(*c.closes, c)
return c.err
}
type asyncCloser struct {
t *testing.T
ctx context.Context
started chan struct{}
isComplete chan struct{}
comepleteOnce sync.Once
}
func (c *asyncCloser) Close() error {
close(c.started)
select {
case <-c.ctx.Done():
c.t.Error("timed out")
return c.ctx.Err()
case <-c.isComplete:
return nil
}
}
func (c *asyncCloser) complete() {
c.comepleteOnce.Do(func() { close(c.isComplete) })
}
func newAsyncCloser(ctx context.Context, t *testing.T) *asyncCloser {
return &asyncCloser{
t: t,
ctx: ctx,
isComplete: make(chan struct{}),
started: make(chan struct{}),
}
}
func Test_getWorkspaceAgent(t *testing.T) {
t.Parallel()
createWorkspaceWithAgents := func(agents []codersdk.WorkspaceAgent) codersdk.Workspace {
return codersdk.Workspace{
Name: "test-workspace",
LatestBuild: codersdk.WorkspaceBuild{
Resources: []codersdk.WorkspaceResource{
{
Agents: agents,
},
},
},
}
}
createAgent := func(name string) codersdk.WorkspaceAgent {
return codersdk.WorkspaceAgent{
ID: uuid.New(),
Name: name,
}
}
t.Run("SingleAgent_NoNameSpecified", func(t *testing.T) {
t.Parallel()
agent := createAgent("main")
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent})
result, err := getWorkspaceAgent(workspace, "")
require.NoError(t, err)
assert.Equal(t, agent.ID, result.ID)
assert.Equal(t, "main", result.Name)
})
t.Run("MultipleAgents_NoNameSpecified", func(t *testing.T) {
t.Parallel()
agent1 := createAgent("main1")
agent2 := createAgent("main2")
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2})
_, err := getWorkspaceAgent(workspace, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "multiple agents found")
assert.Contains(t, err.Error(), "available agents: [main1 main2]")
})
t.Run("AgentNameSpecified_Found", func(t *testing.T) {
t.Parallel()
agent1 := createAgent("main1")
agent2 := createAgent("main2")
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2})
result, err := getWorkspaceAgent(workspace, "main1")
require.NoError(t, err)
assert.Equal(t, agent1.ID, result.ID)
assert.Equal(t, "main1", result.Name)
})
t.Run("AgentNameSpecified_NotFound", func(t *testing.T) {
t.Parallel()
agent1 := createAgent("main1")
agent2 := createAgent("main2")
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent1, agent2})
_, err := getWorkspaceAgent(workspace, "nonexistent")
require.Error(t, err)
assert.Contains(t, err.Error(), `agent not found by name "nonexistent"`)
assert.Contains(t, err.Error(), "available agents: [main1 main2]")
})
t.Run("NoAgents", func(t *testing.T) {
t.Parallel()
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{})
_, err := getWorkspaceAgent(workspace, "")
require.Error(t, err)
assert.Contains(t, err.Error(), `workspace "test-workspace" has no agents`)
})
t.Run("AvailableAgentNames_SortedCorrectly", func(t *testing.T) {
t.Parallel()
// Define agents in non-alphabetical order.
agent2 := createAgent("zod")
agent1 := createAgent("clark")
agent3 := createAgent("krypton")
workspace := createWorkspaceWithAgents([]codersdk.WorkspaceAgent{agent2, agent1, agent3})
_, err := getWorkspaceAgent(workspace, "nonexistent")
require.Error(t, err)
// Available agents should be sorted alphabetically.
assert.Contains(t, err.Error(), "available agents: [clark krypton zod]")
})
}