chore: refactor tailnetAPIConnector to use dialer (#15347)

refactors `tailnetAPIConnector` to use the `Dialer` interface in `tailnet`, introduced lower in this stack of PRs. This will let us use the same Tailnet API handling code across different things that connect to the Tailnet API (CLI client, coderd, workspace proxies, and soon: Coder VPN).

chore re: #14729
This commit is contained in:
Spike Curtis
2024-11-07 17:24:19 +04:00
committed by GitHub
parent ba483efd0f
commit 2d061e698d
5 changed files with 626 additions and 558 deletions

View File

@ -56,32 +56,28 @@ type tailnetAPIConnector struct {
logger slog.Logger
agentID uuid.UUID
coordinateURL string
clock quartz.Clock
dialOptions *websocket.DialOptions
derpCtrl tailnet.DERPController
coordCtrl tailnet.CoordinationController
telCtrl *tailnet.BasicTelemetryController
agentID uuid.UUID
clock quartz.Clock
dialer tailnet.ControlProtocolDialer
derpCtrl tailnet.DERPController
coordCtrl tailnet.CoordinationController
telCtrl *tailnet.BasicTelemetryController
tokenCtrl tailnet.ResumeTokenController
connected chan error
resumeToken *proto.RefreshResumeTokenResponse
isFirst bool
closed chan struct{}
closed chan struct{}
}
// Create a new tailnetAPIConnector without running it
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, coordinateURL string, clock quartz.Clock, dialOptions *websocket.DialOptions) *tailnetAPIConnector {
func newTailnetAPIConnector(ctx context.Context, logger slog.Logger, agentID uuid.UUID, dialer tailnet.ControlProtocolDialer, clock quartz.Clock) *tailnetAPIConnector {
return &tailnetAPIConnector{
ctx: ctx,
logger: logger,
agentID: agentID,
coordinateURL: coordinateURL,
clock: clock,
dialOptions: dialOptions,
connected: make(chan error, 1),
closed: make(chan struct{}),
telCtrl: tailnet.NewBasicTelemetryController(logger),
ctx: ctx,
logger: logger,
agentID: agentID,
clock: clock,
dialer: dialer,
closed: make(chan struct{}),
telCtrl: tailnet.NewBasicTelemetryController(logger),
tokenCtrl: tailnet.NewBasicResumeTokenController(logger, clock),
}
}
@ -105,17 +101,25 @@ func (tac *tailnetAPIConnector) runConnector(conn tailnetConn) {
tac.gracefulCtx, tac.cancelGracefulCtx = context.WithCancel(context.Background())
go tac.manageGracefulTimeout()
go func() {
tac.isFirst = true
defer close(tac.closed)
// Sadly retry doesn't support quartz.Clock yet so this is not
// influenced by the configured clock.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
tailnetClient, err := tac.dial()
tailnetClients, err := tac.dialer.Dial(tac.ctx, tac.tokenCtrl)
if err != nil {
if xerrors.Is(err, context.Canceled) {
continue
}
errF := slog.Error(err)
var sdkErr *codersdk.Error
if xerrors.As(err, &sdkErr) {
errF = slog.Error(sdkErr)
}
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", errF)
continue
}
tac.logger.Debug(tac.ctx, "obtained tailnet API v2+ client")
tac.runConnectorOnce(tailnetClient)
tac.runConnectorOnce(tailnetClients)
tac.logger.Debug(tac.ctx, "tailnet API v2+ connection lost")
}
}()
@ -127,26 +131,152 @@ var permanentErrorStatuses = []int{
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
}
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
// into one function so that a problem with one tears down the other and triggers a retry (if
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
// fate.
func (tac *tailnetAPIConnector) runConnectorOnce(clients tailnet.ControlProtocolClients) {
defer func() {
closeErr := clients.Closer.Close()
if closeErr != nil &&
!xerrors.Is(closeErr, io.EOF) &&
!xerrors.Is(closeErr, context.Canceled) &&
!xerrors.Is(closeErr, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
}
}()
u, err := url.Parse(tac.coordinateURL)
if err != nil {
return nil, xerrors.Errorf("parse URL %q: %w", tac.coordinateURL, err)
}
if tac.resumeToken != nil {
q := u.Query()
q.Set("resume_token", tac.resumeToken.Token)
u.RawQuery = q.Encode()
tac.logger.Debug(tac.ctx, "using resume token", slog.F("resume_token", tac.resumeToken))
}
tac.telCtrl.New(clients.Telemetry) // synchronous, doesn't need a goroutine
coordinateURL := u.String()
tac.logger.Debug(tac.ctx, "using coordinate URL", slog.F("url", coordinateURL))
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
tac.coordinate(clients.Coordinator)
}()
go func() {
defer wg.Done()
defer refreshTokenCancel()
dErr := tac.derpMap(clients.DERP)
if dErr != nil && tac.ctx.Err() == nil {
// The main context is still active, meaning that we want the tailnet data plane to stay
// up, even though we hit some error getting DERP maps on the control plane. That means
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
clients.Closer.Close()
// Note that derpMap() logs it own errors, we don't bother here.
}
}()
go func() {
defer wg.Done()
tac.refreshToken(refreshTokenCtx, clients.ResumeToken)
}()
wg.Wait()
}
func (tac *tailnetAPIConnector) coordinate(client tailnet.CoordinatorClient) {
defer func() {
cErr := client.Close()
if cErr != nil {
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
}
}()
coordination := tac.coordCtrl.New(client)
tac.logger.Debug(tac.ctx, "serving coordinator")
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close(tac.gracefulCtx)
if crdErr != nil {
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(crdErr))
}
case err := <-coordination.Wait():
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
}
}
}
func (tac *tailnetAPIConnector) derpMap(client tailnet.DERPClient) error {
defer func() {
cErr := client.Close()
if cErr != nil {
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
cw := tac.derpCtrl.New(client)
select {
case <-tac.ctx.Done():
cErr := client.Close()
if cErr != nil {
tac.logger.Warn(tac.ctx, "failed to close StreamDERPMaps RPC", slog.Error(cErr))
}
return nil
case err := <-cw.Wait():
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
if err != nil && !xerrors.Is(err, io.EOF) {
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
}
return err
}
}
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client tailnet.ResumeTokenClient) {
cw := tac.tokenCtrl.New(client)
go func() {
<-ctx.Done()
cErr := cw.Close(tac.ctx)
if cErr != nil {
tac.logger.Error(tac.ctx, "error closing token refresher", slog.Error(cErr))
}
}()
err := <-cw.Wait()
if err != nil && !xerrors.Is(err, context.Canceled) && !xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error receiving refresh token", slog.Error(err))
}
}
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
tac.telCtrl.SendTelemetryEvent(event)
}
type WebsocketDialer struct {
logger slog.Logger
dialOptions *websocket.DialOptions
url *url.URL
resumeTokenFailed bool
connected chan error
isFirst bool
}
func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenController,
) (
tailnet.ControlProtocolClients, error,
) {
w.logger.Debug(ctx, "dialing Coder tailnet v2+ API")
u := new(url.URL)
*u = *w.url
if r != nil && !w.resumeTokenFailed {
if token, ok := r.Token(); ok {
q := u.Query()
q.Set("resume_token", token)
u.RawQuery = q.Encode()
w.logger.Debug(ctx, "using resume token on dial")
}
}
// nolint:bodyclose
ws, res, err := websocket.Dial(tac.ctx, coordinateURL, tac.dialOptions)
if tac.isFirst {
ws, res, err := websocket.Dial(ctx, u.String(), w.dialOptions)
if w.isFirst {
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
err = codersdk.ReadBodyAsError(res)
// A bit more human-readable help in the case the API version was rejected
@ -159,11 +289,11 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
buildinfo.Version())
}
}
tac.connected <- err
return nil, err
w.connected <- err
return tailnet.ControlProtocolClients{}, err
}
tac.isFirst = false
close(tac.connected)
w.isFirst = false
close(w.connected)
}
if err != nil {
bodyErr := codersdk.ReadBodyAsError(res)
@ -172,167 +302,62 @@ func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
for _, v := range sdkErr.Validations {
if v.Field == "resume_token" {
// Unset the resume token for the next attempt
tac.logger.Warn(tac.ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
tac.resumeToken = nil
return nil, err
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
w.resumeTokenFailed = true
return tailnet.ControlProtocolClients{}, err
}
}
}
if !errors.Is(err, context.Canceled) {
tac.logger.Error(tac.ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
w.logger.Error(ctx, "failed to dial tailnet v2+ API", slog.Error(err), slog.F("sdk_err", sdkErr))
}
return nil, err
return tailnet.ControlProtocolClients{}, err
}
w.resumeTokenFailed = false
client, err := tailnet.NewDRPCClient(
websocket.NetConn(tac.gracefulCtx, ws, websocket.MessageBinary),
tac.logger,
websocket.NetConn(context.Background(), ws, websocket.MessageBinary),
w.logger,
)
if err != nil {
tac.logger.Debug(tac.ctx, "failed to create DRPCClient", slog.Error(err))
w.logger.Debug(ctx, "failed to create DRPCClient", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
return nil, err
return tailnet.ControlProtocolClients{}, err
}
return client, err
}
// runConnectorOnce uses the provided client to coordinate and stream DERP Maps. It is combined
// into one function so that a problem with one tears down the other and triggers a retry (if
// appropriate). We multiplex both RPCs over the same websocket, so we want them to share the same
// fate.
func (tac *tailnetAPIConnector) runConnectorOnce(client proto.DRPCTailnetClient) {
defer func() {
conn := client.DRPCConn()
closeErr := conn.Close()
if closeErr != nil &&
!xerrors.Is(closeErr, io.EOF) &&
!xerrors.Is(closeErr, context.Canceled) &&
!xerrors.Is(closeErr, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "error closing DRPC connection", slog.Error(closeErr))
<-conn.Closed()
}
}()
tac.telCtrl.New(client) // synchronous, doesn't need a goroutine
refreshTokenCtx, refreshTokenCancel := context.WithCancel(tac.ctx)
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
tac.coordinate(client)
}()
go func() {
defer wg.Done()
defer refreshTokenCancel()
dErr := tac.derpMap(client)
if dErr != nil && tac.ctx.Err() == nil {
// The main context is still active, meaning that we want the tailnet data plane to stay
// up, even though we hit some error getting DERP maps on the control plane. That means
// we do NOT want to gracefully disconnect on the coordinate() routine. So, we'll just
// close the underlying connection. This will trigger a retry of the control plane in
// run().
client.DRPCConn().Close()
// Note that derpMap() logs it own errors, we don't bother here.
}
}()
go func() {
defer wg.Done()
tac.refreshToken(refreshTokenCtx, client)
}()
wg.Wait()
}
func (tac *tailnetAPIConnector) coordinate(client proto.DRPCTailnetClient) {
// we use the gracefulCtx here so that we'll have time to send the graceful disconnect
coord, err := client.Coordinate(tac.gracefulCtx)
coord, err := client.Coordinate(context.Background())
if err != nil {
tac.logger.Error(tac.ctx, "failed to connect to Coordinate RPC", slog.Error(err))
return
w.logger.Debug(ctx, "failed to create Coordinate RPC", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
return tailnet.ControlProtocolClients{}, err
}
defer func() {
cErr := coord.Close()
if cErr != nil {
tac.logger.Debug(tac.ctx, "error closing Coordinate RPC", slog.Error(cErr))
}
}()
coordination := tac.coordCtrl.New(coord)
tac.logger.Debug(tac.ctx, "serving coordinator")
select {
case <-tac.ctx.Done():
tac.logger.Debug(tac.ctx, "main context canceled; do graceful disconnect")
crdErr := coordination.Close(tac.gracefulCtx)
if crdErr != nil {
tac.logger.Warn(tac.ctx, "failed to close remote coordination", slog.Error(err))
}
case err = <-coordination.Wait():
if err != nil &&
!xerrors.Is(err, io.EOF) &&
!xerrors.Is(err, context.Canceled) &&
!xerrors.Is(err, context.DeadlineExceeded) {
tac.logger.Error(tac.ctx, "remote coordination error", slog.Error(err))
}
}
}
func (tac *tailnetAPIConnector) derpMap(client proto.DRPCTailnetClient) error {
s := &tailnet.DERPFromDRPCWrapper{}
var err error
s.Client, err = client.StreamDERPMaps(tac.ctx, &proto.StreamDERPMapsRequest{})
derps := &tailnet.DERPFromDRPCWrapper{}
derps.Client, err = client.StreamDERPMaps(context.Background(), &proto.StreamDERPMapsRequest{})
if err != nil {
return xerrors.Errorf("failed to connect to StreamDERPMaps RPC: %w", err)
w.logger.Debug(ctx, "failed to create DERPMap stream", slog.Error(err))
_ = ws.Close(websocket.StatusInternalError, "")
return tailnet.ControlProtocolClients{}, err
}
defer func() {
cErr := s.Close()
if cErr != nil {
tac.logger.Debug(tac.ctx, "error closing StreamDERPMaps RPC", slog.Error(cErr))
}
}()
cw := tac.derpCtrl.New(s)
err = <-cw.Wait()
if xerrors.Is(err, context.Canceled) || xerrors.Is(err, context.DeadlineExceeded) {
return nil
}
if err != nil && !xerrors.Is(err, io.EOF) {
tac.logger.Error(tac.ctx, "error receiving DERP Map", slog.Error(err))
}
return err
return tailnet.ControlProtocolClients{
Closer: client.DRPCConn(),
Coordinator: coord,
DERP: derps,
ResumeToken: client,
Telemetry: client,
}, nil
}
func (tac *tailnetAPIConnector) refreshToken(ctx context.Context, client proto.DRPCTailnetClient) {
ticker := tac.clock.NewTicker(15*time.Second, "tailnetAPIConnector", "refreshToken")
defer ticker.Stop()
func (w *WebsocketDialer) Connected() <-chan error {
return w.connected
}
initialCh := make(chan struct{}, 1)
initialCh <- struct{}{}
defer close(initialCh)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-initialCh:
}
attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
res, err := client.RefreshResumeToken(attemptCtx, &proto.RefreshResumeTokenRequest{})
cancel()
if err != nil {
if ctx.Err() == nil {
tac.logger.Error(tac.ctx, "error refreshing coordinator resume token", slog.Error(err))
}
return
}
tac.logger.Debug(tac.ctx, "refreshed coordinator resume token", slog.F("resume_token", res))
tac.resumeToken = res
dur := res.RefreshIn.AsDuration()
if dur <= 0 {
// A sensible delay to refresh again.
dur = 30 * time.Minute
}
ticker.Reset(dur, "tailnetAPIConnector", "refreshToken", "reset")
func NewWebsocketDialer(logger slog.Logger, u *url.URL, opts *websocket.DialOptions) *WebsocketDialer {
return &WebsocketDialer{
logger: logger,
dialOptions: opts,
url: u,
connected: make(chan error, 1),
isFirst: true,
}
}
func (tac *tailnetAPIConnector) SendTelemetryEvent(event *proto.TelemetryEvent) {
tac.telCtrl.SendTelemetryEvent(event)
}