feat: Add support for executing processes with Windows ConPty (#311)

* Initial agent

* fix: Use buffered reader in peer to fix ShortBuffer

This prevents a io.ErrShortBuffer from occurring when the byte
slice being read is smaller than the chunks sent from the opposite
pipe.

This makes sense for unordered connections, where transmission is
not guarunteed, but does not make sense for TCP-like connections.

We use a bufio.Reader when ordered to ensure data isn't lost.

* SSH server works!

* Start Windows support

* Something works

* Refactor pty package to support Windows spawn

* SSH server now works on Windows

* Fix non-Windows

* Fix Linux PTY render

* FIx linux build tests

* Remove agent and wintest

* Add test for Windows resize

* Fix linting errors

* Add Windows environment variables

* Add strings import

* Add comment for attrs

* Add goleak

* Add require import
This commit is contained in:
Kyle Carberry
2022-02-17 10:44:49 -06:00
committed by GitHub
parent c2ad91bb74
commit 503d09c149
32 changed files with 582 additions and 1140 deletions

View File

@ -33,6 +33,7 @@
"drpcserver",
"fatih",
"goleak",
"gossh",
"hashicorp",
"httpmw",
"isatty",
@ -51,9 +52,12 @@
"protobuf",
"provisionerd",
"provisionersdk",
"ptty",
"ptytest",
"retrier",
"sdkproto",
"stretchr",
"tcpip",
"tfexec",
"tfstate",
"unconvert",

View File

@ -8,7 +8,7 @@ import (
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/console"
"github.com/coder/coder/pty/ptytest"
)
func TestMain(m *testing.M) {
@ -21,11 +21,12 @@ func TestCli(t *testing.T) {
client := coderdtest.New(t)
cmd, config := clitest.New(t)
clitest.SetupConfig(t, client, config)
cons := console.New(t, cmd)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
err := cmd.Execute()
require.NoError(t, err)
}()
_, err := cons.ExpectString("coder")
require.NoError(t, err)
pty.ExpectMatch("coder")
}

View File

@ -3,10 +3,11 @@ package cli_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/console"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest"
)
func TestLogin(t *testing.T) {
@ -26,7 +27,9 @@ func TestLogin(t *testing.T) {
// accurately detect Windows ptys when they are not attached to a process:
// https://github.com/mattn/go-isatty/issues/59
root, _ := clitest.New(t, "login", client.URL.String(), "--force-tty")
cons := console.New(t, root)
pty := ptytest.New(t)
root.SetIn(pty.Input())
root.SetOut(pty.Output())
go func() {
err := root.Execute()
require.NoError(t, err)
@ -42,12 +45,9 @@ func TestLogin(t *testing.T) {
for i := 0; i < len(matches); i += 2 {
match := matches[i]
value := matches[i+1]
_, err := cons.ExpectString(match)
require.NoError(t, err)
_, err = cons.SendLine(value)
require.NoError(t, err)
pty.ExpectMatch(match)
pty.WriteLine(value)
}
_, err := cons.ExpectString("Welcome to Coder")
require.NoError(t, err)
pty.ExpectMatch("Welcome to Coder")
})
}

View File

@ -7,10 +7,10 @@ import (
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/console"
"github.com/coder/coder/database"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/pty/ptytest"
)
func TestProjectCreate(t *testing.T) {
@ -26,7 +26,9 @@ func TestProjectCreate(t *testing.T) {
cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho))
clitest.SetupConfig(t, client, root)
_ = coderdtest.NewProvisionerDaemon(t, client)
console := console.New(t, cmd)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
closeChan := make(chan struct{})
go func() {
err := cmd.Execute()
@ -43,10 +45,8 @@ func TestProjectCreate(t *testing.T) {
for i := 0; i < len(matches); i += 2 {
match := matches[i]
value := matches[i+1]
_, err := console.ExpectString(match)
require.NoError(t, err)
_, err = console.SendLine(value)
require.NoError(t, err)
pty.ExpectMatch(match)
pty.WriteLine(value)
}
<-closeChan
})
@ -73,7 +73,9 @@ func TestProjectCreate(t *testing.T) {
cmd, root := clitest.New(t, "projects", "create", "--directory", source, "--provisioner", string(database.ProvisionerTypeEcho))
clitest.SetupConfig(t, client, root)
coderdtest.NewProvisionerDaemon(t, client)
cons := console.New(t, cmd)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
closeChan := make(chan struct{})
go func() {
err := cmd.Execute()
@ -91,10 +93,8 @@ func TestProjectCreate(t *testing.T) {
for i := 0; i < len(matches); i += 2 {
match := matches[i]
value := matches[i+1]
_, err := cons.ExpectString(match)
require.NoError(t, err)
_, err = cons.SendLine(value)
require.NoError(t, err)
pty.ExpectMatch(match)
pty.WriteLine(value)
}
<-closeChan
})

View File

@ -12,7 +12,6 @@ import (
"github.com/manifoldco/promptui"
"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
"github.com/coder/coder/cli/config"
"github.com/coder/coder/coderd"
@ -138,14 +137,9 @@ func isTTY(cmd *cobra.Command) bool {
}
func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) {
var ok bool
prompt.Stdin, ok = cmd.InOrStdin().(io.ReadCloser)
if !ok {
return "", xerrors.New("stdin must be a readcloser")
}
prompt.Stdout, ok = cmd.OutOrStdout().(io.WriteCloser)
if !ok {
return "", xerrors.New("stdout must be a readcloser")
prompt.Stdin = io.NopCloser(cmd.InOrStdin())
prompt.Stdout = readWriteCloser{
Writer: cmd.OutOrStdout(),
}
// The prompt library displays defaults in a jarring way for the user
@ -199,3 +193,10 @@ func prompt(cmd *cobra.Command, prompt *promptui.Prompt) (string, error) {
return value, err
}
// readWriteCloser fakes reads, writes, and closing!
type readWriteCloser struct {
io.Reader
io.Writer
io.Closer
}

View File

@ -3,12 +3,13 @@ package cli_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/console"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty/ptytest"
)
func TestWorkspaceCreate(t *testing.T) {
@ -36,7 +37,9 @@ func TestWorkspaceCreate(t *testing.T) {
cmd, root := clitest.New(t, "workspaces", "create", project.Name)
clitest.SetupConfig(t, client, root)
cons := console.New(t, cmd)
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
closeChan := make(chan struct{})
go func() {
err := cmd.Execute()
@ -51,13 +54,10 @@ func TestWorkspaceCreate(t *testing.T) {
for i := 0; i < len(matches); i += 2 {
match := matches[i]
value := matches[i+1]
_, err := cons.ExpectString(match)
require.NoError(t, err)
_, err = cons.SendLine(value)
require.NoError(t, err)
pty.ExpectMatch(match)
pty.WriteLine(value)
}
_, err := cons.ExpectString("Create")
require.NoError(t, err)
pty.ExpectMatch("Create")
<-closeChan
})
}

View File

@ -5,13 +5,14 @@ import (
"net/http"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/database"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/stretchr/testify/require"
)
func TestPostProjectImportByOrganization(t *testing.T) {

View File

@ -5,12 +5,13 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestCreateProjectImportJob(t *testing.T) {

View File

@ -1,107 +0,0 @@
//go:build windows
// +build windows
// Original copyright 2020 ActiveState Software. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file
package conpty
import (
"fmt"
"io"
"os"
"golang.org/x/sys/windows"
)
// ConPty represents a windows pseudo console.
type ConPty struct {
hpCon windows.Handle
outPipePseudoConsoleSide windows.Handle
outPipeOurSide windows.Handle
inPipeOurSide windows.Handle
inPipePseudoConsoleSide windows.Handle
consoleSize uintptr
outFilePseudoConsoleSide *os.File
outFileOurSide *os.File
inFilePseudoConsoleSide *os.File
inFileOurSide *os.File
closed bool
}
// New returns a new ConPty pseudo terminal device
func New(columns int16, rows int16) (*ConPty, error) {
c := &ConPty{
consoleSize: uintptr(columns) + (uintptr(rows) << 16),
}
return c, c.createPseudoConsoleAndPipes()
}
// Close closes the pseudo-terminal and cleans up all attached resources
func (c *ConPty) Close() error {
// Trying to close these pipes multiple times will result in an
// access violation
if c.closed {
return nil
}
err := closePseudoConsole(c.hpCon)
c.outFilePseudoConsoleSide.Close()
c.outFileOurSide.Close()
c.inFilePseudoConsoleSide.Close()
c.inFileOurSide.Close()
c.closed = true
return err
}
// OutPipe returns the output pipe of the pseudo terminal
func (c *ConPty) OutPipe() *os.File {
return c.outFilePseudoConsoleSide
}
func (c *ConPty) Reader() io.Reader {
return c.outFileOurSide
}
// InPipe returns input pipe of the pseudo terminal
// Note: It is safer to use the Write method to prevent partially-written VT sequences
// from corrupting the terminal
func (c *ConPty) InPipe() *os.File {
return c.inFilePseudoConsoleSide
}
func (c *ConPty) WriteString(str string) (int, error) {
return c.inFileOurSide.WriteString(str)
}
func (c *ConPty) createPseudoConsoleAndPipes() error {
// Create the stdin pipe
if err := windows.CreatePipe(&c.inPipePseudoConsoleSide, &c.inPipeOurSide, nil, 0); err != nil {
return err
}
// Create the stdout pipe
if err := windows.CreatePipe(&c.outPipeOurSide, &c.outPipePseudoConsoleSide, nil, 0); err != nil {
return err
}
// Create the pty with our stdin/stdout
if err := createPseudoConsole(c.consoleSize, c.inPipePseudoConsoleSide, c.outPipePseudoConsoleSide, &c.hpCon); err != nil {
return fmt.Errorf("failed to create pseudo console: %d, %v", uintptr(c.hpCon), err)
}
c.outFilePseudoConsoleSide = os.NewFile(uintptr(c.outPipePseudoConsoleSide), "|0")
c.outFileOurSide = os.NewFile(uintptr(c.outPipeOurSide), "|1")
c.inFilePseudoConsoleSide = os.NewFile(uintptr(c.inPipePseudoConsoleSide), "|2")
c.inFileOurSide = os.NewFile(uintptr(c.inPipeOurSide), "|3")
c.closed = false
return nil
}
func (c *ConPty) Resize(cols uint16, rows uint16) error {
return resizePseudoConsole(c.hpCon, uintptr(cols)+(uintptr(rows)<<16))
}

View File

@ -1,53 +0,0 @@
//go:build windows
// +build windows
// Copyright 2020 ActiveState Software. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file
package conpty
import (
"unsafe"
"golang.org/x/sys/windows"
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole")
procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole")
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
)
func createPseudoConsole(consoleSize uintptr, ptyIn windows.Handle, ptyOut windows.Handle, hpCon *windows.Handle) (err error) {
r1, _, e1 := procCreatePseudoConsole.Call(
consoleSize,
uintptr(ptyIn),
uintptr(ptyOut),
0,
uintptr(unsafe.Pointer(hpCon)),
)
if r1 != 0 { // !S_OK
err = e1
}
return
}
func resizePseudoConsole(handle windows.Handle, consoleSize uintptr) (err error) {
r1, _, e1 := procResizePseudoConsole.Call(uintptr(handle), consoleSize)
if r1 != 0 { // !S_OK
err = e1
}
return
}
func closePseudoConsole(handle windows.Handle) (err error) {
r1, _, e1 := procClosePseudoConsole.Call(uintptr(handle))
if r1 == 0 {
err = e1
}
return
}

View File

@ -1,163 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package console
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"unicode/utf8"
"github.com/coder/coder/console/pty"
)
// Console is an interface to automate input and output for interactive
// applications. Console can block until a specified output is received and send
// input back on it's tty. Console can also multiplex other sources of input
// and multiplex its output to other writers.
type Console struct {
opts Opts
pty pty.Pty
runeReader *bufio.Reader
closers []io.Closer
}
// Opt allows setting Console options.
type Opt func(*Opts) error
// Opts provides additional options on creating a Console.
type Opts struct {
Logger *log.Logger
Stdouts []io.Writer
ExpectObservers []Observer
}
// Observer provides an interface for a function callback that will
// be called after each Expect operation.
// matchers will be the list of active matchers when an error occurred,
// or a list of matchers that matched `buf` when err is nil.
// buf is the captured output that was matched against.
// err is error that might have occurred. May be nil.
type Observer func(matchers []Matcher, buf string, err error)
// WithStdout adds writers that Console duplicates writes to, similar to the
// Unix tee(1) command.
//
// Each write is written to each listed writer, one at a time. Console is the
// last writer, writing to it's internal buffer for matching expects.
// If a listed writer returns an error, that overall write operation stops and
// returns the error; it does not continue down the list.
func WithStdout(writers ...io.Writer) Opt {
return func(opts *Opts) error {
opts.Stdouts = append(opts.Stdouts, writers...)
return nil
}
}
// WithLogger adds a logger for Console to log debugging information to. By
// default Console will discard logs.
func WithLogger(logger *log.Logger) Opt {
return func(opts *Opts) error {
opts.Logger = logger
return nil
}
}
// WithExpectObserver adds an ExpectObserver to allow monitoring Expect operations.
func WithExpectObserver(observers ...Observer) Opt {
return func(opts *Opts) error {
opts.ExpectObservers = append(opts.ExpectObservers, observers...)
return nil
}
}
// NewConsole returns a new Console with the given options.
func NewConsole(opts ...Opt) (*Console, error) {
options := Opts{
Logger: log.New(ioutil.Discard, "", 0),
}
for _, opt := range opts {
if err := opt(&options); err != nil {
return nil, err
}
}
consolePty, err := pty.New()
if err != nil {
return nil, err
}
closers := []io.Closer{consolePty}
reader := consolePty.Reader()
cons := &Console{
opts: options,
pty: consolePty,
runeReader: bufio.NewReaderSize(reader, utf8.UTFMax),
closers: closers,
}
return cons, nil
}
// Tty returns an input Tty for accepting input
func (c *Console) InTty() *os.File {
return c.pty.InPipe()
}
// OutTty returns an output tty for writing
func (c *Console) OutTty() *os.File {
return c.pty.OutPipe()
}
// Close closes Console's tty. Calling Close will unblock Expect and ExpectEOF.
func (c *Console) Close() error {
for _, fd := range c.closers {
err := fd.Close()
if err != nil {
c.Logf("failed to close: %s", err)
}
}
return nil
}
// Send writes string s to Console's tty.
func (c *Console) Send(s string) (int, error) {
c.Logf("console send: %q", s)
n, err := c.pty.WriteString(s)
return n, err
}
// SendLine writes string s to Console's tty with a trailing newline.
func (c *Console) SendLine(s string) (int, error) {
bytes, err := c.Send(fmt.Sprintf("%s\n", s))
return bytes, err
}
// Log prints to Console's logger.
// Arguments are handled in the manner of fmt.Print.
func (c *Console) Log(v ...interface{}) {
c.opts.Logger.Print(v...)
}
// Logf prints to Console's logger.
// Arguments are handled in the manner of fmt.Printf.
func (c *Console) Logf(format string, v ...interface{}) {
c.opts.Logger.Printf(format, v...)
}

View File

@ -1,19 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package expect provides an expect-like interface to automate control of
// applications. It is unlike expect in that it does not spawn or manage
// process lifecycle. This package only focuses on expecting output and sending
// input through it's psuedoterminal.
package console

View File

@ -1,109 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package console
import (
"bufio"
"bytes"
"fmt"
"io"
"unicode/utf8"
)
// Expectf reads from the Console's tty until the provided formatted string
// is read or an error occurs, and returns the buffer read by Console.
func (c *Console) Expectf(format string, args ...interface{}) (string, error) {
return c.Expect(String(fmt.Sprintf(format, args...)))
}
// ExpectString reads from Console's tty until the provided string is read or
// an error occurs, and returns the buffer read by Console.
func (c *Console) ExpectString(s string) (string, error) {
return c.Expect(String(s))
}
// Expect reads from Console's tty until a condition specified from opts is
// encountered or an error occurs, and returns the buffer read by console.
// No extra bytes are read once a condition is met, so if a program isn't
// expecting input yet, it will be blocked. Sends are queued up in tty's
// internal buffer so that the next Expect will read the remaining bytes (i.e.
// rest of prompt) as well as its conditions.
func (c *Console) Expect(opts ...ExpectOpt) (string, error) {
var options ExpectOpts
for _, opt := range opts {
if err := opt(&options); err != nil {
return "", err
}
}
buf := new(bytes.Buffer)
writer := io.MultiWriter(append(c.opts.Stdouts, buf)...)
runeWriter := bufio.NewWriterSize(writer, utf8.UTFMax)
var matcher Matcher
var err error
defer func() {
for _, observer := range c.opts.ExpectObservers {
if matcher != nil {
observer([]Matcher{matcher}, buf.String(), err)
return
}
observer(options.Matchers, buf.String(), err)
}
}()
for {
var r rune
r, _, err = c.runeReader.ReadRune()
if err != nil {
matcher = options.Match(err)
if matcher != nil {
err = nil
break
}
return buf.String(), err
}
c.Logf("expect read: %q", string(r))
_, err = runeWriter.WriteRune(r)
if err != nil {
return buf.String(), err
}
// Immediately flush rune to the underlying writers.
err = runeWriter.Flush()
if err != nil {
return buf.String(), err
}
matcher = options.Match(buf)
if matcher != nil {
break
}
}
if matcher != nil {
cb, ok := matcher.(CallbackMatcher)
if ok {
err = cb.Callback(buf)
if err != nil {
return buf.String(), err
}
}
}
return buf.String(), err
}

View File

@ -1,139 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package console
import (
"bytes"
"strings"
"time"
)
// ExpectOpt allows settings Expect options.
type ExpectOpt func(*ExpectOpts) error
// Callback is a callback function to execute if a match is found for
// the chained matcher.
type Callback func(buf *bytes.Buffer) error
// ExpectOpts provides additional options on Expect.
type ExpectOpts struct {
Matchers []Matcher
ReadTimeout *time.Duration
}
// Match sequentially calls Match on all matchers in ExpectOpts and returns the
// first matcher if a match exists, otherwise nil.
func (eo ExpectOpts) Match(v interface{}) Matcher {
for _, matcher := range eo.Matchers {
if matcher.Match(v) {
return matcher
}
}
return nil
}
// CallbackMatcher is a matcher that provides a Callback function.
type CallbackMatcher interface {
// Callback executes the matcher's callback with the content buffer at the
// time of match.
Callback(buf *bytes.Buffer) error
}
// Matcher provides an interface for finding a match in content read from
// Console's tty.
type Matcher interface {
// Match returns true iff a match is found.
Match(v interface{}) bool
Criteria() interface{}
}
// stringMatcher fulfills the Matcher interface to match strings against a given
// bytes.Buffer.
type stringMatcher struct {
str string
}
func (sm *stringMatcher) Match(v interface{}) bool {
buf, ok := v.(*bytes.Buffer)
if !ok {
return false
}
if strings.Contains(buf.String(), sm.str) {
return true
}
return false
}
func (sm *stringMatcher) Criteria() interface{} {
return sm.str
}
// allMatcher fulfills the Matcher interface to match a group of ExpectOpt
// against any value.
type allMatcher struct {
options ExpectOpts
}
func (am *allMatcher) Match(v interface{}) bool {
var matchers []Matcher
for _, matcher := range am.options.Matchers {
if matcher.Match(v) {
continue
}
matchers = append(matchers, matcher)
}
am.options.Matchers = matchers
return len(matchers) == 0
}
func (am *allMatcher) Criteria() interface{} {
var criteria []interface{}
for _, matcher := range am.options.Matchers {
criteria = append(criteria, matcher.Criteria())
}
return criteria
}
// All adds an Expect condition to exit if the content read from Console's tty
// matches all of the provided ExpectOpt, in any order.
func All(expectOpts ...ExpectOpt) ExpectOpt {
return func(opts *ExpectOpts) error {
var options ExpectOpts
for _, opt := range expectOpts {
if err := opt(&options); err != nil {
return err
}
}
opts.Matchers = append(opts.Matchers, &allMatcher{
options: options,
})
return nil
}
}
// String adds an Expect condition to exit if the content read from Console's
// tty contains any of the given strings.
func String(strs ...string) ExpectOpt {
return func(opts *ExpectOpts) error {
for _, str := range strs {
opts.Matchers = append(opts.Matchers, &stringMatcher{
str: str,
})
}
return nil
}
}

View File

@ -1,163 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package console_test
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
. "github.com/coder/coder/console"
)
func TestExpectOptString(t *testing.T) {
t.Parallel()
tests := []struct {
title string
opt ExpectOpt
data string
expected bool
}{
{
"No args",
String(),
"Hello world",
false,
},
{
"Single arg",
String("Hello"),
"Hello world",
true,
},
{
"Multiple arg",
String("other", "world"),
"Hello world",
true,
},
{
"No matches",
String("hello"),
"Hello world",
false,
},
}
for _, test := range tests {
test := test
t.Run(test.title, func(t *testing.T) {
t.Parallel()
var options ExpectOpts
err := test.opt(&options)
require.Nil(t, err)
buf := new(bytes.Buffer)
_, err = buf.WriteString(test.data)
require.Nil(t, err)
matcher := options.Match(buf)
if test.expected {
require.NotNil(t, matcher)
} else {
require.Nil(t, matcher)
}
})
}
}
func TestExpectOptAll(t *testing.T) {
t.Parallel()
tests := []struct {
title string
opt ExpectOpt
data string
expected bool
}{
{
"No opts",
All(),
"Hello world",
true,
},
{
"Single string match",
All(String("Hello")),
"Hello world",
true,
},
{
"Single string no match",
All(String("Hello")),
"No match",
false,
},
{
"Ordered strings match",
All(String("Hello"), String("world")),
"Hello world",
true,
},
{
"Ordered strings not all match",
All(String("Hello"), String("world")),
"Hello",
false,
},
{
"Unordered strings",
All(String("world"), String("Hello")),
"Hello world",
true,
},
{
"Unordered strings not all match",
All(String("world"), String("Hello")),
"Hello",
false,
},
{
"Repeated strings match",
All(String("Hello"), String("Hello")),
"Hello world",
true,
},
}
for _, test := range tests {
test := test
t.Run(test.title, func(t *testing.T) {
t.Parallel()
var options ExpectOpts
err := test.opt(&options)
require.Nil(t, err)
buf := new(bytes.Buffer)
_, err = buf.WriteString(test.data)
require.Nil(t, err)
matcher := options.Match(buf)
if test.expected {
require.NotNil(t, matcher)
} else {
require.Nil(t, matcher)
}
})
}
}

View File

@ -1,181 +0,0 @@
// Copyright 2018 Netflix, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package console_test
import (
"bufio"
"errors"
"fmt"
"io"
"runtime/debug"
"strings"
"sync"
"testing"
"golang.org/x/xerrors"
. "github.com/coder/coder/console"
)
var (
ErrWrongAnswer = xerrors.New("wrong answer")
)
type Survey struct {
Prompt string
Answer string
}
func Prompt(in io.Reader, out io.Writer) error {
reader := bufio.NewReader(in)
for _, survey := range []Survey{
{
"What is 1+1?", "2",
},
{
"What is Netflix backwards?", "xilfteN",
},
} {
_, err := fmt.Fprintf(out, "%s: ", survey.Prompt)
if err != nil {
return err
}
text, err := reader.ReadString('\n')
if err != nil {
return err
}
_, err = fmt.Fprint(out, text)
if err != nil {
return err
}
text = strings.TrimSpace(text)
if text != survey.Answer {
return ErrWrongAnswer
}
}
return nil
}
func newTestConsole(t *testing.T, opts ...Opt) (*Console, error) {
opts = append([]Opt{
expectNoError(t),
}, opts...)
return NewConsole(opts...)
}
func expectNoError(t *testing.T) Opt {
return WithExpectObserver(
func(matchers []Matcher, buf string, err error) {
if err == nil {
return
}
if len(matchers) == 0 {
t.Fatalf("Error occurred while matching %q: %s\n%s", buf, err, string(debug.Stack()))
} else {
var criteria []string
for _, matcher := range matchers {
criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria()))
}
t.Fatalf("Failed to find [%s] in %q: %s\n%s", strings.Join(criteria, ", "), buf, err, string(debug.Stack()))
}
},
)
}
func testCloser(t *testing.T, closer io.Closer) {
if err := closer.Close(); err != nil {
t.Errorf("Close failed: %s", err)
debug.PrintStack()
}
}
func TestExpectf(t *testing.T) {
t.Parallel()
console, err := newTestConsole(t)
if err != nil {
t.Errorf("Expected no error but got'%s'", err)
}
defer testCloser(t, console)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
console.Expectf("What is 1+%d?", 1)
console.SendLine("2")
console.Expectf("What is %s backwards?", "Netflix")
console.SendLine("xilfteN")
}()
err = Prompt(console.InTty(), console.OutTty())
if err != nil {
t.Errorf("Expected no error but got '%s'", err)
}
wg.Wait()
}
func TestExpect(t *testing.T) {
t.Parallel()
console, err := newTestConsole(t)
if err != nil {
t.Errorf("Expected no error but got'%s'", err)
}
defer testCloser(t, console)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
console.ExpectString("What is 1+1?")
console.SendLine("2")
console.ExpectString("What is Netflix backwards?")
console.SendLine("xilfteN")
}()
err = Prompt(console.InTty(), console.OutTty())
if err != nil {
t.Errorf("Expected no error but got '%s'", err)
}
wg.Wait()
}
func TestExpectOutput(t *testing.T) {
t.Parallel()
console, err := newTestConsole(t)
if err != nil {
t.Errorf("Expected no error but got'%s'", err)
}
defer testCloser(t, console)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
console.ExpectString("What is 1+1?")
console.SendLine("3")
}()
err = Prompt(console.InTty(), console.OutTty())
if err == nil || !errors.Is(err, ErrWrongAnswer) {
t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err)
}
wg.Wait()
}

