package test

import (
	"context"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"golang.org/x/xerrors"
	"tailscale.com/types/key"

	"github.com/coder/coder/v2/tailnet"
	"github.com/coder/coder/v2/tailnet/proto"
)
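
// PeerStatus is the last known state of a peer as derived from coordinator
// responses: its preferred DERP region, the kind of its most recent update,
// and whether it has signaled readiness for handshake.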
type PeerStatus struct {
	preferredDERP     int32
	status            proto.CoordinateResponse_PeerUpdate_Kind
	readyForHandshake bool
}

type PeerOption func(*Peer)

func WithID(id uuid.UUID) PeerOption {
	return func(p *Peer) {
		p.ID = id
	}
}

func WithAuth(auth tailnet.CoordinateeAuth) PeerOption {
	return func(p *Peer) {
		p.auth = auth
	}
}
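
// Peer is a fake coordination peer for exercising a tailnet.CoordinatorV2 in
// tests. It tracks the status of peers it hears about and records every peer
// update it receives so tests can assert on coordinator behavior.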
type Peer struct {
	ctx         context.Context
	cancel      context.CancelFunc
	t           testing.TB
	ID          uuid.UUID
	auth        tailnet.CoordinateeAuth
	name        string
	nodeKey     key.NodePublic
	discoKey    key.DiscoPublic
	resps       <-chan *proto.CoordinateResponse
	reqs        chan<- *proto.CoordinateRequest
	peers       map[uuid.UUID]PeerStatus
	peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
}
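
// NewPeer constructs a Peer with a fresh ID and generated node/disco keys,
// applies the given options, and connects it to coord. As an illustrative
// sketch only (not part of the original helpers), a test that already has a
// CoordinatorV2 in hand might wire an agent and a client like:
//
//	agent := NewAgent(ctx, t, coord, "agent")
//	defer agent.Close(ctx)
//	client := NewClient(ctx, t, coord, "client", agent.ID)
//	defer client.Close(ctx)
//	agent.UpdateDERP(3)
//	client.AssertEventuallyHasDERP(agent.ID, 3)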
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, opts ...PeerOption) *Peer {
	p := &Peer{
		t:           t,
		name:        name,
		peers:       make(map[uuid.UUID]PeerStatus),
		peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
		ID:          uuid.New(),
		// SingleTailnetCoordinateeAuth allows connections to arbitrary peers
		auth: tailnet.SingleTailnetCoordinateeAuth{},
		// required for converting to and from protobuf, so we always include them
		nodeKey:  key.NewNode().Public(),
		discoKey: key.NewDisco().Public(),
	}
	p.ctx, p.cancel = context.WithCancel(ctx)
	for _, opt := range opts {
		opt(p)
	}

	p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, p.auth)
	return p
}

// NewAgent is a wrapper around NewPeer, creating a peer with Agent auth tied to its ID
func NewAgent(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string) *Peer {
	id := uuid.New()
	return NewPeer(ctx, t, coord, name, WithID(id), WithAuth(tailnet.AgentCoordinateeAuth{ID: id}))
}

// NewClient is a wrapper around NewPeer, creating a peer with Client auth tied to the provided agentID
func NewClient(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, agentID uuid.UUID) *Peer {
	p := NewPeer(ctx, t, coord, name, WithAuth(tailnet.ClientCoordinateeAuth{AgentID: agentID}))
	p.AddTunnel(agentID)
	return p
}
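
// ConnectToCoordinator connects the peer to the coordinator c, replacing the
// peer's request and response channels.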
func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
	p.t.Helper()
	p.reqs, p.resps = c.Coordinate(ctx, p.ID, p.name, p.auth)
}
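
// AddTunnel sends a request to add a tunnel to the peer with the given ID,
// failing the test if the peer's context is done before the request is sent.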
func (p *Peer) AddTunnel(other uuid.UUID) {
	p.t.Helper()
	req := &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}}
	select {
	case <-p.ctx.Done():
		p.t.Errorf("timeout adding tunnel for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}

func (p *Peer) RemoveTunnel(other uuid.UUID) {
	p.t.Helper()
	req := &proto.CoordinateRequest{RemoveTunnel: &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}}
	select {
	case <-p.ctx.Done():
		p.t.Errorf("timeout removing tunnel for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}

func (p *Peer) UpdateDERP(derp int32) {
	p.t.Helper()
	node := &proto.Node{PreferredDerp: derp}
	p.UpdateNode(node)
}
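
// UpdateNode sends an UpdateSelf request for node, first setting the peer's
// node and disco keys on it.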
func (p *Peer) UpdateNode(node *proto.Node) {
	p.t.Helper()
	nk, err := p.nodeKey.MarshalBinary()
	assert.NoError(p.t, err)
	node.Key = nk
	dk, err := p.discoKey.MarshalText()
	assert.NoError(p.t, err)
	node.Disco = string(dk)
	req := &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: node}}
	select {
	case <-p.ctx.Done():
		p.t.Errorf("timeout updating node for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}

func (p *Peer) ReadyForHandshake(peer uuid.UUID) {
	p.t.Helper()

	req := &proto.CoordinateRequest{ReadyForHandshake: []*proto.CoordinateRequest_ReadyForHandshake{{
		Id: peer[:],
	}}}
	select {
	case <-p.ctx.Done():
		p.t.Errorf("timeout sending ready for handshake for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}
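
// Disconnect sends a Disconnect request to the coordinator, failing the test
// if the peer's context is done before the request is sent.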
func (p *Peer) Disconnect() {
	p.t.Helper()
	req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
	select {
	case <-p.ctx.Done():
		p.t.Errorf("timeout sending disconnect for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}
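
// AssertEventuallyHasDERP reads responses until the peer `other` is reported
// with the given preferred DERP region, failing the test on any error.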
func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
	p.t.Helper()
	for {
		o, ok := p.peers[other]
		if ok && o.preferredDERP == derp {
			return
		}
		if err := p.readOneResp(); err != nil {
			assert.NoError(p.t, err)
			return
		}
	}
}
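
// AssertNeverHasDERPs reads responses until ctx is done, asserting that the
// peer `other` is never reported with a preferred DERP in expected.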
func (p *Peer) AssertNeverHasDERPs(ctx context.Context, other uuid.UUID, expected ...int32) {
	p.t.Helper()
	for {
		select {
		case <-ctx.Done():
			return
		case resp, ok := <-p.resps:
			if !ok {
				p.t.Errorf("response channel closed")
				return
			}
			if !assert.NoError(p.t, p.handleResp(resp)) {
				return
			}
			derp, ok := p.peers[other]
			if !ok {
				continue
			}
			if !assert.NotContains(p.t, expected, derp.preferredDERP) {
				return
			}
		}
	}
}

func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
	p.t.Helper()
	for {
		_, ok := p.peers[other]
		if !ok {
			return
		}
		if err := p.readOneResp(); err != nil {
			assert.NoError(p.t, err)
			return
		}
	}
}

func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
	p.t.Helper()
	for {
		o := p.peers[other]
		if o.status == proto.CoordinateResponse_PeerUpdate_LOST {
			return
		}
		if err := p.readOneResp(); err != nil {
			assert.NoError(p.t, err)
			return
		}
	}
}
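
// AssertEventuallyResponsesClosed reads responses until the response channel
// is closed, failing the test on any other error.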
func (p *Peer) AssertEventuallyResponsesClosed() {
	p.t.Helper()
	for {
		err := p.readOneResp()
		if xerrors.Is(err, responsesClosed) {
			return
		}
		if !assert.NoError(p.t, err) {
			return
		}
	}
}
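
// AssertNotClosed reads responses for duration d, asserting that the response
// channel does not close and no response produces an error in that time.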
func (p *Peer) AssertNotClosed(d time.Duration) {
	p.t.Helper()
	// nolint: gocritic // erroneously thinks we're hardcoding non testutil constants here
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()
	for {
		select {
		case <-ctx.Done():
			// success!
			return
		case <-p.ctx.Done():
			p.t.Error("main ctx timeout before elapsed time")
			return
		case resp, ok := <-p.resps:
			if !ok {
				p.t.Error("response channel closed")
				return
			}
			err := p.handleResp(resp)
			if !assert.NoError(p.t, err) {
				return
			}
		}
	}
}

func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
	p.t.Helper()
	for {
		o := p.peers[other]
		if o.readyForHandshake {
			return
		}

		err := p.readOneResp()
		if xerrors.Is(err, responsesClosed) {
			return
		}
		if !assert.NoError(p.t, err) {
			return
		}
	}
}

func (p *Peer) AssertEventuallyGetsError(match string) {
	p.t.Helper()
	for {
		err := p.readOneResp()
		if xerrors.Is(err, responsesClosed) {
			p.t.Error("closed before target error")
			return
		}

		if err != nil && assert.ErrorContains(p.t, err, match) {
			return
		}
	}
}

// AssertNeverUpdateKind asserts that we have not received
// any updates on the provided peer for the provided kind.
func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind) {
	p.t.Helper()

	updates, ok := p.peerUpdates[peer]
	assert.True(p.t, ok, "expected updates for peer %s", peer)

	for _, update := range updates {
		assert.NotEqual(p.t, kind, update.Kind, update)
	}
}

var responsesClosed = xerrors.New("responses closed")
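
// readOneResp reads a single response and processes it via handleResp,
// returning responsesClosed if the response channel has been closed, or the
// peer context's error if the context is done.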
func (p *Peer) readOneResp() error {
	select {
	case <-p.ctx.Done():
		return p.ctx.Err()
	case resp, ok := <-p.resps:
		if !ok {
			return responsesClosed
		}
		err := p.handleResp(resp)
		if err != nil {
			return err
		}
	}
	return nil
}
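
// handleResp records each peer update in the response and updates the tracked
// PeerStatus for NODE, LOST, DISCONNECTED, and READY_FOR_HANDSHAKE updates. It
// returns an error if the response carries an error or an unhandled kind.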
func (p *Peer) handleResp(resp *proto.CoordinateResponse) error {
	if resp.Error != "" {
		return xerrors.New(resp.Error)
	}
	for _, update := range resp.PeerUpdates {
		id, err := uuid.FromBytes(update.Id)
		if err != nil {
			return err
		}
		p.peerUpdates[id] = append(p.peerUpdates[id], update)

		switch update.Kind {
		case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
			peer := p.peers[id]
			peer.preferredDERP = update.GetNode().GetPreferredDerp()
			peer.status = update.Kind
			p.peers[id] = peer
		case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
			delete(p.peers, id)
		case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
			peer := p.peers[id]
			peer.readyForHandshake = true
			p.peers[id] = peer
		default:
			return xerrors.Errorf("unhandled update kind %s", update.Kind)
		}
	}
	return nil
}
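
// Close cancels the peer's context and reads remaining responses until the
// coordinator closes the response channel, failing the test if ctx is done
// first.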
func (p *Peer) Close(ctx context.Context) {
	p.t.Helper()
	p.cancel()
	for {
		select {
		case <-ctx.Done():
			p.t.Errorf("timeout waiting for responses to close for %s", p.name)
			return
		case _, ok := <-p.resps:
			if ok {
				continue
			}
			return
		}
	}
}
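
// UngracefulDisconnect closes the peer's request channel without sending a
// Disconnect request, simulating an abrupt disconnection, then closes the
// peer.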
func (p *Peer) UngracefulDisconnect(ctx context.Context) {
	p.t.Helper()
	close(p.reqs)
	p.Close(ctx)
}
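
// FakeSubjectKey is the context key FakeCoordinateeAuth checks for; storing a
// struct{} under it marks the context as authorized.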
type FakeSubjectKey struct{}
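
// FakeCoordinateeAuth is a test CoordinateeAuth that authorizes a request only
// when FakeSubjectKey is set on the context, and signals each successful
// authorization on Chan.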
type FakeCoordinateeAuth struct {
	Chan chan struct{}
}

func (f FakeCoordinateeAuth) Authorize(ctx context.Context, _ *proto.CoordinateRequest) error {
	_, ok := ctx.Value(FakeSubjectKey{}).(struct{})
	if !ok {
		return xerrors.New("unauthorized")
	}
	f.Chan <- struct{}{}
	return nil
}

var _ tailnet.CoordinateeAuth = (*FakeCoordinateeAuth)(nil)