package codersdk

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"nhooyr.io/websocket"
	"tailscale.com/tailcfg"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
	"github.com/coder/coder/v2/tailnet/tailnettest"
	"github.com/coder/coder/v2/testutil"
)

// TestTailnetAPIConnector_Disconnects checks two hang-up paths of the tailnet
// API connector against a real ClientService backed by a fake coordinator:
//
//  1. A DERP map failure on the server (simulated by serving a nil DERP map)
//     ends the coordinate call WITHOUT a Disconnect request, after which the
//     connector reconnects and sends a fresh AddTunnel.
//  2. Canceling the connector's context makes it send a Disconnect request.
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
	t.Parallel()
	testCtx := testutil.Context(t, testutil.WaitShort)
	ctx, cancel := context.WithCancel(testCtx)
	// cancel is also invoked explicitly near the end of the test; deferring it
	// as well stops the connector if an earlier require aborts the test.
	defer cancel()
	logger := slogtest.Make(t, &slogtest.Options{
		// we get EOF when we simulate a DERPMap error
		IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
	}).Leveled(slog.LevelDebug)
	agentID := uuid.UUID{0x55}
	clientID := uuid.UUID{0x66}
	fCoord := tailnettest.NewFakeCoordinator()
	var coord tailnet.Coordinator = fCoord
	coordPtr := atomic.Pointer[tailnet.Coordinator]{}
	coordPtr.Store(&coord)
	// Unbuffered so the test controls exactly when (and which) DERP map the
	// service receives; each receive blocks until the test sends.
	derpMapCh := make(chan *tailcfg.DERPMap)
	defer close(derpMapCh)
	svc, err := tailnet.NewClientService(
		logger, &coordPtr,
		time.Millisecond,
		func() *tailcfg.DERPMap { return <-derpMapCh },
	)
	require.NoError(t, err)

	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sws, err := websocket.Accept(w, r, nil)
		if !assert.NoError(t, err) {
			return
		}
		ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
		err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
			Name: "client",
			ID:   clientID,
			Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
		})
		assert.NoError(t, err)
	}))
	// Release the listener and server goroutines when the test ends. The
	// websocket connections are hijacked, so Close does not wait for the
	// handler above to return.
	defer svr.Close()

	fConn := newFakeTailnetConn()

	uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)

	call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	_ = testutil.RequireRecvCtx(ctx, t, uut.connected)

	// simulate a problem with DERPMaps by sending nil
	testutil.RequireSendCtx(ctx, t, derpMapCh, nil)

	// this should cause the coordinate call to hang up WITHOUT disconnecting
	reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.Nil(t, reqNil)

	// ...and then reconnect
	call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
	reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
	require.NotNil(t, reqTun.AddTunnel)

	// canceling the context should trigger the disconnect message
	cancel()
	// ctx is now canceled, so wait on the still-live parent context instead.
	reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
	require.NotNil(t, reqDisc)
	require.NotNil(t, reqDisc.Disconnect)
}

// fakeTailnetConn is a do-nothing stand-in for the tailnet connection handed
// to runTailnetAPIConnector. Every method accepts its input and discards it.
type fakeTailnetConn struct{}

// UpdatePeers discards peer updates and reports success. The fake coordinator
// in this test is never told to send responses, so no updates are expected.
// It returns nil rather than panicking: a panic raised outside the test
// goroutine (presumably in a connector goroutine) would abort the entire
// test binary instead of failing this test.
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
	return nil
}

// SetAllPeersLost is a no-op.
func (*fakeTailnetConn) SetAllPeersLost() {}

// SetNodeCallback is a no-op; the callback is never invoked.
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}

// SetDERPMap is a no-op.
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}

// newFakeTailnetConn returns a new no-op fakeTailnetConn.
func newFakeTailnetConn() *fakeTailnetConn {
	return &fakeTailnetConn{}
}