Andrey Smirnov 761d53a418 feat: update dependencies
Rekres, update dependencies.

Mostly to bring in and
get it tested, as this repo has client tests.

Signed-off-by: Andrey Smirnov <>
2025-02-26 15:34:11 +04:00

459 lines
12 KiB

// Copyright (c) 2024 Sidero Labs, Inc.
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package storage_test
import (
storagepb ""
// TestExport seeds a state with each test case snapshot (via importTestState), exports it through
// storage.Export into an in-memory buffer, and checks that the returned stats equal statsForSnapshot and
// that the exported bytes decode back into the same snapshot, ignoring cluster/affiliate order.
//
// NOTE(review): this copy of the test appears to have lost lines. The case literals below are missing
// their braces, the "empty state" and "large state" cases have no snapshot value, t.Run is missing its
// name argument (presumably tc.name), and the closing braces of the closure, loop and function are
// absent. Restore them from version control; the values cannot be inferred from here.
func TestExport(t *testing.T) {
for _, tc := range []struct { //nolint:govet
name string
snapshot *storagepb.StateSnapshot
// NOTE(review): snapshot value missing for "empty state" below
"empty state",
"small state",
&storagepb.StateSnapshot{Clusters: []*storagepb.ClusterSnapshot{{Id: "a"}, {Id: "b"}}},
// NOTE(review): snapshot value missing for "large state" below
"large state",
} {
t.Run(, func(t *testing.T) { // NOTE(review): subtest name argument missing (presumably tc.name)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
logger := zaptest.NewLogger(t)
state := state.NewState(logger)
// seed the state with the test case data so there is something to export
importTestState(t, state, tc.snapshot)
stateStorage := storage.New(path, state, logger)
var buffer bytes.Buffer
exportStats, err := stateStorage.Export(&buffer)
require.NoError(t, err)
assert.Equal(t, statsForSnapshot(tc.snapshot), exportStats)
// the buffer must hold a marshaled StateSnapshot equal to the seeded data
exported := &storagepb.StateSnapshot{}
require.NoError(t, exported.UnmarshalVT(buffer.Bytes()))
requireEqualIgnoreOrder(t, tc.snapshot, exported)
// TestImport marshals each test case snapshot, feeds it to storage.Import through a reader, and checks
// that the returned stats equal statsForSnapshot. It also checks that the state then exports the same
// clusters and affiliates, ignoring order.
//
// NOTE(review): this copy of the test appears to have lost lines. The case literals below are missing
// their braces, the "empty state" and "large state" cases have no snapshot value, t.Run is missing its
// name argument (presumably tc.name), and the closing braces of the closure, loop and function are
// absent. Restore them from version control; the values cannot be inferred from here.
func TestImport(t *testing.T) {
for _, tc := range []struct { //nolint:govet
name string
snapshot *storagepb.StateSnapshot
// NOTE(review): snapshot value missing for "empty state" below
"empty state",
"small state",
&storagepb.StateSnapshot{Clusters: []*storagepb.ClusterSnapshot{{Id: "a"}, {Id: "b"}}},
// NOTE(review): snapshot value missing for "large state" below
"large state",
} {
t.Run(, func(t *testing.T) { // NOTE(review): subtest name argument missing (presumably tc.name)
path := filepath.Join(t.TempDir(), "test.binpb")
logger := zaptest.NewLogger(t)
state := state.NewState(logger)
stateStorage := storage.New(path, state, logger)
// serialize the test case so it can be imported from a reader
data, err := tc.snapshot.MarshalVT()
require.NoError(t, err)
importStats, err := stateStorage.Import(bytes.NewReader(data))
require.NoError(t, err)
require.Equal(t, statsForSnapshot(tc.snapshot), importStats)
// read the state back out and compare with the input, ignoring cluster/affiliate order
importedState := exportTestState(t, state)
requireEqualIgnoreOrder(t, tc.snapshot, importedState)
func TestImportMaxSize(t *testing.T) {
cluster := buildMaxSizeCluster()
stateSnapshot := &storagepb.StateSnapshot{Clusters: []*storagepb.ClusterSnapshot{cluster}}
path := filepath.Join(t.TempDir(), "test.binpb")
logger := zaptest.NewLogger(t)
state := state.NewState(logger)
stateStorage := storage.New(path, state, logger)
clusterData, err := cluster.MarshalVT()
require.NoError(t, err)
require.Equal(t, len(clusterData), storage.MaxClusterSize)
data, err := stateSnapshot.MarshalVT()
require.NoError(t, err)
t.Logf("max cluster marshaled size: %d", len(data))
_, err = stateStorage.Import(bytes.NewReader(data))
require.NoError(t, err)
// add one more affiliate to trigger an overflow
cluster.Affiliates = append(cluster.Affiliates, &storagepb.AffiliateSnapshot{
Id: "overflow",
data, err = stateSnapshot.MarshalVT()
require.NoError(t, err)
_, err = stateStorage.Import(bytes.NewReader(data))
require.ErrorIs(t, err, storage.ErrClusterSnapshotTooLarge)
func TestStorage(t *testing.T) {
snapshot := buildTestSnapshot(10)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
state := newTestSnapshotter(t, snapshot)
logger := zaptest.NewLogger(t)
stateStorage := storage.New(path, state, logger)
// test save
require.NoError(t, stateStorage.Save())
savedBytes, err := os.ReadFile(path)
require.NoError(t, err)
savedSnapshot := &storagepb.StateSnapshot{}
require.NoError(t, savedSnapshot.UnmarshalVT(savedBytes))
requireEqualIgnoreOrder(t, snapshot, savedSnapshot)
// test load
require.NoError(t, stateStorage.Load())
require.Len(t, state.getLoads(), 1)
requireEqualIgnoreOrder(t, snapshot, state.getLoads()[0])
// modify, save & load again to assert that the file content gets overwritten
snapshot.Clusters[1].Affiliates[0].Data = []byte("new aff1 data")
require.NoError(t, stateStorage.Save())
require.NoError(t, stateStorage.Load())
require.Len(t, state.getLoads(), 2)
requireEqualIgnoreOrder(t, snapshot, state.getLoads()[1])
func TestSchedule(t *testing.T) {
clock := clockwork.NewFakeClock()
snapshot := buildTestSnapshot(10)
tempDir := t.TempDir()
path := filepath.Join(tempDir, "test.binpb")
state := newTestSnapshotter(t, snapshot)
logger := zaptest.NewLogger(t)
stateStorage := storage.New(path, state, logger)
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
// start the periodic storage and wait for it to block on the timer
errCh := make(chan error)
go func() {
errCh <- stateStorage.Start(ctx, clock, 10*time.Minute)
require.NoError(t, clock.BlockUntilContext(ctx, 1))
// advance time to trigger the first snapshot and assert it
clock.Advance(13 * time.Minute)
require.EventuallyWithT(t, func(collect *assert.CollectT) {
assert.Equal(collect, 1, state.getSnapshots())
}, 2*time.Second, 100*time.Millisecond)
// advance time to trigger the second snapshot and assert it
clock.Advance(10 * time.Minute)
require.EventuallyWithT(t, func(collect *assert.CollectT) {
assert.Equal(collect, 2, state.getSnapshots())
}, 2*time.Second, 100*time.Millisecond)
// cancel the context to stop the storage loop and wait for it to exit
require.NoError(t, <-errCh)
// assert that the state was saved on shutdown
require.EventuallyWithT(t, func(collect *assert.CollectT) {
assert.Equal(collect, 3, state.getSnapshots())
}, 2*time.Second, 100*time.Millisecond)
// testSnapshotter is a mock implementation of storage.Snapshotter for testing purposes.
// It keeps track of the loads and the number of snapshots that have been performed to be used in assertions.
type testSnapshotter struct {
exportData *storagepb.StateSnapshot
tb testing.TB
loads []*storagepb.StateSnapshot
snapshots int
lock sync.Mutex
func newTestSnapshotter(tb testing.TB, exportData *storagepb.StateSnapshot) *testSnapshotter {
state := state.NewState(zaptest.NewLogger(tb))
importTestState(tb, state, exportData)
return &testSnapshotter{exportData: exportData, tb: tb}
func (m *testSnapshotter) getSnapshots() int {
defer m.lock.Unlock()
return m.snapshots
func (m *testSnapshotter) getLoads() []*storagepb.StateSnapshot {
defer m.lock.Unlock()
return append([]*storagepb.StateSnapshot(nil), m.loads...)
// ExportClusterSnapshots implements storage.Snapshotter interface.
func (m *testSnapshotter) ExportClusterSnapshots(f func(snapshot *storagepb.ClusterSnapshot) error) error {
defer m.lock.Unlock()
tempState := state.NewState(zaptest.NewLogger(m.tb))
importTestState(m.tb, tempState, m.exportData)
if err := tempState.ExportClusterSnapshots(f); err != nil {
return err
return nil
// ImportClusterSnapshots implements storage.Snapshotter interface.
func (m *testSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error {
defer m.lock.Unlock()
tempState := state.NewState(zaptest.NewLogger(m.tb))
if err := tempState.ImportClusterSnapshots(f); err != nil {
return err
m.loads = append(m.loads, exportTestState(m.tb, tempState))
return nil
func statsForSnapshot(snapshot *storagepb.StateSnapshot) storage.SnapshotStats {
numAffiliates := 0
numEndpoints := 0
for _, cluster := range snapshot.Clusters {
numAffiliates += len(cluster.Affiliates)
for _, affiliate := range cluster.Affiliates {
numEndpoints += len(affiliate.Endpoints)
return storage.SnapshotStats{
NumClusters: len(snapshot.Clusters),
NumAffiliates: numAffiliates,
NumEndpoints: numEndpoints,
Size: snapshot.SizeVT(),
func buildTestSnapshot(numClusters int) *storagepb.StateSnapshot {
clusters := make([]*storagepb.ClusterSnapshot, 0, numClusters)
for i := range numClusters {
affiliates := make([]*storagepb.AffiliateSnapshot, 0, 5)
for j := range 5 {
affiliates = append(affiliates, &storagepb.AffiliateSnapshot{
Id: fmt.Sprintf("aff%d", j),
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
Data: []byte(fmt.Sprintf("aff%d data", j)),
if i%2 == 0 {
affiliates[0].Endpoints = []*storagepb.EndpointSnapshot{
Expiration: timestamppb.New(time.Now().Add(time.Hour)),
Data: []byte(fmt.Sprintf("endpoint%d data", i)),
clusters = append(clusters, &storagepb.ClusterSnapshot{
Id: fmt.Sprintf("cluster%d", i),
Affiliates: affiliates,
return &storagepb.StateSnapshot{
Clusters: clusters,
// buildMaxSizeCluster creates a cluster snapshot with the maximum possible marshaled size within the limits of the discovery service.
func buildMaxSizeCluster() *storagepb.ClusterSnapshot {
largestTTL := &timestamppb.Timestamp{
Seconds: math.MinInt64,
Nanos: math.MinInt32,
} // the timestamp with the maximum possible marshaled size
affiliates := make([]*storagepb.AffiliateSnapshot, 0, limits.ClusterAffiliatesMax)
for range limits.ClusterAffiliatesMax {
endpoints := make([]*storagepb.EndpointSnapshot, 0, limits.AffiliateEndpointsMax)
for range limits.AffiliateEndpointsMax {
endpoints = append(endpoints, &storagepb.EndpointSnapshot{
Expiration: largestTTL,
Data: bytes.Repeat([]byte("a"), limits.AffiliateDataMax),
affiliates = append(affiliates, &storagepb.AffiliateSnapshot{
Id: strings.Repeat("a", limits.AffiliateIDMax),
Expiration: largestTTL,
Data: bytes.Repeat([]byte("a"), limits.AffiliateDataMax),
Endpoints: endpoints,
return &storagepb.ClusterSnapshot{
Id: strings.Repeat("c", limits.ClusterIDMax),
Affiliates: affiliates,
func importTestState(tb testing.TB, state *state.State, snapshot *storagepb.StateSnapshot) {
clusters := snapshot.Clusters
i := 0
err := state.ImportClusterSnapshots(func() (*storagepb.ClusterSnapshot, bool, error) {
if i >= len(clusters) {
return nil, false, nil
cluster := clusters[i]
return cluster, true, nil
require.NoError(tb, err)
func exportTestState(tb testing.TB, state *state.State) *storagepb.StateSnapshot {
snapshot := &storagepb.StateSnapshot{}
err := state.ExportClusterSnapshots(func(cluster *storagepb.ClusterSnapshot) error {
snapshot.Clusters = append(snapshot.Clusters, cluster.CloneVT()) // clone the cluster here, as its reference is reused across iterations
return nil
require.NoError(tb, err)
return snapshot
func requireEqualIgnoreOrder(tb testing.TB, expected, actual *storagepb.StateSnapshot) {
a := expected.CloneVT()
b := actual.CloneVT()
// sort clusters
for _, st := range []*storagepb.StateSnapshot{a, b} {
slices.SortFunc(st.Clusters, func(a, b *storagepb.ClusterSnapshot) int {
return strings.Compare(a.Id, b.Id)
// sort affiliates
for _, cluster := range append(a.Clusters, b.Clusters...) {
slices.SortFunc(cluster.Affiliates, func(a, b *storagepb.AffiliateSnapshot) int {
return strings.Compare(a.Id, b.Id)
assert.True(tb, a.EqualVT(b))