From 1c3bfacca3eeaa08257c4f92589d2304a7604efc Mon Sep 17 00:00:00 2001 From: Colin Adler Date: Wed, 12 Jul 2023 10:21:54 -0500 Subject: [PATCH] fix(cli): ensure `cliui.Agent` doesn't fetch infinitely (#8446) --- cli/cliui/agent.go | 9 +++-- cli/cliui/agent_test.go | 81 ++++++++++++++++++++++++++++------------- cli/portforward.go | 8 ++-- cli/speedtest.go | 8 ++-- cli/ssh.go | 6 +-- cmd/cliui/main.go | 4 +- 6 files changed, 72 insertions(+), 44 deletions(-) diff --git a/cli/cliui/agent.go b/cli/cliui/agent.go index 4bc0493ee3..acbbca9bef 100644 --- a/cli/cliui/agent.go +++ b/cli/cliui/agent.go @@ -15,13 +15,16 @@ var errAgentShuttingDown = xerrors.New("agent is shutting down") type AgentOptions struct { FetchInterval time.Duration - Fetch func(context.Context) (codersdk.WorkspaceAgent, error) + Fetch func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) FetchLogs func(ctx context.Context, agentID uuid.UUID, after int64, follow bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) Wait bool // If true, wait for the agent to be ready (startup script). } // Agent displays a spinning indicator that waits for a workspace agent to connect. -func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error { +func Agent(ctx context.Context, writer io.Writer, agentID uuid.UUID, opts AgentOptions) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + if opts.FetchInterval == 0 { opts.FetchInterval = 500 * time.Millisecond } @@ -47,7 +50,7 @@ func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error { case <-ctx.Done(): return case <-t.C: - agent, err := opts.Fetch(ctx) + agent, err := opts.Fetch(ctx, agentID) select { case <-fetchedAgent: default: diff --git a/cli/cliui/agent_test.go b/cli/cliui/agent_test.go index 184be6ff85..c08ba163ea 100644 --- a/cli/cliui/agent_test.go +++ b/cli/cliui/agent_test.go @@ -6,6 +6,7 @@ import ( "context" "io" "strings" + "sync/atomic" "testing" "time" @@ -16,6 +17,7 @@ import ( "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" "github.com/coder/coder/testutil" ) @@ -23,10 +25,6 @@ import ( func TestAgent(t *testing.T) { t.Parallel() - ptrTime := func(t time.Time) *time.Time { - return &t - } - for _, tc := range []struct { name string iter []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error @@ -47,7 +45,7 @@ func TestAgent(t *testing.T) { }, func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.FirstConnectedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now()) close(logs) return nil }, @@ -69,7 +67,7 @@ func TestAgent(t *testing.T) { func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnecting agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting - agent.StartedAt = ptrTime(time.Now()) + agent.StartedAt = ptr.Ref(time.Now()) return nil }, func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error { @@ -78,9 +76,9 @@ func TestAgent(t *testing.T) { }, func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.FirstConnectedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady - agent.ReadyAt = ptrTime(time.Now()) + agent.ReadyAt = ptr.Ref(time.Now()) close(logs) return nil }, @@ -102,17 +100,17 @@ func TestAgent(t *testing.T) { iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{ func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentDisconnected - agent.FirstConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute)) - agent.LastConnectedAt = ptrTime(time.Now().Add(-1 * time.Minute)) - agent.DisconnectedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute)) + agent.LastConnectedAt = ptr.Ref(time.Now().Add(-1 * time.Minute)) + agent.DisconnectedAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady - agent.StartedAt = ptrTime(time.Now().Add(-1 * time.Minute)) - agent.ReadyAt = ptrTime(time.Now()) + agent.StartedAt = ptr.Ref(time.Now().Add(-1 * time.Minute)) + agent.ReadyAt = ptr.Ref(time.Now()) return nil }, func(_ context.Context, agent *codersdk.WorkspaceAgent, _ chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.LastConnectedAt = ptrTime(time.Now()) + agent.LastConnectedAt = ptr.Ref(time.Now()) return nil }, func(_ context.Context, _ *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { @@ -136,9 +134,9 @@ func TestAgent(t *testing.T) { iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{ func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.FirstConnectedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting - agent.StartedAt = ptrTime(time.Now()) + agent.StartedAt = ptr.Ref(time.Now()) logs <- []codersdk.WorkspaceAgentStartupLog{ { CreatedAt: time.Now(), @@ -149,7 +147,7 @@ func TestAgent(t *testing.T) { }, func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.LifecycleState = codersdk.WorkspaceAgentLifecycleReady - agent.ReadyAt = ptrTime(time.Now()) + agent.ReadyAt = ptr.Ref(time.Now()) logs <- []codersdk.WorkspaceAgentStartupLog{ { CreatedAt: time.Now(), @@ -176,10 +174,10 @@ func TestAgent(t *testing.T) { iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{ func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.FirstConnectedAt = ptrTime(time.Now()) - agent.StartedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now()) + agent.StartedAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStartError - agent.ReadyAt = ptrTime(time.Now()) + agent.ReadyAt = ptr.Ref(time.Now()) logs <- []codersdk.WorkspaceAgentStartupLog{ { CreatedAt: time.Now(), @@ -222,9 +220,9 @@ func TestAgent(t *testing.T) { iter: []func(context.Context, *codersdk.WorkspaceAgent, chan []codersdk.WorkspaceAgentStartupLog) error{ func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { agent.Status = codersdk.WorkspaceAgentConnected - agent.FirstConnectedAt = ptrTime(time.Now()) + agent.FirstConnectedAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleStarting - agent.StartedAt = ptrTime(time.Now()) + agent.StartedAt = ptr.Ref(time.Now()) logs <- []codersdk.WorkspaceAgentStartupLog{ { CreatedAt: time.Now(), @@ -234,7 +232,7 @@ func TestAgent(t *testing.T) { return nil }, func(_ context.Context, agent *codersdk.WorkspaceAgent, logs chan []codersdk.WorkspaceAgentStartupLog) error { - agent.ReadyAt = ptrTime(time.Now()) + agent.ReadyAt = ptr.Ref(time.Now()) agent.LifecycleState = codersdk.WorkspaceAgentLifecycleShuttingDown close(logs) return nil @@ -310,7 +308,7 @@ func TestAgent(t *testing.T) { cmd := &clibase.Cmd{ Handler: func(inv *clibase.Invocation) error { - tc.opts.Fetch = func(_ context.Context) (codersdk.WorkspaceAgent, error) { + tc.opts.Fetch = func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) { var err error if len(tc.iter) > 0 { err = tc.iter[0](ctx, &agent, logs) @@ -321,7 +319,7 @@ func TestAgent(t *testing.T) { tc.opts.FetchLogs = func(_ context.Context, _ uuid.UUID, _ int64, _ bool) (<-chan []codersdk.WorkspaceAgentStartupLog, io.Closer, error) { return logs, closeFunc(func() error { return nil }), nil } - err := cliui.Agent(inv.Context(), &buf, tc.opts) + err := cliui.Agent(inv.Context(), &buf, uuid.Nil, tc.opts) return err }, } @@ -350,4 +348,37 @@ func TestAgent(t *testing.T) { } }) } + + t.Run("NotInfinite", func(t *testing.T) { + t.Parallel() + var fetchCalled uint64 + + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + buf := bytes.Buffer{} + err := cliui.Agent(inv.Context(), &buf, uuid.Nil, cliui.AgentOptions{ + FetchInterval: 10 * time.Millisecond, + Fetch: func(ctx context.Context, agentID uuid.UUID) (codersdk.WorkspaceAgent, error) { + atomic.AddUint64(&fetchCalled, 1) + + return codersdk.WorkspaceAgent{ + Status: codersdk.WorkspaceAgentConnected, + LifecycleState: codersdk.WorkspaceAgentLifecycleReady, + }, nil + }, + }) + if err != nil { + return err + } + + require.Never(t, func() bool { + called := atomic.LoadUint64(&fetchCalled) + return called > 5 || called == 0 + }, time.Second, 100*time.Millisecond) + + return nil + }, + } + require.NoError(t, cmd.Invoke().Run()) + }) } diff --git a/cli/portforward.go b/cli/portforward.go index a7f42ed650..01dc4a637e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -90,11 +90,9 @@ func (r *RootCmd) portForward() *clibase.Cmd { } } - err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ - Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { - return client.WorkspaceAgent(ctx, workspaceAgent.ID) - }, - Wait: false, + err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{ + Fetch: client.WorkspaceAgent, + Wait: false, }) if err != nil { return xerrors.Errorf("await agent: %w", err) diff --git a/cli/speedtest.go b/cli/speedtest.go index e3ebf65341..150605b333 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -40,11 +40,9 @@ func (r *RootCmd) speedtest() *clibase.Cmd { return err } - err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ - Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { - return client.WorkspaceAgent(ctx, workspaceAgent.ID) - }, - Wait: false, + err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{ + Fetch: client.WorkspaceAgent, + Wait: false, }) if err != nil { return xerrors.Errorf("await agent: %w", err) diff --git a/cli/ssh.go b/cli/ssh.go index e98546cdad..def41c091d 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -175,10 +175,8 @@ func (r *RootCmd) ssh() *clibase.Cmd { // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. - err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ - Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { - return client.WorkspaceAgent(ctx, workspaceAgent.ID) - }, + err = cliui.Agent(ctx, inv.Stderr, workspaceAgent.ID, cliui.AgentOptions{ + Fetch: client.WorkspaceAgent, FetchLogs: client.WorkspaceAgentStartupLogsAfter, Wait: wait, }) diff --git a/cmd/cliui/main.go b/cmd/cliui/main.go index f972afdd4e..cebf354cf4 100644 --- a/cmd/cliui/main.go +++ b/cmd/cliui/main.go @@ -214,10 +214,10 @@ func main() { agent.LastConnectedAt = &lastConnectedAt }, } - err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ + err := cliui.Agent(inv.Context(), inv.Stdout, uuid.Nil, cliui.AgentOptions{ FetchInterval: 100 * time.Millisecond, Wait: true, - Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { + Fetch: func(_ context.Context, _ uuid.UUID) (codersdk.WorkspaceAgent, error) { if len(fetchSteps) == 0 { return agent, nil }