mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
chore: refactor tailnetAPIConnector to tailnet.Controller (#15361)
Refactors `workspacesdk.tailnetAPIConnector` as a `tailnet.Controller` to reuse all the reconnection and graceful disconnect logic. chore re: #14729
This commit is contained in:
@ -1,363 +0,0 @@
|
|||||||
package workspacesdk
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"slices"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"golang.org/x/xerrors"
|
|
||||||
"nhooyr.io/websocket"
|
|
||||||
|
|
||||||
"cdr.dev/slog"
|
|
||||||
"github.com/coder/coder/v2/buildinfo"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
|
||||||
"github.com/coder/coder/v2/tailnet"
|
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
|
||||||
"github.com/coder/quartz"
|
|
||||||
"github.com/coder/retry"
|
|
||||||
)
|
|
||||||
|
|
||||||
// tailnetConnectorGracefulTimeout bounds how long the graceful context
// outlives the main context, giving the Disconnect message a chance to reach
// the coordinator. It is a var (not a const) so tests can lengthen it.
var tailnetConnectorGracefulTimeout = time.Second
// tailnetConn is the subset of the tailnet.Conn methods that tailnetAPIConnector uses. It is
// included so that we can fake it in testing.
//
// @typescript-ignore tailnetConn
type tailnetConn interface {
	tailnet.Coordinatee
	tailnet.DERPMapSetter
}
// tailnetAPIConnector dials the tailnet API (v2+) and then uses the API with a tailnet.Conn to
//
// 1) run the Coordinate API and pass node information back and forth
// 2) stream DERPMap updates and program the Conn
// 3) Send network telemetry events
//
// These functions share the same websocket, and so are combined here so that if we hit a problem
// we tear the whole thing down and start over with a new websocket.
//
// @typescript-ignore tailnetAPIConnector
type tailnetAPIConnector struct {
	// We keep track of two contexts: the main context from the caller, and a "graceful" context
	// that we keep open slightly longer than the main context to give a chance to send the
	// Disconnect message to the coordinator. That tells the coordinator that we really meant to
	// disconnect instead of just losing network connectivity.
	ctx               context.Context
	gracefulCtx       context.Context
	cancelGracefulCtx context.CancelFunc

	logger slog.Logger

	agentID uuid.UUID
	// clock is injected so tests can control the graceful-timeout timer.
	clock  quartz.Clock
	dialer tailnet.ControlProtocolDialer
	// derpCtrl and coordCtrl are built in runConnector (they need the Conn);
	// telCtrl and tokenCtrl are built up front in newTailnetAPIConnector.
	derpCtrl  tailnet.DERPController
	coordCtrl tailnet.CoordinationController
	telCtrl   *tailnet.BasicTelemetryController
	tokenCtrl tailnet.ResumeTokenController

	// closed is closed when the dial/serve loop started by runConnector exits.
	closed chan struct{}
}
// Create a new tailnetAPIConnector without running it
|
|
||||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector {
|
|
||||||
return &tailnetAPIConnector{
|
|
||||||
ctx: ctx,
|
|
||||||
logger: logger,
|
|
||||||
agentID: agentID,
|
|
||||||
clock: clock,
|
|
||||||
dialer: dialer,
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
telCtrl: tailnet.NewBasicTelemetryController(logger),
|
|
||||||
tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// manageGracefulTimeout allows the gracefulContext to last 1 second longer than the main context
|
|
||||||
// to allow a graceful disconnect.
|
|
||||||
func (tac *tailnetAPIConnector) manageGracefulTimeout() {
|
|
||||||
defer tac.cancelGracefulCtx()
|
|
||||||
<-tac.ctx.Done()
|
|
||||||
timer := tac.clock.NewTimer(tailnetConnectorGracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
|
|
||||||
defer timer.Stop()
|
|
||||||
select {
|
|
||||||
case <-tac.closed:
|
|
||||||
case <-timer.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runConnector runs a tailnetAPIConnector using the provided connection. It
// builds the DERP and coordination controllers around conn, starts the
// graceful-timeout watcher, then launches a goroutine that repeatedly dials
// the control API and serves it until the main context ends. The closed
// channel is closed when that loop exits.
func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
	tac.derpCtrl = tailnet.NewBasicDERPController(tac.logger, conn)
	tac.coordCtrl = tailnet.NewSingleDestController(tac.logger, conn, tac.agentID)
	// The graceful context is derived from Background, not tac.ctx, so it can
	// outlive the main context (see manageGracefulTimeout).
	tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
	go tac.manageGracefulTimeout()
	go func() {
		defer close(tac.closed)
		// Sadly retry doesn't support quartz.Clock yet so this is not
		// influenced by the configured clock.
		for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
			tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl)
			if err != nil {
				if xerrors.Is(err, context.Canceled) {
					// Shutdown in progress; retrier.Wait will exit the loop.
					continue
				}
				errF := slog.Error(err)
				var sdkErr *codersdk.Error
				if xerrors.As(err, &sdkErr) {
					errF = slog.Error(sdkErr)
				}
				tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF)
				continue
			}
			tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
			// Blocks until the control connection is torn down, then retry.
			tac.runConnectorOnce(tailnetClients)
			tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
		}
	}()
}
// permanentErrorStatuses are HTTP statuses that, on the first dial attempt,
// are treated as permanent failures: the error is reported on the connected
// channel instead of being retried (see WebsocketDialer.Dial).
var permanentErrorStatuses = []int{
	http.StatusConflict,   // returned if client/agent connections disabled (browser only)
	http.StatusBadRequest, // returned if API mismatch
	http.StatusNotFound,   // returned if user doesn't have permission or agent doesn't exist
}
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
// into one function so that a problem with one tears down the other and triggers a retry (if
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
// fate.
func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) {
	defer func() {
		closeErr := clients.Closer.Close()
		// EOF and context errors are expected teardown; only log anything else.
		if closeErr != nil &&
			!xerrors.Is(closeErr, io.EOF) &&
			!xerrors.Is(closeErr, context.Canceled) &&
			!xerrors.Is(closeErr, context.DeadlineExceeded) {
			tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
		}
	}()

	tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine

	refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
	wg := sync.WaitGroup{}
	wg.Add(3)
	go func() {
		defer wg.Done()
		tac.coordinate(clients.Coordinator)
	}()
	go func() {
		defer wg.Done()
		// When DERP streaming ends, the shared connection is going away, so
		// also stop the token refresher.
		defer refreshTokenCancel()
		dErr := tac.derpMap(clients.DERP)
		if dErr != nil && tac.ctx.Err() == nil {
			// The main context is still active, meaning that we want the tailnet data plane to stay
			// up, even though we hit some error getting DERP maps on the control plane. That means
			// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
			// close the underlying connection. This will trigger a retry of the control plane in
			// run().
			clients.Closer.Close()
			// Note that derpMap() logs it own errors, we don't bother here.
		}
	}()
	go func() {
		defer wg.Done()
		tac.refreshToken(refreshTokenCtx, clients.ResumeToken)
	}()
	wg.Wait()
}
func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) {
|
|
||||||
defer func() {
|
|
||||||
cErr := client.Close()
|
|
||||||
if cErr != nil {
|
|
||||||
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
coordination := tac.coordCtrl.New(client)
|
|
||||||
tac.logger.Debug(tac.ctx, "serving coordinator")
|
|
||||||
select {
|
|
||||||
case <-tac.ctx.Done():
|
|
||||||
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
|
||||||
crdErr := coordination.Close(tac.gracefulCtx)
|
|
||||||
if crdErr != nil {
|
|
||||||
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
|
|
||||||
}
|
|
||||||
case err := <-coordination.Wait():
|
|
||||||
if err != nil &&
|
|
||||||
!xerrors.Is(err, io.EOF) &&
|
|
||||||
!xerrors.Is(err, context.Canceled) &&
|
|
||||||
!xerrors.Is(err, context.DeadlineExceeded) {
|
|
||||||
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// derpMap runs the StreamDERPMaps RPC via the DERP controller until the main
// context ends or the stream fails. It returns nil on context-driven
// teardown; a non-nil return means the stream itself broke, which the caller
// (runConnectorOnce) uses to decide to recycle the control connection.
func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error {
	defer func() {
		cErr := client.Close()
		if cErr != nil {
			tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
		}
	}()
	cw := tac.derpCtrl.New(client)
	select {
	case <-tac.ctx.Done():
		// Proactively close the RPC to stop the controller; the deferred
		// Close above will run again afterward.
		cErr := client.Close()
		if cErr != nil {
			tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr))
		}
		return nil
	case err := <-cw.Wait():
		// Context errors are expected on shutdown and not reported upward.
		if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
			return nil
		}
		// EOF is not logged (it's normal stream end) but IS returned, so the
		// caller still recycles the connection.
		if err != nil && !xerrors.Is(err, io.EOF) {
			tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
		}
		return err
	}
}
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) {
|
|
||||||
cw := tac.tokenCtrl.New(client)
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
cErr := cw.Close(tac.ctx)
|
|
||||||
if cErr != nil {
|
|
||||||
tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := <-cw.Wait()
|
|
||||||
if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) {
|
|
||||||
tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendTelemetryEvent forwards event to the connector's telemetry controller.
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
	tac.telCtrl.SendTelemetryEvent(event)
}
// WebsocketDialer dials the tailnet v2+ API over a websocket and wraps the
// connection in DRPC control-protocol clients (see Dial).
type WebsocketDialer struct {
	logger      slog.Logger
	dialOptions *websocket.DialOptions
	url         *url.URL
	// resumeTokenFailed is set when the server rejects our resume token, so
	// the next Dial attempt omits it.
	resumeTokenFailed bool
	// connected carries the outcome of the first dial attempt; see Connected.
	connected chan error
	// isFirst is true until the first dial attempt completes its
	// permanent-error check.
	isFirst bool
}
// Dial connects to the tailnet v2+ API over a websocket and returns the DRPC
// control-protocol clients multiplexed over it. If r is non-nil, a resume
// token is available, and the previous token wasn't rejected, the token is
// attached to the dial URL. The first attempt additionally reports permanent
// failures (see permanentErrorStatuses) — or first-success — on the connected
// channel.
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
) (
	tailnet.ControlProtocolClients, error,
) {
	w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")

	// Copy the URL so query-string changes don't mutate the configured one.
	u := new(url.URL)
	*u = *w.url
	if r != nil && !w.resumeTokenFailed {
		if token, ok := r.Token(); ok {
			q := u.Query()
			q.Set("resume_token", token)
			u.RawQuery = q.Encode()
			w.logger.Debug(ctx, "using resume token on dial")
		}
	}

	// nolint:bodyclose
	ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
	if w.isFirst {
		// A permanent-status failure on the very first dial is reported to
		// Connected() and returned without clearing isFirst.
		if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
			err = codersdk.ReadBodyAsError(res)
			// A bit more human-readable help in the case the API version was rejected
			var sdkErr *codersdk.Error
			if xerrors.As(err, &sdkErr) {
				if sdkErr.Message == AgentAPIMismatchMessage &&
					sdkErr.StatusCode() == http.StatusBadRequest {
					sdkErr.Helper = fmt.Sprintf(
						"Ensure your client release version (%s, different than the API version) matches the server release version",
						buildinfo.Version())
				}
			}
			w.connected <- err
			return tailnet.ControlProtocolClients{}, err
		}
		w.isFirst = false
		// Closing connected signals the first attempt got past the
		// permanent-error check (it may still have failed transiently).
		close(w.connected)
	}
	if err != nil {
		bodyErr := codersdk.ReadBodyAsError(res)
		var sdkErr *codersdk.Error
		if xerrors.As(bodyErr, &sdkErr) {
			for _, v := range sdkErr.Validations {
				if v.Field == "resume_token" {
					// Unset the resume token for the next attempt
					w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
					w.resumeTokenFailed = true
					return tailnet.ControlProtocolClients{}, err
				}
			}
		}
		if !errors.Is(err, context.Canceled) {
			w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
		}
		return tailnet.ControlProtocolClients{}, err
	}
	// Dial succeeded, so any previously rejected token is no longer relevant.
	w.resumeTokenFailed = false

	client, err := tailnet.NewDRPCClient(
		websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
		w.logger,
	)
	if err != nil {
		w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}
	coord, err := client.Coordinate(context.Background())
	if err != nil {
		w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}

	derps := &tailnet.DERPFromDRPCWrapper{}
	derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
	if err != nil {
		w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}

	return tailnet.ControlProtocolClients{
		Closer:      client.DRPCConn(),
		Coordinator: coord,
		DERP:        derps,
		ResumeToken: client,
		Telemetry:   client,
	}, nil
}
// Connected returns a channel describing the first dial attempt: it receives
// the error on a permanent failure, and is closed once the first attempt
// passes the permanent-error check (see Dial).
func (w *WebsocketDialer) Connected() <-chan error {
	return w.connected
}
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
|
|
||||||
return &WebsocketDialer{
|
|
||||||
logger: logger,
|
|
||||||
dialOptions: opts,
|
|
||||||
url: u,
|
|
||||||
connected: make(chan error, 1),
|
|
||||||
isFirst: true,
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,218 +0,0 @@
|
|||||||
package workspacesdk
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/hashicorp/yamux"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"tailscale.com/tailcfg"
|
|
||||||
|
|
||||||
"cdr.dev/slog"
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
|
||||||
"github.com/coder/coder/v2/tailnet"
|
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
|
||||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
||||||
"github.com/coder/coder/v2/testutil"
|
|
||||||
"github.com/coder/quartz"
|
|
||||||
)
|
|
||||||
|
|
||||||
// init lengthens the graceful-disconnect timeout for the whole test binary.
func init() {
	// Give tests a bit more time to timeout. Darwin is particularly slow.
	tailnetConnectorGracefulTimeout = 5 * time.Second
}
// TestTailnetAPIConnector_Disconnects exercises the connector's reconnect and
// graceful-disconnect behavior: a DERPMap failure must tear down the control
// connection WITHOUT sending Disconnect, a fresh Coordinate call must follow,
// and canceling the main context must produce an explicit Disconnect request.
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
	t.Parallel()
	// testCtx outlives ctx so we can still assert on the Disconnect after
	// canceling the connector's main context.
	testCtx := testutil.Context(t, testutil.WaitShort)
	ctx, cancel := context.WithCancel(testCtx)
	logger := slogtest.Make(t, &slogtest.Options{
		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs,
			io.EOF,                   // we get EOF when we simulate a DERPMap error
			yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session
		),
	}).Leveled(slog.LevelDebug)
	agentID := uuid.UUID{0x55}
	clientID := uuid.UUID{0x66}
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	derpMapCh := make(chan *tailcfg.DERPMap)
	defer close(derpMapCh)
	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
		Logger:                 logger.Named("svc"),
		CoordPtr:               &coordPtr,
		DERPMapUpdateFrequency: time.Millisecond,
		// Each DERPMap update is pulled from the test via derpMapCh, letting
		// the test inject a nil (error) map on demand.
		DERPMapFn:               func() *tailcfg.DERPMap { return <-derpMapCh },
		NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
		ResumeTokenProvider:     tailnet.NewInsecureTestResumeTokenProvider(),
	})
	require.NoError(t, err)

	dialer := &pipeDialer{
		ctx:    testCtx,
		logger: logger,
		t:      t,
		svc:    svc,
		streamID: tailnet.StreamID{
			Name: "client",
			ID:   clientID,
			Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
		},
	}

	fConn := newFakeTailnetConn()

	uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, dialer, quartz.NewReal())
	uut.runConnector(fConn)

	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	// simulate a problem with DERPMaps by sending nil
	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)

	// this should cause the coordinate call to hang up WITHOUT disconnecting
	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.Nil(t, reqNil)

	// ...and then reconnect
	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	// canceling the context should trigger the disconnect message
	cancel()
	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
	require.NotNil(t, reqDisc)
	require.NotNil(t, reqDisc.Disconnect)
	close(call.Resps)
}
// TestTailnetAPIConnector_TelemetrySuccess verifies that an event sent via
// SendTelemetryEvent reaches the service's NetworkTelemetryHandler once the
// control connection is established.
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	agentID := uuid.UUID{0x55}
	clientID := uuid.UUID{0x66}
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	derpMapCh := make(chan *tailcfg.DERPMap)
	defer close(derpMapCh)
	eventCh := make(chan []*proto.TelemetryEvent, 1)
	svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
		Logger:                 logger,
		CoordPtr:               &coordPtr,
		DERPMapUpdateFrequency: time.Millisecond,
		DERPMapFn:              func() *tailcfg.DERPMap { return <-derpMapCh },
		// Forward received batches to the test through eventCh.
		NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
			select {
			case <-ctx.Done():
				t.Error("timeout sending telemetry event")
			case eventCh <- batch:
				t.Log("sent telemetry batch")
			}
		},
		ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
	})
	require.NoError(t, err)

	dialer := &pipeDialer{
		ctx:    ctx,
		logger: logger,
		t:      t,
		svc:    svc,
		streamID: tailnet.StreamID{
			Name: "client",
			ID:   clientID,
			Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
		},
	}

	fConn := newFakeTailnetConn()

	uut := newTailnetAPIConnector(ctx, logger, agentID, dialer, quartz.NewReal())
	uut.runConnector(fConn)
	// Coordinate calls happen _after_ telemetry is connected up, so we use this
	// to ensure telemetry is connected before sending our event
	cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	defer close(cc.Resps)

	uut.SendTelemetryEvent(&proto.TelemetryEvent{
		Id: []byte("test event"),
	})

	testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)

	require.Len(t, testEvents, 1)
	require.Equal(t, []byte("test event"), testEvents[0].Id)
}
// fakeTailnetConn is a stub tailnetConn for tests; every method the connector
// needs is a no-op, and UpdatePeers (unused by these tests) panics if called.
type fakeTailnetConn struct{}

