mirror of
https://github.com/coder/coder.git
synced 2025-07-18 14:17:22 +00:00
chore: refactor tailnetAPIConnector to use dialer (#15347)
refactors `tailnetAPIConnector` to use the `Dialer` interface in `tailnet`, introduced lower in this stack of PRs. This will let us use the same Tailnet API handling code across different things that connect to the Tailnet API (CLI client, coderd, workspace proxies, and soon: Coder VPN). chore re: #14729
This commit is contained in:
@ -57,31 +57,27 @@ type tailnetAPIConnector struct {
|
||||
logger slog.Logger
|
||||
|
||||
agentID uuid.UUID
|
||||
coordinateURL string
|
||||
clock quartz.Clock
|
||||
dialOptions *websocket.DialOptions
|
||||
dialer tailnet.ControlProtocolDialer
|
||||
derpCtrl tailnet.DERPController
|
||||
coordCtrl tailnet.CoordinationController
|
||||
telCtrl *tailnet.BasicTelemetryController
|
||||
tokenCtrl tailnet.ResumeTokenController
|
||||
|
||||
connected chan error
|
||||
resumeToken *proto.RefreshResumeTokenResponse
|
||||
isFirst bool
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// Create a new tailnetAPIConnector without running it
|
||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
|
||||
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector {
|
||||
return &tailnetAPIConnector{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
agentID: agentID,
|
||||
coordinateURL: coordinateURL,
|
||||
clock: clock,
|
||||
dialOptions: dialOptions,
|
||||
connected: make(chan error, 1),
|
||||
dialer: dialer,
|
||||
closed: make(chan struct{}),
|
||||
telCtrl: tailnet.NewBasicTelemetryController(logger),
|
||||
tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock),
|
||||
}
|
||||
}
|
||||
|
||||
@ -105,17 +101,25 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
|
||||
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
|
||||
go tac.manageGracefulTimeout()
|
||||
go func() {
|
||||
tac.isFirst = true
|
||||
defer close(tac.closed)
|
||||
// Sadly retry doesn't support quartz.Clock yet so this is not
|
||||
// influenced by the configured clock.
|
||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
|
||||
tailnetClient, err := tac.dial()
|
||||
tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
continue
|
||||
}
|
||||
errF := slog.Error(err)
|
||||
var sdkErr *codersdk.Error
|
||||
if xerrors.As(err, &sdkErr) {
|
||||
errF = slog.Error(sdkErr)
|
||||
}
|
||||
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF)
|
||||
continue
|
||||
}
|
||||
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
|
||||
tac.runConnectorOnce(tailnetClient)
|
||||
tac.runConnectorOnce(tailnetClients)
|
||||
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
|
||||
}
|
||||
}()
|
||||
@ -127,26 +131,152 @@ var permanentErrorStatuses = []int{
|
||||
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
|
||||
|
||||
u, err := url.Parse(tac.coordinateURL)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
|
||||
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
|
||||
// into one function so that a problem with one tears down the other and triggers a retry (if
|
||||
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
|
||||
// fate.
|
||||
func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) {
|
||||
defer func() {
|
||||
closeErr := clients.Closer.Close()
|
||||
if closeErr != nil &&
|
||||
!xerrors.Is(closeErr, io.EOF) &&
|
||||
!xerrors.Is(closeErr, context.Canceled) &&
|
||||
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
|
||||
}
|
||||
if tac.resumeToken != nil {
|
||||
}()
|
||||
|
||||
tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
|
||||
|
||||
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.coordinate(clients.Coordinator)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer refreshTokenCancel()
|
||||
dErr := tac.derpMap(clients.DERP)
|
||||
if dErr != nil && tac.ctx.Err() == nil {
|
||||
// The main context is still active, meaning that we want the tailnet data plane to stay
|
||||
// up, even though we hit some error getting DERP maps on the control plane. That means
|
||||
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
|
||||
// close the underlying connection. This will trigger a retry of the control plane in
|
||||
// run().
|
||||
clients.Closer.Close()
|
||||
// Note that derpMap() logs it own errors, we don't bother here.
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.refreshToken(refreshTokenCtx, clients.ResumeToken)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) {
|
||||
defer func() {
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
coordination := tac.coordCtrl.New(client)
|
||||
tac.logger.Debug(tac.ctx, "serving coordinator")
|
||||
select {
|
||||
case <-tac.ctx.Done():
|
||||
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
||||
crdErr := coordination.Close(tac.gracefulCtx)
|
||||
if crdErr != nil {
|
||||
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
|
||||
}
|
||||
case err := <-coordination.Wait():
|
||||
if err != nil &&
|
||||
!xerrors.Is(err, io.EOF) &&
|
||||
!xerrors.Is(err, context.Canceled) &&
|
||||
!xerrors.Is(err, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error {
|
||||
defer func() {
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
cw := tac.derpCtrl.New(client)
|
||||
select {
|
||||
case <-tac.ctx.Done():
|
||||
cErr := client.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr))
|
||||
}
|
||||
return nil
|
||||
case err := <-cw.Wait():
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) {
|
||||
cw := tac.tokenCtrl.New(client)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cErr := cw.Close(tac.ctx)
|
||||
if cErr != nil {
|
||||
tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
|
||||
err := <-cw.Wait()
|
||||
if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
|
||||
tac.telCtrl.SendTelemetryEvent(event)
|
||||
}
|
||||
|
||||
type WebsocketDialer struct {
|
||||
logger slog.Logger
|
||||
dialOptions *websocket.DialOptions
|
||||
url *url.URL
|
||||
resumeTokenFailed bool
|
||||
connected chan error
|
||||
isFirst bool
|
||||
}
|
||||
|
||||
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
|
||||
) (
|
||||
tailnet.ControlProtocolClients, error,
|
||||
) {
|
||||
w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")
|
||||
|
||||
u := new(url.URL)
|
||||
*u = *w.url
|
||||
if r != nil && !w.resumeTokenFailed {
|
||||
if token, ok := r.Token(); ok {
|
||||
q := u.Query()
|
||||
q.Set("resume_token", tac.resumeToken.Token)
|
||||
q.Set("resume_token", token)
|
||||
u.RawQuery = q.Encode()
|
||||
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
|
||||
w.logger.Debug(ctx, "using resume token on dial")
|
||||
}
|
||||
}
|
||||
|
||||
coordinateURL := u.String()
|
||||
tac.logger.Debug(tac.ctx, "using coordinate URL", slog.F("url", coordinateURL))
|
||||
|
||||
// nolint:bodyclose
|
||||
ws, res, err := websocket.Dial(tac.ctx, coordinateURL, tac.dialOptions)
|
||||
if tac.isFirst {
|
||||
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
|
||||
if w.isFirst {
|
||||
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
|
||||
err = codersdk.ReadBodyAsError(res)
|
||||
// A bit more human-readable help in the case the API version was rejected
|
||||
@ -159,11 +289,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
buildinfo.Version())
|
||||
}
|
||||
}
|
||||
tac.connected <- err
|
||||
return nil, err
|
||||
w.connected <- err
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
tac.isFirst = false
|
||||
close(tac.connected)
|
||||
w.isFirst = false
|
||||
close(w.connected)
|
||||
}
|
||||
if err != nil {
|
||||
bodyErr := codersdk.ReadBodyAsError(res)
|
||||
@ -172,167 +302,62 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||
for _, v := range sdkErr.Validations {
|
||||
if v.Field == "resume_token" {
|
||||
// Unset the resume token for the next attempt
|
||||
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
|
||||
tac.resumeToken = nil
|
||||
return nil, err
|
||||
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
|
||||
w.resumeTokenFailed = true
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
|
||||
w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
|
||||
}
|
||||
return nil, err
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
w.resumeTokenFailed = false
|
||||
|
||||
client, err := tailnet.NewDRPCClient(
|
||||
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
|
||||
tac.logger,
|
||||
websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
|
||||
w.logger,
|
||||
)
|
||||
if err != nil {
|
||||
tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return nil, err
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
return client, err
|
||||
}
|
||||
|
||||
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
|
||||
// into one function so that a problem with one tears down the other and triggers a retry (if
|
||||
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
|
||||
// fate.
|
||||
func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) {
|
||||
defer func() {
|
||||
conn := client.DRPCConn()
|
||||
closeErr := conn.Close()
|
||||
if closeErr != nil &&
|
||||
!xerrors.Is(closeErr, io.EOF) &&
|
||||
!xerrors.Is(closeErr, context.Canceled) &&
|
||||
!xerrors.Is(closeErr, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
|
||||
<-conn.Closed()
|
||||
}
|
||||
}()
|
||||
|
||||
tac.telCtrl.New(client) // synchronous, doesn't need a goroutine
|
||||
|
||||
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(3)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.coordinate(client)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer refreshTokenCancel()
|
||||
dErr := tac.derpMap(client)
|
||||
if dErr != nil && tac.ctx.Err() == nil {
|
||||
// The main context is still active, meaning that we want the tailnet data plane to stay
|
||||
// up, even though we hit some error getting DERP maps on the control plane. That means
|
||||
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
|
||||
// close the underlying connection. This will trigger a retry of the control plane in
|
||||
// run().
|
||||
client.DRPCConn().Close()
|
||||
// Note that derpMap() logs it own errors, we don't bother here.
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tac.refreshToken(refreshTokenCtx, client)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
|
||||
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
|
||||
coord, err := client.Coordinate(tac.gracefulCtx)
|
||||
coord, err := client.Coordinate(context.Background())
|
||||
if err != nil {
|
||||
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
cErr := coord.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
coordination := tac.coordCtrl.New(coord)
|
||||
tac.logger.Debug(tac.ctx, "serving coordinator")
|
||||
select {
|
||||
case <-tac.ctx.Done():
|
||||
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
|
||||
crdErr := coordination.Close(tac.gracefulCtx)
|
||||
if crdErr != nil {
|
||||
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
|
||||
}
|
||||
case err = <-coordination.Wait():
|
||||
if err != nil &&
|
||||
!xerrors.Is(err, io.EOF) &&
|
||||
!xerrors.Is(err, context.Canceled) &&
|
||||
!xerrors.Is(err, context.DeadlineExceeded) {
|
||||
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
|
||||
}
|
||||
}
|
||||
w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
|
||||
s := &tailnet.DERPFromDRPCWrapper{}
|
||||
var err error
|
||||
s.Client, err = client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
|
||||
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
cErr := s.Close()
|
||||
if cErr != nil {
|
||||
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
|
||||
}
|
||||
}()
|
||||
cw := tac.derpCtrl.New(s)
|
||||
err = <-cw.Wait()
|
||||
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
|
||||
}
|
||||
return err
|
||||
w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
|
||||
_ = ws.Close(websocket.StatusInternalError, "")
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
|
||||
ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
|
||||
defer ticker.Stop()
|
||||
|
||||
initialCh := make(chan struct{}, 1)
|
||||
initialCh <- struct{}{}
|
||||
defer close(initialCh)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
case <-initialCh:
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: client.DRPCConn(),
|
||||
Coordinator: coord,
|
||||
DERP: derps,
|
||||
ResumeToken: client,
|
||||
Telemetry: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
res, err := client.RefreshResumeToken(attemptCtx, &proto.RefreshResumeTokenRequest{})
|
||||
cancel()
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
tac.logger.Error(tac.ctx, "error refreshing coordinator resume token", slog.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res))
|
||||
tac.resumeToken = res
|
||||
dur := res.RefreshIn.AsDuration()
|
||||
if dur <= 0 {
|
||||
// A sensible delay to refresh again.
|
||||
dur = 30 * time.Minute
|
||||
}
|
||||
ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset")
|
||||
}
|
||||
func (w *WebsocketDialer) Connected() <-chan error {
|
||||
return w.connected
|
||||
}
|
||||
|
||||
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
|
||||
tac.telCtrl.SendTelemetryEvent(event)
|
||||
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
|
||||
return &WebsocketDialer{
|
||||
logger: logger,
|
||||
dialOptions: opts,
|
||||
url: u,
|
||||
connected: make(chan error, 1),
|
||||
isFirst: true,
|
||||
}
|
||||
}
|
||||
|
@ -3,8 +3,7 @@ package workspacesdk
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@ -13,16 +12,10 @@ import (
|
||||
"github.com/hashicorp/yamux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nhooyr.io/websocket"
|
||||
"storj.io/drpc"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/jwtutils"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
@ -63,32 +56,27 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
dialer := &pipeDialer{
|
||||
ctx: testCtx,
|
||||
logger: logger,
|
||||
t: t,
|
||||
svc: svc,
|
||||
streamID: tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
},
|
||||
}
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, svr.URL,
|
||||
quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut := newTailnetAPIConnector(ctx, logger.Named("tac"), agentID, dialer, quartz.NewReal())
|
||||
uut.runConnector(fConn)
|
||||
|
||||
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
||||
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
||||
require.NotNil(t, reqTun.AddTunnel)
|
||||
|
||||
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
|
||||
|
||||
// simulate a problem with DERPMaps by sending nil
|
||||
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
|
||||
|
||||
@ -109,259 +97,6 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
||||
close(call.Resps)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
|
||||
|
||||
// the following matches what Coderd does;
|
||||
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
||||
cVer := r.URL.Query().Get("version")
|
||||
if err := sVer.Validate(cVer); err != nil {
|
||||
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
||||
Message: AgentAPIMismatchMessage,
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "version", Detail: err.Error()},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}))
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
err := testutil.RequireRecvCtx(ctx, t, uut.connected)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Equal(t, AgentAPIMismatchMessage, sdkErr.Message)
|
||||
require.NotEmpty(t, sdkErr.Helper)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_ResumeToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: "123",
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func([]*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
websocketConnCh = make(chan *websocket.Conn, 64)
|
||||
expectResumeToken = ""
|
||||
)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Accept a resume_token query parameter to use the same peer ID. This
|
||||
// behavior matches the actual client coordinate route.
|
||||
var (
|
||||
peerID = uuid.New()
|
||||
resumeToken = r.URL.Query().Get("resume_token")
|
||||
)
|
||||
t.Logf("received resume token: %s", resumeToken)
|
||||
assert.Equal(t, expectResumeToken, resumeToken)
|
||||
if resumeToken != "" {
|
||||
peerID, err = resumeTokenProvider.VerifyResumeToken(ctx, resumeToken)
|
||||
assert.NoError(t, err, "failed to parse resume token")
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: CoordinateAPIInvalidResumeToken,
|
||||
Detail: err.Error(),
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: peerID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
||||
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
||||
defer newTickerTrap.Close()
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
// Fetch first token. We don't need to advance the clock since we use a
|
||||
// channel with a single item to immediately fetch.
|
||||
newTickerTrap.MustWait(ctx).Release()
|
||||
// We call ticker.Reset after each token fetch to apply the refresh duration
|
||||
// requested by the server.
|
||||
trappedReset := tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
originalResumeToken := uut.resumeToken.Token
|
||||
|
||||
// Fetch second token.
|
||||
waiter := clock.Advance(trappedReset.Duration)
|
||||
waiter.MustWait(ctx)
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
||||
expectResumeToken = uut.resumeToken.Token
|
||||
t.Logf("expecting resume token: %s", expectResumeToken)
|
||||
|
||||
// Sever the connection and expect it to reconnect with the resume token.
|
||||
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
||||
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
||||
|
||||
// Wait for the resume token to be refreshed.
|
||||
trappedTicker := newTickerTrap.MustWait(ctx)
|
||||
// Advance the clock slightly to ensure the new JWT is different.
|
||||
clock.Advance(time.Second).MustWait(ctx)
|
||||
trappedTicker.Release()
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
|
||||
// The resume token should have changed again.
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, expectResumeToken, uut.resumeToken.Token)
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_ResumeTokenFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
agentID := uuid.UUID{0x55}
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
derpMapCh := make(chan *tailcfg.DERPMap)
|
||||
defer close(derpMapCh)
|
||||
|
||||
clock := quartz.NewMock(t)
|
||||
resumeTokenSigningKey, err := tailnet.GenerateResumeTokenSigningKey()
|
||||
require.NoError(t, err)
|
||||
mgr := jwtutils.StaticKey{
|
||||
ID: uuid.New().String(),
|
||||
Key: resumeTokenSigningKey[:],
|
||||
}
|
||||
resumeTokenProvider := tailnet.NewResumeTokenKeyProvider(mgr, clock, time.Hour)
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Millisecond,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return <-derpMapCh },
|
||||
NetworkTelemetryHandler: func(_ []*proto.TelemetryEvent) {},
|
||||
ResumeTokenProvider: resumeTokenProvider,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var (
|
||||
websocketConnCh = make(chan *websocket.Conn, 64)
|
||||
didFail int64
|
||||
)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("resume_token") != "" {
|
||||
atomic.AddInt64(&didFail, 1)
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: CoordinateAPIInvalidResumeToken,
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.RequireSendCtx(ctx, t, websocketConnCh, sws)
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: uuid.New(),
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
newTickerTrap := clock.Trap().NewTicker("tailnetAPIConnector", "refreshToken")
|
||||
tickerResetTrap := clock.Trap().TickerReset("tailnetAPIConnector", "refreshToken", "reset")
|
||||
defer newTickerTrap.Close()
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, clock, &websocket.DialOptions{})
|
||||
uut.runConnector(fConn)
|
||||
|
||||
// Wait for the resume token to be fetched for the first time.
|
||||
newTickerTrap.MustWait(ctx).Release()
|
||||
trappedReset := tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
originalResumeToken := uut.resumeToken.Token
|
||||
|
||||
// Sever the connection and expect it to reconnect with the resume token,
|
||||
// which should fail and cause the client to be disconnected. The client
|
||||
// should then reconnect with no resume token.
|
||||
wsConn := testutil.RequireRecvCtx(ctx, t, websocketConnCh)
|
||||
_ = wsConn.Close(websocket.StatusGoingAway, "test")
|
||||
|
||||
// Wait for the resume token to be refreshed, which indicates a successful
|
||||
// reconnect.
|
||||
trappedTicker := newTickerTrap.MustWait(ctx)
|
||||
// Since we failed the initial reconnect and we're definitely reconnected
|
||||
// now, the stored resume token should now be nil.
|
||||
require.Nil(t, uut.resumeToken)
|
||||
trappedTicker.Release()
|
||||
trappedReset = tickerResetTrap.MustWait(ctx)
|
||||
trappedReset.Release()
|
||||
require.NotNil(t, uut.resumeToken)
|
||||
require.NotEqual(t, originalResumeToken, uut.resumeToken.Token)
|
||||
|
||||
// The resume token should have been rejected by the server.
|
||||
require.EqualValues(t, 1, atomic.LoadInt64(&didFail))
|
||||
}
|
||||
|
||||
func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
@ -392,23 +127,21 @@ func TestTailnetAPIConnector_TelemetrySuccess(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
ctx, nc := codersdk.WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
||||
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
||||
dialer := &pipeDialer{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
t: t,
|
||||
svc: svc,
|
||||
streamID: tailnet.StreamID{
|
||||
Name: "client",
|
||||
ID: clientID,
|
||||
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}))
|
||||
},
|
||||
}
|
||||
|
||||
fConn := newFakeTailnetConn()
|
||||
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, svr.URL, quartz.NewReal(), &websocket.DialOptions{})
|
||||
uut := newTailnetAPIConnector(ctx, logger, agentID, dialer, quartz.NewReal())
|
||||
uut.runConnector(fConn)
|
||||
// Coordinate calls happen _after_ telemetry is connected up, so we use this
|
||||
// to ensure telemetry is connected before sending our event
|
||||
@ -444,82 +177,42 @@ func newFakeTailnetConn() *fakeTailnetConn {
|
||||
return &fakeTailnetConn{}
|
||||
}
|
||||
|
||||
type fakeDRPCConn struct{}
|
||||
|
||||
var _ drpc.Conn = &fakeDRPCConn{}
|
||||
|
||||
// Close implements drpc.Conn.
|
||||
func (*fakeDRPCConn) Close() error {
|
||||
return nil
|
||||
type pipeDialer struct {
|
||||
ctx context.Context
|
||||
logger slog.Logger
|
||||
t testing.TB
|
||||
svc *tailnet.ClientService
|
||||
streamID tailnet.StreamID
|
||||
}
|
||||
|
||||
// Closed implements drpc.Conn.
|
||||
func (*fakeDRPCConn) Closed() <-chan struct{} {
|
||||
return nil
|
||||
func (p *pipeDialer) Dial(_ context.Context, _ tailnet.ResumeTokenController) (tailnet.ControlProtocolClients, error) {
|
||||
s, c := net.Pipe()
|
||||
go func() {
|
||||
err := p.svc.ServeConnV2(p.ctx, s, p.streamID)
|
||||
p.logger.Debug(p.ctx, "piped tailnet service complete", slog.Error(err))
|
||||
}()
|
||||
client, err := tailnet.NewDRPCClient(c, p.logger)
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
coord, err := client.Coordinate(context.Background())
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
// Invoke implements drpc.Conn.
|
||||
func (*fakeDRPCConn) Invoke(_ context.Context, _ string, _ drpc.Encoding, _ drpc.Message, _ drpc.Message) error {
|
||||
return nil
|
||||
derps := &tailnet.DERPFromDRPCWrapper{}
|
||||
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
|
||||
if !assert.NoError(p.t, err) {
|
||||
_ = c.Close()
|
||||
return tailnet.ControlProtocolClients{}, err
|
||||
}
|
||||
|
||||
// NewStream implements drpc.Conn.
|
||||
func (*fakeDRPCConn) NewStream(_ context.Context, _ string, _ drpc.Encoding) (drpc.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type fakeDRPCStream struct {
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
var _ proto.DRPCTailnet_CoordinateClient = &fakeDRPCStream{}
|
||||
|
||||
// Close implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (f *fakeDRPCStream) Close() error {
|
||||
close(f.ch)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseSend implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (*fakeDRPCStream) CloseSend() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (*fakeDRPCStream) Context() context.Context {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MsgRecv implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (*fakeDRPCStream) MsgRecv(_ drpc.Message, _ drpc.Encoding) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MsgSend implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (*fakeDRPCStream) MsgSend(_ drpc.Message, _ drpc.Encoding) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Recv implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (f *fakeDRPCStream) Recv() (*proto.CoordinateResponse, error) {
|
||||
<-f.ch
|
||||
return &proto.CoordinateResponse{}, nil
|
||||
}
|
||||
|
||||
// Send implements proto.DRPCTailnet_CoordinateClient.
|
||||
func (f *fakeDRPCStream) Send(*proto.CoordinateRequest) error {
|
||||
<-f.ch
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDRPPCMapStream struct {
|
||||
fakeDRPCStream
|
||||
}
|
||||
|
||||
var _ proto.DRPCTailnet_StreamDERPMapsClient = &fakeDRPPCMapStream{}
|
||||
|
||||
// Recv implements proto.DRPCTailnet_StreamDERPMapsClient.
|
||||
func (f *fakeDRPPCMapStream) Recv() (*proto.DERPMap, error) {
|
||||
<-f.fakeDRPCStream.ch
|
||||
return &proto.DERPMap{}, nil
|
||||
return tailnet.ControlProtocolClients{
|
||||
Closer: client.DRPCConn(),
|
||||
Coordinator: coord,
|
||||
DERP: derps,
|
||||
ResumeToken: client,
|
||||
Telemetry: client,
|
||||
}, nil
|
||||
}
|
||||
|
350
codersdk/workspacesdk/connector_test.go
Normal file
350
codersdk/workspacesdk/connector_test.go
Normal file
@ -0,0 +1,350 @@
|
||||
package workspacesdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nhooyr.io/websocket"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/tailnet"
|
||||
"github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/coder/v2/tailnet/tailnettest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestWebsocketDialer_TokenController(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
|
||||
fTokenProv := newFakeTokenController(ctx, t)
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dialTokens := make(chan string, 1)
|
||||
wsErr := make(chan error, 1)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timed out sending token")
|
||||
case dialTokens <- r.URL.Query().Get("resume_token"):
|
||||
// OK
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
||||
// streamID can be empty because we don't call RPCs in this test.
|
||||
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
||||
}))
|
||||
defer svr.Close()
|
||||
svrURL, err := url.Parse(svr.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
||||
|
||||
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
||||
go func() {
|
||||
clients, err := uut.Dial(ctx, fTokenProv)
|
||||
assert.NoError(t, err)
|
||||
clientCh <- clients
|
||||
}()
|
||||
|
||||
call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls)
|
||||
call <- tokenResponse{"test token", true}
|
||||
gotToken := <-dialTokens
|
||||
require.Equal(t, "test token", gotToken)
|
||||
|
||||
clients := testutil.RequireRecvCtx(ctx, t, clientCh)
|
||||
clients.Closer.Close()
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, wsErr)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientCh = make(chan tailnet.ControlProtocolClients, 1)
|
||||
go func() {
|
||||
clients, err := uut.Dial(ctx, fTokenProv)
|
||||
assert.NoError(t, err)
|
||||
clientCh <- clients
|
||||
}()
|
||||
|
||||
call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls)
|
||||
call <- tokenResponse{"test token", false}
|
||||
gotToken = <-dialTokens
|
||||
require.Equal(t, "", gotToken)
|
||||
|
||||
clients = testutil.RequireRecvCtx(ctx, t, clientCh)
|
||||
clients.Closer.Close()
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, wsErr)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWebsocketDialer_NoTokenController(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dialTokens := make(chan string, 1)
|
||||
wsErr := make(chan error, 1)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timed out sending token")
|
||||
case dialTokens <- r.URL.Query().Get("resume_token"):
|
||||
// OK
|
||||
}
|
||||
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
||||
// streamID can be empty because we don't call RPCs in this test.
|
||||
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
||||
}))
|
||||
defer svr.Close()
|
||||
svrURL, err := url.Parse(svr.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
||||
|
||||
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
||||
go func() {
|
||||
clients, err := uut.Dial(ctx, nil)
|
||||
assert.NoError(t, err)
|
||||
clientCh <- clients
|
||||
}()
|
||||
|
||||
gotToken := <-dialTokens
|
||||
require.Equal(t, "", gotToken)
|
||||
|
||||
clients := testutil.RequireRecvCtx(ctx, t, clientCh)
|
||||
clients.Closer.Close()
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, wsErr)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{
|
||||
IgnoreErrors: true,
|
||||
}).Leveled(slog.LevelDebug)
|
||||
|
||||
fTokenProv := newFakeTokenController(ctx, t)
|
||||
fCoord := tailnettest.NewFakeCoordinator()
|
||||
var coord tailnet.Coordinator = fCoord
|
||||
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
||||
coordPtr.Store(&coord)
|
||||
|
||||
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
|
||||
Logger: logger,
|
||||
CoordPtr: &coordPtr,
|
||||
DERPMapUpdateFrequency: time.Hour,
|
||||
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
dialTokens := make(chan string, 1)
|
||||
wsErr := make(chan error, 1)
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resumeToken := r.URL.Query().Get("resume_token")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timed out sending token")
|
||||
case dialTokens <- resumeToken:
|
||||
// OK
|
||||
}
|
||||
|
||||
if resumeToken != "" {
|
||||
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{
|
||||
Message: workspacesdk.CoordinateAPIInvalidResumeToken,
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "resume_token", Detail: workspacesdk.CoordinateAPIInvalidResumeToken},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
sws, err := websocket.Accept(w, r, nil)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
|
||||
// streamID can be empty because we don't call RPCs in this test.
|
||||
wsErr <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{})
|
||||
}))
|
||||
defer svr.Close()
|
||||
svrURL, err := url.Parse(svr.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := uut.Dial(ctx, fTokenProv)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
call := testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls)
|
||||
call <- tokenResponse{"test token", true}
|
||||
gotToken := <-dialTokens
|
||||
require.Equal(t, "test token", gotToken)
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
require.Error(t, err)
|
||||
|
||||
// redial should not use the token
|
||||
clientCh := make(chan tailnet.ControlProtocolClients, 1)
|
||||
go func() {
|
||||
clients, err := uut.Dial(ctx, fTokenProv)
|
||||
assert.NoError(t, err)
|
||||
clientCh <- clients
|
||||
}()
|
||||
gotToken = <-dialTokens
|
||||
require.Equal(t, "", gotToken)
|
||||
|
||||
clients := testutil.RequireRecvCtx(ctx, t, clientCh)
|
||||
require.Error(t, err)
|
||||
clients.Closer.Close()
|
||||
err = testutil.RequireRecvCtx(ctx, t, wsErr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Successful dial should reset to using token again
|
||||
go func() {
|
||||
_, err := uut.Dial(ctx, fTokenProv)
|
||||
errCh <- err
|
||||
}()
|
||||
call = testutil.RequireRecvCtx(ctx, t, fTokenProv.tokenCalls)
|
||||
call <- tokenResponse{"test token", true}
|
||||
gotToken = <-dialTokens
|
||||
require.Equal(t, "test token", gotToken)
|
||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestWebsocketDialer_UplevelVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
|
||||
|
||||
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
|
||||
|
||||
// the following matches what Coderd does;
|
||||
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
||||
cVer := r.URL.Query().Get("version")
|
||||
if err := sVer.Validate(cVer); err != nil {
|
||||
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
||||
Message: workspacesdk.AgentAPIMismatchMessage,
|
||||
Validations: []codersdk.ValidationError{
|
||||
{Field: "version", Detail: err.Error()},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
}))
|
||||
svrURL, err := url.Parse(svr.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := uut.Dial(ctx, nil)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Equal(t, workspacesdk.AgentAPIMismatchMessage, sdkErr.Message)
|
||||
require.NotEmpty(t, sdkErr.Helper)
|
||||
}
|
||||
|
||||
type fakeResumeTokenController struct {
|
||||
ctx context.Context
|
||||
t testing.TB
|
||||
tokenCalls chan chan tokenResponse
|
||||
}
|
||||
|
||||
func (*fakeResumeTokenController) New(tailnet.ResumeTokenClient) tailnet.CloserWaiter {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (f *fakeResumeTokenController) Token() (string, bool) {
|
||||
call := make(chan tokenResponse)
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
f.t.Error("timeout on Token() call")
|
||||
case f.tokenCalls <- call:
|
||||
// OK
|
||||
}
|
||||
select {
|
||||
case <-f.ctx.Done():
|
||||
f.t.Error("timeout on Token() response")
|
||||
return "", false
|
||||
case r := <-call:
|
||||
return r.token, r.ok
|
||||
}
|
||||
}
|
||||
|
||||
var _ tailnet.ResumeTokenController = &fakeResumeTokenController{}
|
||||
|
||||
func newFakeTokenController(ctx context.Context, t testing.TB) *fakeResumeTokenController {
|
||||
return &fakeResumeTokenController{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
tokenCalls: make(chan chan tokenResponse),
|
||||
}
|
||||
}
|
||||
|
||||
type tokenResponse struct {
|
||||
token string
|
||||
ok bool
|
||||
}
|
@ -228,13 +228,13 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
q.Add("version", "2.0")
|
||||
coordinateURL.RawQuery = q.Encode()
|
||||
|
||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, coordinateURL.String(), quartz.NewReal(),
|
||||
&websocket.DialOptions{
|
||||
dialer := NewWebsocketDialer(options.Logger, coordinateURL, &websocket.DialOptions{
|
||||
HTTPClient: c.client.HTTPClient,
|
||||
HTTPHeader: headers,
|
||||
// Need to disable compression to avoid a data-race.
|
||||
CompressionMode: websocket.CompressionDisabled,
|
||||
})
|
||||
connector := newTailnetAPIConnector(ctx, options.Logger, agentID, dialer, quartz.NewReal())
|
||||
|
||||
ip := tailnet.TailscaleServicePrefix.RandomAddr()
|
||||
var header http.Header
|
||||
@ -271,7 +271,7 @@ func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *
|
||||
select {
|
||||
case <-dialCtx.Done():
|
||||
return nil, xerrors.Errorf("timed out waiting for coordinator and derp map: %w", dialCtx.Err())
|
||||
case err = <-connector.connected:
|
||||
case err = <-dialer.Connected():
|
||||
if err != nil {
|
||||
options.Logger.Error(ctx, "failed to connect to tailnet v2+ API", slog.Error(err))
|
||||
return nil, xerrors.Errorf("start connector: %w", err)
|
||||
|
@ -615,7 +615,7 @@ func newBasicResumeTokenRefresher(
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
r.ctx, r.cancel = context.WithCancel(context.Background())
|
||||
r.timer = clock.AfterFunc(never, r.refresh)
|
||||
r.timer = clock.AfterFunc(never, r.refresh, "basicResumeTokenRefresher")
|
||||
go r.refresh()
|
||||
return r
|
||||
}
|
||||
|
Reference in New Issue
Block a user