mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
This commit refactors where custom environment variables are set in the workspace and decouples agent specific configs from the `agentssh.Server`. To reproduce all functionality, `agentssh.Config` is introduced. The custom environment variables are now configured in `agent/agent.go` and the agent retains control of the final state. This will allow for easier extension in the future and keep other modules decoupled.
319 lines
7.3 KiB
Go
319 lines
7.3 KiB
Go
// Package agentssh_test provides tests for basic functionality of the agentssh
|
|
// package, more test coverage can be found in the `agent` and `cli` package(s).
|
|
package agentssh_test
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
|
|
"github.com/coder/coder/v2/agent/agentssh"
|
|
"github.com/coder/coder/v2/pty/ptytest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
func TestNewServer_ServeClient(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
logger := slogtest.Make(t, nil)
|
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := s.Serve(ln)
|
|
assert.Error(t, err) // Server is closed.
|
|
}()
|
|
|
|
c := sshClient(t, ln.Addr().String())
|
|
|
|
var b bytes.Buffer
|
|
sess, err := c.NewSession()
|
|
require.NoError(t, err)
|
|
sess.Stdout = &b
|
|
err = sess.Start("echo hello")
|
|
require.NoError(t, err)
|
|
|
|
err = sess.Wait()
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, "hello", strings.TrimSpace(b.String()))
|
|
|
|
err = s.Close()
|
|
require.NoError(t, err)
|
|
<-done
|
|
}
|
|
|
|
func TestNewServer_ExecuteShebang(t *testing.T) {
|
|
t.Parallel()
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("bash doesn't exist on Windows")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
logger := slogtest.Make(t, nil)
|
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = s.Close()
|
|
})
|
|
|
|
t.Run("Basic", func(t *testing.T) {
|
|
t.Parallel()
|
|
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
|
|
echo test`, nil)
|
|
require.NoError(t, err)
|
|
output, err := cmd.AsExec().CombinedOutput()
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test\n", string(output))
|
|
})
|
|
t.Run("Args", func(t *testing.T) {
|
|
t.Parallel()
|
|
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
|
|
echo test`, nil)
|
|
require.NoError(t, err)
|
|
output, err := cmd.AsExec().CombinedOutput()
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test\n", string(output))
|
|
})
|
|
}
|
|
|
|
func TestNewServer_CloseActiveConnections(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
err := s.Serve(ln)
|
|
assert.Error(t, err) // Server is closed.
|
|
}()
|
|
|
|
pty := ptytest.New(t)
|
|
|
|
doClose := make(chan struct{})
|
|
go func() {
|
|
defer wg.Done()
|
|
c := sshClient(t, ln.Addr().String())
|
|
sess, err := c.NewSession()
|
|
assert.NoError(t, err)
|
|
sess.Stdin = pty.Input()
|
|
sess.Stdout = pty.Output()
|
|
sess.Stderr = pty.Output()
|
|
|
|
assert.NoError(t, err)
|
|
err = sess.Start("")
|
|
assert.NoError(t, err)
|
|
|
|
close(doClose)
|
|
err = sess.Wait()
|
|
assert.Error(t, err)
|
|
}()
|
|
|
|
<-doClose
|
|
err = s.Close()
|
|
require.NoError(t, err)
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestNewServer_Signal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Stdout", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
logger := slogtest.Make(t, nil)
|
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := s.Serve(ln)
|
|
assert.Error(t, err) // Server is closed.
|
|
}()
|
|
defer func() {
|
|
err := s.Close()
|
|
require.NoError(t, err)
|
|
<-done
|
|
}()
|
|
|
|
c := sshClient(t, ln.Addr().String())
|
|
|
|
sess, err := c.NewSession()
|
|
require.NoError(t, err)
|
|
r, err := sess.StdoutPipe()
|
|
require.NoError(t, err)
|
|
|
|
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
|
// the process group, this lets us exit early.
|
|
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
|
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
|
require.NoError(t, err)
|
|
|
|
sc := bufio.NewScanner(r)
|
|
for sc.Scan() {
|
|
t.Log(sc.Text())
|
|
if strings.Contains(sc.Text(), "hello") {
|
|
break
|
|
}
|
|
}
|
|
require.NoError(t, sc.Err())
|
|
|
|
err = sess.Signal(ssh.SIGKILL)
|
|
require.NoError(t, err)
|
|
|
|
// Assumption, signal propagates and the command exists, closing stdout.
|
|
for sc.Scan() {
|
|
t.Log(sc.Text())
|
|
require.NotContains(t, sc.Text(), "bye")
|
|
}
|
|
require.NoError(t, sc.Err())
|
|
|
|
err = sess.Wait()
|
|
exitErr := &ssh.ExitError{}
|
|
require.ErrorAs(t, err, &exitErr)
|
|
wantCode := 255
|
|
if runtime.GOOS == "windows" {
|
|
wantCode = 1
|
|
}
|
|
require.Equal(t, wantCode, exitErr.ExitStatus())
|
|
})
|
|
t.Run("PTY", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
logger := slogtest.Make(t, nil)
|
|
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), nil)
|
|
require.NoError(t, err)
|
|
defer s.Close()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
defer close(done)
|
|
err := s.Serve(ln)
|
|
assert.Error(t, err) // Server is closed.
|
|
}()
|
|
defer func() {
|
|
err := s.Close()
|
|
require.NoError(t, err)
|
|
<-done
|
|
}()
|
|
|
|
c := sshClient(t, ln.Addr().String())
|
|
|
|
pty := ptytest.New(t)
|
|
|
|
sess, err := c.NewSession()
|
|
require.NoError(t, err)
|
|
r, err := sess.StdoutPipe()
|
|
require.NoError(t, err)
|
|
|
|
// Note, we request pty but don't use ptytest here because we can't
|
|
// easily test for no text before EOF.
|
|
sess.Stdin = pty.Input()
|
|
sess.Stderr = pty.Output()
|
|
|
|
err = sess.RequestPty("xterm", 80, 80, nil)
|
|
require.NoError(t, err)
|
|
|
|
// Perform multiple sleeps since the interrupt signal doesn't propagate to
|
|
// the process group, this lets us exit early.
|
|
sleeps := strings.Repeat("sleep 1 && ", int(testutil.WaitMedium.Seconds()))
|
|
err = sess.Start(fmt.Sprintf("echo hello && %s echo bye", sleeps))
|
|
require.NoError(t, err)
|
|
|
|
sc := bufio.NewScanner(r)
|
|
for sc.Scan() {
|
|
t.Log(sc.Text())
|
|
if strings.Contains(sc.Text(), "hello") {
|
|
break
|
|
}
|
|
}
|
|
require.NoError(t, sc.Err())
|
|
|
|
err = sess.Signal(ssh.SIGKILL)
|
|
require.NoError(t, err)
|
|
|
|
// Assumption, signal propagates and the command exists, closing stdout.
|
|
for sc.Scan() {
|
|
t.Log(sc.Text())
|
|
require.NotContains(t, sc.Text(), "bye")
|
|
}
|
|
require.NoError(t, sc.Err())
|
|
|
|
err = sess.Wait()
|
|
exitErr := &ssh.ExitError{}
|
|
require.ErrorAs(t, err, &exitErr)
|
|
wantCode := 255
|
|
if runtime.GOOS == "windows" {
|
|
wantCode = 1
|
|
}
|
|
require.Equal(t, wantCode, exitErr.ExitStatus())
|
|
})
|
|
}
|
|
|
|
func sshClient(t *testing.T, addr string) *ssh.Client {
|
|
conn, err := net.Dial("tcp", addr)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = conn.Close()
|
|
})
|
|
|
|
sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // This is a test.
|
|
})
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_ = sshConn.Close()
|
|
})
|
|
c := ssh.NewClient(sshConn, channels, requests)
|
|
t.Cleanup(func() {
|
|
_ = c.Close()
|
|
})
|
|
return c
|
|
}
|