// UpdatePeers is not exercised by these tests and panics if reached.
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
	// TODO implement me
	panic("implement me")
}

func (*fakeTailnetConn) SetAllPeersLost() {}

func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}

func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}

func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}

// newFakeTailnetConn returns a ready-to-use stub connection.
func newFakeTailnetConn() *fakeTailnetConn {
	return &fakeTailnetConn{}
}
// pipeDialer is a test dialer that serves the control protocol over an
// in-memory net.Pipe instead of a real websocket (see its Dial method).
type pipeDialer struct {
	ctx    context.Context
	logger slog.Logger
	t      testing.TB
	// svc serves the server half of each pipe.
	svc *tailnet.ClientService
	// streamID identifies the client to the service on each Dial.
	streamID tailnet.StreamID
}
// Dial creates an in-memory pipe, serves the server half with the test's
// ClientService, and wraps the client half in the same DRPC control clients
// that WebsocketDialer.Dial would produce. The passed context and resume
// token controller are ignored.
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
	s, c := net.Pipe()
	go func() {
		err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
		p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
	}()
	client, err := tailnet.NewDRPCClient(c, p.logger)
	// assert (not require) so failures clean up the pipe instead of aborting
	// from a non-test goroutine.
	if !assert.NoError(p.t, err) {
		_ = c.Close()
		return tailnet.ControlProtocolClients{}, err
	}
	coord, err := client.Coordinate(context.Background())
	if !assert.NoError(p.t, err) {
		_ = c.Close()
		return tailnet.ControlProtocolClients{}, err
	}

	derps := &tailnet.DERPFromDRPCWrapper{}
	derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
	if !assert.NoError(p.t, err) {
		_ = c.Close()
		return tailnet.ControlProtocolClients{}, err
	}
	return tailnet.ControlProtocolClients{
		Closer:      client.DRPCConn(),
		Coordinator: coord,
		DERP:        derps,
		ResumeToken: client,
		Telemetry:   client,
	}, nil
}
139
codersdk/workspacesdk/dialer.go
Normal file
139
codersdk/workspacesdk/dialer.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package workspacesdk
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
"nhooyr.io/websocket"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/buildinfo"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/tailnet"
|
||||||
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// permanentErrorStatuses are HTTP statuses that, on the first dial attempt,
// are treated as permanent failures: the error is reported on the connected
// channel instead of being retried (see WebsocketDialer.Dial).
var permanentErrorStatuses = []int{
	http.StatusConflict,   // returned if client/agent connections disabled (browser only)
	http.StatusBadRequest, // returned if API mismatch
	http.StatusNotFound,   // returned if user doesn't have permission or agent doesn't exist
}
// WebsocketDialer dials the tailnet v2+ API over a websocket and wraps the
// connection in DRPC control-protocol clients (see Dial).
type WebsocketDialer struct {
	logger      slog.Logger
	dialOptions *websocket.DialOptions
	url         *url.URL
	// resumeTokenFailed is set when the server rejects our resume token, so
	// the next Dial attempt omits it.
	resumeTokenFailed bool
	// connected carries the outcome of the first dial attempt; see Connected.
	connected chan error
	// isFirst is true until the first dial attempt completes its
	// permanent-error check.
	isFirst bool
}
// Dial connects to the tailnet v2+ API over a websocket and returns the DRPC
// control-protocol clients multiplexed over it. If r is non-nil, a resume
// token is available, and the previous token wasn't rejected, the token is
// attached to the dial URL. The first attempt additionally reports permanent
// failures (see permanentErrorStatuses) — or first-success — on the connected
// channel.
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
) (
	tailnet.ControlProtocolClients, error,
) {
	w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")

	// Copy the URL so query-string changes don't mutate the configured one.
	u := new(url.URL)
	*u = *w.url
	if r != nil && !w.resumeTokenFailed {
		if token, ok := r.Token(); ok {
			q := u.Query()
			q.Set("resume_token", token)
			u.RawQuery = q.Encode()
			w.logger.Debug(ctx, "using resume token on dial")
		}
	}

	// nolint:bodyclose
	ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
	if w.isFirst {
		// A permanent-status failure on the very first dial is reported to
		// Connected() and returned without clearing isFirst.
		if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
			err = codersdk.ReadBodyAsError(res)
			// A bit more human-readable help in the case the API version was rejected
			var sdkErr *codersdk.Error
			if xerrors.As(err, &sdkErr) {
				if sdkErr.Message == AgentAPIMismatchMessage &&
					sdkErr.StatusCode() == http.StatusBadRequest {
					sdkErr.Helper = fmt.Sprintf(
						"Ensure your client release version (%s, different than the API version) matches the server release version",
						buildinfo.Version())
				}
			}
			w.connected <- err
			return tailnet.ControlProtocolClients{}, err
		}
		w.isFirst = false
		// Closing connected signals the first attempt got past the
		// permanent-error check (it may still have failed transiently).
		close(w.connected)
	}
	if err != nil {
		bodyErr := codersdk.ReadBodyAsError(res)
		var sdkErr *codersdk.Error
		if xerrors.As(bodyErr, &sdkErr) {
			for _, v := range sdkErr.Validations {
				if v.Field == "resume_token" {
					// Unset the resume token for the next attempt
					w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
					w.resumeTokenFailed = true
					return tailnet.ControlProtocolClients{}, err
				}
			}
		}
		if !errors.Is(err, context.Canceled) {
			w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
		}
		return tailnet.ControlProtocolClients{}, err
	}
	// Dial succeeded, so any previously rejected token is no longer relevant.
	w.resumeTokenFailed = false

	client, err := tailnet.NewDRPCClient(
		websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
		w.logger,
	)
	if err != nil {
		w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}
	coord, err := client.Coordinate(context.Background())
	if err != nil {
		w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}

	derps := &tailnet.DERPFromDRPCWrapper{}
	derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
	if err != nil {
		w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
		_ = ws.Close(websocket.StatusInternalError, "")
		return tailnet.ControlProtocolClients{}, err
	}

	return tailnet.ControlProtocolClients{
		Closer:      client.DRPCConn(),
		Coordinator: coord,
		DERP:        derps,
		ResumeToken: client,
		Telemetry:   client,
	}, nil
}
// Connected returns a channel describing the first dial attempt: it receives
// the error on a permanent failure, and is closed once the first attempt
// passes the permanent-error check (see Dial).
func (w *WebsocketDialer) Connected() <-chan error {
	return w.connected
}
// NewWebsocketDialer returns a WebsocketDialer that dials the tailnet v2+
// API at u using opts. The connected channel is buffered (capacity 1) so the
// first dial outcome can be recorded without a waiting receiver.
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
	return &WebsocketDialer{
		logger:      logger,
		dialOptions: opts,
		url:         u,
		connected:   make(chan error, 1),
		isFirst:     true,
	}
}
@ -234,7 +234,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
|||||||
// Need to disable compression to avoid a data-race.
|
// Need to disable compression to avoid a data-race.
|
||||||
CompressionMode: websocket.CompressionDisabled,
|
CompressionMode: websocket.CompressionDisabled,
|
||||||
})
|
})
|
||||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, dialer, quartz.NewReal())
|
clk := quartz.NewReal()
|
||||||
|
controller := tailnet.NewController(options.Logger, dialer)
|
||||||
|
controller.ResumeTokenCtrl = tailnet.NewBasicResumeTokenController(options.Logger, clk)
|
||||||
|
|
||||||
ip := tailnet.TailscaleServicePrefix.RandomAddr()
|
ip := tailnet.TailscaleServicePrefix.RandomAddr()
|
||||||
var header http.Header
|
var header http.Header
|
||||||
@ -243,7 +245,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
|||||||
}
|
}
|
||||||
var telemetrySink tailnet.TelemetrySink
|
var telemetrySink tailnet.TelemetrySink
|
||||||
if options.EnableTelemetry {
|
if options.EnableTelemetry {
|
||||||
telemetrySink = connector
|
basicTel := tailnet.NewBasicTelemetryController(options.Logger)
|
||||||
|
telemetrySink = basicTel
|
||||||
|
controller.TelemetryCtrl = basicTel
|
||||||
}
|
}
|
||||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||||
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
|
||||||
@ -264,7 +268,9 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
|||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
connector.runConnector(conn)
|
controller.CoordCtrl = tailnet.NewSingleDestController(options.Logger, conn, agentID)
|
||||||
|
controller.DERPCtrl = tailnet.NewBasicDERPController(options.Logger, conn)
|
||||||
|
controller.Run(ctx)
|
||||||
|
|
||||||
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
|
options.Logger.Debug(ctx, "running tailnet API v2+ connector")
|
||||||
|
|
||||||
@ -283,7 +289,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
|||||||
AgentID: agentID,
|
AgentID: agentID,
|
||||||
CloseFunc: func() error {
|
CloseFunc: func() error {
|
||||||
cancel()
|
cancel()
|
||||||
<-connector.closed
|
<-controller.Closed()
|
||||||
return conn.Close()
|
return conn.Close()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -16,8 +16,10 @@ import (
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
"github.com/coder/quartz"
|
"github.com/coder/quartz"
|
||||||
|
"github.com/coder/retry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// A Controller connects to the tailnet control plane, and then uses the control protocols to
|
// A Controller connects to the tailnet control plane, and then uses the control protocols to
|
||||||
@ -30,6 +32,16 @@ type Controller struct {
|
|||||||
DERPCtrl DERPController
|
DERPCtrl DERPController
|
||||||
ResumeTokenCtrl ResumeTokenController
|
ResumeTokenCtrl ResumeTokenController
|
||||||
TelemetryCtrl TelemetryController
|
TelemetryCtrl TelemetryController
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
gracefulCtx context.Context
|
||||||
|
cancelGracefulCtx context.CancelFunc
|
||||||
|
logger slog.Logger
|
||||||
|
closedCh chan struct{}
|
||||||
|
|
||||||
|
// Testing only
|
||||||
|
clock quartz.Clock
|
||||||
|
gracefulTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloserWaiter interface {
|
type CloserWaiter interface {
|
||||||
@ -664,3 +676,211 @@ func (r *basicResumeTokenRefresher) refresh() {
|
|||||||
}
|
}
|
||||||
r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
|
r.timer.Reset(dur, "basicResumeTokenRefresher", "refresh")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewController creates a new Controller without running it
|
||||||
|
func NewController(logger slog.Logger, dialer ControlProtocolDialer, opts ...ControllerOpt) *Controller {
|
||||||
|
c := &Controller{
|
||||||
|
logger: logger,
|
||||||
|
clock: quartz.NewReal(),
|
||||||
|
gracefulTimeout: time.Second,
|
||||||
|
Dialer: dialer,
|
||||||
|
closedCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(c)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
type ControllerOpt func(*Controller)
|
||||||
|
|
||||||
|
func WithTestClock(clock quartz.Clock) ControllerOpt {
|
||||||
|
return func(c *Controller) {
|
||||||
|
c.clock = clock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithGracefulTimeout(timeout time.Duration) ControllerOpt {
|
||||||
|
return func(c *Controller) {
|
||||||
|
c.gracefulTimeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// manageGracefulTimeout allows the gracefulContext to last longer than the main context
|
||||||
|
// to allow a graceful disconnect.
|
||||||
|
func (c *Controller) manageGracefulTimeout() {
|
||||||
|
defer c.cancelGracefulCtx()
|
||||||
|
<-c.ctx.Done()
|
||||||
|
timer := c.clock.NewTimer(c.gracefulTimeout, "tailnetAPIClient", "gracefulTimeout")
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-c.closedCh:
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run dials the API and uses it with the provided controllers.
|
||||||
|
func (c *Controller) Run(ctx context.Context) {
|
||||||
|
c.ctx = ctx
|
||||||
|
c.gracefulCtx, c.cancelGracefulCtx = context.WithCancel(context.Background())
|
||||||
|
go c.manageGracefulTimeout()
|
||||||
|
go func() {
|
||||||
|
defer close(c.closedCh)
|
||||||
|
// Sadly retry doesn't support quartz.Clock yet so this is not
|
||||||
|
// influenced by the configured clock.
|
||||||
|
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(c.ctx); {
|
||||||
|
tailnetClients, err := c.Dialer.Dial(c.ctx, c.ResumeTokenCtrl)
|
||||||
|
if err != nil {
|
||||||
|
if xerrors.Is(err, context.Canceled) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
errF := slog.Error(err)
|
||||||
|
var sdkErr *codersdk.Error
|
||||||
|
if xerrors.As(err, &sdkErr) {
|
||||||
|
errF = slog.Error(sdkErr)
|
||||||
|
}
|
||||||
|
c.logger.Error(c.ctx, "failed to dial tailnet v2+ API", errF)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.logger.Debug(c.ctx, "obtained tailnet API v2+ client")
|
||||||
|
c.runControllersOnce(tailnetClients)
|
||||||
|
c.logger.Debug(c.ctx, "tailnet API v2+ connection lost")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runControllersOnce uses the provided clients to call into the controllers once. It is combined
|
||||||
|
// into one function so that a problem with one tears down the other and triggers a retry (if
|
||||||
|
// appropriate). We typically multiplex all RPCs over the same websocket, so we want them to share
|
||||||
|
// the same fate.
|
||||||
|
func (c *Controller) runControllersOnce(clients ControlProtocolClients) {
|
||||||
|
defer func() {
|
||||||
|
closeErr := clients.Closer.Close()
|
||||||
|
if closeErr != nil &&
|
||||||
|
!xerrors.Is(closeErr, io.EOF) &&
|
||||||
|
!xerrors.Is(closeErr, context.Canceled) &&
|
||||||
|
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||||
|
c.logger.Error(c.ctx, "error closing DRPC connection", slog.Error(closeErr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if c.TelemetryCtrl != nil {
|
||||||
|
c.TelemetryCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
if c.CoordCtrl != nil {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
c.coordinate(clients.Coordinator)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
if c.DERPCtrl != nil {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
dErr := c.derpMap(clients.DERP)
|
||||||
|
if dErr != nil && c.ctx.Err() == nil {
|
||||||
|
// The main context is still active, meaning that we want the tailnet data plane to stay
|
||||||
|
// up, even though we hit some error getting DERP maps on the control plane. That means
|
||||||
|
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
|
||||||
|
// close the underlying connection. This will trigger a retry of the control plane in
|
||||||
|
// run().
|
||||||
|
_ = clients.Closer.Close()
|
||||||
|
// Note that derpMap() logs it own errors, we don't bother here.
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh token is a little different, in that we don't want its controller to hold open the
|
||||||
|
// connection on its own. So we keep it separate from the other wait group, and cancel its
|
||||||
|
// context as soon as the other routines exit.
|
||||||
|
refreshTokenCtx, refreshTokenCancel := context.WithCancel(c.ctx)
|
||||||
|
refreshTokenDone := make(chan struct{})
|
||||||
|
defer func() {
|
||||||
|
<-refreshTokenDone
|
||||||
|
}()
|
||||||
|
defer refreshTokenCancel()
|
||||||
|
go func() {
|
||||||
|
defer close(refreshTokenDone)
|
||||||
|
if c.ResumeTokenCtrl != nil {
|
||||||
|
c.refreshToken(refreshTokenCtx, clients.ResumeToken)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) coordinate(client CoordinatorClient) {
|
||||||
|
defer func() {
|
||||||
|
cErr := client.Close()
|
||||||
|
if cErr != nil {
|
||||||
|
c.logger.Debug(c.ctx, "error closing Coordinate RPC", slog.Error(cErr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
coordination := c.CoordCtrl.New(client)
|
||||||
|
c.logger.Debug(c.ctx, "serving coordinator")
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
c.logger.Debug(c.ctx, "main context canceled; do graceful disconnect")
|
||||||
|
crdErr := coordination.Close(c.gracefulCtx)
|
||||||
|
if crdErr != nil {
|
||||||
|
c.logger.Warn(c.ctx, "failed to close remote coordination", slog.Error(crdErr))
|
||||||
|
}
|
||||||
|
case err := <-coordination.Wait():
|
||||||
|
if err != nil &&
|
||||||
|
!xerrors.Is(err, io.EOF) &&
|
||||||
|
!xerrors.Is(err, context.Canceled) &&
|
||||||
|
!xerrors.Is(err, context.DeadlineExceeded) {
|
||||||
|
c.logger.Error(c.ctx, "remote coordination error", slog.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) derpMap(client DERPClient) error {
|
||||||
|
defer func() {
|
||||||
|
cErr := client.Close()
|
||||||
|
if cErr != nil {
|
||||||
|
c.logger.Debug(c.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
cw := c.DERPCtrl.New(client)
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
cErr := client.Close()
|
||||||
|
if cErr != nil {
|
||||||
|
c.logger.Warn(c.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case err := <-cw.Wait():
|
||||||
|
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||||
|
c.logger.Error(c.ctx, "error receiving DERP Map", slog.Error(err))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) refreshToken(ctx context.Context, client ResumeTokenClient) {
|
||||||
|
cw := c.ResumeTokenCtrl.New(client)
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
cErr := cw.Close(c.ctx)
|
||||||
|
if cErr != nil {
|
||||||
|
c.logger.Error(c.ctx, "error closing token refresher", slog.Error(cErr))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := <-cw.Wait()
|
||||||
|
if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) {
|
||||||
|
c.logger.Error(c.ctx, "error receiving refresh token", slog.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) Closed() <-chan struct{} {
|
||||||
|
return c.closedCh
|
||||||
|
}
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.uber.org/mock/gomock"
|
"go.uber.org/mock/gomock"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
@ -678,3 +680,202 @@ type fakeResumeTokenCall struct {
|
|||||||
resp chan *proto.RefreshResumeTokenResponse
|
resp chan *proto.RefreshResumeTokenResponse
|
||||||
errCh chan error
|
errCh chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestController_Disconnects(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
ctx, cancel := context.WithCancel(testCtx)
|
||||||
|
logger := slogtest.Make(t, &slogtest.Options{
|
||||||
|
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs,
|
||||||
|
io.EOF, // we get EOF when we simulate a DERPMap error
|
||||||
|
yamux.ErrSessionShutdown, // coordination can throw these when DERP error tears down session
|
||||||
|
),
|
||||||
|
}).Leveled(slog.LevelDebug)
|
||||||
|
agentID := uuid.UUID{0x55}
|
||||||
|
clientID := uuid.UUID{0x66}
|
||||||
|
fCoord := tailnettest.NewFakeCoordinator()
|
||||||
|
var coord tailnet.Coordinator = fCoord
|
||||||
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||||
|
coordPtr.Store(&coord)
|
||||||
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||||
|
defer close(derpMapCh)
|
||||||
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||||
|
Logger: logger.Named("svc"),
|
||||||
|
CoordPtr: &coordPtr,
|
||||||
|
DERPMapUpdateFrequency: time.Millisecond,
|
||||||
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||||
|
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
||||||
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dialer := &pipeDialer{
|
||||||
|
ctx: testCtx,
|
||||||
|
logger: logger,
|
||||||
|
t: t,
|
||||||
|
svc: svc,
|
||||||
|
streamID: tailnet.StreamID{
|
||||||
|
Name: "client",
|
||||||
|
ID: clientID,
|
||||||
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
peersLost := make(chan struct{})
|
||||||
|
fConn := &fakeTailnetConn{peersLostCh: peersLost}
|
||||||
|
|
||||||
|
uut := tailnet.NewController(logger.Named("tac"), dialer,
|
||||||
|
// darwin can be slow sometimes.
|
||||||
|
tailnet.WithGracefulTimeout(5*time.Second))
|
||||||
|
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger.Named("coord_ctrl"), fConn)
|
||||||
|
uut.DERPCtrl = tailnet.NewBasicDERPController(logger.Named("derp_ctrl"), fConn)
|
||||||
|
uut.Run(ctx)
|
||||||
|
|
||||||
|
call := testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
||||||
|
|
||||||
|
// simulate a problem with DERPMaps by sending nil
|
||||||
|
testutil.RequireSendCtx(testCtx, t, derpMapCh, nil)
|
||||||
|
|
||||||
|
// this should cause the coordinate call to hang up WITHOUT disconnecting
|
||||||
|
reqNil := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
||||||
|
require.Nil(t, reqNil)
|
||||||
|
|
||||||
|
// and mark all peers lost
|
||||||
|
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
||||||
|
|
||||||
|
// ...and then reconnect
|
||||||
|
call = testutil.RequireRecvCtx(testCtx, t, fCoord.CoordinateCalls)
|
||||||
|
|
||||||
|
// canceling the context should trigger the disconnect message
|
||||||
|
cancel()
|
||||||
|
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
||||||
|
require.NotNil(t, reqDisc)
|
||||||
|
require.NotNil(t, reqDisc.Disconnect)
|
||||||
|
close(call.Resps)
|
||||||
|
|
||||||
|
_ = testutil.RequireRecvCtx(testCtx, t, peersLost)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestController_TelemetrySuccess(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
agentID := uuid.UUID{0x55}
|
||||||
|
clientID := uuid.UUID{0x66}
|
||||||
|
fCoord := tailnettest.NewFakeCoordinator()
|
||||||
|
var coord tailnet.Coordinator = fCoord
|
||||||
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||||
|
coordPtr.Store(&coord)
|
||||||
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||||
|
defer close(derpMapCh)
|
||||||
|
eventCh := make(chan []*proto.TelemetryEvent, 1)
|
||||||
|
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||||
|
Logger: logger,
|
||||||
|
CoordPtr: &coordPtr,
|
||||||
|
DERPMapUpdateFrequency: time.Millisecond,
|
||||||
|
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||||
|
NetworkTelemetryHandler: func(batch []*proto.TelemetryEvent) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Error("timeout sending telemetry event")
|
||||||
|
case eventCh <- batch:
|
||||||
|
t.Log("sent telemetry batch")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dialer := &pipeDialer{
|
||||||
|
ctx: ctx,
|
||||||
|
logger: logger,
|
||||||
|
t: t,
|
||||||
|
svc: svc,
|
||||||
|
streamID: tailnet.StreamID{
|
||||||
|
Name: "client",
|
||||||
|
ID: clientID,
|
||||||
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
uut := tailnet.NewController(logger, dialer)
|
||||||
|
uut.CoordCtrl = tailnet.NewAgentCoordinationController(logger, &fakeTailnetConn{})
|
||||||
|
tel := tailnet.NewBasicTelemetryController(logger)
|
||||||
|
uut.TelemetryCtrl = tel
|
||||||
|
uut.Run(ctx)
|
||||||
|
// Coordinate calls happen _after_ telemetry is connected up, so we use this
|
||||||
|
// to ensure telemetry is connected before sending our event
|
||||||
|
cc := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||||
|
defer close(cc.Resps)
|
||||||
|
|
||||||
|
tel.SendTelemetryEvent(&proto.TelemetryEvent{
|
||||||
|
Id: []byte("test event"),
|
||||||
|
})
|
||||||
|
|
||||||
|
testEvents := testutil.RequireRecvCtx(ctx, t, eventCh)
|
||||||
|
|
||||||
|
require.Len(t, testEvents, 1)
|
||||||
|
require.Equal(t, []byte("test event"), testEvents[0].Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeTailnetConn struct {
|
||||||
|
peersLostCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTailnetConn) SetAllPeersLost() {
|
||||||
|
if f.peersLostCh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.peersLostCh <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
|
||||||
|
|
||||||
|
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
|
||||||
|
|
||||||
|
func (*fakeTailnetConn) SetTunnelDestination(uuid.UUID) {}
|
||||||
|
|
||||||
|
type pipeDialer struct {
|
||||||
|
ctx context.Context
|
||||||
|
logger slog.Logger
|
||||||
|
t testing.TB
|
||||||
|
svc *tailnet.ClientService
|
||||||
|
streamID tailnet.StreamID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
||||||
|
s, c := net.Pipe()
|
||||||
|
go func() {
|
||||||
|
err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
|
||||||
|
p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
|
||||||
|
}()
|
||||||
|
client, err := tailnet.NewDRPCClient(c, p.logger)
|
||||||
|
if !assert.NoError(p.t, err) {
|
||||||
|
_ = c.Close()
|
||||||
|
return tailnet.ControlProtocolClients{}, err
|
||||||
|
}
|
||||||
|
coord, err := client.Coordinate(context.Background())
|
||||||
|
if !assert.NoError(p.t, err) {
|
||||||
|
_ = c.Close()
|
||||||
|
return tailnet.ControlProtocolClients{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||||
|
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||||
|
if !assert.NoError(p.t, err) {
|
||||||
|
_ = c.Close()
|
||||||
|
return tailnet.ControlProtocolClients{}, err
|
||||||
|
}
|
||||||
|
return tailnet.ControlProtocolClients{
|
||||||
|
Closer: client.DRPCConn(),
|
||||||
|
Coordinator: coord,
|
||||||
|
DERP: derps,
|
||||||
|
ResumeToken: client,
|
||||||
|
Telemetry: client,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user