mirror of
https://github.com/siderolabs/discovery-service.git
synced 2025-03-14 09:55:08 +00:00
feat: add an option to redirect all clients to a fixed endpoint
This allows to launch discovery service with a flag like `--redirect-endpoint=example.com:443`. Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
This commit is contained in:
@ -48,6 +48,7 @@ var (
|
||||
debugAddr = ":2123"
|
||||
devMode = false
|
||||
gcInterval = time.Minute
|
||||
redirectEndpoint = ""
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -56,6 +57,7 @@ func init() {
|
||||
flag.StringVar(&metricsAddr, "metrics-addr", metricsAddr, "prometheus metrics listen addr")
|
||||
flag.BoolVar(&devMode, "debug", devMode, "enable debug mode")
|
||||
flag.DurationVar(&gcInterval, "gc-interval", gcInterval, "garbage collection interval")
|
||||
flag.StringVar(&redirectEndpoint, "redirect-endpoint", redirectEndpoint, "redirect all clients to a new endpoint (gRPC endpoint, e.g. 'example.com:443'")
|
||||
|
||||
if debug.Enabled {
|
||||
flag.StringVar(&debugAddr, "debug-addr", debugAddr, "debug (pprof, trace, expvar) listen addr")
|
||||
@ -148,7 +150,7 @@ func run(ctx context.Context, logger *zap.Logger) error {
|
||||
state := state.NewState(logger)
|
||||
prom.MustRegister(state)
|
||||
|
||||
srv := server.NewClusterServer(state, ctx.Done())
|
||||
srv := server.NewClusterServer(state, ctx.Done(), redirectEndpoint)
|
||||
prom.MustRegister(srv)
|
||||
|
||||
lis, err := net.Listen("tcp", listenAddr)
|
||||
|
@ -32,7 +32,7 @@ import (
|
||||
func TestClient(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
endpoint := setupServer(t, 5000)
|
||||
endpoint := setupServer(t, 5000, "")
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
|
@ -30,13 +30,16 @@ type ClusterServer struct {
|
||||
stopCh <-chan struct{}
|
||||
|
||||
mHello *prom.CounterVec
|
||||
|
||||
redirectEndpoint string
|
||||
}
|
||||
|
||||
// NewClusterServer builds new ClusterServer.
|
||||
func NewClusterServer(state *state.State, stopCh <-chan struct{}) *ClusterServer {
|
||||
func NewClusterServer(state *state.State, stopCh <-chan struct{}, redirectEndpoint string) *ClusterServer {
|
||||
srv := &ClusterServer{
|
||||
state: state,
|
||||
stopCh: stopCh,
|
||||
redirectEndpoint: redirectEndpoint,
|
||||
mHello: prom.NewCounterVec(prom.CounterOpts{
|
||||
Name: "discovery_server_hello_requests_total",
|
||||
Help: "Number of hello requests by client version.",
|
||||
@ -51,7 +54,7 @@ func NewClusterServer(state *state.State, stopCh <-chan struct{}) *ClusterServer
|
||||
|
||||
// NewTestClusterServer builds cluster server for testing code.
|
||||
func NewTestClusterServer(logger *zap.Logger) *ClusterServer {
|
||||
return NewClusterServer(state.NewState(logger), nil)
|
||||
return NewClusterServer(state.NewState(logger), nil, "")
|
||||
}
|
||||
|
||||
// Hello implements cluster API.
|
||||
@ -70,6 +73,12 @@ func (srv *ClusterServer) Hello(ctx context.Context, req *pb.HelloRequest) (*pb.
|
||||
resp.ClientIp, _ = peerAddress.MarshalBinary() //nolint:errcheck // never fails
|
||||
}
|
||||
|
||||
if srv.redirectEndpoint != "" {
|
||||
resp.Redirect = &pb.RedirectMessage{
|
||||
Endpoint: srv.redirectEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ func checkMetrics(t *testing.T, c prom.Collector) {
|
||||
assert.NotZero(t, promtestutil.CollectAndCount(c), "collector should not be unchecked")
|
||||
}
|
||||
|
||||
func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
|
||||
func setupServer(t *testing.T, rateLimit rate.Limit, redirectEndpoint string) (address string) {
|
||||
t.Helper()
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
@ -59,7 +59,7 @@ func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
|
||||
state.RunGC(ctx, logger, time.Second)
|
||||
}()
|
||||
|
||||
srv := server.NewClusterServer(state, ctx.Done())
|
||||
srv := server.NewClusterServer(state, ctx.Done(), redirectEndpoint)
|
||||
|
||||
// Check metrics before and after the test
|
||||
// to ensure that collector does not switch from being unchecked to checked and invalid.
|
||||
@ -99,7 +99,7 @@ func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
|
||||
func TestServerAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addr := setupServer(t, 5000)
|
||||
addr := setupServer(t, 5000, "")
|
||||
|
||||
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, e)
|
||||
@ -119,6 +119,7 @@ func TestServerAPI(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte{0x7f, 0x0, 0x0, 0x1}, resp.ClientIp) // 127.0.0.1
|
||||
assert.Nil(t, resp.Redirect)
|
||||
})
|
||||
|
||||
t.Run("HelloWithRealIP", func(t *testing.T) {
|
||||
@ -298,7 +299,7 @@ func TestServerAPI(t *testing.T) {
|
||||
func TestValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addr := setupServer(t, 5000)
|
||||
addr := setupServer(t, 5000, "")
|
||||
|
||||
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, e)
|
||||
@ -493,7 +494,7 @@ func testHitRateLimit(client pb.ClusterClient, ip string) func(t *testing.T) {
|
||||
func TestServerRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addr := setupServer(t, 1)
|
||||
addr := setupServer(t, 1, "")
|
||||
|
||||
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, e)
|
||||
@ -503,3 +504,25 @@ func TestServerRateLimit(t *testing.T) {
|
||||
t.Run("HitRateLimitIP1", testHitRateLimit(client, "1.2.3.4"))
|
||||
t.Run("HitRateLimitIP2", testHitRateLimit(client, "5.6.7.8"))
|
||||
}
|
||||
|
||||
func TestServerRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addr := setupServer(t, 1, "new.example.com:443")
|
||||
|
||||
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, e)
|
||||
|
||||
client := pb.NewClusterClient(conn)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := client.Hello(ctx, &pb.HelloRequest{
|
||||
ClusterId: "fake",
|
||||
ClientVersion: "v0.12.0",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "new.example.com:443", resp.GetRedirect().GetEndpoint())
|
||||
}
|
||||
|
Reference in New Issue
Block a user