feat: vpn uses WorkspaceHostnameSuffix for DNS names (#17335)

Use the hostname suffix to set DNS names as programmed into the DNS service and returned by the vpn `Tunnel`.

part of: #16828
This commit is contained in:
Spike Curtis
2025-04-11 13:24:20 +04:00
committed by GitHub
parent 12dc086628
commit 2c573dc023
6 changed files with 249 additions and 173 deletions

View File

@ -143,7 +143,7 @@ type AgentConnectionInfo struct {
DERPMap *tailcfg.DERPMap `json:"derp_map"` DERPMap *tailcfg.DERPMap `json:"derp_map"`
DERPForceWebSockets bool `json:"derp_force_websockets"` DERPForceWebSockets bool `json:"derp_force_websockets"`
DisableDirectConnections bool `json:"disable_direct_connections"` DisableDirectConnections bool `json:"disable_direct_connections"`
HostnameSuffix string `json:"hostname_suffix"` HostnameSuffix string `json:"hostname_suffix,omitempty"`
} }
func (c *Client) AgentConnectionInfoGeneric(ctx context.Context) (AgentConnectionInfo, error) { func (c *Client) AgentConnectionInfoGeneric(ctx context.Context) (AgentConnectionInfo, error) {

View File

@ -357,9 +357,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
// A FQDN to be mapped to `tsaddr.CoderServiceIPv6`. This address can be used // A FQDN to be mapped to `tsaddr.CoderServiceIPv6`. This address can be used
// when you want to know if Coder Connect is running, but are not trying to // when you want to know if Coder Connect is running, but are not trying to
// connect to a specific known workspace. // connect to a specific known workspace.
const IsCoderConnectEnabledFQDNString = "is.coder--connect--enabled--right--now.coder." const IsCoderConnectEnabledFmtString = "is.coder--connect--enabled--right--now.%s."
var IsCoderConnectEnabledFQDN, _ = dnsname.ToFQDN(IsCoderConnectEnabledFQDNString)
type ServicePrefix [6]byte type ServicePrefix [6]byte

View File

@ -864,11 +864,12 @@ func (r *basicResumeTokenRefresher) refresh() {
} }
type TunnelAllWorkspaceUpdatesController struct { type TunnelAllWorkspaceUpdatesController struct {
coordCtrl *TunnelSrcCoordController coordCtrl *TunnelSrcCoordController
dnsHostSetter DNSHostsSetter dnsHostSetter DNSHostsSetter
updateHandler UpdatesHandler dnsNameOptions DNSNameOptions
ownerUsername string updateHandler UpdatesHandler
logger slog.Logger ownerUsername string
logger slog.Logger
mu sync.Mutex mu sync.Mutex
updater *tunnelUpdater updater *tunnelUpdater
@ -883,12 +884,16 @@ type Workspace struct {
agents map[uuid.UUID]*Agent agents map[uuid.UUID]*Agent
} }
type DNSNameOptions struct {
Suffix string
}
// updateDNSNames updates the DNS names for all agents in the workspace. // updateDNSNames updates the DNS names for all agents in the workspace.
// DNS hosts must be all lowercase, or the resolver won't be able to find them. // DNS hosts must be all lowercase, or the resolver won't be able to find them.
// Usernames are globally unique & case-insensitive. // Usernames are globally unique & case-insensitive.
// Workspace names are unique per-user & case-insensitive. // Workspace names are unique per-user & case-insensitive.
// Agent names are unique per-workspace & case-insensitive. // Agent names are unique per-workspace & case-insensitive.
func (w *Workspace) updateDNSNames() error { func (w *Workspace) updateDNSNames(options DNSNameOptions) error {
wsName := strings.ToLower(w.Name) wsName := strings.ToLower(w.Name)
username := strings.ToLower(w.ownerUsername) username := strings.ToLower(w.ownerUsername)
for id, a := range w.agents { for id, a := range w.agents {
@ -896,24 +901,22 @@ func (w *Workspace) updateDNSNames() error {
names := make(map[dnsname.FQDN][]netip.Addr) names := make(map[dnsname.FQDN][]netip.Addr)
// TODO: technically, DNS labels cannot start with numbers, but the rules are often not // TODO: technically, DNS labels cannot start with numbers, but the rules are often not
// strictly enforced. // strictly enforced.
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.%s.me.coder.", agentName, wsName)) fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.%s.me.%s.", agentName, wsName, options.Suffix))
if err != nil { if err != nil {
return err return err
} }
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)} names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)}
fqdn, err = dnsname.ToFQDN(fmt.Sprintf("%s.%s.%s.coder.", agentName, wsName, username)) fqdn, err = dnsname.ToFQDN(fmt.Sprintf("%s.%s.%s.%s.", agentName, wsName, username, options.Suffix))
if err != nil { if err != nil {
return err return err
} }
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)} names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)}
if len(w.agents) == 1 { if len(w.agents) == 1 {
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%s.coder.", wsName)) fqdn, err = dnsname.ToFQDN(fmt.Sprintf("%s.%s.", wsName, options.Suffix))
if err != nil { if err != nil {
return err return err
} }
for _, a := range w.agents { names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)}
names[fqdn] = []netip.Addr{CoderServicePrefix.AddrFromUUID(a.ID)}
}
} }
a.Hosts = names a.Hosts = names
w.agents[id] = a w.agents[id] = a
@ -950,6 +953,7 @@ func (t *TunnelAllWorkspaceUpdatesController) New(client WorkspaceUpdatesClient)
logger: t.logger, logger: t.logger,
coordCtrl: t.coordCtrl, coordCtrl: t.coordCtrl,
dnsHostsSetter: t.dnsHostSetter, dnsHostsSetter: t.dnsHostSetter,
dnsNameOptions: t.dnsNameOptions,
updateHandler: t.updateHandler, updateHandler: t.updateHandler,
ownerUsername: t.ownerUsername, ownerUsername: t.ownerUsername,
recvLoopDone: make(chan struct{}), recvLoopDone: make(chan struct{}),
@ -996,6 +1000,7 @@ type tunnelUpdater struct {
updateHandler UpdatesHandler updateHandler UpdatesHandler
ownerUsername string ownerUsername string
recvLoopDone chan struct{} recvLoopDone chan struct{}
dnsNameOptions DNSNameOptions
sync.Mutex sync.Mutex
workspaces map[uuid.UUID]*Workspace workspaces map[uuid.UUID]*Workspace
@ -1250,7 +1255,7 @@ func (t *tunnelUpdater) allAgentIDsLocked() []uuid.UUID {
func (t *tunnelUpdater) updateDNSNamesLocked() map[dnsname.FQDN][]netip.Addr { func (t *tunnelUpdater) updateDNSNamesLocked() map[dnsname.FQDN][]netip.Addr {
names := make(map[dnsname.FQDN][]netip.Addr) names := make(map[dnsname.FQDN][]netip.Addr)
for _, w := range t.workspaces { for _, w := range t.workspaces {
err := w.updateDNSNames() err := w.updateDNSNames(t.dnsNameOptions)
if err != nil { if err != nil {
// This should never happen in production, because converting the FQDN only fails // This should never happen in production, because converting the FQDN only fails
// if names are too long, and we put strict length limits on agent, workspace, and user // if names are too long, and we put strict length limits on agent, workspace, and user
@ -1258,6 +1263,7 @@ func (t *tunnelUpdater) updateDNSNamesLocked() map[dnsname.FQDN][]netip.Addr {
t.logger.Critical(context.Background(), t.logger.Critical(context.Background(),
"failed to include DNS name(s)", "failed to include DNS name(s)",
slog.F("workspace_id", w.ID), slog.F("workspace_id", w.ID),
slog.F("suffix", t.dnsNameOptions.Suffix),
slog.Error(err)) slog.Error(err))
} }
for _, a := range w.agents { for _, a := range w.agents {
@ -1266,7 +1272,13 @@ func (t *tunnelUpdater) updateDNSNamesLocked() map[dnsname.FQDN][]netip.Addr {
} }
} }
} }
names[IsCoderConnectEnabledFQDN] = []netip.Addr{tsaddr.CoderServiceIPv6()} isCoderConnectEnabledFQDN, err := dnsname.ToFQDN(fmt.Sprintf(IsCoderConnectEnabledFmtString, t.dnsNameOptions.Suffix))
if err != nil {
t.logger.Critical(context.Background(),
"failed to include Coder Connect enabled DNS name", slog.F("suffix", t.dnsNameOptions.Suffix))
} else {
names[isCoderConnectEnabledFQDN] = []netip.Addr{tsaddr.CoderServiceIPv6()}
}
return names return names
} }
@ -1274,10 +1286,11 @@ type TunnelAllOption func(t *TunnelAllWorkspaceUpdatesController)
// WithDNS configures the tunnelAllWorkspaceUpdatesController to set DNS names for all workspaces // WithDNS configures the tunnelAllWorkspaceUpdatesController to set DNS names for all workspaces
// and agents it learns about. // and agents it learns about.
func WithDNS(d DNSHostsSetter, ownerUsername string) TunnelAllOption { func WithDNS(d DNSHostsSetter, ownerUsername string, options DNSNameOptions) TunnelAllOption {
return func(t *TunnelAllWorkspaceUpdatesController) { return func(t *TunnelAllWorkspaceUpdatesController) {
t.dnsHostSetter = d t.dnsHostSetter = d
t.ownerUsername = ownerUsername t.ownerUsername = ownerUsername
t.dnsNameOptions = options
} }
} }
@ -1293,7 +1306,11 @@ func WithHandler(h UpdatesHandler) TunnelAllOption {
func NewTunnelAllWorkspaceUpdatesController( func NewTunnelAllWorkspaceUpdatesController(
logger slog.Logger, c *TunnelSrcCoordController, opts ...TunnelAllOption, logger slog.Logger, c *TunnelSrcCoordController, opts ...TunnelAllOption,
) *TunnelAllWorkspaceUpdatesController { ) *TunnelAllWorkspaceUpdatesController {
t := &TunnelAllWorkspaceUpdatesController{logger: logger, coordCtrl: c} t := &TunnelAllWorkspaceUpdatesController{
logger: logger,
coordCtrl: c,
dnsNameOptions: DNSNameOptions{"coder"},
}
for _, opt := range opts { for _, opt := range opts {
opt(t) opt(t)
} }

View File

@ -1522,7 +1522,7 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
fUH := newFakeUpdateHandler(ctx, t) fUH := newFakeUpdateHandler(ctx, t)
fDNS := newFakeDNSSetter(ctx, t) fDNS := newFakeDNSSetter(ctx, t)
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
tailnet.WithDNS(fDNS, "testy"), tailnet.WithDNS(fDNS, "testy", tailnet.DNSNameOptions{Suffix: "mctest"}),
tailnet.WithHandler(fUH), tailnet.WithHandler(fUH),
) )
@ -1562,16 +1562,19 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
w2a1IP := netip.MustParseAddr("fd60:627a:a42b:0201::") w2a1IP := netip.MustParseAddr("fd60:627a:a42b:0201::")
w2a2IP := netip.MustParseAddr("fd60:627a:a42b:0202::") w2a2IP := netip.MustParseAddr("fd60:627a:a42b:0202::")
expectedCoderConnectFQDN, err := dnsname.ToFQDN(fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, "mctest"))
require.NoError(t, err)
// Also triggers setting DNS hosts // Also triggers setting DNS hosts
expectedDNS := map[dnsname.FQDN][]netip.Addr{ expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.me.coder.": {ws1a1IP}, "w1a1.w1.me.mctest.": {ws1a1IP},
"w2a1.w2.me.coder.": {w2a1IP}, "w2a1.w2.me.mctest.": {w2a1IP},
"w2a2.w2.me.coder.": {w2a2IP}, "w2a2.w2.me.mctest.": {w2a2IP},
"w1a1.w1.testy.coder.": {ws1a1IP}, "w1a1.w1.testy.mctest.": {ws1a1IP},
"w2a1.w2.testy.coder.": {w2a1IP}, "w2a1.w2.testy.mctest.": {w2a1IP},
"w2a2.w2.testy.coder.": {w2a2IP}, "w2a2.w2.testy.mctest.": {w2a2IP},
"w1.coder.": {ws1a1IP}, "w1.mctest.": {ws1a1IP},
tailnet.IsCoderConnectEnabledFQDNString: {tsaddr.CoderServiceIPv6()}, expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
} }
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts) require.Equal(t, expectedDNS, dnsCall.hosts)
@ -1586,23 +1589,23 @@ func TestTunnelAllWorkspaceUpdatesController_Initial(t *testing.T) {
{ {
ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID, ID: w1a1ID, Name: "w1a1", WorkspaceID: w1ID,
Hosts: map[dnsname.FQDN][]netip.Addr{ Hosts: map[dnsname.FQDN][]netip.Addr{
"w1.coder.": {ws1a1IP}, "w1.mctest.": {ws1a1IP},
"w1a1.w1.me.coder.": {ws1a1IP}, "w1a1.w1.me.mctest.": {ws1a1IP},
"w1a1.w1.testy.coder.": {ws1a1IP}, "w1a1.w1.testy.mctest.": {ws1a1IP},
}, },
}, },
{ {
ID: w2a1ID, Name: "w2a1", WorkspaceID: w2ID, ID: w2a1ID, Name: "w2a1", WorkspaceID: w2ID,
Hosts: map[dnsname.FQDN][]netip.Addr{ Hosts: map[dnsname.FQDN][]netip.Addr{
"w2a1.w2.me.coder.": {w2a1IP}, "w2a1.w2.me.mctest.": {w2a1IP},
"w2a1.w2.testy.coder.": {w2a1IP}, "w2a1.w2.testy.mctest.": {w2a1IP},
}, },
}, },
{ {
ID: w2a2ID, Name: "w2a2", WorkspaceID: w2ID, ID: w2a2ID, Name: "w2a2", WorkspaceID: w2ID,
Hosts: map[dnsname.FQDN][]netip.Addr{ Hosts: map[dnsname.FQDN][]netip.Addr{
"w2a2.w2.me.coder.": {w2a2IP}, "w2a2.w2.me.mctest.": {w2a2IP},
"w2a2.w2.testy.coder.": {w2a2IP}, "w2a2.w2.testy.mctest.": {w2a2IP},
}, },
}, },
}, },
@ -1634,7 +1637,7 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
fUH := newFakeUpdateHandler(ctx, t) fUH := newFakeUpdateHandler(ctx, t)
fDNS := newFakeDNSSetter(ctx, t) fDNS := newFakeDNSSetter(ctx, t)
coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger, coordC, updateC, updateCtrl := setupConnectedAllWorkspaceUpdatesController(ctx, t, logger,
tailnet.WithDNS(fDNS, "testy"), tailnet.WithDNS(fDNS, "testy", tailnet.DNSNameOptions{Suffix: "coder"}),
tailnet.WithHandler(fUH), tailnet.WithHandler(fUH),
) )
@ -1661,12 +1664,15 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId()) require.Equal(t, w1a1ID[:], coordCall.req.GetAddTunnel().GetId())
testutil.RequireSendCtx(ctx, t, coordCall.err, nil) testutil.RequireSendCtx(ctx, t, coordCall.err, nil)
expectedCoderConnectFQDN, err := dnsname.ToFQDN(fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, "coder"))
require.NoError(t, err)
// DNS for w1a1 // DNS for w1a1
expectedDNS := map[dnsname.FQDN][]netip.Addr{ expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.testy.coder.": {ws1a1IP}, "w1a1.w1.testy.coder.": {ws1a1IP},
"w1a1.w1.me.coder.": {ws1a1IP}, "w1a1.w1.me.coder.": {ws1a1IP},
"w1.coder.": {ws1a1IP}, "w1.coder.": {ws1a1IP},
tailnet.IsCoderConnectEnabledFQDNString: {tsaddr.CoderServiceIPv6()}, expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
} }
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts) require.Equal(t, expectedDNS, dnsCall.hosts)
@ -1719,10 +1725,10 @@ func TestTunnelAllWorkspaceUpdatesController_DeleteAgent(t *testing.T) {
// DNS contains only w1a2 // DNS contains only w1a2
expectedDNS = map[dnsname.FQDN][]netip.Addr{ expectedDNS = map[dnsname.FQDN][]netip.Addr{
"w1a2.w1.testy.coder.": {ws1a2IP}, "w1a2.w1.testy.coder.": {ws1a2IP},
"w1a2.w1.me.coder.": {ws1a2IP}, "w1a2.w1.me.coder.": {ws1a2IP},
"w1.coder.": {ws1a2IP}, "w1.coder.": {ws1a2IP},
tailnet.IsCoderConnectEnabledFQDNString: {tsaddr.CoderServiceIPv6()}, expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
} }
dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls) dnsCall = testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts) require.Equal(t, expectedDNS, dnsCall.hosts)
@ -1779,7 +1785,7 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
fConn := &fakeCoordinatee{} fConn := &fakeCoordinatee{}
tsc := tailnet.NewTunnelSrcCoordController(logger, fConn) tsc := tailnet.NewTunnelSrcCoordController(logger, fConn)
uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc, uut := tailnet.NewTunnelAllWorkspaceUpdatesController(logger, tsc,
tailnet.WithDNS(fDNS, "testy"), tailnet.WithDNS(fDNS, "testy", tailnet.DNSNameOptions{Suffix: "coder"}),
) )
updateC := newFakeWorkspaceUpdateClient(ctx, t) updateC := newFakeWorkspaceUpdateClient(ctx, t)
@ -1800,12 +1806,15 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv) upRecvCall := testutil.RequireRecvCtx(ctx, t, updateC.recv)
testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp) testutil.RequireSendCtx(ctx, t, upRecvCall.resp, initUp)
expectedCoderConnectFQDN, err := dnsname.ToFQDN(fmt.Sprintf(tailnet.IsCoderConnectEnabledFmtString, "coder"))
require.NoError(t, err)
// DNS for w1a1 // DNS for w1a1
expectedDNS := map[dnsname.FQDN][]netip.Addr{ expectedDNS := map[dnsname.FQDN][]netip.Addr{
"w1a1.w1.me.coder.": {ws1a1IP}, "w1a1.w1.me.coder.": {ws1a1IP},
"w1a1.w1.testy.coder.": {ws1a1IP}, "w1a1.w1.testy.coder.": {ws1a1IP},
"w1.coder.": {ws1a1IP}, "w1.coder.": {ws1a1IP},
tailnet.IsCoderConnectEnabledFQDNString: {tsaddr.CoderServiceIPv6()}, expectedCoderConnectFQDN: {tsaddr.CoderServiceIPv6()},
} }
dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls) dnsCall := testutil.RequireRecvCtx(ctx, t, fDNS.calls)
require.Equal(t, expectedDNS, dnsCall.hosts) require.Equal(t, expectedDNS, dnsCall.hosts)
@ -1816,7 +1825,7 @@ func TestTunnelAllWorkspaceUpdatesController_DNSError(t *testing.T) {
testutil.RequireSendCtx(ctx, t, closeCall, io.EOF) testutil.RequireSendCtx(ctx, t, closeCall, io.EOF)
// error should be our initial DNS error // error should be our initial DNS error
err := testutil.RequireRecvCtx(ctx, t, updateCW.Wait()) err = testutil.RequireRecvCtx(ctx, t, updateCW.Wait())
require.ErrorIs(t, err, dnsError) require.ErrorIs(t, err, dnsError)
} }

