mirror of
https://github.com/siderolabs/discovery-service.git
synced 2025-03-14 09:55:08 +00:00
317 lines
8.1 KiB
Go
317 lines
8.1 KiB
Go
|
// Copyright (c) 2024 Sidero Labs, Inc.
|
||
|
//
|
||
|
// Use of this software is governed by the Business Source License
|
||
|
// included in the LICENSE file.
|
||
|
|
||
|
// Package service implements the high-level service entry point.
|
||
|
package service
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"time"
|
||
|
|
||
|
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
|
||
|
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
|
||
|
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery"
|
||
|
"github.com/jonboulle/clockwork"
|
||
|
prom "github.com/prometheus/client_golang/prometheus"
|
||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||
|
"github.com/siderolabs/discovery-api/api/v1alpha1/server/pb"
|
||
|
"github.com/siderolabs/go-debug"
|
||
|
"go.uber.org/zap"
|
||
|
"golang.org/x/sync/errgroup"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/keepalive"
|
||
|
"google.golang.org/grpc/status"
|
||
|
|
||
|
"github.com/siderolabs/discovery-service/internal/landing"
|
||
|
"github.com/siderolabs/discovery-service/internal/limiter"
|
||
|
"github.com/siderolabs/discovery-service/internal/state"
|
||
|
"github.com/siderolabs/discovery-service/internal/state/storage"
|
||
|
"github.com/siderolabs/discovery-service/pkg/limits"
|
||
|
"github.com/siderolabs/discovery-service/pkg/server"
|
||
|
)
|
||
|
|
||
|
// Options are the configuration options for the service.
|
||
|
type Options struct {
|
||
|
MetricsRegisterer prom.Registerer
|
||
|
|
||
|
LandingAddr string
|
||
|
MetricsAddr string
|
||
|
SnapshotPath string
|
||
|
DebugAddr string
|
||
|
RedirectEndpoint string
|
||
|
ListenAddr string
|
||
|
|
||
|
GCInterval time.Duration
|
||
|
SnapshotInterval time.Duration
|
||
|
|
||
|
LandingServerEnabled bool
|
||
|
DebugServerEnabled bool
|
||
|
MetricsServerEnabled bool
|
||
|
SnapshotsEnabled bool
|
||
|
}
|
||
|
|
||
|
// Run starts the service with the given options.
|
||
|
func Run(ctx context.Context, options Options, logger *zap.Logger) error {
|
||
|
logger.Info("service starting")
|
||
|
|
||
|
defer logger.Info("service shut down")
|
||
|
|
||
|
recoveryOpt := grpc_recovery.WithRecoveryHandler(recoveryHandler(logger))
|
||
|
|
||
|
limiter := limiter.NewIPRateLimiter(limits.IPRateRequestsPerSecondMax, limits.IPRateBurstSizeMax)
|
||
|
|
||
|
metrics := grpc_prometheus.NewServerMetrics(
|
||
|
grpc_prometheus.WithServerHandlingTimeHistogram(grpc_prometheus.WithHistogramBuckets([]float64{0.01, 0.1, 0.25, 0.5, 1.0, 2.5})),
|
||
|
)
|
||
|
|
||
|
loggingOpts := []logging.Option{
|
||
|
logging.WithLogOnEvents(logging.StartCall, logging.FinishCall),
|
||
|
logging.WithFieldsFromContext(logging.ExtractFields),
|
||
|
}
|
||
|
|
||
|
//nolint:contextcheck
|
||
|
serverOptions := []grpc.ServerOption{
|
||
|
grpc.ChainUnaryInterceptor(
|
||
|
server.AddLoggingFieldsUnaryServerInterceptor(),
|
||
|
logging.UnaryServerInterceptor(interceptorLogger(logger), loggingOpts...),
|
||
|
server.RateLimitUnaryServerInterceptor(limiter),
|
||
|
metrics.UnaryServerInterceptor(),
|
||
|
grpc_recovery.UnaryServerInterceptor(recoveryOpt),
|
||
|
),
|
||
|
grpc.ChainStreamInterceptor(
|
||
|
server.AddLoggingFieldsStreamServerInterceptor(),
|
||
|
server.RateLimitStreamServerInterceptor(limiter),
|
||
|
logging.StreamServerInterceptor(interceptorLogger(logger), loggingOpts...),
|
||
|
metrics.StreamServerInterceptor(),
|
||
|
grpc_recovery.StreamServerInterceptor(recoveryOpt),
|
||
|
),
|
||
|
grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
|
||
|
MinTime: 10 * time.Second,
|
||
|
}),
|
||
|
grpc.SharedWriteBuffer(true),
|
||
|
grpc.ReadBufferSize(16 * 1024),
|
||
|
grpc.WriteBufferSize(16 * 1024),
|
||
|
}
|
||
|
|
||
|
state := state.NewState(logger)
|
||
|
|
||
|
var stateStorage *storage.Storage
|
||
|
|
||
|
if options.SnapshotsEnabled {
|
||
|
stateStorage = storage.New(options.SnapshotPath, state, logger)
|
||
|
if err := stateStorage.Load(); err != nil {
|
||
|
logger.Warn("failed to load state from storage", zap.Error(err))
|
||
|
}
|
||
|
} else {
|
||
|
logger.Info("snapshots are disabled")
|
||
|
}
|
||
|
|
||
|
srv := server.NewClusterServer(state, ctx.Done(), options.RedirectEndpoint)
|
||
|
|
||
|
lis, err := net.Listen("tcp", options.ListenAddr)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to listen: %w", err)
|
||
|
}
|
||
|
|
||
|
landingLis, err := net.Listen("tcp", options.LandingAddr)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to listen: %w", err)
|
||
|
}
|
||
|
|
||
|
s := grpc.NewServer(serverOptions...)
|
||
|
pb.RegisterClusterServer(s, srv)
|
||
|
|
||
|
metrics.InitializeMetrics(s)
|
||
|
|
||
|
var (
|
||
|
metricsServer http.Server
|
||
|
landingServer http.Server
|
||
|
)
|
||
|
|
||
|
if options.MetricsServerEnabled {
|
||
|
var metricsMux http.ServeMux
|
||
|
|
||
|
metricsMux.Handle("/metrics", promhttp.Handler())
|
||
|
|
||
|
metricsServer = http.Server{
|
||
|
Addr: options.MetricsAddr,
|
||
|
Handler: &metricsMux,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if options.LandingServerEnabled {
|
||
|
landingServer = http.Server{
|
||
|
Handler: landing.Handler(state, logger),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
eg, ctx := errgroup.WithContext(ctx)
|
||
|
|
||
|
if options.SnapshotsEnabled {
|
||
|
eg.Go(func() error {
|
||
|
return stateStorage.Start(ctx, clockwork.NewRealClock(), options.SnapshotInterval)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
eg.Go(func() error {
|
||
|
logger.Info("gRPC server starting", zap.Stringer("address", lis.Addr()))
|
||
|
|
||
|
if serveErr := s.Serve(lis); serveErr != nil {
|
||
|
return fmt.Errorf("failed to serve: %w", serveErr)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
if options.LandingServerEnabled {
|
||
|
eg.Go(func() error {
|
||
|
logger.Info("landing server starting", zap.Stringer("address", landingLis.Addr()))
|
||
|
|
||
|
if serveErr := landingServer.Serve(landingLis); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||
|
return fmt.Errorf("failed to serve: %w", serveErr)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
if options.MetricsServerEnabled {
|
||
|
eg.Go(func() error {
|
||
|
logger.Info("metrics starting", zap.String("address", metricsServer.Addr))
|
||
|
|
||
|
if serveErr := metricsServer.ListenAndServe(); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||
|
return serveErr
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
eg.Go(func() error {
|
||
|
<-ctx.Done()
|
||
|
|
||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
|
defer shutdownCancel()
|
||
|
|
||
|
s.GracefulStop()
|
||
|
|
||
|
if options.LandingServerEnabled {
|
||
|
landingServer.Shutdown(ctx) //nolint:errcheck
|
||
|
}
|
||
|
|
||
|
if options.MetricsServerEnabled {
|
||
|
metricsServer.Shutdown(shutdownCtx) //nolint:errcheck,contextcheck
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
eg.Go(func() error {
|
||
|
state.RunGC(ctx, logger, options.GCInterval)
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
eg.Go(func() error {
|
||
|
limiter.RunGC(ctx)
|
||
|
|
||
|
return nil
|
||
|
})
|
||
|
|
||
|
if options.DebugServerEnabled {
|
||
|
eg.Go(func() error {
|
||
|
return debug.ListenAndServe(ctx, options.DebugAddr, func(msg string) { logger.Info(msg) })
|
||
|
})
|
||
|
}
|
||
|
|
||
|
if options.MetricsRegisterer != nil {
|
||
|
collectors := []prom.Collector{state, srv, metrics, stateStorage}
|
||
|
|
||
|
defer unregisterCollectors(options.MetricsRegisterer, collectors...)
|
||
|
|
||
|
if err = registerCollectors(options.MetricsRegisterer, collectors...); err != nil {
|
||
|
return fmt.Errorf("failed to register collectors: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return eg.Wait()
|
||
|
}
|
||
|
|
||
|
func recoveryHandler(logger *zap.Logger) grpc_recovery.RecoveryHandlerFunc {
|
||
|
return func(p interface{}) error {
|
||
|
if logger != nil {
|
||
|
logger.Error("grpc panic", zap.Any("panic", p), zap.Stack("stack"))
|
||
|
}
|
||
|
|
||
|
return status.Errorf(codes.Internal, "%v", p)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func interceptorLogger(l *zap.Logger) logging.Logger {
|
||
|
return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) {
|
||
|
f := make([]zap.Field, 0, len(fields)/2)
|
||
|
|
||
|
for i := 0; i < len(fields); i += 2 {
|
||
|
key := fields[i].(string) //nolint:forcetypeassert,errcheck
|
||
|
value := fields[i+1]
|
||
|
|
||
|
switch v := value.(type) {
|
||
|
case string:
|
||
|
f = append(f, zap.String(key, v))
|
||
|
case int:
|
||
|
f = append(f, zap.Int(key, v))
|
||
|
case bool:
|
||
|
f = append(f, zap.Bool(key, v))
|
||
|
default:
|
||
|
f = append(f, zap.Any(key, v))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
logger := l.WithOptions(zap.AddCallerSkip(1)).With(f...)
|
||
|
|
||
|
switch lvl {
|
||
|
case logging.LevelDebug:
|
||
|
logger.Debug(msg)
|
||
|
case logging.LevelInfo:
|
||
|
logger.Info(msg)
|
||
|
case logging.LevelWarn:
|
||
|
logger.Warn(msg)
|
||
|
case logging.LevelError:
|
||
|
logger.Error(msg)
|
||
|
default:
|
||
|
panic(fmt.Sprintf("unknown level %v", lvl))
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func unregisterCollectors(registerer prom.Registerer, collectors ...prom.Collector) {
|
||
|
for _, collector := range collectors {
|
||
|
if collector == nil {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
registerer.Unregister(collector)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func registerCollectors(registerer prom.Registerer, collectors ...prom.Collector) (err error) {
|
||
|
for _, collector := range collectors {
|
||
|
if collector == nil {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if err = registerer.Register(collector); err != nil {
|
||
|
return fmt.Errorf("failed to register collector: %w", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|