View File

@ -1,21 +0,0 @@
package pty
import (
"io"
"os"
)
// Pty is the minimal pseudo-tty interface we require.
type Pty interface {
InPipe() *os.File
OutPipe() *os.File
Resize(cols uint16, rows uint16) error
WriteString(str string) (int, error)
Reader() io.Reader
Close() error
}
// New creates a new Pty.
func New() (Pty, error) {
return newPty()
}

View File

@ -1,78 +0,0 @@
//go:build windows
// +build windows
package pty
import (
"io"
"os"
"golang.org/x/sys/windows"
"github.com/coder/coder/console/conpty"
)
func newPty() (Pty, error) {
// We use the CreatePseudoConsole API which was introduced in build 17763
vsn := windows.RtlGetVersion()
if vsn.MajorVersion < 10 ||
vsn.BuildNumber < 17763 {
// If the CreatePseudoConsole API is not available, we fall back to a simpler
// implementation that doesn't create an actual PTY - just uses os.Pipe
return pipePty()
}
return conpty.New(80, 80)
}
func pipePty() (Pty, error) {
inFilePipeSide, inFileOurSide, err := os.Pipe()
if err != nil {
return nil, err
}
outFileOurSide, outFilePipeSide, err := os.Pipe()
if err != nil {
return nil, err
}
return &pipePtyVal{
inFilePipeSide,
inFileOurSide,
outFileOurSide,
outFilePipeSide,
}, nil
}
type pipePtyVal struct {
inFilePipeSide, inFileOurSide *os.File
outFileOurSide, outFilePipeSide *os.File
}
func (p *pipePtyVal) InPipe() *os.File {
return p.inFilePipeSide
}
func (p *pipePtyVal) OutPipe() *os.File {
return p.outFilePipeSide
}
func (p *pipePtyVal) Reader() io.Reader {
return p.outFileOurSide
}
func (p *pipePtyVal) WriteString(str string) (int, error) {
return p.inFileOurSide.WriteString(str)
}
func (p *pipePtyVal) Resize(uint16, uint16) error {
return nil
}
func (p *pipePtyVal) Close() error {
p.inFileOurSide.Close()
p.inFilePipeSide.Close()
p.outFilePipeSide.Close()
p.outFileOurSide.Close()
return nil
}

