diff --git a/coderd/workspaceapps/apptest/apptest.go b/coderd/workspaceapps/apptest/apptest.go index e20ba046ba..ab90b0a4b4 100644 --- a/coderd/workspaceapps/apptest/apptest.go +++ b/coderd/workspaceapps/apptest/apptest.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" - "nhooyr.io/websocket" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/rbac" @@ -72,7 +71,13 @@ func Run(t *testing.T, factory DeploymentFactory) { // Run the test against the path app hostname since that's where the // reconnecting-pty proxy server we want to test is mounted. client := appDetails.AppClient(t) - conn, err := client.WorkspaceAgentReconnectingPTY(ctx, appDetails.Agent.ID, uuid.New(), 80, 80, "/bin/bash") + conn, err := client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + AgentID: appDetails.Agent.ID, + Reconnect: uuid.New(), + Height: 80, + Width: 80, + Command: "/bin/bash", + }) require.NoError(t, err) defer conn.Close() @@ -125,29 +130,42 @@ func Run(t *testing.T, factory DeploymentFactory) { }) require.NoError(t, err) - // Try to connect to the endpoint with the signed token and no other - // authentication. - q := u.Query() - q.Set("reconnect", uuid.NewString()) - q.Set("height", strconv.Itoa(24)) - q.Set("width", strconv.Itoa(80)) - q.Set("command", `/bin/sh -c "echo test"`) - q.Set(codersdk.SignedAppTokenQueryParameter, issueRes.SignedToken) - u.RawQuery = q.Encode() + // Make an unauthenticated client. + unauthedAppClient := codersdk.New(appDetails.AppClient(t).URL) + conn, err := unauthedAppClient.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + AgentID: appDetails.Agent.ID, + Reconnect: uuid.New(), + Height: 80, + Width: 80, + Command: "/bin/bash", + SignedToken: issueRes.SignedToken, + }) + require.NoError(t, err) + defer conn.Close() - //nolint:bodyclose - wsConn, res, err := websocket.Dial(ctx, u.String(), nil) - if !assert.NoError(t, err) { - dump, err := httputil.DumpResponse(res, true) - if err == nil { - t.Log(string(dump)) - } - return - } - defer wsConn.Close(websocket.StatusNormalClosure, "") - conn := websocket.NetConn(ctx, wsConn, websocket.MessageBinary) + // First attempt to resize the TTY. + // The websocket will close if it fails! + data, err := json.Marshal(codersdk.ReconnectingPTYRequest{ + Height: 250, + Width: 250, + }) + require.NoError(t, err) + _, err = conn.Write(data) + require.NoError(t, err) bufRead := bufio.NewReader(conn) + // Brief pause to reduce the likelihood that we send keystrokes while + // the shell is simultaneously sending a prompt. + time.Sleep(100 * time.Millisecond) + + data, err = json.Marshal(codersdk.ReconnectingPTYRequest{ + Data: "echo test\r\n", + }) + require.NoError(t, err) + _, err = conn.Write(data) + require.NoError(t, err) + + expectLine(t, bufRead, matchEchoCommand) expectLine(t, bufRead, matchEchoOutput) }) }) diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 87a13d45de..8f418eebf2 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -385,32 +385,55 @@ func (c *Client) IssueReconnectingPTYSignedToken(ctx context.Context, req IssueR return resp, json.NewDecoder(res.Body).Decode(&resp) } +// @typescript-ignore:WorkspaceAgentReconnectingPTYOpts +type WorkspaceAgentReconnectingPTYOpts struct { + AgentID uuid.UUID + Reconnect uuid.UUID + Width uint16 + Height uint16 + Command string + + // SignedToken is an optional signed token from the + // issue-reconnecting-pty-signed-token endpoint. If set, the session token + // on the client will not be sent. + SignedToken string +} + // WorkspaceAgentReconnectingPTY spawns a PTY that reconnects using the token provided. // It communicates using `agent.ReconnectingPTYRequest` marshaled as JSON. // Responses are PTY output that can be rendered. -func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, agentID, reconnect uuid.UUID, height, width uint16, command string) (net.Conn, error) { - serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty", agentID)) +func (c *Client) WorkspaceAgentReconnectingPTY(ctx context.Context, opts WorkspaceAgentReconnectingPTYOpts) (net.Conn, error) { + serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/workspaceagents/%s/pty", opts.AgentID)) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } q := serverURL.Query() - q.Set("reconnect", reconnect.String()) - q.Set("height", strconv.Itoa(int(height))) - q.Set("width", strconv.Itoa(int(width))) - q.Set("command", command) + q.Set("reconnect", opts.Reconnect.String()) + q.Set("width", strconv.Itoa(int(opts.Width))) + q.Set("height", strconv.Itoa(int(opts.Height))) + q.Set("command", opts.Command) + // If we're using a signed token, set the query parameter. + if opts.SignedToken != "" { + q.Set(SignedAppTokenQueryParameter, opts.SignedToken) + } serverURL.RawQuery = q.Encode() - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenCookie, - Value: c.SessionToken(), - }}) - httpClient := &http.Client{ - Jar: jar, - Transport: c.HTTPClient.Transport, + // If we're not using a signed token, we need to set the session token as a + // cookie. + httpClient := c.HTTPClient + if opts.SignedToken == "" { + jar, err := cookiejar.New(nil) + if err != nil { + return nil, xerrors.Errorf("create cookie jar: %w", err) + } + jar.SetCookies(serverURL, []*http.Cookie{{ + Name: SessionTokenCookie, + Value: c.SessionToken(), + }}) + httpClient = &http.Client{ + Jar: jar, + Transport: c.HTTPClient.Transport, + } } conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, diff --git a/scaletest/reconnectingpty/run.go b/scaletest/reconnectingpty/run.go index 5c7a042812..4069220c5b 100644 --- a/scaletest/reconnectingpty/run.go +++ b/scaletest/reconnectingpty/run.go @@ -64,7 +64,13 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { _, _ = fmt.Fprintf(logs, "\tHeight: %d\n", height) _, _ = fmt.Fprintf(logs, "\tCommand: %q\n\n", r.cfg.Init.Command) - conn, err := r.client.WorkspaceAgentReconnectingPTY(ctx, r.cfg.AgentID, id, width, height, r.cfg.Init.Command) + conn, err := r.client.WorkspaceAgentReconnectingPTY(ctx, codersdk.WorkspaceAgentReconnectingPTYOpts{ + AgentID: r.cfg.AgentID, + Reconnect: id, + Width: width, + Height: height, + Command: r.cfg.Init.Command, + }) if err != nil { return xerrors.Errorf("open reconnecting PTY: %w", err) }