mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
This adds the ability for `TunnelAuth` to also authorize incoming WireGuard node IPs, preventing agents from reporting any IP other than the static IP generated from their agent ID.
107 lines
3.0 KiB
Go
107 lines
3.0 KiB
Go
package codersdk
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"nhooyr.io/websocket"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/coder/v2/tailnet/proto"
|
|
"github.com/coder/coder/v2/tailnet/tailnettest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|
t.Parallel()
|
|
testCtx := testutil.Context(t, testutil.WaitShort)
|
|
ctx, cancel := context.WithCancel(testCtx)
|
|
logger := slogtest.Make(t, &slogtest.Options{
|
|
// we get EOF when we simulate a DERPMap error
|
|
IgnoredErrorIs: append(slogtest.DefaultIgnoredErrorIs, io.EOF),
|
|
}).Leveled(slog.LevelDebug)
|
|
agentID := uuid.UUID{0x55}
|
|
clientID := uuid.UUID{0x66}
|
|
fCoord := tailnettest.NewFakeCoordinator()
|
|
var coord tailnet.Coordinator = fCoord
|
|
coordPtr := atomic.Pointer[tailnet.Coordinator]{}
|
|
coordPtr.Store(&coord)
|
|
derpMapCh := make(chan *tailcfg.DERPMap)
|
|
defer close(derpMapCh)
|
|
svc, err := tailnet.NewClientService(
|
|
logger, &coordPtr,
|
|
time.Millisecond, func() *tailcfg.DERPMap { return <-derpMapCh },
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
sws, err := websocket.Accept(w, r, nil)
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
ctx, nc := WebsocketNetConn(r.Context(), sws, websocket.MessageBinary)
|
|
err = svc.ServeConnV2(ctx, nc, tailnet.StreamID{
|
|
Name: "client",
|
|
ID: clientID,
|
|
Auth: tailnet.ClientCoordinateeAuth{AgentID: agentID},
|
|
})
|
|
assert.NoError(t, err)
|
|
}))
|
|
|
|
fConn := newFakeTailnetConn()
|
|
|
|
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
|
|
|
|
call := testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
|
reqTun := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.NotNil(t, reqTun.AddTunnel)
|
|
|
|
_ = testutil.RequireRecvCtx(ctx, t, uut.connected)
|
|
|
|
// simulate a problem with DERPMaps by sending nil
|
|
testutil.RequireSendCtx(ctx, t, derpMapCh, nil)
|
|
|
|
// this should cause the coordinate call to hang up WITHOUT disconnecting
|
|
reqNil := testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.Nil(t, reqNil)
|
|
|
|
// ...and then reconnect
|
|
call = testutil.RequireRecvCtx(ctx, t, fCoord.CoordinateCalls)
|
|
reqTun = testutil.RequireRecvCtx(ctx, t, call.Reqs)
|
|
require.NotNil(t, reqTun.AddTunnel)
|
|
|
|
// canceling the context should trigger the disconnect message
|
|
cancel()
|
|
reqDisc := testutil.RequireRecvCtx(testCtx, t, call.Reqs)
|
|
require.NotNil(t, reqDisc)
|
|
require.NotNil(t, reqDisc.Disconnect)
|
|
}
|
|
|
|
type fakeTailnetConn struct{}
|
|
|
|
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
|
// TODO implement me
|
|
panic("implement me")
|
|
}
|
|
|
|
func (*fakeTailnetConn) SetAllPeersLost() {}
|
|
|
|
func (*fakeTailnetConn) SetNodeCallback(func(*tailnet.Node)) {}
|
|
|
|
func (*fakeTailnetConn) SetDERPMap(*tailcfg.DERPMap) {}
|
|
|
|
func newFakeTailnetConn() *fakeTailnetConn {
|
|
return &fakeTailnetConn{}
|
|
}
|