mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
chore: refactor tailnetAPIConnector to tailnet.Controller (#15361)
Refactors `workspacesdk.tailnetAPIConnector` as a `tailnet.Controller` to reuse all the reconnection and graceful disconnect logic. chore re: #14729
This commit is contained in:
@ -1,363 +0,0 @@
|
||||
package workspacesdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/quartz"
|
||||
"github.com/coder/retry"
|
||||
)
|
||||
|
||||
var tailnetConnectorGracefulTimeout = time.Second
|
||||
|
||||
// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
|
||||
// included so that we can fake it in testing.
|
||||
//
|
||||
// @typescript-ignore tailnetConn
|
||||
type tailnetConn interface {
|
||||
tailnet.Coordinatee
|
||||
tailnet.DERPMapSetter
|
||||
}
|
||||
|
||||
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
|
||||
//
|
||||
// 1) run the Coordinate API and pass node information back and forth
|
||||
// 2) stream DERPMap updates and program the Conn
|
||||
// 3) Send network telemetry events
|
||||
//
|
||||
// These functions share the same websocket, and so are combined here so that if we hit a problem
|
||||
// we tear the whole thing down and start over with a new websocket.
|
||||
//
|
||||
// @typescript-ignore tailnetAPIConnector
|
||||
type tailnetAPIConnector struct {
|
||||
// We keep track of two contexts: the main context from the caller, and a "graceful" context
|
||||
// that we keep open slightly longer than the main context to give a chance to send the
|
||||
// Disconnect message to the coordinator. That tells the coordinator that we really meant to
|
||||
// disconnect instead of just losing network connectivity.
|
||||
ctx context.Context
|
||||
gracefulCtx context.Context
|
||||
cancelGracefulCtx context.CancelFunc
|
||||
|
||||
logger slog.Logger
|
||||
|
||||
agentID uuid.UUID
|
||||
clock quartz.Clock
|
||||
dialer tailnet.ControlProtocolDialer
|
||||
derpCtrl tailnet.DERPController
|
||||
coordCtrl tailnet.CoordinationController
|
||||
telCtrl *tailnet.BasicTelemetryController
|
||||
tokenCtrl tailnet.ResumeTokenController
|
||||
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// Create a new tailnetAPIConnector without running it
|
||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector {
|
||||
return &tailnetAPIConnector{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
agentID: agentID,
|
||||
clock: clock,
|
||||
dialer: dialer,
|
||||
closed: make(chan struct{}),
|
||||
telCtrl: tailnet.NewBasicTelemetryController(logger),
|
||||
tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock),
|
||||
}
|
||||
}
|
||||
|
||||
// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
|
||||
// to allow a graceful disconnect.
|
||||
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
|
||||
defer tac.cancelGracefulCtx()
|
||||
<-tac.ctx.Done()
|
||||
timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-tac.closed:
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
// Runs a tailnetAPIConnector using the provided connection
|
||||
func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
|
||||
tac.derpCtrl = tailnet.NewBasicDERPController(tac.logger, conn)
|
||||
tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID)
|
||||
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
|
||||
go tac.manageGracefulTimeout()
|
||||
go func() {
|
||||
defer close(tac.closed)
|
||||
// Sadly retry doesn't support quartz.Clock yet so this is not
|
||||
// influenced by the configured clock.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
|
||||
tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
continue
|
||||
}
|
||||
errF := slog.Error(err)
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
errF = slog.Error(sdkErr)
|
||||
}
|
||||
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF)
|
||||
continue
|
||||
}
|
||||
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
|
||||
tac.runConnectorOnce(tailnetClients)
|
||||
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var permanentErrorStatuses = []int{
|
||||
http.StatusConflict, // returned if client/agent connections disabled (browser only)
|
||||
http.StatusBadRequest, // returned if API mismatch
|
||||
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
|
||||
}
|
||||
|
||||
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
|
||||
// into one function so that a problem with one tears down the other and triggers a retry (if
|
||||
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
|
||||
// fate.
|
||||
func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) {
|
||||
defer func() {
|
||||
closeErr := clients.Closer.Close()
|
||||
if closeErr != nil &&
|
||||
!xerrors.Is(closeErr, io.EOF) &&
|
||||
!xerrors.Is(closeErr, context.Canceled) &&
|
||||
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
|
||||
}
|
||||
}()
|
||||
|
||||
tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
|
||||
|
||||
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.coordinate(clients.Coordinator)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer refreshTokenCancel()
|
||||
dErr := tac.derpMap(clients.DERP)
|
||||
if dErr != nil && tac.ctx.Err() == nil {
|
||||
// The main context is still active, meaning that we want the tailnet data plane to stay
|
||||
// up, even though we hit some error getting DERP maps on the control plane. That means
|
||||
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
|
||||
// close the underlying connection. This will trigger a retry of the control plane in
|
||||
// run().
|
||||
clients.Closer.Close()
|
||||
// Note that derpMap() logs it own errors, we don't bother here.
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.refreshToken(refreshTokenCtx, clients.ResumeToken)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) {
|
||||
defer func() {
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
coordination := tac.coordCtrl.New(client)
|
||||
tac.logger.Debug(tac.ctx, "serving coordinator")
|
||||
select {
|
||||
case <-tac.ctx.Done():
|
||||
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
||||
crdErr := coordination.Close(tac.gracefulCtx)
|
||||
if crdErr != nil {
|
||||
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
|
||||
}
|
||||
case err := <-coordination.Wait():
|
||||
if err != nil &&
|
||||
!xerrors.Is(err, io.EOF) &&
|
||||
!xerrors.Is(err, context.Canceled) &&
|
||||
!xerrors.Is(err, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error {
|
||||
defer func() {
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
cw := tac.derpCtrl.New(client)
|
||||
select {
|
||||
case <-tac.ctx.Done():
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr))
|
||||
}
|
||||
return nil
|
||||
case err := <-cw.Wait():
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) {
|
||||
cw := tac.tokenCtrl.New(client)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cErr := cw.Close(tac.ctx)
|
||||
if cErr != nil {
|
||||
tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-cw.Wait()
|
||||
if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
|
||||
tac.telCtrl.SendTelemetryEvent(event)
|
||||
}
|
||||
|
||||
type WebsocketDialer struct {
|
||||
logger slog.Logger
|
||||
dialOptions *websocket.DialOptions
|
||||
url *url.URL
|
||||
resumeTokenFailed bool
|
||||
connected chan error
|
||||
isFirst bool
|
||||
}
|
||||
|
||||
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
|
||||
) (
|
||||
tailnet.ControlProtocolClients, error,
|
||||
) {
|
||||
w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")
|
||||
|
||||
u := new(url.URL)
|
||||
*u = *w.url
|
||||
if r != nil && !w.resumeTokenFailed {
|
||||
if token, ok := r.Token(); ok {
|
||||
q := u.Query()
|
||||
q.Set("resume_token", token)
|
||||
u.RawQuery = q.Encode()
|
||||
w.logger.Debug(ctx, "using resume token on dial")
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
|
||||
if w.isFirst {
|
||||
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
// A bit more human-readable help in the case the API version was rejected
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
if sdkErr.Message == AgentAPIMismatchMessage &&
|
||||
sdkErr.StatusCode() == http.StatusBadRequest {
|
||||
sdkErr.Helper = fmt.Sprintf(
|
||||
"Ensure your client release version (%s, different than the API version) matches the server release version",
|
||||
buildinfo.Version())
|
||||
}
|
||||
}
|
||||
w.connected <- err
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
w.isFirst = false
|
||||
close(w.connected)
|
||||
}
|
||||
if err != nil {
|
||||
bodyErr := codersdk.ReadBodyAsError(res)
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(bodyErr, &sdkErr) {
|
||||
for _, v := range sdkErr.Validations {
|
||||
if v.Field == "resume_token" {
|
||||
// Unset the resume token for the next attempt
|
||||
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
|
||||
w.resumeTokenFailed = true
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
|
||||
}
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
w.resumeTokenFailed = false
|
||||
|
||||
client, err := tailnet.NewDRPCClient(
|
||||
websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
|
||||
w.logger,
|
||||
)
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
coord, err := client.Coordinate(context.Background())
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: client.DRPCConn(),
|
||||
Coordinator: coord,
|
||||
DERP: derps,
|
||||
ResumeToken: client,
|
||||
Telemetry: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *WebsocketDialer) Connected() <-chan error {
|
||||
return w.connected
|
||||
}
|
||||
|
||||
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
|
||||
return &WebsocketDialer{
|
||||
logger: logger,
|
||||
dialOptions: opts,
|
||||
url: u,
|
||||
connected: make(chan error, 1),
|
||||
isFirst: true,
|
||||
}
|
||||
}
|
@ -1,218 +0,0 @@
|
||||
package workspacesdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Give tests a bit more time to timeout. Darwin is particularly slow.
|
||||
tailnetConnectorGracefulTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||
ctx, cancel := context.WithCancel(testCtx)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs,
|
||||
io.EOF, // we get EOF when we simulate a DERPMap error
|
||||
yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session
|
||||
),
|
||||
}).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
clientID := uuid.UUID{0x66}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger.Named("svc"),
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &pipeDialer{
|
||||
ctx: testCtx,
|
||||
logger: logger,
|
||||
t: t,
|
||||
svc: svc,
|
||||
streamID: tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
},
|
||||
}
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, dialer, quartz.NewReal())
|
||||
uut.runConnector(fConn)
|
||||
|
||||
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
||||
require.NotNil(t, reqTun.AddTunnel)
|
||||
|
||||
// simulate a problem with DERPMaps by sending nil
|
||||
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
|
||||
|
||||
// this should cause the coordinate call to hang up WITHOUT disconnecting
|
||||
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
||||
require.Nil(t, reqNil)
|
||||
|
||||
// ...and then reconnect
|
||||
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
||||
require.NotNil(t, reqTun.AddTunnel)
|
||||
|
||||
// canceling the context should trigger the disconnect message
|
||||
cancel()
|
||||
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
||||
require.NotNil(t, reqDisc)
|
||||
require.NotNil(t, reqDisc.Disconnect)
|
||||
close(call.Resps)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
clientID := uuid.UUID{0x66}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
eventCh := make(chan []*proto.TelemetryEvent, 1)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timeout sending telemetry event")
|
||||
case eventCh <- batch:
|
||||
t.Log("sent telemetry batch")
|
||||
}
|
||||
},
|
||||
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &pipeDialer{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
t: t,
|
||||
svc: svc,
|
||||
streamID: tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
},
|
||||
}
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, dialer, quartz.NewReal())
|
||||
uut.runConnector(fConn)
|
||||
// Coordinate calls happen _after_ telemetry is connected up, so we use this
|
||||
// to ensure telemetry is connected before sending our event
|
||||
cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||
defer close(cc.Resps)
|
||||
|
||||
uut.SendTelemetryEvent(&proto.TelemetryEvent{
|
||||
Id: []byte("test event"),
|
||||
})
|
||||
|
||||
testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)
|
||||
|
||||
require.Len(t, testEvents, 1)
|
||||
require.Equal(t, []byte("test event"), testEvents[0].Id)
|
||||
}
|
||||
|
||||
type fakeTailnetConn struct{}
|
||||
|
||||
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (*fakeTailnetConn) SetAllPeersLost() {}
|
||||
|
||||
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
|
||||
|
||||
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
|
||||
|
||||
func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}
|
||||
|
||||
func newFakeTailnetConn() *fakeTailnetConn {
|
||||
return &fakeTailnetConn{}
|
||||
}
|
||||
|
||||
type pipeDialer struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
t testing.TB
|
||||
svc *tailnet.ClientService
|
||||
streamID tailnet.StreamID
|
||||
}
|
||||
|
||||
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
||||
s, c := net.Pipe()
|
||||
go func() {
|
||||
err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
|
||||
p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
|
||||
}()
|
||||
client, err := tailnet.NewDRPCClient(c, p.logger)
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
coord, err := client.Coordinate(context.Background())
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: client.DRPCConn(),
|
||||
Coordinator: coord,
|
||||
DERP: derps,
|
||||
ResumeToken: client,
|
||||
Telemetry: client,
|
||||
}, nil
|
||||
}
|
139
codersdk/workspacesdk/dialer.go
Normal file
139
codersdk/workspacesdk/dialer.go
Normal file
@ -0,0 +1,139 @@
|
||||
package workspacesdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
)
|
||||
|
||||
var permanentErrorStatuses = []int{
|
||||
http.StatusConflict, // returned if client/agent connections disabled (browser only)
|
||||
http.StatusBadRequest, // returned if API mismatch
|
||||
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
|
||||
}
|
||||
|
||||
type WebsocketDialer struct {
|
||||
logger slog.Logger
|
||||
dialOptions *websocket.DialOptions
|
||||
url *url.URL
|
||||
resumeTokenFailed bool
|
||||
connected chan error
|
||||
isFirst bool
|
||||
}
|
||||
|
||||
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
|
||||
) (
|
||||
tailnet.ControlProtocolClients, error,
|
||||
) {
|
||||
w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")
|
||||
|
||||
u := new(url.URL)
|
||||
*u = *w.url
|
||||
if r != nil && !w.resumeTokenFailed {
|
||||
if token, ok := r.Token(); ok {
|
||||
q := u.Query()
|
||||
q.Set("resume_token", token)
|
||||
u.RawQuery = q.Encode()
|
||||
w.logger.Debug(ctx, "using resume token on dial")
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:bodyclose
|
||||
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
|
||||
if w.isFirst {
|
||||
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
// A bit more human-readable help in the case the API version was rejected
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
if sdkErr.Message == AgentAPIMismatchMessage &&
|
||||
sdkErr.StatusCode() == http.StatusBadRequest {
|
||||
sdkErr.Helper = fmt.Sprintf(
|
||||
"Ensure your client release version (%s, different than the API version) matches the server release version",
|
||||
buildinfo.Version())
|
||||
}
|
||||
}
|
||||
w.connected <- err
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
w.isFirst = false
|
||||
close(w.connected)
|
||||
}
|
||||
if err != nil {
|
||||
bodyErr := codersdk.ReadBodyAsError(res)
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(bodyErr, &sdkErr) {
|
||||
for _, v := range sdkErr.Validations {
|
||||
if v.Field == "resume_token" {
|
||||
// Unset the resume token for the next attempt
|
||||
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
|
||||
w.resumeTokenFailed = true
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
|
||||
}
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
w.resumeTokenFailed = false
|
||||
|
||||
client, err := tailnet.NewDRPCClient(
|
||||
websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
|
||||
w.logger,
|
||||
)
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
coord, err := client.Coordinate(context.Background())
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||
if err != nil {
|
||||
w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: client.DRPCConn(),
|
||||
Coordinator: coord,
|
||||
DERP: derps,
|
||||
ResumeToken: client,
|
||||
Telemetry: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (w *WebsocketDialer) Connected() <-chan error {
|
||||
return w.connected
|
||||
}
|
||||
|
||||
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
|
||||
return &WebsocketDialer{
|
||||
logger: logger,
|
||||
dialOptions: opts,
|
||||
url: u,
|
||||
connected: make(chan error, 1),
|
||||
isFirst: true,
|
||||
}
|
||||
}
|
@ -234,7 +234,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, dialer, quartz.NewReal())
|
||||
clk := quartz.NewReal()
|
||||
controller := tailnet.NewController(options.Logger, dialer)
|
||||
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
|
||||
|
||||
ip := tailnet.TailscaleServicePrefix.RandomAddr()
|
||||
var header http.Header
|
||||
@ -243,7 +245,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
}
|
||||
var telemetrySink tailnet.TelemetrySink
|
||||
if options.EnableTelemetry {
|
||||
telemetrySink = connector
|
||||
basicTel := tailnet.NewBasicTelemetryController(options.Logger)
|
||||
telemetrySink = basicTel
|
||||
controller.TelemetryCtrl = basicTel
|
||||
}
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
||||
@ -264,7 +268,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
connector.runConnector(conn)
|
||||
controller.CoordCtrl = tailnet.NewSingleDestController(options.Logger, conn, agentID)
|
||||
controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn)
|
||||
controller.Run(ctx)
|
||||
|
||||
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
|
||||
|
||||
@ -283,7 +289,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
AgentID: agentID,
|
||||
CloseFunc: func() error {
|
||||
cancel()
|
||||
<-connector.closed
|
||||
<-controller.Closed()
|
||||
return conn.Close()
|
||||
},
|
||||
})
|
||||
|
Reference in New Issue
Block a user