feat: add reconnectingpty loadtest (#5083)

This commit is contained in:
Dean Sheather
2022-11-18 02:57:15 +10:00
committed by GitHub
parent acf34d4295
commit 69e8c9e7b4
11 changed files with 607 additions and 20 deletions

View File

@ -725,6 +725,7 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.ReconnectingPTYInit, conn net.Conn) {
defer conn.Close()
connectionID := uuid.NewString()
var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
if ok {
@ -760,7 +761,11 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
a.closeMutex.Unlock()
ctx, cancelFunc := context.WithCancel(ctx)
rpty = &reconnectingPTY{
activeConns: make(map[string]net.Conn),
activeConns: map[string]net.Conn{
// We have to put the connection in the map instantly otherwise
// the connection won't be closed if the process instantly dies.
connectionID: conn,
},
ptty: ptty,
// Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
@ -827,7 +832,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err))
return
}
connectionID := uuid.NewString()
// Multiple connections to the same TTY are permitted.
// This could easily be used for terminal sharing, but
// we do it because it's a nice user experience to

View File

@ -83,7 +83,7 @@ func TestAgent(t *testing.T) {
conn, stats, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash")
ptyConn, err := conn.ReconnectingPTY(ctx, uuid.New(), 128, 128, "/bin/bash")
require.NoError(t, err)
defer ptyConn.Close()
@ -405,7 +405,7 @@ func TestAgent(t *testing.T) {
defer cancel()
conn, _, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
id := uuid.NewString()
id := uuid.New()
netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
require.NoError(t, err)
bufRead := bufio.NewReader(netConn)

View File

@ -10,6 +10,7 @@ import (
"github.com/coder/coder/loadtest/agentconn"
"github.com/coder/coder/loadtest/harness"
"github.com/coder/coder/loadtest/placebo"
"github.com/coder/coder/loadtest/reconnectingpty"
"github.com/coder/coder/loadtest/workspacebuild"
)
@ -90,6 +91,7 @@ type LoadTestType string
const (
LoadTestTypeAgentConn LoadTestType = "agentconn"
LoadTestTypePlacebo LoadTestType = "placebo"
LoadTestTypeReconnectingPTY LoadTestType = "reconnectingpty"
LoadTestTypeWorkspaceBuild LoadTestType = "workspacebuild"
)
@ -104,6 +106,8 @@ type LoadTest struct {
AgentConn *agentconn.Config `json:"agentconn,omitempty"`
// Placebo must be set if type == "placebo".
Placebo *placebo.Config `json:"placebo,omitempty"`
// ReconnectingPTY must be set if type == "reconnectingpty".
ReconnectingPTY *reconnectingpty.Config `json:"reconnectingpty,omitempty"`
// WorkspaceBuild must be set if type == "workspacebuild".
WorkspaceBuild *workspacebuild.Config `json:"workspacebuild,omitempty"`
}
@ -120,6 +124,11 @@ func (t LoadTest) NewRunner(client *codersdk.Client) (harness.Runnable, error) {
return nil, xerrors.New("placebo config must be set")
}
return placebo.NewRunner(*t.Placebo), nil
case LoadTestTypeReconnectingPTY:
if t.ReconnectingPTY == nil {
return nil, xerrors.New("reconnectingpty config must be set")
}
return reconnectingpty.NewRunner(client, *t.ReconnectingPTY), nil
case LoadTestTypeWorkspaceBuild:
if t.WorkspaceBuild == nil {
return nil, xerrors.Errorf("workspacebuild config must be set")
@ -185,6 +194,15 @@ func (t *LoadTest) Validate() error {
if err != nil {
return xerrors.Errorf("validate placebo: %w", err)
}
case LoadTestTypeReconnectingPTY:
if t.ReconnectingPTY == nil {
return xerrors.Errorf("reconnectingpty test type must specify reconnectingpty")
}
err := t.ReconnectingPTY.Validate()
if err != nil {
return xerrors.Errorf("validate reconnectingpty: %w", err)
}
case LoadTestTypeWorkspaceBuild:
if t.WorkspaceBuild == nil {
return xerrors.New("workspacebuild test type must specify workspacebuild")

View File

@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
@ -26,6 +25,7 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/httpapi"
@ -247,17 +247,13 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
return
}
defer release()
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect, uint16(height), uint16(width), r.URL.Query().Get("command"))
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
return
}
defer ptNetConn.Close()
// Pipe the ends together!
go func() {
_, _ = io.Copy(wsNetConn, ptNetConn)
}()
_, _ = io.Copy(ptNetConn, wsNetConn)
agent.Bicopy(ctx, wsNetConn, ptNetConn)
}
func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Request) {

View File

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"tailscale.com/net/speedtest"
@ -158,13 +159,13 @@ func (c *AgentConn) Close() error {
// @typescript-ignore ReconnectingPTYInit
type ReconnectingPTYInit struct {
ID string
ID uuid.UUID
Height uint16
Width uint16
Command string
}
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) {
func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string) (net.Conn, error) {
ctx, span := tracing.StartSpan(ctx)
defer span.End()

View File

@ -268,7 +268,7 @@ func readBodyAsError(res *http.Response) error {
return &Error{
statusCode: res.StatusCode,
Response: Response{
Message: "unexpected non-JSON response",
Message: fmt.Sprintf("unexpected non-JSON response %q", contentType),
Detail: string(resp),
},
Helper: helper,

View File

@ -501,11 +501,18 @@ func (c *Client) PostWorkspaceAgentVersion(ctx context.Context, version string)
// WorkspaceAgentReconnectingPTY spawns a PTY that reconnects using the token provided.
// It communicates using `agent.ReconnectingPTYRequest` marshaled as JSON.
// Responses are PTY output that can be rendered.
func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, reconnect uuid.UUID, height, width int, command string) (net.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty?reconnect=%s&height=%d&width=%d&command=%s", agentID, reconnect, height, width, command))
func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, reconnect uuid.UUID, height, width uint16, command string) (net.Conn, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty", agentID))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
q := serverURL.Query()
q.Set("reconnect", reconnect.String())
q.Set("height", strconv.Itoa(int(height)))
q.Set("width", strconv.Itoa(int(width)))
q.Set("command", command)
serverURL.RawQuery = q.Encode()
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)

View File

@ -0,0 +1,52 @@
package reconnectingpty
import (
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
const (
DefaultWidth = 80
DefaultHeight = 24
DefaultTimeout = httpapi.Duration(5 * time.Minute)
)
type Config struct {
// AgentID is the ID of the agent to run the command in.
AgentID uuid.UUID `json:"agent_id"`
// Init is the initial packet to send to the agent when launching the TTY.
// If the ID is not set, defaults to a random UUID. If the width or height
// is not set, defaults to 80x24. If the command is not set, defaults to
// opening a login shell. Command runs in the default shell.
Init codersdk.ReconnectingPTYInit `json:"init"`
// Timeout is the duration to wait for the command to exit. Defaults to
// 5 minutes.
Timeout httpapi.Duration `json:"timeout"`
// ExpectTimeout means we expect the timeout to be reached (i.e. the command
// doesn't exit within the given timeout).
ExpectTimeout bool `json:"expect_timeout"`
// ExpectOutput checks that the given string is present in the output. The
// string must be present on a single line.
ExpectOutput string `json:"expect_output"`
// LogOutput determines whether the output of the command should be logged.
// For commands that produce a lot of output this should be disabled to
// avoid loadtest OOMs. All log output is still read and discarded if this
// is false.
LogOutput bool `json:"log_output"`
}
func (c Config) Validate() error {
if c.AgentID == uuid.Nil {
return xerrors.New("agent_id must be set")
}
if c.Timeout < 0 {
return xerrors.New("timeout must be a positive value")
}
return nil
}

View File

@ -0,0 +1,78 @@
package reconnectingpty_test
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/reconnectingpty"
)
func Test_Config(t *testing.T) {
t.Parallel()
id := uuid.New()
cases := []struct {
name string
config reconnectingpty.Config
errContains string
}{
{
name: "OKBasic",
config: reconnectingpty.Config{
AgentID: id,
},
},
{
name: "OKFull",
config: reconnectingpty.Config{
AgentID: id,
Init: codersdk.ReconnectingPTYInit{
ID: id,
Width: 80,
Height: 24,
Command: "echo 'hello world'",
},
Timeout: httpapi.Duration(time.Minute),
ExpectTimeout: false,
ExpectOutput: "hello world",
LogOutput: true,
},
},
{
name: "NoAgentID",
config: reconnectingpty.Config{
AgentID: uuid.Nil,
},
errContains: "agent_id must be set",
},
{
name: "NegativeTimeout",
config: reconnectingpty.Config{
AgentID: id,
Timeout: httpapi.Duration(-time.Minute),
},
errContains: "timeout must be a positive value",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
err := c.config.Validate()
if c.errContains != "" {
require.Error(t, err)
require.Contains(t, err.Error(), c.errContains)
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -0,0 +1,137 @@
package reconnectingpty
import (
"bufio"
"context"
"fmt"
"io"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/harness"
"github.com/coder/coder/loadtest/loadtestutil"
)
type Runner struct {
client *codersdk.Client
cfg Config
}
var _ harness.Runnable = &Runner{}
func NewRunner(client *codersdk.Client, cfg Config) *Runner {
return &Runner{
client: client,
cfg: cfg,
}
}
// Run implements Runnable.
func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
ctx, span := tracing.StartSpan(ctx)
defer span.End()
logs = loadtestutil.NewSyncWriter(logs)
logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
r.client.Logger = logger
r.client.LogBodies = true
var (
id = r.cfg.Init.ID
width = r.cfg.Init.Width
height = r.cfg.Init.Height
)
if id == uuid.Nil {
id = uuid.New()
}
if width == 0 {
width = DefaultWidth
}
if height == 0 {
height = DefaultHeight
}
_, _ = fmt.Fprintln(logs, "Opening reconnecting PTY connection to agent via coderd...")
_, _ = fmt.Fprintf(logs, "\tID: %s\n", id.String())
_, _ = fmt.Fprintf(logs, "\tWidth: %d\n", width)
_, _ = fmt.Fprintf(logs, "\tHeight: %d\n", height)
_, _ = fmt.Fprintf(logs, "\tCommand: %q\n\n", r.cfg.Init.Command)
conn, err := r.client.WorkspaceAgentReconnectingPTY(ctx, r.cfg.AgentID, id, width, height, r.cfg.Init.Command)
if err != nil {
return xerrors.Errorf("open reconnecting PTY: %w", err)
}
defer conn.Close()
var (
copyTimeout = r.cfg.Timeout
copyOutput = io.Discard
)
if copyTimeout == 0 {
copyTimeout = DefaultTimeout
}
if r.cfg.LogOutput {
_, _ = fmt.Fprintln(logs, "Output:")
copyOutput = logs
}
copyCtx, copyCancel := context.WithTimeout(ctx, time.Duration(copyTimeout))
matched, err := copyContext(copyCtx, copyOutput, conn, r.cfg.ExpectOutput)
copyCancel()
if r.cfg.ExpectTimeout {
if err == nil {
return xerrors.Errorf("expected timeout, but the command exited successfully")
}
if !xerrors.Is(err, context.DeadlineExceeded) {
return xerrors.Errorf("expected timeout, but got a different error: %w", err)
}
} else if err != nil {
return xerrors.Errorf("copy context: %w", err)
}
if !matched {
return xerrors.Errorf("expected string %q not found in output", r.cfg.ExpectOutput)
}
return nil
}
func copyContext(ctx context.Context, dst io.Writer, src io.Reader, expectOutput string) (bool, error) {
var (
copyErr = make(chan error)
matched = expectOutput == ""
)
go func() {
defer close(copyErr)
scanner := bufio.NewScanner(src)
for scanner.Scan() {
if expectOutput != "" && strings.Contains(scanner.Text(), expectOutput) {
matched = true
}
_, err := dst.Write([]byte("\t" + scanner.Text() + "\n"))
if err != nil {
copyErr <- xerrors.Errorf("write to logs: %w", err)
return
}
}
if scanner.Err() != nil {
copyErr <- xerrors.Errorf("read from reconnecting PTY: %w", scanner.Err())
return
}
}()
select {
case <-ctx.Done():
return matched, ctx.Err()
case err := <-copyErr:
return matched, err
}
}

View File

@ -0,0 +1,294 @@
package reconnectingpty_test
import (
"bytes"
"context"
"runtime"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/loadtest/reconnectingpty"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
func Test_Runner(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
t.Skip("PTY is flakey on Windows")
}
t.Run("OK", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world' && sleep 1",
},
LogOutput: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.NoError(t, err)
require.Contains(t, logStr, "Output:")
require.Contains(t, logStr, "\thello world")
})
t.Run("NoLogOutput", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world'",
},
LogOutput: false,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.NoError(t, err)
require.NotContains(t, logStr, "Output:")
require.NotContains(t, logStr, "\thello world")
})
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
t.Run("NoTimeout", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world'",
},
Timeout: httpapi.Duration(5 * time.Second),
LogOutput: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.NoError(t, err)
})
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "sleep 5",
},
Timeout: httpapi.Duration(2 * time.Second),
LogOutput: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
})
t.Run("ExpectTimeout", func(t *testing.T) {
t.Parallel()
t.Run("Timeout", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "sleep 5",
},
Timeout: httpapi.Duration(2 * time.Second),
ExpectTimeout: true,
LogOutput: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.NoError(t, err)
})
t.Run("NoTimeout", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world'",
},
Timeout: httpapi.Duration(5 * time.Second),
ExpectTimeout: true,
LogOutput: true,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.Error(t, err)
require.ErrorContains(t, err, "expected timeout")
})
})
t.Run("ExpectOutput", func(t *testing.T) {
t.Parallel()
t.Run("Matches", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world' && sleep 1",
},
ExpectOutput: "hello world",
LogOutput: false,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.NoError(t, err)
})
t.Run("NotMatches", func(t *testing.T) {
t.Parallel()
client, agentID := setupRunnerTest(t)
runner := reconnectingpty.NewRunner(client, reconnectingpty.Config{
AgentID: agentID,
Init: codersdk.ReconnectingPTYInit{
Command: "echo 'hello world' && sleep 1",
},
ExpectOutput: "bello borld",
LogOutput: false,
})
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
logs := bytes.NewBuffer(nil)
err := runner.Run(ctx, "1", logs)
logStr := logs.String()
t.Log("Runner logs:\n\n" + logStr)
require.Error(t, err)
require.ErrorContains(t, err, `expected string "bello borld" not found`)
})
})
}
func setupRunnerTest(t *testing.T) (client *codersdk.Client, agentID uuid.UUID) {
t.Helper()
client = coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.ProvisionComplete,
ProvisionApply: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "agent",
Auth: &proto.Agent_Token{
Token: authToken,
},
Apps: []*proto.App{},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SetSessionToken(authToken)
agentCloser := agent.New(agent.Options{
Client: agentClient,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {
_ = agentCloser.Close()
})
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
return client, resources[0].Agents[0].ID
}