mirror of
synced 2025-03-14 09:55:08 +00:00
Update removing multiple old middlewares, rework the way data is passed through the context, logging fields, etc. Fix minimum keepalive interval enforcement. Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
752 lines
16 KiB
752 lines
16 KiB
// Copyright (c) 2024 Sidero Labs, Inc.
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package server_test
import (
clientpb "github.com/siderolabs/discovery-api/api/v1alpha1/client/pb"
func TestClient(t *testing.T) {
endpoint := setupServer(t, 5000, "").address
logger := zaptest.NewLogger(t)
t.Run("TwoClients", func(t *testing.T) {
clusterID := "cluster_1"
key := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, key)
require.NoError(t, err)
cipher, err := aes.NewCipher(key)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
affiliate1 := "af_1"
affiliate2 := "af_2"
client1, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Minute,
Insecure: true,
require.NoError(t, err)
client2, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
require.NoError(t, err)
notify1 := make(chan struct{}, 1)
notify2 := make(chan struct{}, 1)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client1.Run(ctx, logger, notify1)
eg.Go(func() error {
return client2.Run(ctx, logger, notify2)
select {
case <-notify1:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client1.GetAffiliates())
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client2.GetAffiliates())
affiliate1PB := &client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: affiliate1,
Addresses: [][]byte{{1, 2, 3}},
Hostname: "host1",
Nodename: "node1",
MachineType: "controlplane",
require.NoError(t, client1.SetLocalData(affiliate1PB, nil))
affiliate2PB := &client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: affiliate2,
Addresses: [][]byte{{2, 3, 4}},
Hostname: "host2",
Nodename: "node2",
MachineType: "worker",
require.NoError(t, client2.SetLocalData(affiliate2PB, nil))
// both clients should eventually discover each other
for {
t.Logf("client1 affiliates = %d", len(client1.GetAffiliates()))
if len(client1.GetAffiliates()) == 1 {
select {
case <-notify1:
case <-time.After(2 * time.Second):
t.Logf("client1 affiliates on timeout = %d", len(client1.GetAffiliates()))
require.Fail(t, "no incremental update")
require.Len(t, client1.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate2PB}, client1.GetAffiliates())
for {
t.Logf("client2 affiliates = %d", len(client1.GetAffiliates()))
if len(client2.GetAffiliates()) == 1 {
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no incremental update")
require.Len(t, client2.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
// update affiliate1, client2 should see the update
affiliate1PB.Endpoints = []*clientpb.Endpoint{
Ip: []byte{1, 2, 3, 4},
Port: 5678,
require.NoError(t, client1.SetLocalData(affiliate1PB, nil))
for {
select {
case <-notify2:
case <-time.After(time.Second):
require.Fail(t, "no incremental update")
if len(client2.GetAffiliates()[0].Endpoints) == 1 {
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
// delete affiliate1, client2 should see the update
for {
select {
case <-notify2:
case <-time.After(time.Second):
require.Fail(t, "no incremental update")
if len(client2.GetAffiliates()) == 0 {
require.Len(t, client2.GetAffiliates(), 0)
err = eg.Wait()
if err != nil && !errors.Is(err, context.Canceled) {
assert.NoError(t, err)
t.Run("AffiliateExpire", func(t *testing.T) {
clusterID := "cluster_2"
key := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, key)
require.NoError(t, err)
cipher, err := aes.NewCipher(key)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
affiliate1 := "af_1"
affiliate2 := "af_2"
client1, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Second,
Insecure: true,
require.NoError(t, err)
client2, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
require.NoError(t, err)
notify1 := make(chan struct{}, 1)
notify2 := make(chan struct{}, 1)
eg, ctx := errgroup.WithContext(ctx)
ctx1, cancel1 := context.WithCancel(ctx)
defer cancel1()
ctx2, cancel2 := context.WithCancel(ctx)
defer cancel2()
eg.Go(func() error {
return client1.Run(ctx1, logger, notify1)
eg.Go(func() error {
return client2.Run(ctx2, logger, notify2)
select {
case <-notify1:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client1.GetAffiliates())
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client2.GetAffiliates())
// client1 publishes an affiliate with short TTL
affiliate1PB := &client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: affiliate1,
Addresses: [][]byte{{1, 2, 3}},
Hostname: "host1",
Nodename: "node1",
MachineType: "controlplane",
require.NoError(t, client1.SetLocalData(affiliate1PB, nil))
// client2 should see the update from client1
for {
t.Logf("client2 affiliates = %d", len(client2.GetAffiliates()))
if len(client2.GetAffiliates()) == 1 {
select {
case <-notify2:
case <-time.After(2 * time.Second):
t.Logf("client2 affiliates on timeout = %d", len(client2.GetAffiliates()))
require.Fail(t, "no incremental update")
require.Len(t, client2.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
// stop client1
for {
t.Logf("client2 affiliates = %d", len(client2.GetAffiliates()))
if len(client2.GetAffiliates()) == 0 {
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no expiration")
require.Len(t, client2.GetAffiliates(), 0)
err = eg.Wait()
if err != nil && !errors.Is(err, context.Canceled) {
assert.NoError(t, err)
t.Run("Cluster1", func(t *testing.T) {
clusterSimulator(t, endpoint, logger, 5)
t.Run("Cluster2", func(t *testing.T) {
clusterSimulator(t, endpoint, logger, 15)
t.Run("Cluster3", func(t *testing.T) {
clusterSimulator(t, endpoint, logger, 50)
// clusterSimulator simulates a cluster of numAffiliates clients discovering
// each other through the server at endpoint.
//
// Every client publishes its own affiliate record with one endpoint and, in
// addition, one endpoint on behalf of the next affiliate (index
// (i+1)%numAffiliates). The function then polls until each client reports all
// other affiliates with the two expected ports, failing the test with
// "state not converged" after NumAttempts polls.
//
// NOTE(review): this copy of the function appears to have lost lines during
// extraction (closing braces and other bare statements are not present) —
// compare with the upstream file before relying on exact control flow.
func clusterSimulator(t *testing.T, endpoint string, logger *zap.Logger, numAffiliates int) {
// Random 32-byte cluster ID, base64-encoded.
clusterIDBytes := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, clusterIDBytes)
require.NoError(t, err)
cluterID := base64.StdEncoding.EncodeToString(clusterIDBytes)
// Random 32-byte key for the AES cipher shared by all simulated clients.
key := make([]byte, 32)
_, err = io.ReadFull(rand.Reader, key)
require.NoError(t, err)
cipher, err := aes.NewCipher(key)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// One client per affiliate, identified as "affiliate-<index>".
affiliates := make([]*client.Client, numAffiliates)
for i := range affiliates {
affiliates[i], err = client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: cluterID,
AffiliateID: fmt.Sprintf("affiliate-%d", i),
ClientVersion: "v0.0.1",
TTL: 10 * time.Second,
Insecure: true,
require.NoError(t, err)
// One channel (buffer size 1) per client, passed to Run below.
notifyCh := make([]chan struct{}, numAffiliates)
for i := range notifyCh {
notifyCh[i] = make(chan struct{}, 1)
eg, ctx := errgroup.WithContext(ctx)
for i := range affiliates {
// Per-iteration copy of the index for the closure below.
i := i
eg.Go(func() error {
return affiliates[i].Run(ctx, logger, notifyCh[i])
// establish data for each affiliate
for i := range affiliates {
// Hostname carries the affiliate index; checkDiscoveredState parses it back.
require.NoError(t, affiliates[i].SetLocalData(&client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: fmt.Sprintf("affiliate-%d", i),
Hostname: strconv.Itoa(i),
Endpoints: []*clientpb.Endpoint{
Ip: make([]byte, 4), // IPv4
Port: uint32((i + 1) * 10),
}, []client.Endpoint{
AffiliateID: fmt.Sprintf("affiliate-%d", (i+1)%numAffiliates),
Endpoints: []*clientpb.Endpoint{
Ip: make([]byte, 16), // IPv6
Port: uint32(((i+1)%numAffiliates + 1) * 100),
// checkDiscoveredState reports, as an error, the first mismatch between what
// the affiliate with index affiliateID discovered and the expected state.
checkDiscoveredState := func(affiliateID int, discovered []*client.Affiliate) error {
if len(discovered) != numAffiliates-1 {
return fmt.Errorf("discovered count %d != expected %d", len(discovered), numAffiliates-1)
// Set of affiliate indices that should be seen: everyone except affiliateID.
expected := make(map[int]struct{})
for i := 0; i < numAffiliates; i++ {
if i != affiliateID {
expected[i] = struct{}{}
for _, affiliate := range discovered {
// thisID is declared separately because err below is the enclosing
// function's variable, assigned with '='.
var thisID int
thisID, err = strconv.Atoi(affiliate.Affiliate.Hostname)
require.NoError(t, err)
delete(expected, thisID)
// each affiliate should have two endpoints: one coming from itself, another from different affiliate
if len(affiliate.Endpoints) != 2 {
return fmt.Errorf("expected 2 endpoints, got %d", len(affiliate.Endpoints))
// NOTE(review): the elements of the ports literal are not visible in this
// copy of the file — confirm against upstream.
ports := []int{
// (thisID+1)*10 is the port the affiliate published for itself;
// (thisID+1)*100 is the port published for it by another affiliate.
expectedPorts := []int{
(thisID + 1) * 10,
(thisID + 1) * 100,
if !reflect.DeepEqual(expectedPorts, ports) {
return fmt.Errorf("expected ports %v, got %v", expectedPorts, ports)
if len(expected) > 0 {
return fmt.Errorf("some affiliates not discovered: %v", expected)
return nil
// eventually all affiliates should see discovered state
const NumAttempts = 50 // 50 * 100ms = 5s
for j := 0; j < NumAttempts; j++ {
matches := true
for i := range affiliates {
discovered := affiliates[i].GetAffiliates()
if err = checkDiscoveredState(i, discovered); err != nil {
matches = false
if matches {
// The last attempt without a match fails the test.
if j == NumAttempts-1 {
assert.Fail(t, "state not converged")
time.Sleep(100 * time.Millisecond)
// Only errors other than context.Canceled from the client goroutines are
// reported as test failures.
err = eg.Wait()
if err != nil && !errors.Is(err, context.Canceled) {
assert.NoError(t, err)
func TestClientRedirect(t *testing.T) {
srv1 := setupServer(t, 5000, "")
srv2 := setupServer(t, 5000, "")
endpoint := srv1.address
logger := zaptest.NewLogger(t)
clusterID := "cluster_redirect"
key := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, key)
require.NoError(t, err)
cipher, err := aes.NewCipher(key)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
affiliate1 := "affiliate_one"
affiliate2 := "affiliate_two"
client1, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate1,
TTL: time.Minute,
Insecure: true,
require.NoError(t, err)
client2, err := client.NewClient(client.Options{
Cipher: cipher,
Endpoint: endpoint,
ClusterID: clusterID,
AffiliateID: affiliate2,
TTL: time.Minute,
Insecure: true,
require.NoError(t, err)
notify1 := make(chan struct{}, 1)
notify2 := make(chan struct{}, 1)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client1.Run(ctx, logger, notify1)
eg.Go(func() error {
return client2.Run(ctx, logger, notify2)
select {
case <-notify1:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client1.GetAffiliates())
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no initial snapshot update")
assert.Empty(t, client2.GetAffiliates())
affiliate1PB := &client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: affiliate1,
Addresses: [][]byte{{1, 2, 3}},
Hostname: "host1",
Nodename: "node1",
MachineType: "controlplane",
require.NoError(t, client1.SetLocalData(affiliate1PB, nil))
affiliate2PB := &client.Affiliate{
Affiliate: &clientpb.Affiliate{
NodeId: affiliate2,
Addresses: [][]byte{{2, 3, 4}},
Hostname: "host2",
Nodename: "node2",
MachineType: "worker",
require.NoError(t, client2.SetLocalData(affiliate2PB, nil))
// both clients should eventually discover each other
for {
t.Logf("client1 affiliates = %d", len(client1.GetAffiliates()))
if len(client1.GetAffiliates()) == 1 {
select {
case <-notify1:
case <-time.After(2 * time.Second):
t.Logf("client1 affiliates on timeout = %d", len(client1.GetAffiliates()))
require.Fail(t, "no incremental update")
require.Len(t, client1.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate2PB}, client1.GetAffiliates())
for {
t.Logf("client2 affiliates = %d", len(client1.GetAffiliates()))
if len(client2.GetAffiliates()) == 1 {
select {
case <-notify2:
case <-time.After(2 * time.Second):
require.Fail(t, "no incremental update")
require.Len(t, client2.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
// drain notify channels
for {
select {
case <-notify1:
case <-notify2:
case <-time.After(time.Second):
break drainLoop
// make srv1 redirect all clients to srv2
srv1.restartWithRedirect(t, srv2.address)
// both clients should get updates about each other after a reconnect
for {
select {
case <-notify1:
t.Logf("reconnect: client1 affiliates = %d", len(client1.GetAffiliates()))
if len(client1.GetAffiliates()) == 1 {
break client1Loop
case <-time.After(2 * time.Second):
require.Fail(t, "no incremental update")
require.Len(t, client1.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate2PB}, client1.GetAffiliates())
for {
select {
case <-notify2:
t.Logf("reconnect: client2 affiliates = %d", len(client2.GetAffiliates()))
if len(client2.GetAffiliates()) == 1 {
break client2Loop
case <-time.After(2 * time.Second):
require.Fail(t, "no incremental update")
require.Len(t, client2.GetAffiliates(), 1)
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
// stop old srv1, graceful stop should work as all clients should have disconnected
// update affiliate1, client2 should see the update
affiliate1PB.Endpoints = []*clientpb.Endpoint{
Ip: []byte{1, 2, 3, 4},
Port: 5678,
require.NoError(t, client1.SetLocalData(affiliate1PB, nil))
for {
select {
case <-notify2:
case <-time.After(time.Second):
require.Fail(t, "no incremental update")
if len(client2.GetAffiliates()[0].Endpoints) == 1 {
assert.Equal(t, []*client.Affiliate{affiliate1PB}, client2.GetAffiliates())
err = eg.Wait()
if err != nil && !errors.Is(err, context.Canceled) {
assert.NoError(t, err)