View File

@ -107,6 +107,11 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
if err != nil { if err != nil {
return nil, xerrors.Errorf("get connection info: %w", err) return nil, xerrors.Errorf("get connection info: %w", err)
} }
// default to DNS suffix of "coder" if the server hasn't set it (might be too old).
dnsNameOptions := tailnet.DNSNameOptions{Suffix: "coder"}
if connInfo.HostnameSuffix != "" {
dnsNameOptions.Suffix = connInfo.HostnameSuffix
}
headers.Set(codersdk.SessionTokenHeader, token) headers.Set(codersdk.SessionTokenHeader, token)
dialer := workspacesdk.NewWebsocketDialer(options.Logger, rpcURL, &websocket.DialOptions{ dialer := workspacesdk.NewWebsocketDialer(options.Logger, rpcURL, &websocket.DialOptions{
@ -148,7 +153,7 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
updatesCtrl := tailnet.NewTunnelAllWorkspaceUpdatesController( updatesCtrl := tailnet.NewTunnelAllWorkspaceUpdatesController(
options.Logger, options.Logger,
coordCtrl, coordCtrl,
tailnet.WithDNS(conn, me.Username), tailnet.WithDNS(conn, me.Username, dnsNameOptions),
tailnet.WithHandler(options.UpdateHandler), tailnet.WithHandler(options.UpdateHandler),
) )
controller.WorkspaceUpdatesCtrl = updatesCtrl controller.WorkspaceUpdatesCtrl = updatesCtrl

View File

@ -3,11 +3,14 @@ package vpn_test
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip"
"net/url" "net/url"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
"tailscale.com/util/dnsname"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -29,136 +32,180 @@ import (
func TestClient_WorkspaceUpdates(t *testing.T) { func TestClient_WorkspaceUpdates(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
logger := testutil.Logger(t)
userID := uuid.UUID{1} userID := uuid.UUID{1}
wsID := uuid.UUID{2} wsID := uuid.UUID{2}
peerID := uuid.UUID{3} peerID := uuid.UUID{3}
agentID := uuid.UUID{4}
fCoord := tailnettest.NewFakeCoordinator() testCases := []struct {
var coord tailnet.Coordinator = fCoord name string
coordPtr := atomic.Pointer[tailnet.Coordinator]{} agentConnectionInfo workspacesdk.AgentConnectionInfo
coordPtr.Store(&coord) hostnames []string
ctrl := gomock.NewController(t) }{
mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl) {
name: "empty",
mSub := tailnettest.NewMockSubscription(ctrl) agentConnectionInfo: workspacesdk.AgentConnectionInfo{},
outUpdateCh := make(chan *proto.WorkspaceUpdate, 1) hostnames: []string{"wrk.coder.", "agnt.wrk.me.coder.", "agnt.wrk.rootbeer.coder."},
inUpdateCh := make(chan tailnet.WorkspaceUpdate, 1) },
mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil) {
mSub.EXPECT().Updates().MinTimes(1).Return(outUpdateCh) name: "suffix",
mSub.EXPECT().Close().Times(1).Return(nil) agentConnectionInfo: workspacesdk.AgentConnectionInfo{HostnameSuffix: "float"},
hostnames: []string{"wrk.float.", "agnt.wrk.me.float.", "agnt.wrk.rootbeer.float."},
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
WorkspaceUpdatesProvider: mProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
user := make(chan struct{})
connInfo := make(chan struct{})
serveErrCh := make(chan error)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me":
httpapi.Write(ctx, w, http.StatusOK, codersdk.User{
ReducedUser: codersdk.ReducedUser{
MinimalUser: codersdk.MinimalUser{
ID: userID,
},
},
})
user <- struct{}{}
case "/api/v2/workspaceagents/connection":
httpapi.Write(ctx, w, http.StatusOK, workspacesdk.AgentConnectionInfo{
DisableDirectConnections: false,
})
connInfo <- struct{}{}
case "/api/v2/tailnet":
// need 2.3 for WorkspaceUpdates RPC
cVer := r.URL.Query().Get("version")
assert.Equal(t, "2.3", cVer)
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
serveErrCh <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{
Name: "client",
ID: peerID,
// Auth can be nil as we use a mock update provider
Auth: tailnet.ClientUserCoordinateeAuth{
Auth: nil,
},
})
default:
http.NotFound(w, r)
}
}))
t.Cleanup(server.Close)
svrURL, err := url.Parse(server.URL)
require.NoError(t, err)
connErrCh := make(chan error)
connCh := make(chan vpn.Conn)
go func() {
conn, err := vpn.NewClient().NewConn(ctx, svrURL, "fakeToken", &vpn.Options{
UpdateHandler: updateHandler(func(wu tailnet.WorkspaceUpdate) error {
inUpdateCh <- wu
return nil
}),
DNSConfigurator: &noopConfigurator{},
})
connErrCh <- err
connCh <- conn
}()
testutil.RequireRecvCtx(ctx, t, user)
testutil.RequireRecvCtx(ctx, t, connInfo)
err = testutil.RequireRecvCtx(ctx, t, connErrCh)
require.NoError(t, err)
conn := testutil.RequireRecvCtx(ctx, t, connCh)
// Send a workspace update
update := &proto.WorkspaceUpdate{
UpsertedWorkspaces: []*proto.Workspace{
{
Id: wsID[:],
},
}, },
} }
testutil.RequireSendCtx(ctx, t, outUpdateCh, update) for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// It'll be received by the update handler ctx := testutil.Context(t, testutil.WaitShort)
recvUpdate := testutil.RequireRecvCtx(ctx, t, inUpdateCh) logger := testutil.Logger(t)
require.Len(t, recvUpdate.UpsertedWorkspaces, 1)
require.Equal(t, wsID, recvUpdate.UpsertedWorkspaces[0].ID)
// And be reflected on the Conn's state fCoord := tailnettest.NewFakeCoordinator()
state, err := conn.CurrentWorkspaceState() var coord tailnet.Coordinator = fCoord
require.NoError(t, err) coordPtr := atomic.Pointer[tailnet.Coordinator]{}
require.Equal(t, tailnet.WorkspaceUpdate{ coordPtr.Store(&coord)
UpsertedWorkspaces: []*tailnet.Workspace{ ctrl := gomock.NewController(t)
{ mProvider := tailnettest.NewMockWorkspaceUpdatesProvider(ctrl)
ID: wsID,
},
},
UpsertedAgents: []*tailnet.Agent{},
DeletedWorkspaces: []*tailnet.Workspace{},
DeletedAgents: []*tailnet.Agent{},
}, state)
// Close the conn mSub := tailnettest.NewMockSubscription(ctrl)
conn.Close() outUpdateCh := make(chan *proto.WorkspaceUpdate, 1)
err = testutil.RequireRecvCtx(ctx, t, serveErrCh) inUpdateCh := make(chan tailnet.WorkspaceUpdate, 1)
require.NoError(t, err) mProvider.EXPECT().Subscribe(gomock.Any(), userID).Times(1).Return(mSub, nil)
mSub.EXPECT().Updates().MinTimes(1).Return(outUpdateCh)
mSub.EXPECT().Close().Times(1).Return(nil)
svc, err := tailnet.NewClientService(tailnet.ClientServiceOptions{
Logger: logger,
CoordPtr: &coordPtr,
DERPMapUpdateFrequency: time.Hour,
DERPMapFn: func() *tailcfg.DERPMap { return &tailcfg.DERPMap{} },
WorkspaceUpdatesProvider: mProvider,
ResumeTokenProvider: tailnet.NewInsecureTestResumeTokenProvider(),
})
require.NoError(t, err)
user := make(chan struct{})
connInfo := make(chan struct{})
serveErrCh := make(chan error)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me":
httpapi.Write(ctx, w, http.StatusOK, codersdk.User{
ReducedUser: codersdk.ReducedUser{
MinimalUser: codersdk.MinimalUser{
ID: userID,
Username: "rootbeer",
},
},
})
user <- struct{}{}
case "/api/v2/workspaceagents/connection":
httpapi.Write(ctx, w, http.StatusOK, tc.agentConnectionInfo)
connInfo <- struct{}{}
case "/api/v2/tailnet":
// need 2.3 for WorkspaceUpdates RPC
cVer := r.URL.Query().Get("version")
assert.Equal(t, "2.3", cVer)
sws, err := websocket.Accept(w, r, nil)
if !assert.NoError(t, err) {
return
}
wsCtx, nc := codersdk.WebsocketNetConn(ctx, sws, websocket.MessageBinary)
serveErrCh <- svc.ServeConnV2(wsCtx, nc, tailnet.StreamID{
Name: "client",
ID: peerID,
// Auth can be nil as we use a mock update provider
Auth: tailnet.ClientUserCoordinateeAuth{
Auth: nil,
},
})
default:
http.NotFound(w, r)
}
}))
t.Cleanup(server.Close)
svrURL, err := url.Parse(server.URL)
require.NoError(t, err)
connErrCh := make(chan error)
connCh := make(chan vpn.Conn)
go func() {
conn, err := vpn.NewClient().NewConn(ctx, svrURL, "fakeToken", &vpn.Options{
UpdateHandler: updateHandler(func(wu tailnet.WorkspaceUpdate) error {
inUpdateCh <- wu
return nil
}),
DNSConfigurator: &noopConfigurator{},
})
connErrCh <- err
connCh <- conn
}()
testutil.RequireRecvCtx(ctx, t, user)
testutil.RequireRecvCtx(ctx, t, connInfo)
err = testutil.RequireRecvCtx(ctx, t, connErrCh)
require.NoError(t, err)
conn := testutil.RequireRecvCtx(ctx, t, connCh)
// Send a workspace update
update := &proto.WorkspaceUpdate{
UpsertedWorkspaces: []*proto.Workspace{
{
Id: wsID[:],
Name: "wrk",
},
},
UpsertedAgents: []*proto.Agent{
{
Id: agentID[:],
Name: "agnt",
WorkspaceId: wsID[:],
},
},
}
testutil.RequireSendCtx(ctx, t, outUpdateCh, update)
// It'll be received by the update handler
recvUpdate := testutil.RequireRecvCtx(ctx, t, inUpdateCh)
require.Len(t, recvUpdate.UpsertedWorkspaces, 1)
require.Equal(t, wsID, recvUpdate.UpsertedWorkspaces[0].ID)
require.Len(t, recvUpdate.UpsertedAgents, 1)
expectedHosts := map[dnsname.FQDN][]netip.Addr{}
for _, name := range tc.hostnames {
expectedHosts[dnsname.FQDN(name)] = []netip.Addr{tailnet.CoderServicePrefix.AddrFromUUID(agentID)}
}
// And be reflected on the Conn's state
state, err := conn.CurrentWorkspaceState()
require.NoError(t, err)
require.Equal(t, tailnet.WorkspaceUpdate{
UpsertedWorkspaces: []*tailnet.Workspace{
{
ID: wsID,
Name: "wrk",
},
},
UpsertedAgents: []*tailnet.Agent{
{
ID: agentID,
Name: "agnt",
WorkspaceID: wsID,
Hosts: expectedHosts,
},
},
DeletedWorkspaces: []*tailnet.Workspace{},
DeletedAgents: []*tailnet.Agent{},
}, state)
// Close the conn
conn.Close()
err = testutil.RequireRecvCtx(ctx, t, serveErrCh)
require.NoError(t, err)
})
}
} }
type updateHandler func(tailnet.WorkspaceUpdate) error type updateHandler func(tailnet.WorkspaceUpdate) error