feat: add an option to redirect all clients to a fixed endpoint

This allows to launch discovery service with a flag like
`--redirect-endpoint=example.com:443`.

Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
This commit is contained in:
Andrey Smirnov
2022-09-08 21:27:47 +04:00
parent b34803b6e0
commit 8db8ef361e
4 changed files with 51 additions and 17 deletions

View File

@ -48,6 +48,7 @@ var (
debugAddr = ":2123"
devMode = false
gcInterval = time.Minute
redirectEndpoint = ""
)
func init() {
@ -56,6 +57,7 @@ func init() {
flag.StringVar(&metricsAddr, "metrics-addr", metricsAddr, "prometheus metrics listen addr")
flag.BoolVar(&devMode, "debug", devMode, "enable debug mode")
flag.DurationVar(&gcInterval, "gc-interval", gcInterval, "garbage collection interval")
flag.StringVar(&redirectEndpoint, "redirect-endpoint", redirectEndpoint, "redirect all clients to a new endpoint (gRPC endpoint, e.g. 'example.com:443'")
if debug.Enabled {
flag.StringVar(&debugAddr, "debug-addr", debugAddr, "debug (pprof, trace, expvar) listen addr")
@ -148,7 +150,7 @@ func run(ctx context.Context, logger *zap.Logger) error {
state := state.NewState(logger)
prom.MustRegister(state)
srv := server.NewClusterServer(state, ctx.Done())
srv := server.NewClusterServer(state, ctx.Done(), redirectEndpoint)
prom.MustRegister(srv)
lis, err := net.Listen("tcp", listenAddr)

View File

@ -32,7 +32,7 @@ import (
func TestClient(t *testing.T) {
t.Parallel()
endpoint := setupServer(t, 5000)
endpoint := setupServer(t, 5000, "")
logger := zaptest.NewLogger(t)

View File

@ -30,13 +30,16 @@ type ClusterServer struct {
stopCh <-chan struct{}
mHello *prom.CounterVec
redirectEndpoint string
}
// NewClusterServer builds new ClusterServer.
func NewClusterServer(state *state.State, stopCh <-chan struct{}) *ClusterServer {
func NewClusterServer(state *state.State, stopCh <-chan struct{}, redirectEndpoint string) *ClusterServer {
srv := &ClusterServer{
state: state,
stopCh: stopCh,
redirectEndpoint: redirectEndpoint,
mHello: prom.NewCounterVec(prom.CounterOpts{
Name: "discovery_server_hello_requests_total",
Help: "Number of hello requests by client version.",
@ -51,7 +54,7 @@ func NewClusterServer(state *state.State, stopCh <-chan struct{}) *ClusterServer
// NewTestClusterServer builds cluster server for testing code.
func NewTestClusterServer(logger *zap.Logger) *ClusterServer {
return NewClusterServer(state.NewState(logger), nil)
return NewClusterServer(state.NewState(logger), nil, "")
}
// Hello implements cluster API.
@ -70,6 +73,12 @@ func (srv *ClusterServer) Hello(ctx context.Context, req *pb.HelloRequest) (*pb.
resp.ClientIp, _ = peerAddress.MarshalBinary() //nolint:errcheck // never fails
}
if srv.redirectEndpoint != "" {
resp.Redirect = &pb.RedirectMessage{
Endpoint: srv.redirectEndpoint,
}
}
return resp, nil
}

View File

@ -45,7 +45,7 @@ func checkMetrics(t *testing.T, c prom.Collector) {
assert.NotZero(t, promtestutil.CollectAndCount(c), "collector should not be unchecked")
}
func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
func setupServer(t *testing.T, rateLimit rate.Limit, redirectEndpoint string) (address string) {
t.Helper()
logger := zaptest.NewLogger(t)
@ -59,7 +59,7 @@ func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
state.RunGC(ctx, logger, time.Second)
}()
srv := server.NewClusterServer(state, ctx.Done())
srv := server.NewClusterServer(state, ctx.Done(), redirectEndpoint)
// Check metrics before and after the test
// to ensure that collector does not switch from being unchecked to checked and invalid.
@ -99,7 +99,7 @@ func setupServer(t *testing.T, rateLimit rate.Limit) (address string) {
func TestServerAPI(t *testing.T) {
t.Parallel()
addr := setupServer(t, 5000)
addr := setupServer(t, 5000, "")
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, e)
@ -119,6 +119,7 @@ func TestServerAPI(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []byte{0x7f, 0x0, 0x0, 0x1}, resp.ClientIp) // 127.0.0.1
assert.Nil(t, resp.Redirect)
})
t.Run("HelloWithRealIP", func(t *testing.T) {
@ -298,7 +299,7 @@ func TestServerAPI(t *testing.T) {
func TestValidation(t *testing.T) {
t.Parallel()
addr := setupServer(t, 5000)
addr := setupServer(t, 5000, "")
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, e)
@ -493,7 +494,7 @@ func testHitRateLimit(client pb.ClusterClient, ip string) func(t *testing.T) {
func TestServerRateLimit(t *testing.T) {
t.Parallel()
addr := setupServer(t, 1)
addr := setupServer(t, 1, "")
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, e)
@ -503,3 +504,25 @@ func TestServerRateLimit(t *testing.T) {
t.Run("HitRateLimitIP1", testHitRateLimit(client, "1.2.3.4"))
t.Run("HitRateLimitIP2", testHitRateLimit(client, "5.6.7.8"))
}
func TestServerRedirect(t *testing.T) {
t.Parallel()
addr := setupServer(t, 1, "new.example.com:443")
conn, e := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, e)
client := pb.NewClusterClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
resp, err := client.Hello(ctx, &pb.HelloRequest{
ClusterId: "fake",
ClientVersion: "v0.12.0",
})
require.NoError(t, err)
assert.Equal(t, "new.example.com:443", resp.GetRedirect().GetEndpoint())
}