Files
coder/tailnet/test/peer.go
Spike Curtis 8c00ebc6ee chore: refactor ServerTailnet to use tailnet.Controllers (#15408)
chore of #14729

Refactors the `ServerTailnet` to use `tailnet.Controller` so that we reuse logic around reconnection and handling control messages, instead of reimplementing.  This unifies our "client" use of the tailscale API across CLI, coderd, and wsproxy.
2024-11-08 13:18:56 +04:00

402 lines
9.0 KiB
Go

package test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/xerrors"
"tailscale.com/types/key"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/coder/v2/tailnet/proto"
)
// PeerStatus is the locally tracked state of one remote peer, derived from
// the PeerUpdates this Peer has received (see handleResp).
type PeerStatus struct {
	// preferredDERP is taken from the node in the most recent NODE or LOST
	// update; it is 0 when that update carried no node.
	preferredDERP int32
	// status is the Kind of the most recent NODE or LOST update.
	status proto.CoordinateResponse_PeerUpdate_Kind
	// readyForHandshake is set once a READY_FOR_HANDSHAKE update arrives.
	readyForHandshake bool
}
// PeerOption customizes a Peer inside NewPeer. Options are applied after the
// defaults are set but before the peer calls Coordinate.
type PeerOption func(*Peer)

// WithID replaces the randomly generated ID that NewPeer would otherwise use.
func WithID(id uuid.UUID) PeerOption {
	setID := func(peer *Peer) {
		peer.ID = id
	}
	return setID
}

