fix(cli/ssh): prevent reads/writes to stdin/stdout in stdio mode (#12045)

Fixes #11530
This commit is contained in:
Mathias Fredriksson
2024-02-08 13:09:42 +02:00
committed by GitHub
parent 151aaadc23
commit e659957b65
2 changed files with 170 additions and 1 deletions

View File

@ -1,6 +1,7 @@
package cli_test
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
@ -338,6 +339,157 @@ func TestSSH(t *testing.T) {
<-cmdDone
})
t.Run("Stdio_StartStoppedWorkspace_CleanStdout", func(t *testing.T) {
t.Parallel()
authToken := uuid.NewString()
ownerClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
owner := coderdtest.CreateFirstUser(t, ownerClient)
client, _ := coderdtest.CreateAnotherUser(t, ownerClient, owner.OrganizationID, rbac.RoleTemplateAdmin())
version := coderdtest.CreateTemplateVersion(t, client, owner.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: echo.PlanComplete,
ProvisionApply: echo.ProvisionApplyWithAgent(authToken),
})
coderdtest.AwaitTemplateVersionJobCompleted(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, owner.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, owner.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspace.LatestBuild.ID)
// Stop the workspace
workspaceBuild := coderdtest.CreateWorkspaceBuild(t, client, workspace, database.WorkspaceTransitionStop)
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, workspaceBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
clientStdinR, clientStdinW := io.Pipe()
// Here's a simple flowchart for how these pipes are used:
//
// flowchart LR
// A[ProxyCommand] --> B[captureProxyCommandStdoutW]
// B --> C[captureProxyCommandStdoutR]
// C --> VA[Validate output]
// C --> D[proxyCommandStdoutW]
// D --> E[proxyCommandStdoutR]
// E --> F[SSH Client]
proxyCommandStdoutR, proxyCommandStdoutW := io.Pipe()
captureProxyCommandStdoutR, captureProxyCommandStdoutW := io.Pipe()
closePipes := func() {
for _, c := range []io.Closer{clientStdinR, clientStdinW, proxyCommandStdoutR, proxyCommandStdoutW, captureProxyCommandStdoutR, captureProxyCommandStdoutW} {
_ = c.Close()
}
}
defer closePipes()
tGo(t, func() {
<-ctx.Done()
closePipes()
})
// Here we start a monitor for the output produced by the proxy command,
// which is read by the SSH client. This is done to validate that the
// output is clean.
proxyCommandOutputBuf := make(chan byte, 4096)
tGo(t, func() {
defer close(proxyCommandOutputBuf)
gotHeader := false
buf := bytes.Buffer{}
r := bufio.NewReader(captureProxyCommandStdoutR)
for {
b, err := r.ReadByte()
if err != nil {
if errors.Is(err, io.ErrClosedPipe) {
return
}
assert.NoError(t, err, "read byte failed")
return
}
if b == '\n' || b == '\r' {
out := buf.Bytes()
t.Logf("monitorServerOutput: %q (%#x)", out, out)
buf.Reset()
// Ideally we would do further verification, but that would
// involve parsing the SSH protocol to look for output that
// doesn't belong. This at least ensures that no garbage is
// being sent to the SSH client before trying to connect.
if !gotHeader {
gotHeader = true
assert.Equal(t, "SSH-2.0-Go", string(out), "invalid header")
}
} else {
_ = buf.WriteByte(b)
}
select {
case proxyCommandOutputBuf <- b:
case <-ctx.Done():
return
}
}
})
tGo(t, func() {
defer proxyCommandStdoutW.Close()
// Range closed by above goroutine.
for b := range proxyCommandOutputBuf {
_, err := proxyCommandStdoutW.Write([]byte{b})
if err != nil {
if errors.Is(err, io.ErrClosedPipe) {
return
}
assert.NoError(t, err, "write byte failed")
return
}
}
})
// Start the SSH stdio command.
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
clitest.SetupConfig(t, client, root)
inv.Stdin = clientStdinR
inv.Stdout = captureProxyCommandStdoutW
inv.Stderr = io.Discard
cmdDone := tGo(t, func() {
err := inv.WithContext(ctx).Run()
assert.NoError(t, err)
})
tGo(t, func() {
// When the agent connects, the workspace was started, and we should
// have access to the shell.
_ = agenttest.New(t, client.URL, authToken)
coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
})
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
Reader: proxyCommandStdoutR,
Writer: clientStdinW,
}, "", &ssh.ClientConfig{
// #nosec
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
})
require.NoError(t, err)
defer conn.Close()
sshClient := ssh.NewClient(conn, channels, requests)
session, err := sshClient.NewSession()
require.NoError(t, err)
defer session.Close()
command := "sh -c exit"
if runtime.GOOS == "windows" {
command = "cmd.exe /c exit"
}
err = session.Run(command)
require.NoError(t, err)
err = sshClient.Close()
require.NoError(t, err)
_ = clientStdinR.Close()
<-cmdDone
})
t.Run("Stdio_RemoteForward_Signal", func(t *testing.T) {
t.Parallel()
client, workspace, agentToken := setupWorkspaceForAgent(t)