mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: add support for X11 forwarding (#7205)
* feat: add support for X11 forwarding * Only run X forwarding on Linux * Fix piping * Fix comments
This commit is contained in:
@ -161,7 +161,7 @@ type agent struct {
|
||||
}
|
||||
|
||||
func (a *agent) init(ctx context.Context) {
|
||||
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.sshMaxTimeout)
|
||||
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.filesystem, a.sshMaxTimeout, "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
@ -48,6 +49,7 @@ const (
|
||||
|
||||
type Server struct {
|
||||
mu sync.RWMutex // Protects following.
|
||||
fs afero.Fs
|
||||
listeners map[net.Listener]struct{}
|
||||
conns map[net.Conn]struct{}
|
||||
sessions map[ssh.Session]struct{}
|
||||
@ -58,6 +60,7 @@ type Server struct {
|
||||
|
||||
logger slog.Logger
|
||||
srv *ssh.Server
|
||||
x11SocketDir string
|
||||
|
||||
Env map[string]string
|
||||
AgentToken func() string
|
||||
@ -68,7 +71,7 @@ type Server struct {
|
||||
connCountSSHSession atomic.Int64
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration) (*Server, error) {
|
||||
func NewServer(ctx context.Context, logger slog.Logger, fs afero.Fs, maxTimeout time.Duration, x11SocketDir string) (*Server, error) {
|
||||
// Clients' should ignore the host key when connecting.
|
||||
// The agent needs to authenticate with coderd to SSH,
|
||||
// so SSH authentication doesn't improve security.
|
||||
@ -80,15 +83,20 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if x11SocketDir == "" {
|
||||
x11SocketDir = filepath.Join(os.TempDir(), ".X11-unix")
|
||||
}
|
||||
|
||||
forwardHandler := &ssh.ForwardedTCPHandler{}
|
||||
unixForwardHandler := &forwardedUnixHandler{log: logger}
|
||||
|
||||
s := &Server{
|
||||
listeners: make(map[net.Listener]struct{}),
|
||||
fs: fs,
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
sessions: make(map[ssh.Session]struct{}),
|
||||
logger: logger,
|
||||
x11SocketDir: x11SocketDir,
|
||||
}
|
||||
|
||||
s.srv = &ssh.Server{
|
||||
@ -125,6 +133,7 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
|
||||
"streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
"cancel-streamlocal-forward@openssh.com": unixForwardHandler.HandleSSHRequest,
|
||||
},
|
||||
X11Callback: s.x11Callback,
|
||||
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
|
||||
return &gossh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
@ -163,6 +172,15 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
|
||||
ctx := session.Context()
|
||||
|
||||
x11, hasX11 := session.X11()
|
||||
if hasX11 {
|
||||
handled := s.x11Handler(session.Context(), x11)
|
||||
if !handled {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch ss := session.Subsystem(); ss {
|
||||
case "":
|
||||
case "sftp":
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
@ -32,7 +33,7 @@ func TestNewServer_ServeClient(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, nil)
|
||||
s, err := agentssh.NewServer(ctx, logger, 0)
|
||||
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
@ -50,6 +51,7 @@ func TestNewServer_ServeClient(t *testing.T) {
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
|
||||
var b bytes.Buffer
|
||||
sess, err := c.NewSession()
|
||||
sess.Stdout = &b
|
||||
@ -72,7 +74,7 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
s, err := agentssh.NewServer(ctx, logger, 0)
|
||||
s, err := agentssh.NewServer(ctx, logger, afero.NewMemMapFs(), 0, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
|
190
agent/agentssh/x11.go
Normal file
190
agent/agentssh/x11.go
Normal file
@ -0,0 +1,190 @@
|
||||
package agentssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/gofrs/flock"
|
||||
"github.com/spf13/afero"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
)
|
||||
|
||||
// x11Callback is called when the client requests X11 forwarding.
|
||||
// It adds an Xauthority entry to the Xauthority file.
|
||||
func (s *Server) x11Callback(ctx ssh.Context, x11 ssh.X11) bool {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to get hostname", slog.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
err = s.fs.MkdirAll(s.x11SocketDir, 0o700)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to make the x11 socket dir", slog.F("dir", s.x11SocketDir), slog.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
err = addXauthEntry(ctx, s.fs, hostname, strconv.Itoa(int(x11.ScreenNumber)), x11.AuthProtocol, x11.AuthCookie)
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to add Xauthority entry", slog.Error(err))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// x11Handler is called when a session has requested X11 forwarding.
|
||||
// It listens for X11 connections and forwards them to the client.
|
||||
func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool {
|
||||
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
if !valid {
|
||||
s.logger.Warn(ctx, "failed to get server connection")
|
||||
return false
|
||||
}
|
||||
listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)))
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err))
|
||||
return false
|
||||
}
|
||||
s.trackListener(listener, true)
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
defer s.trackListener(listener, false)
|
||||
handledFirstConnection := false
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
s.logger.Warn(ctx, "failed to accept X11 connection", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if x11.SingleConnection && handledFirstConnection {
|
||||
s.logger.Warn(ctx, "X11 connection rejected because single connection is enabled")
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
handledFirstConnection = true
|
||||
|
||||
unixConn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
s.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to UnixConn. got: %T", conn))
|
||||
return
|
||||
}
|
||||
unixAddr, ok := unixConn.LocalAddr().(*net.UnixAddr)
|
||||
if !ok {
|
||||
s.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to UnixAddr. got: %T", unixConn.LocalAddr()))
|
||||
return
|
||||
}
|
||||
|
||||
channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct {
|
||||
OriginatorAddress string
|
||||
OriginatorPort uint32
|
||||
}{
|
||||
OriginatorAddress: unixAddr.Name,
|
||||
OriginatorPort: 0,
|
||||
}))
|
||||
if err != nil {
|
||||
s.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err))
|
||||
return
|
||||
}
|
||||
go gossh.DiscardRequests(reqs)
|
||||
go Bicopy(ctx, conn, channel)
|
||||
}
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// addXauthEntry adds an Xauthority entry to the Xauthority file.
|
||||
// The Xauthority file is located at ~/.Xauthority.
|
||||
func addXauthEntry(ctx context.Context, fs afero.Fs, host string, display string, authProtocol string, authCookie string) error {
|
||||
// Get the Xauthority file path
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
xauthPath := filepath.Join(homeDir, ".Xauthority")
|
||||
|
||||
lock := flock.New(xauthPath)
|
||||
defer lock.Close()
|
||||
ok, err := lock.TryLockContext(ctx, 100*time.Millisecond)
|
||||
if !ok {
|
||||
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to lock Xauthority file: %w", err)
|
||||
}
|
||||
|
||||
// Open or create the Xauthority file
|
||||
file, err := fs.OpenFile(xauthPath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o600)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to open Xauthority file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Convert the authCookie from hex string to byte slice
|
||||
authCookieBytes, err := hex.DecodeString(authCookie)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to decode auth cookie: %w", err)
|
||||
}
|
||||
|
||||
// Write Xauthority entry
|
||||
family := uint16(0x0100) // FamilyLocal
|
||||
err = binary.Write(file, binary.BigEndian, family)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write family: %w", err)
|
||||
}
|
||||
|
||||
err = binary.Write(file, binary.BigEndian, uint16(len(host)))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write host length: %w", err)
|
||||
}
|
||||
_, err = file.WriteString(host)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write host: %w", err)
|
||||
}
|
||||
|
||||
err = binary.Write(file, binary.BigEndian, uint16(len(display)))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write display length: %w", err)
|
||||
}
|
||||
_, err = file.WriteString(display)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write display: %w", err)
|
||||
}
|
||||
|
||||
err = binary.Write(file, binary.BigEndian, uint16(len(authProtocol)))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write auth protocol length: %w", err)
|
||||
}
|
||||
_, err = file.WriteString(authProtocol)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write auth protocol: %w", err)
|
||||
}
|
||||
|
||||
err = binary.Write(file, binary.BigEndian, uint16(len(authCookieBytes)))
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write auth cookie length: %w", err)
|
||||
}
|
||||
_, err = file.Write(authCookieBytes)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to write auth cookie: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
99
agent/agentssh/x11_test.go
Normal file
99
agent/agentssh/x11_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package agentssh_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/atomic"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent/agentssh"
|
||||
"github.com/coder/coder/codersdk/agentsdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestServer_X11(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("X11 forwarding is only supported on Linux")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
fs := afero.NewOsFs()
|
||||
dir := t.TempDir()
|
||||
s, err := agentssh.NewServer(ctx, logger, fs, 0, dir)
|
||||
require.NoError(t, err)
|
||||
defer s.Close()
|
||||
|
||||
// The assumption is that these are set before serving SSH connections.
|
||||
s.AgentToken = func() string { return "" }
|
||||
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err := s.Serve(ln)
|
||||
assert.Error(t, err) // Server is closed.
|
||||
}()
|
||||
|
||||
c := sshClient(t, ln.Addr().String())
|
||||
|
||||
sess, err := c.NewSession()
|
||||
require.NoError(t, err)
|
||||
|
||||
reply, err := sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
|
||||
AuthProtocol: "MIT-MAGIC-COOKIE-1",
|
||||
AuthCookie: hex.EncodeToString([]byte("cookie")),
|
||||
ScreenNumber: 0,
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reply)
|
||||
|
||||
err = sess.Shell()
|
||||
require.NoError(t, err)
|
||||
|
||||
x11Chans := c.HandleChannelOpen("x11")
|
||||
payload := "hello world"
|
||||
require.Eventually(t, func() bool {
|
||||
conn, err := net.Dial("unix", filepath.Join(dir, "X0"))
|
||||
if err == nil {
|
||||
_, err = conn.Write([]byte(payload))
|
||||
assert.NoError(t, err)
|
||||
_ = conn.Close()
|
||||
}
|
||||
return err == nil
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
x11 := <-x11Chans
|
||||
ch, reqs, err := x11.Accept()
|
||||
require.NoError(t, err)
|
||||
go gossh.DiscardRequests(reqs)
|
||||
got := make([]byte, len(payload))
|
||||
_, err = ch.Read(got)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, payload, string(got))
|
||||
_ = ch.Close()
|
||||
_ = s.Close()
|
||||
<-done
|
||||
|
||||
// Ensure the Xauthority file was written!
|
||||
home, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
_, err = fs.Stat(filepath.Join(home, ".Xauthority"))
|
||||
require.NoError(t, err)
|
||||
}
|
2
go.mod
2
go.mod
@ -45,7 +45,7 @@ replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20230418202606-ed93
|
||||
// repo as tailscale.com/tempfork/gliderlabs/ssh, however, we can't replace the
|
||||
// subpath and it includes changes to golang.org/x/crypto/ssh as well which
|
||||
// makes importing it directly a bit messy.
|
||||
replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20220811105153-fcea99919338
|
||||
replace github.com/gliderlabs/ssh => github.com/coder/ssh v0.0.0-20230419180646-49c741437b53
|
||||
|
||||
// Waiting on https://github.com/imulab/go-scim/pull/95 to merge.
|
||||
replace github.com/imulab/go-scim/pkg/v2 => github.com/coder/go-scim/pkg/v2 v2.0.0-20230221055123-1d63c1222136
|
||||
|
4
go.sum
4
go.sum
@ -343,6 +343,10 @@ github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d h1:09JG37IgTB6n3ouX9
|
||||
github.com/coder/retry v1.3.1-0.20230210155434-e90a2e1e091d/go.mod h1:r+1J5i/989wt6CUeNSuvFKKA9hHuKKPMxdzDbTuvwwk=
|
||||
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko=
|
||||
github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
|
||||
github.com/coder/ssh v0.0.0-20230419175457-0612ba535202 h1:1I/Im5ZUan1Y9ypAr6VuAKQ4NbvEy/frR3cV86pKQk8=
|
||||
github.com/coder/ssh v0.0.0-20230419175457-0612ba535202/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
|
||||
github.com/coder/ssh v0.0.0-20230419180646-49c741437b53 h1:kaLOp3tlVnbOJIjmAvXuBTgeWWoZZlJJJ4QGeSMjOnA=
|
||||
github.com/coder/ssh v0.0.0-20230419180646-49c741437b53/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
|
||||
github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22 h1:bvGOqnI0ITbwOZFQ0SZ4MBw/8LLUEjxmNu57XEujrfQ=
|
||||
github.com/coder/tailscale v1.1.1-0.20230418202606-ed9307cf1b22/go.mod h1:jpg+77g19FpXL43U1VoIqoSg1K/Vh5CVxycGldQ8KhA=
|
||||
github.com/coder/terraform-provider-coder v0.6.23 h1:O2Rcj0umez4DfVdGnKZi63z1Xzxd0IQOn9VQDB8YU8g=
|
||||
|
Reference in New Issue
Block a user