// WithAuth replaces NewPeer's default SingleTailnetCoordinateeAuth.
func WithAuth(auth tailnet.CoordinateeAuth) PeerOption {
	return PeerOption(func(peer *Peer) { peer.auth = auth })
}
// Peer is a test harness around a single coordinatee connection to a
// tailnet.CoordinatorV2. It sends requests on the channel returned by
// Coordinate and records the peer updates it reads back.
type Peer struct {
	// ctx is derived from the context passed to NewPeer; cancel is called by Close.
	ctx    context.Context
	cancel context.CancelFunc
	t      testing.TB
	// ID is the coordinator ID for this peer; random unless set with WithID.
	ID   uuid.UUID
	auth tailnet.CoordinateeAuth
	name string
	// nodeKey and discoKey are stamped onto every node sent by UpdateNode.
	nodeKey  key.NodePublic
	discoKey key.DiscoPublic
	// resps and reqs are the channels returned by Coordinate.
	resps <-chan *proto.CoordinateResponse
	reqs  chan<- *proto.CoordinateRequest
	// peers holds the current state per remote peer; a DISCONNECTED update
	// removes the entry.
	peers map[uuid.UUID]PeerStatus
	// peerUpdates keeps every update received per remote peer, in arrival order.
	peerUpdates map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate
}
// NewPeer builds a Peer named name, applies opts, and then connects it by
// calling coord.Coordinate with the peer's ID, name, and auth.
//
// By default the peer gets a random ID and SingleTailnetCoordinateeAuth,
// which allows connections to arbitrary peers. Its context is a cancelable
// child of ctx, and Close cancels it.
func NewPeer(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, opts ...PeerOption) *Peer {
	peerCtx, cancelPeer := context.WithCancel(ctx)
	p := &Peer{
		ctx:         peerCtx,
		cancel:      cancelPeer,
		t:           t,
		name:        name,
		peers:       make(map[uuid.UUID]PeerStatus),
		peerUpdates: make(map[uuid.UUID][]*proto.CoordinateResponse_PeerUpdate),
		ID:          uuid.New(),
		auth:        tailnet.SingleTailnetCoordinateeAuth{},
		// The keys are required for protobuf conversion, so every peer has them.
		nodeKey:  key.NewNode().Public(),
		discoKey: key.NewDisco().Public(),
	}
	for _, apply := range opts {
		apply(p)
	}
	p.reqs, p.resps = coord.Coordinate(p.ctx, p.ID, name, p.auth)
	return p
}
// NewAgent wraps NewPeer. It creates a peer with a fresh ID whose auth is
// AgentCoordinateeAuth bound to that same ID.
func NewAgent(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string) *Peer {
	agentID := uuid.New()
	agentOpts := []PeerOption{
		WithID(agentID),
		WithAuth(tailnet.AgentCoordinateeAuth{ID: agentID}),
	}
	return NewPeer(ctx, t, coord, name, agentOpts...)
}
// NewClient wraps NewPeer. It creates a peer with ClientCoordinateeAuth for
// agentID and then immediately requests a tunnel to that agent.
func NewClient(ctx context.Context, t testing.TB, coord tailnet.CoordinatorV2, name string, agentID uuid.UUID) *Peer {
	clientAuth := tailnet.ClientCoordinateeAuth{AgentID: agentID}
	client := NewPeer(ctx, t, coord, name, WithAuth(clientAuth))
	client.AddTunnel(agentID)
	return client
}
// ConnectToCoordinator opens a new Coordinate session on c with the peer's
// existing ID, name, and auth, and replaces the peer's request and response
// channels. The new session is bound to ctx, not to the peer's own context.
func (p *Peer) ConnectToCoordinator(ctx context.Context, c tailnet.CoordinatorV2) {
	p.t.Helper()
	reqs, resps := c.Coordinate(ctx, p.ID, p.name, p.auth)
	p.reqs = reqs
	p.resps = resps
}
// AddTunnel sends an AddTunnel request for other. If the peer's context ends
// before the request can be sent, the test is marked as failed.
func (p *Peer) AddTunnel(other uuid.UUID) {
	p.t.Helper()
	tunnel := &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}
	select {
	case p.reqs <- &proto.CoordinateRequest{AddTunnel: tunnel}:
	case <-p.ctx.Done():
		p.t.Errorf("timeout adding tunnel for %s", p.name)
	}
}
// RemoveTunnel sends a RemoveTunnel request for other. If the peer's context
// ends before the request can be sent, the test is marked as failed.
func (p *Peer) RemoveTunnel(other uuid.UUID) {
	p.t.Helper()
	tunnel := &proto.CoordinateRequest_Tunnel{Id: tailnet.UUIDToByteSlice(other)}
	select {
	case p.reqs <- &proto.CoordinateRequest{RemoveTunnel: tunnel}:
	case <-p.ctx.Done():
		p.t.Errorf("timeout removing tunnel for %s", p.name)
	}
}
// UpdateDERP sends a node update carrying only the preferred DERP region.
// UpdateNode fills in the peer's keys before sending.
func (p *Peer) UpdateDERP(derp int32) {
	p.t.Helper()
	p.UpdateNode(&proto.Node{PreferredDerp: derp})
}
// UpdateNode stamps node with the peer's node key and disco key, modifying it
// in place, and sends it as an UpdateSelf request.
//
// Key-marshalling failures are reported via assert but do not stop the send.
// If the peer's context ends before the request can be sent, the test is
// marked as failed.
func (p *Peer) UpdateNode(node *proto.Node) {
	p.t.Helper()
	nodeKeyBytes, nodeKeyErr := p.nodeKey.MarshalBinary()
	assert.NoError(p.t, nodeKeyErr)
	node.Key = nodeKeyBytes
	discoText, discoErr := p.discoKey.MarshalText()
	assert.NoError(p.t, discoErr)
	node.Disco = string(discoText)
	update := &proto.CoordinateRequest_UpdateSelf{Node: node}
	select {
	case p.reqs <- &proto.CoordinateRequest{UpdateSelf: update}:
	case <-p.ctx.Done():
		p.t.Errorf("timeout updating node for %s", p.name)
	}
}
// ReadyForHandshake tells the coordinator this peer is ready to handshake
// with peer. If the peer's context ends before the request can be sent, the
// test is marked as failed.
func (p *Peer) ReadyForHandshake(peer uuid.UUID) {
	p.t.Helper()
	ready := []*proto.CoordinateRequest_ReadyForHandshake{
		{Id: tailnet.UUIDToByteSlice(peer)},
	}
	select {
	case p.reqs <- &proto.CoordinateRequest{ReadyForHandshake: ready}:
	case <-p.ctx.Done():
		p.t.Errorf("timeout sending ready for handshake for %s", p.name)
	}
}
// Disconnect sends a graceful Disconnect request to the coordinator. If the
// peer's context ends before the request can be sent, the test is marked as
// failed.
func (p *Peer) Disconnect() {
	p.t.Helper()
	req := &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}
	select {
	case <-p.ctx.Done():
		// The message names the operation that actually timed out; it used to
		// be a copy of UpdateNode's "timeout updating node".
		p.t.Errorf("timeout sending disconnect for %s", p.name)
		return
	case p.reqs <- req:
		return
	}
}
// AssertEventuallyHasDERP reads responses until other is known with
// preferred DERP region derp. If a read fails first (peer context done,
// responses closed, or an error response), it reports the error and returns.
func (p *Peer) AssertEventuallyHasDERP(other uuid.UUID, derp int32) {
	p.t.Helper()
	for {
		if status, known := p.peers[other]; known && status.preferredDERP == derp {
			return
		}
		readErr := p.readOneResp()
		if readErr == nil {
			continue
		}
		assert.NoError(p.t, readErr)
		return
	}
}
// AssertNeverHasDERPs consumes and applies responses until ctx is done. It
// fails the test if other's recorded preferred DERP ever equals one of
// expected. It also fails and returns if the response channel closes or an
// error response arrives.
func (p *Peer) AssertNeverHasDERPs(ctx context.Context, other uuid.UUID, expected ...int32) {
	p.t.Helper()
	for {
		select {
		case <-ctx.Done():
			return
		case resp, ok := <-p.resps:
			if !ok {
				p.t.Errorf("response channel closed")
				// Without this return, the nil resp from the closed channel
				// would be dereferenced in handleResp.
				return
			}
			if !assert.NoError(p.t, p.handleResp(resp)) {
				return
			}
			status, known := p.peers[other]
			if !known {
				continue
			}
			// Compare the DERP region itself. Comparing the whole PeerStatus
			// against the []int32 could never match, so the check was vacuous.
			if !assert.NotContains(p.t, expected, status.preferredDERP) {
				return
			}
		}
	}
}
// AssertEventuallyDisconnected reads responses until other is no longer
// tracked, i.e. a DISCONNECTED update removed it or it was never seen. If a
// read fails first, it reports the error and returns.
func (p *Peer) AssertEventuallyDisconnected(other uuid.UUID) {
	p.t.Helper()
	for {
		if _, tracked := p.peers[other]; !tracked {
			return
		}
		readErr := p.readOneResp()
		if readErr == nil {
			continue
		}
		assert.NoError(p.t, readErr)
		return
	}
}
// AssertEventuallyLost reads responses until the most recent NODE/LOST
// update recorded for other is LOST. If a read fails first, it reports the
// error and returns.
func (p *Peer) AssertEventuallyLost(other uuid.UUID) {
	p.t.Helper()
	for p.peers[other].status != proto.CoordinateResponse_PeerUpdate_LOST {
		if readErr := p.readOneResp(); readErr != nil {
			assert.NoError(p.t, readErr)
			return
		}
	}
}
// AssertEventuallyResponsesClosed reads responses until the response channel
// is closed. Any other read failure (peer context done, error response) is
// reported and ends the wait.
func (p *Peer) AssertEventuallyResponsesClosed() {
	p.t.Helper()
	for {
		switch readErr := p.readOneResp(); {
		case xerrors.Is(readErr, responsesClosed):
			return
		case !assert.NoError(p.t, readErr):
			return
		}
	}
}
// AssertNotClosed keeps consuming and applying responses for duration d. It
// fails the test if the peer's own context ends first, the response channel
// closes, or an error response arrives.
func (p *Peer) AssertNotClosed(d time.Duration) {
	p.t.Helper()
	// nolint: gocritic // erroneously thinks we're hardcoding non testutil constants here
	waitCtx, stopWaiting := context.WithTimeout(context.Background(), d)
	defer stopWaiting()
	for {
		select {
		case <-waitCtx.Done():
			// The full duration elapsed with the channel still open: success.
			return
		case <-p.ctx.Done():
			p.t.Error("main ctx timeout before elapsed time")
			return
		case resp, open := <-p.resps:
			if !open {
				p.t.Error("response channel closed")
				return
			}
			if handleErr := p.handleResp(resp); !assert.NoError(p.t, handleErr) {
				return
			}
		}
	}
}
// AssertEventuallyReadyForHandshake reads responses until a
// READY_FOR_HANDSHAKE update has been recorded for other.
//
// Any read failure reports the error and returns, as the other
// AssertEventually* helpers do. That covers the peer context being done,
// responses closing, or an error response. Previously only a closed channel
// stopped the loop, and it stopped silently. A done context made the loop
// spin forever.
func (p *Peer) AssertEventuallyReadyForHandshake(other uuid.UUID) {
	p.t.Helper()
	for {
		o := p.peers[other]
		if o.readyForHandshake {
			return
		}
		if err := p.readOneResp(); err != nil {
			assert.NoError(p.t, err, "waiting for ready for handshake from %s", other)
			return
		}
	}
}
// AssertEventuallyGetsError reads responses until an error response
// containing match arrives.
//
// It fails and returns if the response channel closes or the peer's context
// ends first. An error response that does not contain match is recorded as a
// failure, and the search continues.
func (p *Peer) AssertEventuallyGetsError(match string) {
	p.t.Helper()
	for {
		err := p.readOneResp()
		if xerrors.Is(err, responsesClosed) {
			p.t.Error("closed before target error")
			return
		}
		// Once the peer context is done, readOneResp keeps returning its error.
		// Without this exit the loop would spin forever, failing ErrorContains
		// on each pass.
		if ctxErr := p.ctx.Err(); ctxErr != nil && xerrors.Is(err, ctxErr) {
			p.t.Errorf("timeout waiting for error containing %q for %s", match, p.name)
			return
		}
		if err != nil && assert.ErrorContains(p.t, err, match) {
			return
		}
	}
}
// AssertNeverUpdateKind asserts that updates have been recorded for peer and
// that none of them has the given kind. It inspects only updates already
// received; it does not read any new responses.
func (p *Peer) AssertNeverUpdateKind(peer uuid.UUID, kind proto.CoordinateResponse_PeerUpdate_Kind) {
	p.t.Helper()
	recorded, found := p.peerUpdates[peer]
	assert.True(p.t, found, "expected updates for peer %s", peer)
	for i := range recorded {
		assert.NotEqual(p.t, kind, recorded[i].Kind, recorded[i])
	}
}
// responsesClosed is returned by readOneResp once the response channel has
// been closed.
var responsesClosed = xerrors.New("responses closed")

