2024-03-12 18:12:45 +04:00
|
|
|
// Copyright (c) 2024 Sidero Labs, Inc.
|
2022-02-02 22:03:31 +01:00
|
|
|
//
|
|
|
|
// Use of this software is governed by the Business Source License
|
|
|
|
// included in the LICENSE file.
|
|
|
|
|
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
|
|
|
|
"google.golang.org/grpc"
|
|
|
|
"google.golang.org/grpc/codes"
|
|
|
|
"google.golang.org/grpc/status"
|
|
|
|
|
2024-01-12 18:10:16 +04:00
|
|
|
"github.com/siderolabs/discovery-service/internal/limiter"
|
2022-02-02 22:03:31 +01:00
|
|
|
)
|
|
|
|
|
2024-01-12 18:10:16 +04:00
|
|
|
func pause(ctx context.Context, limiter *limiter.IPRateLimiter) error {
|
2024-03-12 18:12:45 +04:00
|
|
|
iPAddr := PeerAddress(ctx)
|
2024-01-12 18:10:16 +04:00
|
|
|
if !IsZero(iPAddr) {
|
2022-02-02 22:03:31 +01:00
|
|
|
limit := limiter.Get(iPAddr)
|
|
|
|
|
|
|
|
err := limit.Wait(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return status.Error(codes.ResourceExhausted, "rate limit exceeds request timeout")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// RateLimitUnaryServerInterceptor limits Unary PRCs from an IPAdress.
|
2024-01-12 18:10:16 +04:00
|
|
|
func RateLimitUnaryServerInterceptor(limiter *limiter.IPRateLimiter) grpc.UnaryServerInterceptor {
|
2024-03-12 18:12:45 +04:00
|
|
|
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
2022-02-02 22:03:31 +01:00
|
|
|
err = pause(ctx, limiter)
|
|
|
|
if err != nil {
|
|
|
|
return resp, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return handler(ctx, req)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// RateLimitStreamServerInterceptor limits Stream PRCs from an IPAdress.
|
2024-01-12 18:10:16 +04:00
|
|
|
func RateLimitStreamServerInterceptor(limiter *limiter.IPRateLimiter) grpc.StreamServerInterceptor {
|
2024-03-12 18:12:45 +04:00
|
|
|
return func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
2022-02-02 22:03:31 +01:00
|
|
|
ctx := ss.Context()
|
|
|
|
|
|
|
|
err := pause(ctx, limiter)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return handler(srv, ss)
|
|
|
|
}
|
|
|
|
}
|