feat: Add Tailscale networking (#3505)

* fix: Add coder user to docker group on installation

This makes for a simpler setup and reduces the likelihood that
a user runs into a strange issue.

* Add wgnet

* Add ping

* Add listening

* Finish refactor to make this work

* Add interface for swapping

* Fix conncache with interface

* chore: update gvisor

* fix tailscale types

* linting

* more linting

* Add coordinator

* Add coordinator tests

* Fix coordination

* It compiles!

* Move all connection negotiation in-memory

* Migrate coordinator to use net.Conn

* Add closed func

* Fix close listener func

* Make reconnecting PTY work

* Fix reconnecting PTY

* Update CI to Go 1.19

* Add CLI flags for DERP mapping

* Fix Tailnet test

* Rename ConnCoordinator to TailnetCoordinator

* Remove print statement from workspace agent test

* Refactor wsconncache to use tailnet

* Remove STUN from unit tests

* Add migrate back to dump

* chore: Upgrade to Go 1.19

This is required as part of #3505.

* Fix reconnecting PTY tests

* fix: update wireguard-go to fix devtunnel

* fix migration numbers

* linting

* Return early for status if endpoints are empty

* Update cli/server.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Update cli/server.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Fix frontend entities

* Fix agent bicopy

* Fix race condition for the last node

* Fix down migration

* Fix connection RBAC

* Fix migration numbers

* Fix forwarding TCP to a local port

* Implement ping for tailnet

* Rename to ForceHTTP

* Add external derpmapping

* Expose DERP region names to the API

* Add global option to enable Tailscale networking for web

* Mark DERP flags hidden while testing

* Update DERP map on reconnect

* Add close func to workspace agents

* Fix race condition in upstream dependency

* Fix feature columns race condition

Co-authored-by: Colin Adler <colin1adler@gmail.com>
Author: Kyle Carberry
Date: 2022-08-31 20:09:44 -05:00
Committed by: GitHub
Parent: 00da01fdf7
Commit: 9bd83e5ec7
56 changed files with 2498 additions and 1817 deletions


@@ -4,11 +4,13 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"net/url"
"os"
"os/exec"
@@ -27,15 +29,14 @@ import (
"go.uber.org/atomic"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"
"inet.af/netaddr"
"tailscale.com/types/key"
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
"github.com/coder/coder/peer/peerwg"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/pty"
"github.com/coder/coder/tailnet"
"github.com/coder/retry"
)
@@ -50,57 +51,63 @@ const (
MagicSessionErrorCode = 229
)
+var (
+// tailnetIP is a static IPv6 address with the Tailscale prefix that is used to route
+// connections from clients to this node. A dynamic address is not required because a Tailnet
+// client only dials a single agent at a time.
+tailnetIP = netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4")
+tailnetSSHPort = 1
+tailnetReconnectingPTYPort = 2
+)
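
For illustration: because the agent binds these well-known values, a client needs only the static tailnet IP and a port constant to reach it. A minimal client-side sketch, not part of this diff, assuming the tailnet.Conn API this PR introduces (NewConn plus a DialContextTCP helper) and a DERP map supplied by coderd:

	package main

	import (
		"context"
		"net"
		"net/netip"

		"tailscale.com/tailcfg"

		"cdr.dev/slog"
		"github.com/coder/coder/tailnet"
	)

	// dialAgentSSH dials the agent's fixed SSH port over tailnet. The client
	// address here is arbitrary (only the agent's address is static), and in
	// practice nodes must also be exchanged through the coordinator (see
	// runCoordinator below) before traffic can flow.
	func dialAgentSSH(ctx context.Context, logger slog.Logger, derpMap *tailcfg.DERPMap) (net.Conn, error) {
		conn, err := tailnet.NewConn(&tailnet.Options{
			Addresses: []netip.Prefix{netip.PrefixFrom(netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f5"), 128)},
			DERPMap:   derpMap,
			Logger:    logger,
		})
		if err != nil {
			return nil, err
		}
		// tailnetIP and tailnetSSHPort mirror the agent-side constants above.
		agentAddr := netip.AddrPortFrom(netip.MustParseAddr("fd7a:115c:a1e0:49d6:b259:b7ac:b1b2:48f4"), 1)
		return conn.DialContextTCP(ctx, agentAddr)
	}
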
type Options struct {
-EnableWireguard bool
-UploadWireguardKeys UploadWireguardKeys
-ListenWireguardPeers ListenWireguardPeers
+CoordinatorDialer CoordinatorDialer
+WebRTCDialer WebRTCDialer
+FetchMetadata FetchMetadata
ReconnectingPTYTimeout time.Duration
EnvironmentVariables map[string]string
Logger slog.Logger
}
type Metadata struct {
-WireguardAddresses []netaddr.IPPrefix `json:"addresses"`
-EnvironmentVariables map[string]string `json:"environment_variables"`
-StartupScript string `json:"startup_script"`
-Directory string `json:"directory"`
+DERPMap *tailcfg.DERPMap `json:"derpmap"`
+EnvironmentVariables map[string]string `json:"environment_variables"`
+StartupScript string `json:"startup_script"`
+Directory string `json:"directory"`
}
-type WireguardPublicKeys struct {
-Public key.NodePublic `json:"public"`
-Disco key.DiscoPublic `json:"disco"`
-}
+type WebRTCDialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
-type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
-type UploadWireguardKeys func(ctx context.Context, keys WireguardPublicKeys) error
-type ListenWireguardPeers func(ctx context.Context, logger slog.Logger) (<-chan peerwg.Handshake, func(), error)
+// CoordinatorDialer is a function that constructs a new broker.
+// A dialer must be passed in to allow for reconnects.
+type CoordinatorDialer func(ctx context.Context) (net.Conn, error)
-func New(dialer Dialer, options *Options) io.Closer {
-if options == nil {
-options = &Options{}
-}
+// FetchMetadata is a function to obtain metadata for the agent.
+type FetchMetadata func(ctx context.Context) (Metadata, error)
+func New(options Options) io.Closer {
if options.ReconnectingPTYTimeout == 0 {
options.ReconnectingPTYTimeout = 5 * time.Minute
}
ctx, cancelFunc := context.WithCancel(context.Background())
server := &agent{
-dialer: dialer,
+webrtcDialer: options.WebRTCDialer,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
-enableWireguard: options.EnableWireguard,
-postKeys: options.UploadWireguardKeys,
-listenWireguardPeers: options.ListenWireguardPeers,
+coordinatorDialer: options.CoordinatorDialer,
+fetchMetadata: options.FetchMetadata,
}
server.init(ctx)
return server
}
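
The constructor above takes everything through Options, which is easiest to see from the caller's side. A hypothetical wiring, not from this commit (the static metadata and plain TCP dialer stand in for coderd's real API client):

	package main

	import (
		"context"
		"net"
		"os"

		"cdr.dev/slog"
		"cdr.dev/slog/sloggers/sloghuman"

		"github.com/coder/coder/agent"
	)

	func main() {
		logger := slog.Make(sloghuman.Sink(os.Stderr))
		closer := agent.New(agent.Options{
			// In production this fetches Metadata (including the DERP map)
			// from coderd; empty metadata just keeps the sketch compiling.
			FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
				return agent.Metadata{}, nil
			},
			// A function is injected rather than an established net.Conn so
			// the agent can redial after a dropped coordinator connection.
			CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, "tcp", "127.0.0.1:3000")
			},
			Logger: logger,
		})
		_ = closer.Close()
	}

Injecting the dialer rather than a live connection is what lets runCoordinator rebuild the broker connection after a drop.
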
type agent struct {
-dialer Dialer
-logger slog.Logger
+webrtcDialer WebRTCDialer
+logger slog.Logger
reconnectingPTYs sync.Map
reconnectingPTYTimeout time.Duration
@@ -113,24 +120,21 @@ type agent struct {
envVars map[string]string
// metadata is atomic because values can change after reconnection.
metadata atomic.Value
-startupScript atomic.Bool
+fetchMetadata FetchMetadata
sshServer *ssh.Server
-enableWireguard bool
-network *peerwg.Network
-postKeys UploadWireguardKeys
-listenWireguardPeers ListenWireguardPeers
+network *tailnet.Conn
+coordinatorDialer CoordinatorDialer
}
func (a *agent) run(ctx context.Context) {
var metadata Metadata
-var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
a.logger.Info(ctx, "connecting")
-metadata, peerListener, err = a.dialer(ctx, a.logger)
+metadata, err = a.fetchMetadata(ctx)
if err != nil {
if errors.Is(err, context.Canceled) {
return
@@ -141,7 +145,7 @@ func (a *agent) run(ctx context.Context) {
a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
continue
}
-a.logger.Info(context.Background(), "connected")
+a.logger.Info(context.Background(), "fetched metadata")
break
}
select {
@@ -151,24 +155,164 @@ func (a *agent) run(ctx context.Context) {
}
a.metadata.Store(metadata)
-if a.startupScript.CAS(false, true) {
-// The startup script has not ran yet!
-go func() {
-err := a.runStartupScript(ctx, metadata.StartupScript)
+// The startup script has not ran yet!
+go func() {
+err := a.runStartupScript(ctx, metadata.StartupScript)
+if errors.Is(err, context.Canceled) {
+return
+}
+if err != nil {
+a.logger.Warn(ctx, "agent script failed", slog.Error(err))
+}
+}()
+if a.webrtcDialer != nil {
+go a.runWebRTCNetworking(ctx)
+}
+if metadata.DERPMap != nil {
+go a.runTailnet(ctx, metadata.DERPMap)
+}
+}
+func (a *agent) runTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) {
+a.closeMutex.Lock()
+defer a.closeMutex.Unlock()
+if a.isClosed() {
+return
+}
+if a.network != nil {
+a.network.SetDERPMap(derpMap)
+return
+}
+var err error
+a.network, err = tailnet.NewConn(&tailnet.Options{
+Addresses: []netip.Prefix{netip.PrefixFrom(tailnetIP, 128)},
+DERPMap: derpMap,
+Logger: a.logger.Named("tailnet"),
+})
+if err != nil {
+a.logger.Critical(ctx, "create tailnet", slog.Error(err))
+return
+}
+go a.runCoordinator(ctx)
+sshListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetSSHPort))
+if err != nil {
+a.logger.Critical(ctx, "listen for ssh", slog.Error(err))
+return
+}
+go func() {
+for {
+conn, err := sshListener.Accept()
+if err != nil {
+return
+}
+go a.sshServer.HandleConn(conn)
+}
+}()
+reconnectingPTYListener, err := a.network.Listen("tcp", ":"+strconv.Itoa(tailnetReconnectingPTYPort))
+if err != nil {
+a.logger.Critical(ctx, "listen for reconnecting pty", slog.Error(err))
+return
+}
+go func() {
+for {
+conn, err := reconnectingPTYListener.Accept()
+if err != nil {
+return
+}
+// This cannot use a JSON decoder, since that can
+// buffer additional data that is required for the PTY.
+rawLen := make([]byte, 2)
+_, err = conn.Read(rawLen)
+if err != nil {
+continue
+}
+length := binary.LittleEndian.Uint16(rawLen)
+data := make([]byte, length)
+_, err = conn.Read(data)
+if err != nil {
+continue
+}
+var msg reconnectingPTYInit
+err = json.Unmarshal(data, &msg)
+if err != nil {
+continue
+}
+go a.handleReconnectingPTY(ctx, msg, conn)
+}
+}()
+}
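
The listener above implies a tiny wire protocol: a 2-byte little-endian length prefix, a JSON-encoded reconnectingPTYInit, then raw PTY traffic on the same conn. A hypothetical client-side counterpart — the JSON key names are assumptions, since reconnectingPTYInit is defined outside this hunk:

	package main

	import (
		"encoding/binary"
		"encoding/json"
		"net"
	)

	func writeReconnectingPTYInit(conn net.Conn, id string, height, width uint16, command string) error {
		payload, err := json.Marshal(map[string]any{
			"id":      id,
			"height":  height,
			"width":   width,
			"command": command,
		})
		if err != nil {
			return err
		}
		// The agent reads the prefix and body with single Read calls rather
		// than io.ReadFull, so sending the whole frame in one Write makes a
		// short read on the agent side less likely.
		frame := make([]byte, 2+len(payload))
		binary.LittleEndian.PutUint16(frame[:2], uint16(len(payload)))
		copy(frame[2:], payload)
		_, err = conn.Write(frame)
		return err
	}
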
+// runCoordinator listens for nodes and updates the self-node as it changes.
+func (a *agent) runCoordinator(ctx context.Context) {
+var coordinator net.Conn
+var err error
+// An exponential back-off occurs when the connection is failing to dial.
+// This is to prevent server spam in case of a coderd outage.
+for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
+coordinator, err = a.coordinatorDialer(ctx)
+if err != nil {
+if errors.Is(err, context.Canceled) {
+return
+}
-if err != nil {
-a.logger.Warn(ctx, "agent script failed", slog.Error(err))
+if a.isClosed() {
+return
+}
-}()
-}
-if a.enableWireguard {
-err = a.startWireguard(ctx, metadata.WireguardAddresses)
-if err != nil {
-a.logger.Error(ctx, "start wireguard", slog.Error(err))
+a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
+continue
}
+a.logger.Info(context.Background(), "connected to coordination server")
+break
}
+select {
+case <-ctx.Done():
+return
+default:
+}
+defer coordinator.Close()
+sendNodes, errChan := tailnet.ServeCoordinator(coordinator, a.network.UpdateNodes)
+a.network.SetNodeCallback(sendNodes)
+select {
+case <-ctx.Done():
+return
+case err := <-errChan:
+if a.isClosed() {
+return
+}
+if errors.Is(err, context.Canceled) {
+return
+}
+a.logger.Debug(ctx, "node broker accept exited; restarting connection", slog.Error(err))
+a.runCoordinator(ctx)
+return
+}
+}
+func (a *agent) runWebRTCNetworking(ctx context.Context) {
+var peerListener *peerbroker.Listener
+var err error
+// An exponential back-off occurs when the connection is failing to dial.
+// This is to prevent server spam in case of a coderd outage.
+for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
+peerListener, err = a.webrtcDialer(ctx, a.logger)
+if err != nil {
+if errors.Is(err, context.Canceled) {
+return
+}
+if a.isClosed() {
+return
+}
+a.logger.Warn(context.Background(), "failed to dial", slog.Error(err))
+continue
+}
+a.logger.Info(context.Background(), "connected to webrtc broker")
+break
+}
+select {
+case <-ctx.Done():
+return
+default:
+}
for {
@@ -178,7 +322,7 @@ func (a *agent) run(ctx context.Context) {
return
}
a.logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
-a.run(ctx)
+a.runWebRTCNetworking(ctx)
return
}
a.closeMutex.Lock()
@@ -243,7 +387,38 @@ func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
case ProtocolSSH:
go a.sshServer.HandleConn(channel.NetConn())
case ProtocolReconnectingPTY:
-go a.handleReconnectingPTY(ctx, channel.Label(), channel.NetConn())
+rawID := channel.Label()
+// The ID format is referenced in conn.go.
+// <uuid>:<height>:<width>
+idParts := strings.SplitN(rawID, ":", 4)
+if len(idParts) != 4 {
+a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
+continue
+}
+id := idParts[0]
+// Enforce a consistent format for IDs.
+_, err := uuid.Parse(id)
+if err != nil {
+a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
+continue
+}
+// Parse the initial terminal dimensions.
+height, err := strconv.Atoi(idParts[1])
+if err != nil {
+a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
+continue
+}
+width, err := strconv.Atoi(idParts[2])
+if err != nil {
+a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
+continue
+}
+go a.handleReconnectingPTY(ctx, reconnectingPTYInit{
+ID: id,
+Height: uint16(height),
+Width: uint16(width),
+Command: idParts[3],
+}, channel.NetConn())
case ProtocolDial:
go a.handleDial(ctx, channel.Label(), channel.NetConn())
default:
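
On the WebRTC path the same initialization rides in the channel label instead, as <uuid>:<height>:<width>:<command>; SplitN with a limit of 4 means only the first three colons delimit fields, so commands containing colons survive intact. A hypothetical label builder matching the parser above:

	package main

	import (
		"fmt"

		"github.com/google/uuid"
	)

	// reconnectingPTYLabel builds a channel label the parser above accepts.
	func reconnectingPTYLabel(id uuid.UUID, height, width uint16, command string) string {
		return fmt.Sprintf("%s:%d:%d:%s", id.String(), height, width, command)
	}
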
@@ -514,45 +689,19 @@ func (a *agent) handleSSHSession(session ssh.Session) (retErr error) {
return cmd.Wait()
}
-func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
+func (a *agent) handleReconnectingPTY(ctx context.Context, msg reconnectingPTYInit, conn net.Conn) {
defer conn.Close()
-// The ID format is referenced in conn.go.
-// <uuid>:<height>:<width>
-idParts := strings.SplitN(rawID, ":", 4)
-if len(idParts) != 4 {
-a.logger.Warn(ctx, "client sent invalid id format", slog.F("raw-id", rawID))
-return
-}
-id := idParts[0]
-// Enforce a consistent format for IDs.
-_, err := uuid.Parse(id)
-if err != nil {
-a.logger.Warn(ctx, "client sent reconnection token that isn't a uuid", slog.F("id", id), slog.Error(err))
-return
-}
-// Parse the initial terminal dimensions.
-height, err := strconv.Atoi(idParts[1])
-if err != nil {
-a.logger.Warn(ctx, "client sent invalid height", slog.F("id", id), slog.F("height", idParts[1]))
-return
-}
-width, err := strconv.Atoi(idParts[2])
-if err != nil {
-a.logger.Warn(ctx, "client sent invalid width", slog.F("id", id), slog.F("width", idParts[2]))
-return
-}
var rpty *reconnectingPTY
-rawRPTY, ok := a.reconnectingPTYs.Load(id)
+rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
if ok {
rpty, ok = rawRPTY.(*reconnectingPTY)
if !ok {
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", id))
a.logger.Warn(ctx, "found invalid type in reconnecting pty map", slog.F("id", msg.ID))
}
} else {
// Empty command will default to the users shell!
-cmd, err := a.createCommand(ctx, idParts[3], nil)
+cmd, err := a.createCommand(ctx, msg.Command, nil)
if err != nil {
a.logger.Warn(ctx, "create reconnecting pty command", slog.Error(err))
return
@@ -561,7 +710,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
ptty, process, err := pty.Start(cmd)
if err != nil {
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", id))
a.logger.Warn(ctx, "start reconnecting pty command", slog.F("id", msg.ID))
}
// Default to buffer 64KiB.
@@ -582,7 +731,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
circularBuffer: circularBuffer,
}
-a.reconnectingPTYs.Store(id, rpty)
+a.reconnectingPTYs.Store(msg.ID, rpty)
go func() {
// CommandContext isn't respected for Windows PTYs right now,
// so we need to manually track the lifecycle.
@@ -611,7 +760,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
_, err = rpty.circularBuffer.Write(part)
rpty.circularBufferMutex.Unlock()
if err != nil {
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", id))
a.logger.Error(ctx, "reconnecting pty write buffer", slog.Error(err), slog.F("id", msg.ID))
break
}
rpty.activeConnsMutex.Lock()
@@ -625,22 +774,22 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
// ID from memory.
_ = process.Kill()
rpty.Close()
-a.reconnectingPTYs.Delete(id)
+a.reconnectingPTYs.Delete(msg.ID)
a.connCloseWait.Done()
}()
}
// Resize the PTY to initial height + width.
-err = rpty.ptty.Resize(uint16(height), uint16(width))
+err := rpty.ptty.Resize(msg.Height, msg.Width)
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
}
// Write any previously stored data for the TTY.
rpty.circularBufferMutex.RLock()
_, err = conn.Write(rpty.circularBuffer.Bytes())
rpty.circularBufferMutex.RUnlock()
if err != nil {
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", id), slog.Error(err))
a.logger.Warn(ctx, "write reconnecting pty buffer", slog.F("id", msg.ID), slog.Error(err))
return
}
connectionID := uuid.NewString()
@@ -686,12 +835,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn net.Conn) {
return
}
if err != nil {
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", id), slog.Error(err))
a.logger.Warn(ctx, "reconnecting pty buffer read error", slog.F("id", msg.ID), slog.Error(err))
return
}
_, err = rpty.ptty.Input().Write([]byte(req.Data))
if err != nil {
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", id), slog.Error(err))
a.logger.Warn(ctx, "write to reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
return
}
// Check if a resize needs to happen!
@ -701,7 +850,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, rawID string, conn ne
err = rpty.ptty.Resize(req.Height, req.Width)
if err != nil {
// We can continue after this, it's not fatal!
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", id), slog.Error(err))
a.logger.Error(ctx, "resize reconnecting pty", slog.F("id", msg.ID), slog.Error(err))
}
}
}
@@ -788,6 +937,9 @@ func (a *agent) Close() error {
}
close(a.closed)
a.closeCancel()
+if a.network != nil {
+_ = a.network.Close()
+}
_ = a.sshServer.Close()
a.connCloseWait.Wait()
return nil