View File

@ -1,45 +0,0 @@
package console
import (
"bufio"
"io"
"regexp"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
var (
// Used to ensure terminal output doesn't have anything crazy!
// See: https://stackoverflow.com/a/29497680
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
)
// New creates a new TTY bound to the command provided.
// All ANSI escape codes are stripped to provide clean output.
func New(t *testing.T, cmd *cobra.Command) *Console {
reader, writer := io.Pipe()
scanner := bufio.NewScanner(reader)
t.Cleanup(func() {
_ = reader.Close()
_ = writer.Close()
})
go func() {
for scanner.Scan() {
if scanner.Err() != nil {
return
}
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
}
}()
console, err := NewConsole(WithStdout(writer))
require.NoError(t, err)
t.Cleanup(func() {
console.Close()
})
cmd.SetIn(console.InTty())
cmd.SetOut(console.OutTty())
return console
}

2
go.mod
View File

@ -20,6 +20,7 @@ require (
github.com/coder/retry v1.3.0
github.com/creack/pty v1.1.17
github.com/fatih/color v1.13.0
github.com/gliderlabs/ssh v0.3.3
github.com/go-chi/chi/v5 v5.0.7
github.com/go-chi/render v1.0.1
github.com/go-playground/validator/v10 v10.10.0
@ -64,6 +65,7 @@ require (
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/alecthomas/chroma v0.10.0 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect
github.com/cenkalti/backoff/v4 v4.1.2 // indirect
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect

4
go.sum
View File

@ -132,6 +132,8 @@ github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRF
github.com/alexflint/go-filemutex v0.0.0-20171022225611-72bdc8eae2ae/go.mod h1:CgnQgUtFrFz9mxFNtED3jI5tLDjKlOM+oUF/sTk6ps0=
github.com/andybalholm/crlf v0.0.0-20171020200849-670099aa064f/go.mod h1:k8feO4+kXDxro6ErPXBRTJ/ro2mf0SsFG8s7doP9kJE=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY=
github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
@ -441,6 +443,8 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm
github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14=
github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M=
github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/gliderlabs/ssh v0.3.3 h1:mBQ8NiOgDkINJrZtoizkC3nDNYgSaWtxyem6S2XHBtA=
github.com/gliderlabs/ssh v0.3.3/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8=
github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8=

View File

@ -263,6 +263,11 @@ func (c *Channel) Label() string {
return c.dc.Label()
}
// Protocol returns the protocol of the underlying DataChannel.
func (c *Channel) Protocol() string {
return c.dc.Protocol()
}
// NetConn wraps the DataChannel in a struct fulfilling net.Conn.
// Read, Write, and Close operations can still be used on the *Channel struct.
func (c *Channel) NetConn() net.Conn {

39
pty/pty.go Normal file
View File

@ -0,0 +1,39 @@
package pty
import (
"io"
)
// PTY is a minimal interface for interacting with a TTY.
type PTY interface {
io.Closer
// Output handles TTY output.
//
// cmd.SetOutput(pty.Output()) would be used to specify a command
// uses the output stream for writing.
//
// The same stream could be read to validate output.
Output() io.ReadWriter
// Input handles TTY input.
//
// cmd.SetInput(pty.Input()) would be used to specify a command
// uses the PTY input for reading.
//
// The same stream would be used to provide user input: pty.Input().Write(...)
Input() io.ReadWriter
// Resize sets the size of the PTY.
Resize(cols uint16, rows uint16) error
}
// New constructs a new Pty.
func New() (PTY, error) {
return newPty()
}
type readWriter struct {
io.Reader
io.Writer
}

View File

@ -10,46 +10,44 @@ import (
"github.com/creack/pty"
)
func newPty() (Pty, error) {
func newPty() (PTY, error) {
ptyFile, ttyFile, err := pty.Open()
if err != nil {
return nil, err
}
return &unixPty{
return &otherPty{
pty: ptyFile,
tty: ttyFile,
}, nil
}
type unixPty struct {
type otherPty struct {
pty, tty *os.File
}
func (p *unixPty) InPipe() *os.File {
return p.tty
func (p *otherPty) Input() io.ReadWriter {
return readWriter{
Reader: p.tty,
Writer: p.pty,
}
}
func (p *unixPty) OutPipe() *os.File {
return p.tty
func (p *otherPty) Output() io.ReadWriter {
return readWriter{
Reader: p.pty,
Writer: p.tty,
}
}
func (p *unixPty) Reader() io.Reader {
return p.pty
}
func (p *unixPty) WriteString(str string) (int, error) {
return p.pty.WriteString(str)
}
func (p *unixPty) Resize(cols uint16, rows uint16) error {
func (p *otherPty) Resize(cols uint16, rows uint16) error {
return pty.Setsize(p.tty, &pty.Winsize{
Rows: rows,
Cols: cols,
})
}
func (p *unixPty) Close() error {
func (p *otherPty) Close() error {
err := p.pty.Close()
if err != nil {
return err

107
pty/pty_windows.go Normal file
View File

@ -0,0 +1,107 @@
//go:build windows
// +build windows
package pty
import (
"io"
"os"
"sync"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/xerrors"
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procResizePseudoConsole = kernel32.NewProc("ResizePseudoConsole")
procCreatePseudoConsole = kernel32.NewProc("CreatePseudoConsole")
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
)
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session
func newPty() (PTY, error) {
// We use the CreatePseudoConsole API which was introduced in build 17763
vsn := windows.RtlGetVersion()
if vsn.MajorVersion < 10 ||
vsn.BuildNumber < 17763 {
// If the CreatePseudoConsole API is not available, we fall back to a simpler
// implementation that doesn't create an actual PTY - just uses os.Pipe
return nil, xerrors.Errorf("pty not supported")
}
ptyWindows := &ptyWindows{}
var err error
ptyWindows.inputRead, ptyWindows.inputWrite, err = os.Pipe()
if err != nil {
return nil, err
}
ptyWindows.outputRead, ptyWindows.outputWrite, err = os.Pipe()
consoleSize := uintptr(80) + (uintptr(80) << 16)
ret, _, err := procCreatePseudoConsole.Call(
consoleSize,
uintptr(ptyWindows.inputRead.Fd()),
uintptr(ptyWindows.outputWrite.Fd()),
0,
uintptr(unsafe.Pointer(&ptyWindows.console)),
)
if int32(ret) < 0 {
return nil, xerrors.Errorf("create pseudo console (%d): %w", int32(ret), err)
}
return ptyWindows, nil
}
type ptyWindows struct {
console windows.Handle
outputWrite *os.File
outputRead *os.File
inputWrite *os.File
inputRead *os.File
closeMutex sync.Mutex
closed bool
}
func (p *ptyWindows) Output() io.ReadWriter {
return readWriter{
Reader: p.outputRead,
Writer: p.outputWrite,
}
}
func (p *ptyWindows) Input() io.ReadWriter {
return readWriter{
Reader: p.inputRead,
Writer: p.inputWrite,
}
}
func (p *ptyWindows) Resize(cols uint16, rows uint16) error {
ret, _, err := procResizePseudoConsole.Call(uintptr(p.console), uintptr(cols)+(uintptr(rows)<<16))
if ret != 0 {
return err
}
return nil
}
func (p *ptyWindows) Close() error {
p.closeMutex.Lock()
defer p.closeMutex.Unlock()
if p.closed {
return nil
}
p.closed = true
ret, _, err := procClosePseudoConsole.Call(uintptr(p.console))
if ret != 0 {
return xerrors.Errorf("close pseudo console: %w", err)
}
_ = p.outputRead.Close()
_ = p.inputWrite.Close()
return nil
}

95
pty/ptytest/ptytest.go Normal file
View File

@ -0,0 +1,95 @@
package ptytest
import (
"bufio"
"bytes"
"fmt"
"io"
"os/exec"
"regexp"
"strings"
"testing"
"unicode/utf8"
"github.com/stretchr/testify/require"
"github.com/coder/coder/pty"
)
var (
// Used to ensure terminal output doesn't have anything crazy!
// See: https://stackoverflow.com/a/29497680
stripAnsi = regexp.MustCompile("[\u001B\u009B][[\\]()#;?]*(?:(?:(?:[a-zA-Z\\d]*(?:;[a-zA-Z\\d]*)*)?\u0007)|(?:(?:\\d{1,4}(?:;\\d{0,4})*)?[\\dA-PRZcf-ntqry=><~]))")
)
func New(t *testing.T) *PTY {
ptty, err := pty.New()
require.NoError(t, err)
return create(t, ptty)
}
func Start(t *testing.T, cmd *exec.Cmd) *PTY {
ptty, err := pty.Start(cmd)
require.NoError(t, err)
return create(t, ptty)
}
func create(t *testing.T, ptty pty.PTY) *PTY {
reader, writer := io.Pipe()
scanner := bufio.NewScanner(reader)
t.Cleanup(func() {
_ = reader.Close()
_ = writer.Close()
})
go func() {
for scanner.Scan() {
if scanner.Err() != nil {
return
}
t.Log(stripAnsi.ReplaceAllString(scanner.Text(), ""))
}
}()
t.Cleanup(func() {
_ = ptty.Close()
})
return &PTY{
t: t,
PTY: ptty,
outputWriter: writer,
runeReader: bufio.NewReaderSize(ptty.Output(), utf8.UTFMax),
}
}
type PTY struct {
t *testing.T
pty.PTY
outputWriter io.Writer
runeReader *bufio.Reader
}
func (p *PTY) ExpectMatch(str string) string {
var buffer bytes.Buffer
multiWriter := io.MultiWriter(&buffer, p.outputWriter)
runeWriter := bufio.NewWriterSize(multiWriter, utf8.UTFMax)
for {
var r rune
r, _, err := p.runeReader.ReadRune()
require.NoError(p.t, err)
_, err = runeWriter.WriteRune(r)
require.NoError(p.t, err)
err = runeWriter.Flush()
require.NoError(p.t, err)
if strings.Contains(buffer.String(), str) {
break
}
}
return buffer.String()
}
func (p *PTY) WriteLine(str string) {
_, err := fmt.Fprintf(p.PTY.Input(), "%s\n", str)
require.NoError(p.t, err)
}

View File

@ -0,0 +1,15 @@
package ptytest_test
import (
"testing"
"github.com/coder/coder/pty/ptytest"
)
func TestPtytest(t *testing.T) {
t.Parallel()
pty := ptytest.New(t)
pty.Output().Write([]byte("write"))
pty.ExpectMatch("write")
pty.WriteLine("read")
}

7
pty/start.go Normal file
View File

@ -0,0 +1,7 @@
package pty
import "os/exec"
func Start(cmd *exec.Cmd) (PTY, error) {
return startPty(cmd)
}

34
pty/start_other.go Normal file
View File

@ -0,0 +1,34 @@
//go:build !windows
// +build !windows
package pty
import (
"os/exec"
"syscall"
"github.com/creack/pty"
)
func startPty(cmd *exec.Cmd) (PTY, error) {
ptty, tty, err := pty.Open()
if err != nil {
return nil, err
}
cmd.SysProcAttr = &syscall.SysProcAttr{
Setsid: true,
Setctty: true,
}
cmd.Stdout = tty
cmd.Stderr = tty
cmd.Stdin = tty
err = cmd.Start()
if err != nil {
_ = ptty.Close()
return nil, err
}
return &otherPty{
pty: ptty,
tty: tty,
}, nil
}

25
pty/start_other_test.go Normal file
View File

@ -0,0 +1,25 @@
//go:build !windows
// +build !windows
package pty_test
import (
"os/exec"
"testing"
"github.com/coder/coder/pty/ptytest"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("echo", "test"))
pty.ExpectMatch("test")
})
}

149
pty/start_windows.go Normal file
View File

@ -0,0 +1,149 @@
//go:build windows
// +build windows
package pty
import (
"os"
"os/exec"
"strings"
"unicode/utf16"
"unsafe"
"golang.org/x/sys/windows"
)
// Allocates a PTY and starts the specified command attached to it.
// See: https://docs.microsoft.com/en-us/windows/console/creating-a-pseudoconsole-session#creating-the-hosted-process
func startPty(cmd *exec.Cmd) (PTY, error) {
fullPath, err := exec.LookPath(cmd.Path)
if err != nil {
return nil, err
}
pathPtr, err := windows.UTF16PtrFromString(fullPath)
if err != nil {
return nil, err
}
argsPtr, err := windows.UTF16PtrFromString(windows.ComposeCommandLine(cmd.Args))
if err != nil {
return nil, err
}
if cmd.Dir == "" {
cmd.Dir, err = os.Getwd()
if err != nil {
return nil, err
}
}
dirPtr, err := windows.UTF16PtrFromString(cmd.Dir)
if err != nil {
return nil, err
}
pty, err := newPty()
if err != nil {
return nil, err
}
winPty := pty.(*ptyWindows)
attrs, err := windows.NewProcThreadAttributeList(1)
if err != nil {
return nil, err
}
// Taken from: https://github.com/microsoft/hcsshim/blob/2314362e977aa03b3ed245a4beb12d00422af0e2/internal/winapi/process.go#L6
err = attrs.Update(0x20016, unsafe.Pointer(winPty.console), unsafe.Sizeof(winPty.console))
if err != nil {
return nil, err
}
startupInfo := &windows.StartupInfoEx{}
startupInfo.ProcThreadAttributeList = attrs.List()
startupInfo.StartupInfo.Flags = windows.STARTF_USESTDHANDLES
startupInfo.StartupInfo.Cb = uint32(unsafe.Sizeof(*startupInfo))
var processInfo windows.ProcessInformation
err = windows.CreateProcess(
pathPtr,
argsPtr,
nil,
nil,
false,
// https://docs.microsoft.com/en-us/windows/win32/procthread/process-creation-flags#create_unicode_environment
windows.CREATE_UNICODE_ENVIRONMENT|windows.EXTENDED_STARTUPINFO_PRESENT,
createEnvBlock(addCriticalEnv(dedupEnvCase(true, cmd.Env))),
dirPtr,
&startupInfo.StartupInfo,
&processInfo,
)
if err != nil {
return nil, err
}
defer windows.CloseHandle(processInfo.Thread)
defer windows.CloseHandle(processInfo.Process)
return pty, nil
}
// Taken from: https://github.com/microsoft/hcsshim/blob/7fbdca16f91de8792371ba22b7305bf4ca84170a/internal/exec/exec.go#L476
func createEnvBlock(envv []string) *uint16 {
if len(envv) == 0 {
return &utf16.Encode([]rune("\x00\x00"))[0]
}
length := 0
for _, s := range envv {
length += len(s) + 1
}
length += 1
b := make([]byte, length)
i := 0
for _, s := range envv {
l := len(s)
copy(b[i:i+l], []byte(s))
copy(b[i+l:i+l+1], []byte{0})
i = i + l + 1
}
copy(b[i:i+1], []byte{0})
return &utf16.Encode([]rune(string(b)))[0]
}
// dedupEnvCase is dedupEnv with a case option for testing.
// If caseInsensitive is true, the case of keys is ignored.
func dedupEnvCase(caseInsensitive bool, env []string) []string {
out := make([]string, 0, len(env))
saw := make(map[string]int, len(env)) // key => index into out
for _, kv := range env {
eq := strings.Index(kv, "=")
if eq < 0 {
out = append(out, kv)
continue
}
k := kv[:eq]
if caseInsensitive {
k = strings.ToLower(k)
}
if dupIdx, isDup := saw[k]; isDup {
out[dupIdx] = kv
continue
}
saw[k] = len(out)
out = append(out, kv)
}
return out
}
// addCriticalEnv adds any critical environment variables that are required
// (or at least almost always required) on the operating system.
// Currently this is only used for Windows.
func addCriticalEnv(env []string) []string {
for _, kv := range env {
eq := strings.Index(kv, "=")
if eq < 0 {
continue
}
k := kv[:eq]
if strings.EqualFold(k, "SYSTEMROOT") {
// We already have it.
return env
}
}
return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}

32
pty/start_windows_test.go Normal file
View File

@ -0,0 +1,32 @@
//go:build windows
// +build windows
package pty_test
import (
"os/exec"
"testing"
"github.com/coder/coder/pty/ptytest"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestStart(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("cmd.exe", "/c", "echo", "test"))
pty.ExpectMatch("test")
})
t.Run("Resize", func(t *testing.T) {
t.Parallel()
pty := ptytest.Start(t, exec.Command("cmd.exe"))
err := pty.Resize(100, 50)
require.NoError(t, err)
})
}