// Copyright (c) 2024 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.

// Package storage implements persistent storage for the state of the discovery service.
package storage

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"slices"
	"sync"
	"time"

	"github.com/jonboulle/clockwork"
	prom "github.com/prometheus/client_golang/prometheus"
	"go.uber.org/zap"

	storagepb "github.com/siderolabs/discovery-service/api/storage"
)

const (
	labelOperation = "operation"
	labelStatus    = "status"

	operationSave = "save"
	operationLoad = "load"

	statusSuccess = "success"
	statusError   = "error"
)

// Storage is a persistent storage for the state of the discovery service.
type Storage struct {
	state  Snapshotter
	logger *zap.Logger

	operationsMetric              *prom.CounterVec
	lastSnapshotSizeMetric        *prom.GaugeVec
	lastOperationClustersMetric   *prom.GaugeVec
	lastOperationAffiliatesMetric *prom.GaugeVec
	lastOperationEndpointsMetric  *prom.GaugeVec
	lastOperationDurationMetric   *prom.GaugeVec

	path string
}

// Describe implements prometheus.Collector interface.
func (storage *Storage) Describe(descs chan<- *prom.Desc) {
	prom.DescribeByCollect(storage, descs)
}

// Collect implements prometheus.Collector interface.
func (storage *Storage) Collect(metrics chan<- prom.Metric) {
	storage.operationsMetric.Collect(metrics)
	storage.lastSnapshotSizeMetric.Collect(metrics)
	storage.lastOperationClustersMetric.Collect(metrics)
	storage.lastOperationAffiliatesMetric.Collect(metrics)
	storage.lastOperationEndpointsMetric.Collect(metrics)
	storage.lastOperationDurationMetric.Collect(metrics)
}

// Snapshotter is an interface for exporting and importing cluster state.
type Snapshotter interface {
	// ExportClusterSnapshots exports cluster snapshots to the given function.
	ExportClusterSnapshots(f func(*storagepb.ClusterSnapshot) error) error

	// ImportClusterSnapshots imports cluster snapshots from the given function.
	ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error
}
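
// As an illustration only (not part of this package), a minimal in-memory
// Snapshotter could satisfy the interface like the hypothetical memSnapshotter
// sketched below:
//
//	type memSnapshotter struct {
//		clusters []*storagepb.ClusterSnapshot
//	}
//
//	func (m *memSnapshotter) ExportClusterSnapshots(f func(*storagepb.ClusterSnapshot) error) error {
//		for _, c := range m.clusters {
//			if err := f(c); err != nil {
//				return err
//			}
//		}
//
//		return nil
//	}
//
//	func (m *memSnapshotter) ImportClusterSnapshots(f func() (*storagepb.ClusterSnapshot, bool, error)) error {
//		for {
//			c, ok, err := f()
//			if err != nil || !ok {
//				return err
//			}
//
//			m.clusters = append(m.clusters, c)
//		}
//	}
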
// New creates a new instance of Storage.
func New(path string, state Snapshotter, logger *zap.Logger) *Storage {
	return &Storage{
		state:  state,
		logger: logger.With(zap.String("component", "storage"), zap.String("path", path)),
		path:   path,
		operationsMetric: prom.NewCounterVec(prom.CounterOpts{
			Name: "discovery_storage_operations_total",
			Help: "The total number of storage operations.",
		}, []string{labelOperation, labelStatus}),
		lastSnapshotSizeMetric: prom.NewGaugeVec(prom.GaugeOpts{
			Name: "discovery_storage_last_snapshot_size_bytes",
			Help: "The size of the last processed snapshot in bytes.",
		}, []string{labelOperation}),
		lastOperationClustersMetric: prom.NewGaugeVec(prom.GaugeOpts{
			Name: "discovery_storage_last_operation_clusters",
			Help: "The number of clusters in the snapshot of the last operation.",
		}, []string{labelOperation}),
		lastOperationAffiliatesMetric: prom.NewGaugeVec(prom.GaugeOpts{
			Name: "discovery_storage_last_operation_affiliates",
			Help: "The number of affiliates in the snapshot of the last operation.",
		}, []string{labelOperation}),
		lastOperationEndpointsMetric: prom.NewGaugeVec(prom.GaugeOpts{
			Name: "discovery_storage_last_operation_endpoints",
			Help: "The number of endpoints in the snapshot of the last operation.",
		}, []string{labelOperation}),
		lastOperationDurationMetric: prom.NewGaugeVec(prom.GaugeOpts{
			Name: "discovery_storage_last_operation_duration_seconds",
			Help: "The duration of the last operation in seconds.",
		}, []string{labelOperation}),
	}
}
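
// Typical wiring, as a sketch (stateStorage, state, ctx, logger, and the path
// are assumptions, not defined in this file):
//
//	stateStorage := storage.New("/var/lib/discovery-service/state.binpb", state, logger)
//	prom.MustRegister(stateStorage) // Storage implements prometheus.Collector
//
//	if err := stateStorage.Load(); err != nil {
//		logger.Warn("failed to load state, starting empty", zap.Error(err))
//	}
//
//	// periodically save state until ctx is canceled; a final save happens on shutdown
//	err := stateStorage.Start(ctx, clockwork.NewRealClock(), time.Minute)
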
// Start starts the storage loop that periodically saves the state.
func (storage *Storage) Start(ctx context.Context, clock clockwork.Clock, interval time.Duration) error {
	storage.logger.Info("start storage loop", zap.Duration("interval", interval))

	ticker := clock.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			storage.logger.Info("received shutdown signal")

			if err := storage.Save(); err != nil {
				return fmt.Errorf("failed to save state on shutdown: %w", err)
			}

			if errors.Is(ctx.Err(), context.Canceled) {
				return nil
			}

			return ctx.Err()
		case <-ticker.Chan():
			if err := storage.Save(); err != nil {
				storage.logger.Error("failed to save state", zap.Error(err))
			}
		}
	}
}
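
// Taking a clockwork.Clock lets tests drive the save loop deterministically.
// A sketch using clockwork's fake clock (stateStorage and the test wiring are
// assumptions, not part of this file):
//
//	clock := clockwork.NewFakeClock()
//
//	go stateStorage.Start(ctx, clock, time.Minute) //nolint:errcheck
//
//	_ = clock.BlockUntilContext(ctx, 1) // wait until the loop is blocked on the ticker
//	clock.Advance(time.Minute)          // fire one tick, triggering a single Save
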
// Save saves all clusters' states into the persistent storage.
func (storage *Storage) Save() (err error) {
	start := time.Now()

	// defers run LIFO: the recover below fires first and converts a panic
	// into err, so this metrics defer still observes it
	defer func() {
		if err != nil {
			storage.operationsMetric.WithLabelValues(operationSave, statusError).Inc()
		}
	}()

	// never panic, convert it into an error instead
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("save panicked: %v", recovered)
		}
	}()

	if err = os.MkdirAll(filepath.Dir(storage.path), 0o755); err != nil {
		return fmt.Errorf("failed to create directory path: %w", err)
	}

	tmpFile, err := getTempFile(storage.path)
	if err != nil {
		return fmt.Errorf("failed to create temporary file: %w", err)
	}

	defer func() {
		tmpFile.Close()           //nolint:errcheck
		os.Remove(tmpFile.Name()) //nolint:errcheck
	}()

	stats, err := storage.Export(tmpFile)
	if err != nil {
		return fmt.Errorf("failed to write snapshot: %w", err)
	}

	if err = commitTempFile(tmpFile, storage.path); err != nil {
		return fmt.Errorf("failed to commit temporary file: %w", err)
	}

	duration := time.Since(start)

	storage.logger.Info("state saved", zap.Int("clusters", stats.NumClusters), zap.Int("affiliates", stats.NumAffiliates),
		zap.Int("endpoints", stats.NumEndpoints), zap.Duration("duration", duration), zap.Int("size_bytes", stats.Size))

	storage.operationsMetric.WithLabelValues(operationSave, statusSuccess).Inc()
	storage.lastSnapshotSizeMetric.WithLabelValues(operationSave).Set(float64(stats.Size))
	storage.lastOperationClustersMetric.WithLabelValues(operationSave).Set(float64(stats.NumClusters))
	storage.lastOperationAffiliatesMetric.WithLabelValues(operationSave).Set(float64(stats.NumAffiliates))
	storage.lastOperationEndpointsMetric.WithLabelValues(operationSave).Set(float64(stats.NumEndpoints))
	storage.lastOperationDurationMetric.WithLabelValues(operationSave).Set(duration.Seconds())

	return nil
}

// Load loads all clusters' states from the persistent storage.
func (storage *Storage) Load() (err error) {
	defer func() {
		if err != nil {
			storage.operationsMetric.WithLabelValues(operationLoad, statusError).Inc()
		}
	}()

	// never panic, convert it into an error instead
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("load panicked: %v", recovered)
		}
	}()

	start := time.Now()

	// open file for reading
	file, err := os.Open(storage.path)
	if err != nil {
		return fmt.Errorf("failed to open file: %w", err)
	}

	defer file.Close() //nolint:errcheck

	stats, err := storage.Import(file)
	if err != nil {
		return fmt.Errorf("failed to read snapshot: %w", err)
	}

	if err = file.Close(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}

	duration := time.Since(start)

	storage.logger.Info("state loaded", zap.Int("clusters", stats.NumClusters), zap.Int("affiliates", stats.NumAffiliates),
		zap.Int("endpoints", stats.NumEndpoints), zap.Duration("duration", duration), zap.Int("size_bytes", stats.Size))

	storage.operationsMetric.WithLabelValues(operationLoad, statusSuccess).Inc()
	storage.lastSnapshotSizeMetric.WithLabelValues(operationLoad).Set(float64(stats.Size))
	storage.lastOperationClustersMetric.WithLabelValues(operationLoad).Set(float64(stats.NumClusters))
	storage.lastOperationAffiliatesMetric.WithLabelValues(operationLoad).Set(float64(stats.NumAffiliates))
	storage.lastOperationEndpointsMetric.WithLabelValues(operationLoad).Set(float64(stats.NumEndpoints))
	storage.lastOperationDurationMetric.WithLabelValues(operationLoad).Set(duration.Seconds())

	return nil
}

// Import imports all clusters' states from the given reader.
//
// When importing, we avoid unmarshaling to the storagepb.StateSnapshot type directly, as it would allocate all the cluster snapshots at once.
// Instead, we process clusters in a streaming manner, unmarshaling them one by one and importing them into the state.
func (storage *Storage) Import(reader io.Reader) (SnapshotStats, error) {
	size := 0
	numClusters := 0
	numAffiliates := 0
	numEndpoints := 0

	buffer := make([]byte, 256)
	bufferedReader := bufio.NewReader(reader)

	// unmarshal the clusters in a streaming manner and import them into the state
	if err := storage.state.ImportClusterSnapshots(func() (*storagepb.ClusterSnapshot, bool, error) {
		headerSize, clusterSize, err := decodeClusterSnapshotHeader(bufferedReader)
		if err != nil {
			if err == io.EOF { //nolint:errorlint
				return nil, false, nil
			}

			return nil, false, fmt.Errorf("failed to decode cluster header: %w", err)
		}

		if clusterSize > cap(buffer) {
			buffer = slices.Grow(buffer, clusterSize)
		}

		buffer = buffer[:clusterSize]

		if _, err = io.ReadFull(bufferedReader, buffer); err != nil {
			return nil, false, fmt.Errorf("failed to read bytes: %w", err)
		}

		clusterSnapshot, err := decodeClusterSnapshot(buffer)
		if err != nil {
			return nil, false, fmt.Errorf("failed to decode cluster: %w", err)
		}

		buffer = buffer[:0]

		// update stats
		size += headerSize + clusterSize
		numClusters++
		numAffiliates += len(clusterSnapshot.Affiliates)

		for _, affiliate := range clusterSnapshot.Affiliates {
			numEndpoints += len(affiliate.Endpoints)
		}

		return clusterSnapshot, true, nil
	}); err != nil {
		return SnapshotStats{}, fmt.Errorf("failed to import clusters: %w", err)
	}

	return SnapshotStats{
		Size:          size,
		NumClusters:   numClusters,
		NumAffiliates: numAffiliates,
		NumEndpoints:  numEndpoints,
	}, nil
}
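
// The record framing itself lives in the encode/decode helpers, which are
// defined elsewhere in this package and not shown here. The contract Import
// relies on is narrow: decodeClusterSnapshotHeader returns the header size and
// the byte length of the next marshaled ClusterSnapshot (with io.EOF marking a
// clean end of stream), and decodeClusterSnapshot unmarshals exactly that many
// bytes. Reusing one growable buffer across iterations keeps peak memory
// proportional to the largest single cluster rather than the whole snapshot.
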
// Export exports all clusters' states into the given writer.
//
// When exporting, we avoid marshaling to the storagepb.StateSnapshot type directly, as it would allocate all the cluster snapshots at once.
// Instead, we process clusters in a streaming manner, marshaling them one by one and exporting them into the writer.
func (storage *Storage) Export(writer io.Writer) (SnapshotStats, error) {
	numClusters := 0
	numAffiliates := 0
	numEndpoints := 0
	size := 0

	var buffer []byte

	bufferedWriter := bufio.NewWriter(writer)

	// marshal the clusters in a streaming manner and export them into the writer
	if err := storage.state.ExportClusterSnapshots(func(snapshot *storagepb.ClusterSnapshot) error {
		var err error

		buffer, err = encodeClusterSnapshot(buffer, snapshot)
		if err != nil {
			return fmt.Errorf("failed to encode cluster: %w", err)
		}

		written, err := bufferedWriter.Write(buffer)
		if err != nil {
			return fmt.Errorf("failed to write cluster: %w", err)
		}

		// prepare the buffer for the next iteration - reset it
		buffer = buffer[:0]

		// update stats
		size += written
		numClusters++
		numAffiliates += len(snapshot.Affiliates)

		for _, affiliate := range snapshot.Affiliates {
			numEndpoints += len(affiliate.Endpoints)
		}

		return nil
	}); err != nil {
		return SnapshotStats{}, fmt.Errorf("failed to snapshot clusters: %w", err)
	}

	if err := bufferedWriter.Flush(); err != nil {
		return SnapshotStats{}, fmt.Errorf("failed to flush writer: %w", err)
	}

	return SnapshotStats{
		Size:          size,
		NumClusters:   numClusters,
		NumAffiliates: numAffiliates,
		NumEndpoints:  numEndpoints,
	}, nil
}
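
// Export writes to any io.Writer, so the same streaming path can serve ad-hoc
// backups. A sketch (the path, stateStorage, and logger are assumptions):
//
//	f, err := os.Create("/tmp/state-backup.binpb")
//	if err != nil {
//		return err
//	}
//	defer f.Close() //nolint:errcheck
//
//	stats, err := stateStorage.Export(f)
//	if err != nil {
//		return err
//	}
//
//	logger.Info("backup written", zap.Int("size_bytes", stats.Size))
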
// getTempFile creates a temporary file next to the destination path.
func getTempFile(dst string) (*os.File, error) {
	tmpFile, err := os.OpenFile(dst+".tmp", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o666)
	if err != nil {
		return nil, fmt.Errorf("failed to create file: %w", err)
	}

	return tmpFile, nil
}

// commitTempFile syncs the temporary file and renames it over the destination,
// removing the temporary file if the commit fails.
//
// This is the classic atomic-replace pattern: on POSIX filesystems the rename
// is atomic, so readers of dst only ever observe a complete snapshot.
func commitTempFile(tmpFile *os.File, dst string) error {
	renamed := false
	closer := sync.OnceValue(tmpFile.Close)

	defer func() {
		closer() //nolint:errcheck

		if !renamed {
			os.Remove(tmpFile.Name()) //nolint:errcheck
		}
	}()

	if err := tmpFile.Sync(); err != nil {
		return fmt.Errorf("failed to sync data: %w", err)
	}

	if err := closer(); err != nil {
		return fmt.Errorf("failed to close file: %w", err)
	}

	if err := os.Rename(tmpFile.Name(), dst); err != nil {
		return fmt.Errorf("failed to rename file: %w", err)
	}

	renamed = true

	return nil
}

// SnapshotStats contains statistics about a snapshot.
type SnapshotStats struct {
	Size          int
	NumClusters   int
	NumAffiliates int
	NumEndpoints  int
}