// readOneResp waits for one response and applies it with handleResp. It
// returns one of the following:
//   - the peer context's error, if that context ends first;
//   - responsesClosed, if the channel is closed;
//   - whatever handleResp returns.
func (p *Peer) readOneResp() error {
	select {
	case <-p.ctx.Done():
		return p.ctx.Err()
	case resp, open := <-p.resps:
		if !open {
			return responsesClosed
		}
		return p.handleResp(resp)
	}
}
// handleResp folds resp into the peer's bookkeeping.
//
// A non-empty resp.Error is returned as an error before any update is
// applied. Every peer update is first appended to peerUpdates and then
// applied to peers as follows:
//   - NODE and LOST record the DERP region and kind.
//   - DISCONNECTED forgets the peer.
//   - READY_FOR_HANDSHAKE sets the ready flag.
//
// An unparseable peer ID or an unknown kind stops processing and is returned.
func (p *Peer) handleResp(resp *proto.CoordinateResponse) error {
	if msg := resp.Error; msg != "" {
		return xerrors.New(msg)
	}
	for _, u := range resp.PeerUpdates {
		peerID, err := uuid.FromBytes(u.Id)
		if err != nil {
			return err
		}
		p.peerUpdates[peerID] = append(p.peerUpdates[peerID], u)
		switch kind := u.Kind; kind {
		case proto.CoordinateResponse_PeerUpdate_DISCONNECTED:
			delete(p.peers, peerID)
		case proto.CoordinateResponse_PeerUpdate_READY_FOR_HANDSHAKE:
			status := p.peers[peerID]
			status.readyForHandshake = true
			p.peers[peerID] = status
		case proto.CoordinateResponse_PeerUpdate_NODE, proto.CoordinateResponse_PeerUpdate_LOST:
			status := p.peers[peerID]
			status.preferredDERP = u.GetNode().GetPreferredDerp()
			status.status = kind
			p.peers[peerID] = status
		default:
			return xerrors.Errorf("unhandled update kind %s", kind)
		}
	}
	return nil
}
// Close cancels the peer's context, then discards responses until the
// response channel is closed. It fails the test if ctx ends first.
func (p *Peer) Close(ctx context.Context) {
	p.t.Helper()
	p.cancel()
	for {
		select {
		case <-ctx.Done():
			p.t.Errorf("timeout waiting for responses to close for %s", p.name)
			return
		case _, open := <-p.resps:
			if !open {
				return
			}
			// Responses still in flight are dropped; keep draining.
		}
	}
}
// UngracefulDisconnect closes the request channel without first sending a
// Disconnect request. It then waits for the responses to close via Close.
func (p *Peer) UngracefulDisconnect(ctx context.Context) {
	p.t.Helper()
	requests := p.reqs
	close(requests)
	p.Close(ctx)
}
// FakeSubjectKey is the context key FakeCoordinateeAuth checks; the value
// stored under it must be of type struct{}.
type FakeSubjectKey struct{}

// FakeCoordinateeAuth authorizes a request only if its context carries a
// struct{} value under FakeSubjectKey. Each successful authorization is
// signaled by a send on Chan.
type FakeCoordinateeAuth struct {
	Chan chan struct{}
}

// Authorize rejects the request with "unauthorized" unless ctx carries the
// fake subject. Otherwise it sends on Chan and returns nil.
func (f FakeCoordinateeAuth) Authorize(ctx context.Context, _ *proto.CoordinateRequest) error {
	if _, authorized := ctx.Value(FakeSubjectKey{}).(struct{}); !authorized {
		return xerrors.New("unauthorized")
	}
	// NOTE(review): this send blocks until something receives from Chan (or
	// forever if Chan is nil) and does not observe ctx.
	f.Chan <- struct{}{}
	return nil
}

var _ tailnet.CoordinateeAuth = (*FakeCoordinateeAuth)(nil)