From 503d09c149b05a9a8d6c52b723f783afd4015adf Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 17 Feb 2022 10:44:49 -0600 Subject: [PATCH] feat: Add support for executing processes with Windows ConPty (#311) * Initial agent * fix: Use buffered reader in peer to fix ShortBuffer This prevents a io.ErrShortBuffer from occurring when the byte slice being read is smaller than the chunks sent from the opposite pipe. This makes sense for unordered connections, where transmission is not guarunteed, but does not make sense for TCP-like connections. We use a bufio.Reader when ordered to ensure data isn't lost. * SSH server works! * Start Windows support * Something works * Refactor pty package to support Windows spawn * SSH server now works on Windows * Fix non-Windows * Fix Linux PTY render * FIx linux build tests * Remove agent and wintest * Add test for Windows resize * Fix linting errors * Add Windows environment variables * Add strings import * Add comment for attrs * Add goleak * Add require import --- .vscode/settings.json | 4 + cli/clitest/clitest_test.go | 9 +- cli/login_test.go | 18 +-- cli/projectcreate_test.go | 22 ++-- cli/root.go | 19 ++-- cli/workspacecreate_test.go | 18 +-- coderd/projectimport_test.go | 3 +- codersdk/projectimport_test.go | 5 +- console/conpty/conpty.go | 107 ------------------ console/conpty/syscall.go | 53 --------- console/console.go | 163 --------------------------- console/doc.go | 19 ---- console/expect.go | 109 ------------------ console/expect_opt.go | 139 ----------------------- console/expect_opt_test.go | 163 --------------------------- console/expect_test.go | 181 ------------------------------ console/pty/pty.go | 21 ---- console/pty/pty_windows.go | 78 ------------- console/test_console.go | 45 -------- go.mod | 2 + go.sum | 4 + peer/channel.go | 5 + pty/pty.go | 39 +++++++ {console/pty => pty}/pty_other.go | 32 +++--- pty/pty_windows.go | 107 ++++++++++++++++++ pty/ptytest/ptytest.go | 95 ++++++++++++++++ pty/ptytest/ptytest_test.go | 15 +++ pty/start.go | 7 ++ pty/start_other.go | 34 ++++++ pty/start_other_test.go | 25 +++++ pty/start_windows.go | 149 ++++++++++++++++++++++++ pty/start_windows_test.go | 32 ++++++ 32 files changed, 582 insertions(+), 1140 deletions(-) delete mode 100644 console/conpty/conpty.go delete mode 100644 console/conpty/syscall.go delete mode 100644 console/console.go delete mode 100644 console/doc.go delete mode 100644 console/expect.go delete mode 100644 console/expect_opt.go delete mode 100644 console/expect_opt_test.go delete mode 100644 console/expect_test.go delete mode 100644 console/pty/pty.go delete mode 100644 console/pty/pty_windows.go delete mode 100644 console/test_console.go create mode 100644 pty/pty.go rename {console/pty => pty}/pty_other.go (52%) create mode 100644 pty/pty_windows.go create mode 100644 pty/ptytest/ptytest.go create mode 100644 pty/ptytest/ptytest_test.go create mode 100644 pty/start.go create mode 100644 pty/start_other.go create mode 100644 pty/start_other_test.go create mode 100644 pty/start_windows.go create mode 100644 pty/start_windows_test.go diff --git a/.vscode/settings.json b/.vscode/settings.json index d9b2b88f17..02c3b05cc4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -33,6 +33,7 @@ "drpcserver", "fatih", "goleak", + "gossh", "hashicorp", "httpmw", "isatty", @@ -51,9 +52,12 @@ "protobuf", "provisionerd", "provisionersdk", + "ptty", + "ptytest", "retrier", "sdkproto", "stretchr", + "tcpip", "tfexec", "tfstate", "unconvert", diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index fa11db7c04..b1bd908bf6 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -8,7 +8,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" + "github.com/coder/coder/pty/ptytest" ) func TestMain(m *testing.M) { @@ -21,11 +21,12 @@ func TestCli(t *testing.T) { client := coderdtest.New(t) cmd, config := clitest.New(t) clitest.SetupConfig(t, client, config) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) go func() { err := cmd.Execute() require.NoError(t, err) }() - _, err := cons.ExpectString("coder") - require.NoError(t, err) + pty.ExpectMatch("coder") } diff --git a/cli/login_test.go b/cli/login_test.go index b6c581cc41..24caf18e1a 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -3,10 +3,11 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" - "github.com/stretchr/testify/require" + "github.com/coder/coder/pty/ptytest" ) func TestLogin(t *testing.T) { @@ -26,7 +27,9 @@ func TestLogin(t *testing.T) { // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 root, _ := clitest.New(t, "login", client.URL.String(), "--force-tty") - cons := console.New(t, root) + pty := ptytest.New(t) + root.SetIn(pty.Input()) + root.SetOut(pty.Output()) go func() { err := root.Execute() require.NoError(t, err) @@ -42,12 +45,9 @@ func TestLogin(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Welcome to Coder") - require.NoError(t, err) + pty.ExpectMatch("Welcome to Coder") }) } diff --git a/cli/projectcreate_test.go b/cli/projectcreate_test.go index 6311aaf141..873a276263 100644 --- a/cli/projectcreate_test.go +++ b/cli/projectcreate_test.go @@ -7,10 +7,10 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty/ptytest" ) func TestProjectCreate(t *testing.T) { @@ -26,7 +26,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) _ = coderdtest.NewProvisionerDaemon(t, client) - console := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -43,10 +45,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := console.ExpectString(match) - require.NoError(t, err) - _, err = console.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) @@ -73,7 +73,9 @@ func TestProjectCreate(t *testing.T) { cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) coderdtest.NewProvisionerDaemon(t, client) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -91,10 +93,8 @@ func TestProjectCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } <-closeChan }) diff --git a/cli/root.go b/cli/root.go index f4e27a49d9..55e2b4c1d6 100644 --- a/cli/root.go +++ b/cli/root.go @@ -12,7 +12,6 @@ import ( "github.com/manifoldco/promptui" "github.com/mattn/go-isatty" "github.com/spf13/cobra" - "golang.org/x/xerrors" "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd" @@ -138,14 +137,9 @@ func isTTY(cmd *cobra.Command) bool { } func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { - var ok bool - prompt.Stdin, ok = cmd.InOrStdin().(io.ReadCloser) - if !ok { - return "", xerrors.New("stdin must be a readcloser") - } - prompt.Stdout, ok = cmd.OutOrStdout().(io.WriteCloser) - if !ok { - return "", xerrors.New("stdout must be a readcloser") + prompt.Stdin = io.NopCloser(cmd.InOrStdin()) + prompt.Stdout = readWriteCloser{ + Writer: cmd.OutOrStdout(), } // The prompt library displays defaults in a jarring way for the user @@ -199,3 +193,10 @@ func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) { return value, err } + +// readWriteCloser fakes reads, writes, and closing! +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} diff --git a/cli/workspacecreate_test.go b/cli/workspacecreate_test.go index 306caa65c4..b3b1ca2691 100644 --- a/cli/workspacecreate_test.go +++ b/cli/workspacecreate_test.go @@ -3,12 +3,13 @@ package cli_test import ( "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/console" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/stretchr/testify/require" + "github.com/coder/coder/pty/ptytest" ) func TestWorkspaceCreate(t *testing.T) { @@ -36,7 +37,9 @@ func TestWorkspaceCreate(t *testing.T) { cmd, root := clitest.New(t, "workspaces", "create", project.Name) clitest.SetupConfig(t, client, root) - cons := console.New(t, cmd) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) closeChan := make(chan struct{}) go func() { err := cmd.Execute() @@ -51,13 +54,10 @@ func TestWorkspaceCreate(t *testing.T) { for i := 0; i < len(matches); i += 2 { match := matches[i] value := matches[i+1] - _, err := cons.ExpectString(match) - require.NoError(t, err) - _, err = cons.SendLine(value) - require.NoError(t, err) + pty.ExpectMatch(match) + pty.WriteLine(value) } - _, err := cons.ExpectString("Create") - require.NoError(t, err) + pty.ExpectMatch("Create") <-closeChan }) } diff --git a/coderd/projectimport_test.go b/coderd/projectimport_test.go index 06140190f5..b9df691233 100644 --- a/coderd/projectimport_test.go +++ b/coderd/projectimport_test.go @@ -5,13 +5,14 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" "github.com/coder/coder/database" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/stretchr/testify/require" ) func TestPostProjectImportByOrganization(t *testing.T) { diff --git a/codersdk/projectimport_test.go b/codersdk/projectimport_test.go index 8cc6b28a23..ccbe013458 100644 --- a/codersdk/projectimport_test.go +++ b/codersdk/projectimport_test.go @@ -5,12 +5,13 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/google/uuid" - "github.com/stretchr/testify/require" ) func TestCreateProjectImportJob(t *testing.T) { diff --git a/console/conpty/conpty.go b/console/conpty/conpty.go deleted file mode 100644 index a57264b8ff..0000000000 --- a/console/conpty/conpty.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build windows -// +build windows - -// Original copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "fmt" - "io" - "os" - - "golang.org/x/sys/windows" -) - -// ConPty represents a windows pseudo console. -type ConPty struct { - hpCon windows.Handle - outPipePseudoConsoleSide windows.Handle - outPipeOurSide windows.Handle - inPipeOurSide windows.Handle - inPipePseudoConsoleSide windows.Handle - consoleSize uintptr - outFilePseudoConsoleSide *os.File - outFileOurSide *os.File - inFilePseudoConsoleSide *os.File - inFileOurSide *os.File - closed bool -} - -// New returns a new ConPty pseudo terminal device -func New(columns int16, rows int16) (*ConPty, error) { - c := &ConPty{ - consoleSize: uintptr(columns) + (uintptr(rows) << 16), - } - - return c, c.createPseudoConsoleAndPipes() -} - -// Close closes the pseudo-terminal and cleans up all attached resources -func (c *ConPty) Close() error { - // Trying to close these pipes multiple times will result in an - // access violation - if c.closed { - return nil - } - - err := closePseudoConsole(c.hpCon) - c.outFilePseudoConsoleSide.Close() - c.outFileOurSide.Close() - c.inFilePseudoConsoleSide.Close() - c.inFileOurSide.Close() - c.closed = true - return err -} - -// OutPipe returns the output pipe of the pseudo terminal -func (c *ConPty) OutPipe() *os.File { - return c.outFilePseudoConsoleSide -} - -func (c *ConPty) Reader() io.Reader { - return c.outFileOurSide -} - -// InPipe returns input pipe of the pseudo terminal -// Note: It is safer to use the Write method to prevent partially-written VT sequences -// from corrupting the terminal -func (c *ConPty) InPipe() *os.File { - return c.inFilePseudoConsoleSide -} - -func (c *ConPty) WriteString(str string) (int, error) { - return c.inFileOurSide.WriteString(str) -} - -func (c *ConPty) createPseudoConsoleAndPipes() error { - // Create the stdin pipe - if err := windows.CreatePipe(&c.inPipePseudoConsoleSide, &c.inPipeOurSide, nil, 0); err != nil { - return err - } - - // Create the stdout pipe - if err := windows.CreatePipe(&c.outPipeOurSide, &c.outPipePseudoConsoleSide, nil, 0); err != nil { - return err - } - - // Create the pty with our stdin/stdout - if err := createPseudoConsole(c.consoleSize, c.inPipePseudoConsoleSide, c.outPipePseudoConsoleSide, &c.hpCon); err != nil { - return fmt.Errorf("failed to create pseudo console: %d, %v", uintptr(c.hpCon), err) - } - - c.outFilePseudoConsoleSide = os.NewFile(uintptr(c.outPipePseudoConsoleSide), "|0") - c.outFileOurSide = os.NewFile(uintptr(c.outPipeOurSide), "|1") - - c.inFilePseudoConsoleSide = os.NewFile(uintptr(c.inPipePseudoConsoleSide), "|2") - c.inFileOurSide = os.NewFile(uintptr(c.inPipeOurSide), "|3") - c.closed = false - - return nil -} - -func (c *ConPty) Resize(cols uint16, rows uint16) error { - return resizePseudoConsole(c.hpCon, uintptr(cols)+(uintptr(rows)<<16)) -} diff --git a/console/conpty/syscall.go b/console/conpty/syscall.go deleted file mode 100644 index 284603aa8f..0000000000 --- a/console/conpty/syscall.go +++ /dev/null @@ -1,53 +0,0 @@ -//go:build windows -// +build windows - -// Copyright 2020 ActiveState Software. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file - -package conpty - -import ( - "unsafe" - - "golang.org/x/sys/windows" -) - -var ( - kernel32 = windows.NewLazySystemDLL("kernel32.dll") - procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") - procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") - procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") -) - -func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) { - r1, _, e1 := procCreatePseudoConsole.Call( - consoleSize, - uintptr(ptyIn), - uintptr(ptyOut), - 0, - uintptr(unsafe.Pointer(hpCon)), - ) - - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func resizePseudoConsole(handle windows.Handle, consoleSize uintptr) (err error) { - r1, _, e1 := procResizePseudoConsole.Call(uintptr(handle), consoleSize) - if r1 != 0 { // !S_OK - err = e1 - } - return -} - -func closePseudoConsole(handle windows.Handle) (err error) { - r1, _, e1 := procClosePseudoConsole.Call(uintptr(handle)) - if r1 == 0 { - err = e1 - } - - return -} diff --git a/console/console.go b/console/console.go deleted file mode 100644 index e5af7fa209..0000000000 --- a/console/console.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "fmt" - "io" - "io/ioutil" - "log" - "os" - "unicode/utf8" - - "github.com/coder/coder/console/pty" -) - -// Console is an interface to automate input and output for interactive -// applications. Console can block until a specified output is received and send -// input back on it's tty. Console can also multiplex other sources of input -// and multiplex its output to other writers. -type Console struct { - opts Opts - pty pty.Pty - runeReader *bufio.Reader - closers []io.Closer -} - -// Opt allows setting Console options. -type Opt func(*Opts) error - -// Opts provides additional options on creating a Console. -type Opts struct { - Logger *log.Logger - Stdouts []io.Writer - ExpectObservers []Observer -} - -// Observer provides an interface for a function callback that will -// be called after each Expect operation. -// matchers will be the list of active matchers when an error occurred, -// or a list of matchers that matched `buf` when err is nil. -// buf is the captured output that was matched against. -// err is error that might have occurred. May be nil. -type Observer func(matchers []Matcher, buf string, err error) - -// WithStdout adds writers that Console duplicates writes to, similar to the -// Unix tee(1) command. -// -// Each write is written to each listed writer, one at a time. Console is the -// last writer, writing to it's internal buffer for matching expects. -// If a listed writer returns an error, that overall write operation stops and -// returns the error; it does not continue down the list. -func WithStdout(writers ...io.Writer) Opt { - return func(opts *Opts) error { - opts.Stdouts = append(opts.Stdouts, writers...) - return nil - } -} - -// WithLogger adds a logger for Console to log debugging information to. By -// default Console will discard logs. -func WithLogger(logger *log.Logger) Opt { - return func(opts *Opts) error { - opts.Logger = logger - return nil - } -} - -// WithExpectObserver adds an ExpectObserver to allow monitoring Expect operations. -func WithExpectObserver(observers ...Observer) Opt { - return func(opts *Opts) error { - opts.ExpectObservers = append(opts.ExpectObservers, observers...) - return nil - } -} - -// NewConsole returns a new Console with the given options. -func NewConsole(opts ...Opt) (*Console, error) { - options := Opts{ - Logger: log.New(ioutil.Discard, "", 0), - } - - for _, opt := range opts { - if err := opt(&options); err != nil { - return nil, err - } - } - - consolePty, err := pty.New() - if err != nil { - return nil, err - } - closers := []io.Closer{consolePty} - reader := consolePty.Reader() - - cons := &Console{ - opts: options, - pty: consolePty, - runeReader: bufio.NewReaderSize(reader, utf8.UTFMax), - closers: closers, - } - - return cons, nil -} - -// Tty returns an input Tty for accepting input -func (c *Console) InTty() *os.File { - return c.pty.InPipe() -} - -// OutTty returns an output tty for writing -func (c *Console) OutTty() *os.File { - return c.pty.OutPipe() -} - -// Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF. -func (c *Console) Close() error { - for _, fd := range c.closers { - err := fd.Close() - if err != nil { - c.Logf("failed to close: %s", err) - } - } - return nil -} - -// Send writes string s to Console's tty. -func (c *Console) Send(s string) (int, error) { - c.Logf("console send: %q", s) - n, err := c.pty.WriteString(s) - return n, err -} - -// SendLine writes string s to Console's tty with a trailing newline. -func (c *Console) SendLine(s string) (int, error) { - bytes, err := c.Send(fmt.Sprintf("%s\n", s)) - - return bytes, err -} - -// Log prints to Console's logger. -// Arguments are handled in the manner of fmt.Print. -func (c *Console) Log(v ...interface{}) { - c.opts.Logger.Print(v...) -} - -// Logf prints to Console's logger. -// Arguments are handled in the manner of fmt.Printf. -func (c *Console) Logf(format string, v ...interface{}) { - c.opts.Logger.Printf(format, v...) -} diff --git a/console/doc.go b/console/doc.go deleted file mode 100644 index 7a5fc545cd..0000000000 --- a/console/doc.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package expect provides an expect-like interface to automate control of -// applications. It is unlike expect in that it does not spawn or manage -// process lifecycle. This package only focuses on expecting output and sending -// input through it's psuedoterminal. -package console diff --git a/console/expect.go b/console/expect.go deleted file mode 100644 index c2e3f583b0..0000000000 --- a/console/expect.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bufio" - "bytes" - "fmt" - "io" - "unicode/utf8" -) - -// Expectf reads from the Console's tty until the provided formatted string -// is read or an error occurs, and returns the buffer read by Console. -func (c *Console) Expectf(format string, args ...interface{}) (string, error) { - return c.Expect(String(fmt.Sprintf(format, args...))) -} - -// ExpectString reads from Console's tty until the provided string is read or -// an error occurs, and returns the buffer read by Console. -func (c *Console) ExpectString(s string) (string, error) { - return c.Expect(String(s)) -} - -// Expect reads from Console's tty until a condition specified from opts is -// encountered or an error occurs, and returns the buffer read by console. -// No extra bytes are read once a condition is met, so if a program isn't -// expecting input yet, it will be blocked. Sends are queued up in tty's -// internal buffer so that the next Expect will read the remaining bytes (i.e. -// rest of prompt) as well as its conditions. -func (c *Console) Expect(opts ...ExpectOpt) (string, error) { - var options ExpectOpts - for _, opt := range opts { - if err := opt(&options); err != nil { - return "", err - } - } - - buf := new(bytes.Buffer) - writer := io.MultiWriter(append(c.opts.Stdouts, buf)...) - runeWriter := bufio.NewWriterSize(writer, utf8.UTFMax) - - var matcher Matcher - var err error - - defer func() { - for _, observer := range c.opts.ExpectObservers { - if matcher != nil { - observer([]Matcher{matcher}, buf.String(), err) - return - } - observer(options.Matchers, buf.String(), err) - } - }() - - for { - var r rune - r, _, err = c.runeReader.ReadRune() - if err != nil { - matcher = options.Match(err) - if matcher != nil { - err = nil - break - } - return buf.String(), err - } - - c.Logf("expect read: %q", string(r)) - _, err = runeWriter.WriteRune(r) - if err != nil { - return buf.String(), err - } - - // Immediately flush rune to the underlying writers. - err = runeWriter.Flush() - if err != nil { - return buf.String(), err - } - - matcher = options.Match(buf) - if matcher != nil { - break - } - } - - if matcher != nil { - cb, ok := matcher.(CallbackMatcher) - if ok { - err = cb.Callback(buf) - if err != nil { - return buf.String(), err - } - } - } - - return buf.String(), err -} diff --git a/console/expect_opt.go b/console/expect_opt.go deleted file mode 100644 index fec0d9b8f3..0000000000 --- a/console/expect_opt.go +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console - -import ( - "bytes" - "strings" - "time" -) - -// ExpectOpt allows settings Expect options. -type ExpectOpt func(*ExpectOpts) error - -// Callback is a callback function to execute if a match is found for -// the chained matcher. -type Callback func(buf *bytes.Buffer) error - -// ExpectOpts provides additional options on Expect. -type ExpectOpts struct { - Matchers []Matcher - ReadTimeout *time.Duration -} - -// Match sequentially calls Match on all matchers in ExpectOpts and returns the -// first matcher if a match exists, otherwise nil. -func (eo ExpectOpts) Match(v interface{}) Matcher { - for _, matcher := range eo.Matchers { - if matcher.Match(v) { - return matcher - } - } - return nil -} - -// CallbackMatcher is a matcher that provides a Callback function. -type CallbackMatcher interface { - // Callback executes the matcher's callback with the content buffer at the - // time of match. - Callback(buf *bytes.Buffer) error -} - -// Matcher provides an interface for finding a match in content read from -// Console's tty. -type Matcher interface { - // Match returns true iff a match is found. - Match(v interface{}) bool - Criteria() interface{} -} - -// stringMatcher fulfills the Matcher interface to match strings against a given -// bytes.Buffer. -type stringMatcher struct { - str string -} - -func (sm *stringMatcher) Match(v interface{}) bool { - buf, ok := v.(*bytes.Buffer) - if !ok { - return false - } - if strings.Contains(buf.String(), sm.str) { - return true - } - return false -} - -func (sm *stringMatcher) Criteria() interface{} { - return sm.str -} - -// allMatcher fulfills the Matcher interface to match a group of ExpectOpt -// against any value. -type allMatcher struct { - options ExpectOpts -} - -func (am *allMatcher) Match(v interface{}) bool { - var matchers []Matcher - for _, matcher := range am.options.Matchers { - if matcher.Match(v) { - continue - } - matchers = append(matchers, matcher) - } - - am.options.Matchers = matchers - return len(matchers) == 0 -} - -func (am *allMatcher) Criteria() interface{} { - var criteria []interface{} - for _, matcher := range am.options.Matchers { - criteria = append(criteria, matcher.Criteria()) - } - return criteria -} - -// All adds an Expect condition to exit if the content read from Console's tty -// matches all of the provided ExpectOpt, in any order. -func All(expectOpts ...ExpectOpt) ExpectOpt { - return func(opts *ExpectOpts) error { - var options ExpectOpts - for _, opt := range expectOpts { - if err := opt(&options); err != nil { - return err - } - } - - opts.Matchers = append(opts.Matchers, &allMatcher{ - options: options, - }) - return nil - } -} - -// String adds an Expect condition to exit if the content read from Console's -// tty contains any of the given strings. -func String(strs ...string) ExpectOpt { - return func(opts *ExpectOpts) error { - for _, str := range strs { - opts.Matchers = append(opts.Matchers, &stringMatcher{ - str: str, - }) - } - return nil - } -} diff --git a/console/expect_opt_test.go b/console/expect_opt_test.go deleted file mode 100644 index 91efc935fc..0000000000 --- a/console/expect_opt_test.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/require" - - . "github.com/coder/coder/console" -) - -func TestExpectOptString(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No args", - String(), - "Hello world", - false, - }, - { - "Single arg", - String("Hello"), - "Hello world", - true, - }, - { - "Multiple arg", - String("other", "world"), - "Hello world", - true, - }, - { - "No matches", - String("hello"), - "Hello world", - false, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} - -func TestExpectOptAll(t *testing.T) { - t.Parallel() - - tests := []struct { - title string - opt ExpectOpt - data string - expected bool - }{ - { - "No opts", - All(), - "Hello world", - true, - }, - { - "Single string match", - All(String("Hello")), - "Hello world", - true, - }, - { - "Single string no match", - All(String("Hello")), - "No match", - false, - }, - { - "Ordered strings match", - All(String("Hello"), String("world")), - "Hello world", - true, - }, - { - "Ordered strings not all match", - All(String("Hello"), String("world")), - "Hello", - false, - }, - { - "Unordered strings", - All(String("world"), String("Hello")), - "Hello world", - true, - }, - { - "Unordered strings not all match", - All(String("world"), String("Hello")), - "Hello", - false, - }, - { - "Repeated strings match", - All(String("Hello"), String("Hello")), - "Hello world", - true, - }, - } - - for _, test := range tests { - test := test - t.Run(test.title, func(t *testing.T) { - t.Parallel() - var options ExpectOpts - err := test.opt(&options) - require.Nil(t, err) - - buf := new(bytes.Buffer) - _, err = buf.WriteString(test.data) - require.Nil(t, err) - - matcher := options.Match(buf) - if test.expected { - require.NotNil(t, matcher) - } else { - require.Nil(t, matcher) - } - }) - } -} diff --git a/console/expect_test.go b/console/expect_test.go deleted file mode 100644 index c80f981717..0000000000 --- a/console/expect_test.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package console_test - -import ( - "bufio" - "errors" - "fmt" - "io" - "runtime/debug" - "strings" - "sync" - "testing" - - "golang.org/x/xerrors" - - . "github.com/coder/coder/console" -) - -var ( - ErrWrongAnswer = xerrors.New("wrong answer") -) - -type Survey struct { - Prompt string - Answer string -} - -func Prompt(in io.Reader, out io.Writer) error { - reader := bufio.NewReader(in) - - for _, survey := range []Survey{ - { - "What is 1+1?", "2", - }, - { - "What is Netflix backwards?", "xilfteN", - }, - } { - _, err := fmt.Fprintf(out, "%s: ", survey.Prompt) - if err != nil { - return err - } - text, err := reader.ReadString('\n') - if err != nil { - return err - } - - _, err = fmt.Fprint(out, text) - if err != nil { - return err - } - text = strings.TrimSpace(text) - if text != survey.Answer { - return ErrWrongAnswer - } - } - - return nil -} - -func newTestConsole(t *testing.T, opts ...Opt) (*Console, error) { - opts = append([]Opt{ - expectNoError(t), - }, opts...) - return NewConsole(opts...) -} - -func expectNoError(t *testing.T) Opt { - return WithExpectObserver( - func(matchers []Matcher, buf string, err error) { - if err == nil { - return - } - if len(matchers) == 0 { - t.Fatalf("Error occurred while matching %q: %s\n%s", buf, err, string(debug.Stack())) - } else { - var criteria []string - for _, matcher := range matchers { - criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria())) - } - t.Fatalf("Failed to find [%s] in %q: %s\n%s", strings.Join(criteria, ", "), buf, err, string(debug.Stack())) - } - }, - ) -} - -func testCloser(t *testing.T, closer io.Closer) { - if err := closer.Close(); err != nil { - t.Errorf("Close failed: %s", err) - debug.PrintStack() - } -} - -func TestExpectf(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.Expectf("What is 1+%d?", 1) - console.SendLine("2") - console.Expectf("What is %s backwards?", "Netflix") - console.SendLine("xilfteN") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpect(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("2") - console.ExpectString("What is Netflix backwards?") - console.SendLine("xilfteN") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err != nil { - t.Errorf("Expected no error but got '%s'", err) - } - wg.Wait() -} - -func TestExpectOutput(t *testing.T) { - t.Parallel() - - console, err := newTestConsole(t) - if err != nil { - t.Errorf("Expected no error but got'%s'", err) - } - defer testCloser(t, console) - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - console.ExpectString("What is 1+1?") - console.SendLine("3") - }() - - err = Prompt(console.InTty(), console.OutTty()) - if err == nil || !errors.Is(err, ErrWrongAnswer) { - t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) - } - wg.Wait() -} diff --git a/console/pty/pty.go b/console/pty/pty.go deleted file mode 100644 index 86b56e68f9..0000000000 --- a/console/pty/pty.go +++ /dev/null @@ -1,21 +0,0 @@ -package pty - -import ( - "io" - "os" -) - -// Pty is the minimal pseudo-tty interface we require. -type Pty interface { - InPipe() *os.File - OutPipe() *os.File - Resize(cols uint16, rows uint16) error - WriteString(str string) (int, error) - Reader() io.Reader - Close() error -} - -// New creates a new Pty. -func New() (Pty, error) { - return newPty() -} diff --git a/console/pty/pty_windows.go b/console/pty/pty_windows.go deleted file mode 100644 index 01fbe39169..0000000000 --- a/console/pty/pty_windows.go +++ /dev/null @@ -1,78 +0,0 @@ -//go:build windows -// +build windows - -package pty - -import ( - "io" - "os" - - "golang.org/x/sys/windows" - - "github.com/coder/coder/console/conpty" -) - -func newPty() (Pty, error) { - // We use the CreatePseudoConsole API which was introduced in build 17763 - vsn := windows.RtlGetVersion() - if vsn.MajorVersion < 10 || - vsn.BuildNumber < 17763 { - // If the CreatePseudoConsole API is not available, we fall back to a simpler - // implementation that doesn't create an actual PTY - just uses os.Pipe - return pipePty() - } - - return conpty.New(80, 80) -} - -func pipePty() (Pty, error) { - inFilePipeSide, inFileOurSide, err := os.Pipe() - if err != nil { - return nil, err - } - - outFileOurSide, outFilePipeSide, err := os.Pipe() - if err != nil { - return nil, err - } - - return &pipePtyVal{ - inFilePipeSide, - inFileOurSide, - outFileOurSide, - outFilePipeSide, - }, nil -} - -type pipePtyVal struct { - inFilePipeSide, inFileOurSide *os.File - outFileOurSide, outFilePipeSide *os.File -} - -func (p *pipePtyVal) InPipe() *os.File { - return p.inFilePipeSide -} - -func (p *pipePtyVal) OutPipe() *os.File { - return p.outFilePipeSide -} - -func (p *pipePtyVal) Reader() io.Reader { - return p.outFileOurSide -} - -func (p *pipePtyVal) WriteString(str string) (int, error) { - return p.inFileOurSide.WriteString(str) -} - -func (p *pipePtyVal) Resize(uint16, uint16) error { - return nil -} - -func (p *pipePtyVal) Close() error { - p.inFileOurSide.Close() - p.inFilePipeSide.Close() - p.outFilePipeSide.Close() - p.outFileOurSide.Close() - return nil -} diff --git a/console/test_console.go b/console/test_console.go deleted file mode 100644 index d1d845d6cb..0000000000 --- a/console/test_console.go +++ /dev/null @@ -1,45 +0,0 @@ -package console - -import ( - "bufio" - "io" - "regexp" - "testing" - - "github.com/spf13/cobra" - "github.com/stretchr/testify/require" -) - -var ( - // Used to ensure terminal output doesn't have anything crazy! - // See: https://stackoverflow.com/a/29497680 - stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") -) - -// New creates a new TTY bound to the command provided. -// All ANSI escape codes are stripped to provide clean output. -func New(t *testing.T, cmd *cobra.Command) *Console { - reader, writer := io.Pipe() - scanner := bufio.NewScanner(reader) - t.Cleanup(func() { - _ = reader.Close() - _ = writer.Close() - }) - go func() { - for scanner.Scan() { - if scanner.Err() != nil { - return - } - t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) - } - }() - - console, err := NewConsole(WithStdout(writer)) - require.NoError(t, err) - t.Cleanup(func() { - console.Close() - }) - cmd.SetIn(console.InTty()) - cmd.SetOut(console.OutTty()) - return console -} diff --git a/go.mod b/go.mod index d002eeadbf..e224d221c5 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/coder/retry v1.3.0 github.com/creack/pty v1.1.17 github.com/fatih/color v1.13.0 + github.com/gliderlabs/ssh v0.3.3 github.com/go-chi/chi/v5 v5.0.7 github.com/go-chi/render v1.0.1 github.com/go-playground/validator/v10 v10.10.0 @@ -64,6 +65,7 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/alecthomas/chroma v0.10.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect diff --git a/go.sum b/go.sum index 370ff4aeb0..416fe1da0f 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0= github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE= github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= @@ -441,6 +443,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA= +github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= diff --git a/peer/channel.go b/peer/channel.go index d1f4930fe3..732a6a1c1d 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -263,6 +263,11 @@ func (c *Channel) Label() string { return c.dc.Label() } +// Protocol returns the protocol of the underlying DataChannel. +func (c *Channel) Protocol() string { + return c.dc.Protocol() +} + // NetConn wraps the DataChannel in a struct fulfilling net.Conn. // Read, Write, and Close operations can still be used on the *Channel struct. func (c *Channel) NetConn() net.Conn { diff --git a/pty/pty.go b/pty/pty.go new file mode 100644 index 0000000000..0086bfba56 --- /dev/null +++ b/pty/pty.go @@ -0,0 +1,39 @@ +package pty + +import ( + "io" +) + +// PTY is a minimal interface for interacting with a TTY. +type PTY interface { + io.Closer + + // Output handles TTY output. + // + // cmd.SetOutput(pty.Output()) would be used to specify a command + // uses the output stream for writing. + // + // The same stream could be read to validate output. + Output() io.ReadWriter + + // Input handles TTY input. + // + // cmd.SetInput(pty.Input()) would be used to specify a command + // uses the PTY input for reading. + // + // The same stream would be used to provide user input: pty.Input().Write(...) + Input() io.ReadWriter + + // Resize sets the size of the PTY. + Resize(cols uint16, rows uint16) error +} + +// New constructs a new Pty. +func New() (PTY, error) { + return newPty() +} + +type readWriter struct { + io.Reader + io.Writer +} diff --git a/console/pty/pty_other.go b/pty/pty_other.go similarity index 52% rename from console/pty/pty_other.go rename to pty/pty_other.go index 723a6dbfd7..dbdda408b1 100644 --- a/console/pty/pty_other.go +++ b/pty/pty_other.go @@ -10,46 +10,44 @@ import ( "github.com/creack/pty" ) -func newPty() (Pty, error) { +func newPty() (PTY, error) { ptyFile, ttyFile, err := pty.Open() if err != nil { return nil, err } - return &unixPty{ + return &otherPty{ pty: ptyFile, tty: ttyFile, }, nil } -type unixPty struct { +type otherPty struct { pty, tty *os.File } -func (p *unixPty) InPipe() *os.File { - return p.tty +func (p *otherPty) Input() io.ReadWriter { + return readWriter{ + Reader: p.tty, + Writer: p.pty, + } } -func (p *unixPty) OutPipe() *os.File { - return p.tty +func (p *otherPty) Output() io.ReadWriter { + return readWriter{ + Reader: p.pty, + Writer: p.tty, + } } -func (p *unixPty) Reader() io.Reader { - return p.pty -} - -func (p *unixPty) WriteString(str string) (int, error) { - return p.pty.WriteString(str) -} - -func (p *unixPty) Resize(cols uint16, rows uint16) error { +func (p *otherPty) Resize(cols uint16, rows uint16) error { return pty.Setsize(p.tty, &pty.Winsize{ Rows: rows, Cols: cols, }) } -func (p *unixPty) Close() error { +func (p *otherPty) Close() error { err := p.pty.Close() if err != nil { return err diff --git a/pty/pty_windows.go b/pty/pty_windows.go new file mode 100644 index 0000000000..b6a9f8ae2e --- /dev/null +++ b/pty/pty_windows.go @@ -0,0 +1,107 @@ +//go:build windows +// +build windows + +package pty + +import ( + "io" + "os" + "sync" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.org/x/xerrors" +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole") + procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole") + procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole") +) + +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session +func newPty() (PTY, error) { + // We use the CreatePseudoConsole API which was introduced in build 17763 + vsn := windows.RtlGetVersion() + if vsn.MajorVersion < 10 || + vsn.BuildNumber < 17763 { + // If the CreatePseudoConsole API is not available, we fall back to a simpler + // implementation that doesn't create an actual PTY - just uses os.Pipe + return nil, xerrors.Errorf("pty not supported") + } + + ptyWindows := &ptyWindows{} + + var err error + ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe() + if err != nil { + return nil, err + } + ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe() + + consoleSize := uintptr(80) + (uintptr(80) << 16) + ret, _, err := procCreatePseudoConsole.Call( + consoleSize, + uintptr(ptyWindows.inputRead.Fd()), + uintptr(ptyWindows.outputWrite.Fd()), + 0, + uintptr(unsafe.Pointer(&ptyWindows.console)), + ) + if int32(ret) < 0 { + return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err) + } + return ptyWindows, nil +} + +type ptyWindows struct { + console windows.Handle + + outputWrite *os.File + outputRead *os.File + inputWrite *os.File + inputRead *os.File + + closeMutex sync.Mutex + closed bool +} + +func (p *ptyWindows) Output() io.ReadWriter { + return readWriter{ + Reader: p.outputRead, + Writer: p.outputWrite, + } +} + +func (p *ptyWindows) Input() io.ReadWriter { + return readWriter{ + Reader: p.inputRead, + Writer: p.inputWrite, + } +} + +func (p *ptyWindows) Resize(cols uint16, rows uint16) error { + ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(cols)+(uintptr(rows)<<16)) + if ret != 0 { + return err + } + return nil +} + +func (p *ptyWindows) Close() error { + p.closeMutex.Lock() + defer p.closeMutex.Unlock() + if p.closed { + return nil + } + p.closed = true + + ret, _, err := procClosePseudoConsole.Call(uintptr(p.console)) + if ret != 0 { + return xerrors.Errorf("close pseudo console: %w", err) + } + _ = p.outputRead.Close() + _ = p.inputWrite.Close() + return nil +} diff --git a/pty/ptytest/ptytest.go b/pty/ptytest/ptytest.go new file mode 100644 index 0000000000..7ea5b7a119 --- /dev/null +++ b/pty/ptytest/ptytest.go @@ -0,0 +1,95 @@ +package ptytest + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os/exec" + "regexp" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/pty" +) + +var ( + // Used to ensure terminal output doesn't have anything crazy! + // See: https://stackoverflow.com/a/29497680 + stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))") +) + +func New(t *testing.T) *PTY { + ptty, err := pty.New() + require.NoError(t, err) + return create(t, ptty) +} + +func Start(t *testing.T, cmd *exec.Cmd) *PTY { + ptty, err := pty.Start(cmd) + require.NoError(t, err) + return create(t, ptty) +} + +func create(t *testing.T, ptty pty.PTY) *PTY { + reader, writer := io.Pipe() + scanner := bufio.NewScanner(reader) + t.Cleanup(func() { + _ = reader.Close() + _ = writer.Close() + }) + go func() { + for scanner.Scan() { + if scanner.Err() != nil { + return + } + t.Log(stripAnsi.ReplaceAllString(scanner.Text(), "")) + } + }() + + t.Cleanup(func() { + _ = ptty.Close() + }) + return &PTY{ + t: t, + PTY: ptty, + + outputWriter: writer, + runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax), + } +} + +type PTY struct { + t *testing.T + pty.PTY + + outputWriter io.Writer + runeReader *bufio.Reader +} + +func (p *PTY) ExpectMatch(str string) string { + var buffer bytes.Buffer + multiWriter := io.MultiWriter(&buffer, p.outputWriter) + runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax) + for { + var r rune + r, _, err := p.runeReader.ReadRune() + require.NoError(p.t, err) + _, err = runeWriter.WriteRune(r) + require.NoError(p.t, err) + err = runeWriter.Flush() + require.NoError(p.t, err) + if strings.Contains(buffer.String(), str) { + break + } + } + return buffer.String() +} + +func (p *PTY) WriteLine(str string) { + _, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str) + require.NoError(p.t, err) +} diff --git a/pty/ptytest/ptytest_test.go b/pty/ptytest/ptytest_test.go new file mode 100644 index 0000000000..6603b35ad5 --- /dev/null +++ b/pty/ptytest/ptytest_test.go @@ -0,0 +1,15 @@ +package ptytest_test + +import ( + "testing" + + "github.com/coder/coder/pty/ptytest" +) + +func TestPtytest(t *testing.T) { + t.Parallel() + pty := ptytest.New(t) + pty.Output().Write([]byte("write")) + pty.ExpectMatch("write") + pty.WriteLine("read") +} diff --git a/pty/start.go b/pty/start.go new file mode 100644 index 0000000000..2b75843ee1 --- /dev/null +++ b/pty/start.go @@ -0,0 +1,7 @@ +package pty + +import "os/exec" + +func Start(cmd *exec.Cmd) (PTY, error) { + return startPty(cmd) +} diff --git a/pty/start_other.go b/pty/start_other.go new file mode 100644 index 0000000000..103f55202e --- /dev/null +++ b/pty/start_other.go @@ -0,0 +1,34 @@ +//go:build !windows +// +build !windows + +package pty + +import ( + "os/exec" + "syscall" + + "github.com/creack/pty" +) + +func startPty(cmd *exec.Cmd) (PTY, error) { + ptty, tty, err := pty.Open() + if err != nil { + return nil, err + } + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setsid: true, + Setctty: true, + } + cmd.Stdout = tty + cmd.Stderr = tty + cmd.Stdin = tty + err = cmd.Start() + if err != nil { + _ = ptty.Close() + return nil, err + } + return &otherPty{ + pty: ptty, + tty: tty, + }, nil +} diff --git a/pty/start_other_test.go b/pty/start_other_test.go new file mode 100644 index 0000000000..a5e7d94b36 --- /dev/null +++ b/pty/start_other_test.go @@ -0,0 +1,25 @@ +//go:build !windows +// +build !windows + +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestStart(t *testing.T) { + t.Parallel() + t.Run("Echo", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("echo", "test")) + pty.ExpectMatch("test") + }) +} diff --git a/pty/start_windows.go b/pty/start_windows.go new file mode 100644 index 0000000000..136ba24573 --- /dev/null +++ b/pty/start_windows.go @@ -0,0 +1,149 @@ +//go:build windows +// +build windows + +package pty + +import ( + "os" + "os/exec" + "strings" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Allocates a PTY and starts the specified command attached to it. +// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process +func startPty(cmd *exec.Cmd) (PTY, error) { + fullPath, err := exec.LookPath(cmd.Path) + if err != nil { + return nil, err + } + pathPtr, err := windows.UTF16PtrFromString(fullPath) + if err != nil { + return nil, err + } + argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args)) + if err != nil { + return nil, err + } + if cmd.Dir == "" { + cmd.Dir, err = os.Getwd() + if err != nil { + return nil, err + } + } + dirPtr, err := windows.UTF16PtrFromString(cmd.Dir) + if err != nil { + return nil, err + } + pty, err := newPty() + if err != nil { + return nil, err + } + winPty := pty.(*ptyWindows) + + attrs, err := windows.NewProcThreadAttributeList(1) + if err != nil { + return nil, err + } + // Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6 + err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console)) + if err != nil { + return nil, err + } + + startupInfo := &windows.StartupInfoEx{} + startupInfo.ProcThreadAttributeList = attrs.List() + startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES + startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo)) + var processInfo windows.ProcessInformation + err = windows.CreateProcess( + pathPtr, + argsPtr, + nil, + nil, + false, + // https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment + windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT, + createEnvBlock(addCriticalEnv(dedupEnvCase(true, cmd.Env))), + dirPtr, + &startupInfo.StartupInfo, + &processInfo, + ) + if err != nil { + return nil, err + } + defer windows.CloseHandle(processInfo.Thread) + defer windows.CloseHandle(processInfo.Process) + + return pty, nil +} + +// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476 +func createEnvBlock(envv []string) *uint16 { + if len(envv) == 0 { + return &utf16.Encode([]rune("\x00\x00"))[0] + } + length := 0 + for _, s := range envv { + length += len(s) + 1 + } + length += 1 + + b := make([]byte, length) + i := 0 + for _, s := range envv { + l := len(s) + copy(b[i:i+l], []byte(s)) + copy(b[i+l:i+l+1], []byte{0}) + i = i + l + 1 + } + copy(b[i:i+1], []byte{0}) + + return &utf16.Encode([]rune(string(b)))[0] +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. +func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/pty/start_windows_test.go b/pty/start_windows_test.go new file mode 100644 index 0000000000..faee269776 --- /dev/null +++ b/pty/start_windows_test.go @@ -0,0 +1,32 @@ +//go:build windows +// +build windows + +package pty_test + +import ( + "os/exec" + "testing" + + "github.com/coder/coder/pty/ptytest" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestStart(t *testing.T) { + t.Parallel() + t.Run("Echo", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test")) + pty.ExpectMatch("test") + }) + t.Run("Resize", func(t *testing.T) { + t.Parallel() + pty := ptytest.Start(t, exec.Command("cmd.exe")) + err := pty.Resize(100, 50) + require.NoError(t, err) + }) +}