diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index d9e82c5b6e..2b70e6b2a0 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -707,6 +707,13 @@ func (q *querier) DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, u return q.db.DeleteApplicationConnectAPIKeysByUserID(ctx, userID) } +func (q *querier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return err + } + return q.db.DeleteCoordinator(ctx, id) +} + func (q *querier) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { return deleteQ(q.log, q.auth, q.db.GetGitSSHKey, q.db.DeleteGitSSHKey)(ctx, userID) } @@ -765,6 +772,20 @@ func (q *querier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt tim return q.db.DeleteReplicasUpdatedBefore(ctx, updatedAt) } +func (q *querier) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetAgentRow{}, err + } + return q.db.DeleteTailnetAgent(ctx, arg) +} + +func (q *querier) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + if err := q.authorizeContext(ctx, rbac.ActionDelete, rbac.ResourceTailnetCoordinator); err != nil { + return database.DeleteTailnetClientRow{}, err + } + return q.db.DeleteTailnetClient(ctx, arg) +} + func (q *querier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { return fetch(q.log, q.auth, q.db.GetAPIKeyByID)(ctx, id) } @@ -1137,6 +1158,20 @@ func (q *querier) GetServiceBanner(ctx context.Context) (string, error) { return q.db.GetServiceBanner(ctx) } +func (q *querier) GetTailnetAgents(ctx context.Context, id uuid.UUID) 
([]database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetAgents(ctx, id) +} + +func (q *querier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceTailnetCoordinator); err != nil { + return nil, err + } + return q.db.GetTailnetClientsForAgent(ctx, agentID) +} + // Only used by metrics cache. func (q *querier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { @@ -2515,3 +2550,24 @@ func (q *querier) UpsertServiceBanner(ctx context.Context, value string) error { } return q.db.UpsertServiceBanner(ctx, value) } + +func (q *querier) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetAgent{}, err + } + return q.db.UpsertTailnetAgent(ctx, arg) +} + +func (q *querier) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetClient{}, err + } + return q.db.UpsertTailnetClient(ctx, arg) +} + +func (q *querier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { + if err := q.authorizeContext(ctx, rbac.ActionUpdate, rbac.ResourceTailnetCoordinator); err != nil { + return database.TailnetCoordinator{}, err + } + return q.db.UpsertTailnetCoordinator(ctx, id) +} diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 
da623da862..65ba55c06a 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -966,6 +966,14 @@ func isNotNull(v interface{}) bool { return reflect.ValueOf(v).FieldByName("Valid").Bool() } +// ErrUnimplemented is returned by methods only used by the enterprise/tailnet.pgCoord. This coordinator explicitly +// depends on postgres triggers that announce changes on the pubsub. Implementing support for this in the fake +// database would strongly couple the fakeQuerier to the pubsub, which is undesirable. Furthermore, it makes little +// sense to directly test the pgCoord against anything other than postgres. The fakeQuerier is designed to allow us to +// test the Coderd API, and for that kind of test, the in-memory, AGPL tailnet coordinator is sufficient. Therefore, +// these methods remain unimplemented in the fakeQuerier. +var ErrUnimplemented = xerrors.New("unimplemented") + func (*fakeQuerier) AcquireLock(_ context.Context, _ int64) error { return xerrors.New("AcquireLock must only be called within a transaction") } @@ -1066,6 +1074,10 @@ func (q *fakeQuerier) DeleteApplicationConnectAPIKeysByUserID(_ context.Context, return nil } +func (*fakeQuerier) DeleteCoordinator(context.Context, uuid.UUID) error { + return ErrUnimplemented +} + func (q *fakeQuerier) DeleteGitSSHKey(_ context.Context, userID uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -1174,6 +1186,14 @@ func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time return nil } +func (*fakeQuerier) DeleteTailnetAgent(context.Context, database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + return database.DeleteTailnetAgentRow{}, ErrUnimplemented +} + +func (*fakeQuerier) DeleteTailnetClient(context.Context, database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + return database.DeleteTailnetClientRow{}, ErrUnimplemented +} + func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) 
(database.APIKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -2185,6 +2205,14 @@ func (q *fakeQuerier) GetServiceBanner(_ context.Context) (string, error) { return string(q.serviceBanner), nil } +func (*fakeQuerier) GetTailnetAgents(context.Context, uuid.UUID) ([]database.TailnetAgent, error) { + return nil, ErrUnimplemented +} + +func (*fakeQuerier) GetTailnetClientsForAgent(context.Context, uuid.UUID) ([]database.TailnetClient, error) { + return nil, ErrUnimplemented +} + func (q *fakeQuerier) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { if err := validateDatabaseType(arg); err != nil { return database.GetTemplateAverageBuildTimeRow{}, err @@ -5238,3 +5266,15 @@ func (q *fakeQuerier) UpsertServiceBanner(_ context.Context, data string) error q.serviceBanner = []byte(data) return nil } + +func (*fakeQuerier) UpsertTailnetAgent(context.Context, database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + return database.TailnetAgent{}, ErrUnimplemented +} + +func (*fakeQuerier) UpsertTailnetClient(context.Context, database.UpsertTailnetClientParams) (database.TailnetClient, error) { + return database.TailnetClient{}, ErrUnimplemented +} + +func (*fakeQuerier) UpsertTailnetCoordinator(context.Context, uuid.UUID) (database.TailnetCoordinator, error) { + return database.TailnetCoordinator{}, ErrUnimplemented +} diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 71171ed6ff..b73418fcfd 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -143,6 +143,12 @@ func (m metricsStore) DeleteApplicationConnectAPIKeysByUserID(ctx context.Contex return err } +func (m metricsStore) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteCoordinator").Observe(time.Since(start).Seconds()) + return 
m.s.DeleteCoordinator(ctx, id) +} + func (m metricsStore) DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error { start := time.Now() err := m.s.DeleteGitSSHKey(ctx, userID) @@ -199,6 +205,18 @@ func (m metricsStore) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt return err } +func (m metricsStore) DeleteTailnetAgent(ctx context.Context, arg database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetAgent").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetAgent(ctx, arg) +} + +func (m metricsStore) DeleteTailnetClient(ctx context.Context, arg database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("DeleteTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.DeleteTailnetClient(ctx, arg) +} + func (m metricsStore) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { start := time.Now() apiKey, err := m.s.GetAPIKeyByID(ctx, id) @@ -556,6 +574,18 @@ func (m metricsStore) GetServiceBanner(ctx context.Context) (string, error) { return banner, err } +func (m metricsStore) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]database.TailnetAgent, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetAgents").Observe(time.Since(start).Seconds()) + return m.s.GetTailnetAgents(ctx, id) +} + +func (m metricsStore) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]database.TailnetClient, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("GetTailnetClientsForAgent").Observe(time.Since(start).Seconds()) + return m.s.GetTailnetClientsForAgent(ctx, agentID) +} + func (m metricsStore) GetTemplateAverageBuildTime(ctx context.Context, arg database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { start := time.Now() buildTime, err := 
m.s.GetTemplateAverageBuildTime(ctx, arg) @@ -1549,3 +1579,21 @@ func (m metricsStore) UpsertServiceBanner(ctx context.Context, value string) err m.queryLatencies.WithLabelValues("UpsertServiceBanner").Observe(time.Since(start).Seconds()) return r0 } + +func (m metricsStore) UpsertTailnetAgent(ctx context.Context, arg database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetAgent").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetAgent(ctx, arg) +} + +func (m metricsStore) UpsertTailnetClient(ctx context.Context, arg database.UpsertTailnetClientParams) (database.TailnetClient, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetClient").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetClient(ctx, arg) +} + +func (m metricsStore) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (database.TailnetCoordinator, error) { + start := time.Now() + defer m.queryLatencies.WithLabelValues("UpsertTailnetCoordinator").Observe(time.Since(start).Seconds()) + return m.s.UpsertTailnetCoordinator(ctx, id) +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 610795142b..6006e9ff37 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -110,6 +110,20 @@ func (mr *MockStoreMockRecorder) DeleteApplicationConnectAPIKeysByUserID(arg0, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteApplicationConnectAPIKeysByUserID", reflect.TypeOf((*MockStore)(nil).DeleteApplicationConnectAPIKeysByUserID), arg0, arg1) } +// DeleteCoordinator mocks base method. +func (m *MockStore) DeleteCoordinator(arg0 context.Context, arg1 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteCoordinator", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteCoordinator indicates an expected call of DeleteCoordinator. 
+func (mr *MockStoreMockRecorder) DeleteCoordinator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCoordinator", reflect.TypeOf((*MockStore)(nil).DeleteCoordinator), arg0, arg1) +} + // DeleteGitSSHKey mocks base method. func (m *MockStore) DeleteGitSSHKey(arg0 context.Context, arg1 uuid.UUID) error { m.ctrl.T.Helper() @@ -223,6 +237,36 @@ func (mr *MockStoreMockRecorder) DeleteReplicasUpdatedBefore(arg0, arg1 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteReplicasUpdatedBefore", reflect.TypeOf((*MockStore)(nil).DeleteReplicasUpdatedBefore), arg0, arg1) } +// DeleteTailnetAgent mocks base method. +func (m *MockStore) DeleteTailnetAgent(arg0 context.Context, arg1 database.DeleteTailnetAgentParams) (database.DeleteTailnetAgentRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetAgent", arg0, arg1) + ret0, _ := ret[0].(database.DeleteTailnetAgentRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTailnetAgent indicates an expected call of DeleteTailnetAgent. +func (mr *MockStoreMockRecorder) DeleteTailnetAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetAgent", reflect.TypeOf((*MockStore)(nil).DeleteTailnetAgent), arg0, arg1) +} + +// DeleteTailnetClient mocks base method. +func (m *MockStore) DeleteTailnetClient(arg0 context.Context, arg1 database.DeleteTailnetClientParams) (database.DeleteTailnetClientRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteTailnetClient", arg0, arg1) + ret0, _ := ret[0].(database.DeleteTailnetClientRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteTailnetClient indicates an expected call of DeleteTailnetClient. 
+func (mr *MockStoreMockRecorder) DeleteTailnetClient(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTailnetClient", reflect.TypeOf((*MockStore)(nil).DeleteTailnetClient), arg0, arg1) +} + // GetAPIKeyByID mocks base method. func (m *MockStore) GetAPIKeyByID(arg0 context.Context, arg1 string) (database.APIKey, error) { m.ctrl.T.Helper() @@ -1033,6 +1077,36 @@ func (mr *MockStoreMockRecorder) GetServiceBanner(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceBanner", reflect.TypeOf((*MockStore)(nil).GetServiceBanner), arg0) } +// GetTailnetAgents mocks base method. +func (m *MockStore) GetTailnetAgents(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTailnetAgents", arg0, arg1) + ret0, _ := ret[0].([]database.TailnetAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTailnetAgents indicates an expected call of GetTailnetAgents. +func (mr *MockStoreMockRecorder) GetTailnetAgents(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetAgents", reflect.TypeOf((*MockStore)(nil).GetTailnetAgents), arg0, arg1) +} + +// GetTailnetClientsForAgent mocks base method. +func (m *MockStore) GetTailnetClientsForAgent(arg0 context.Context, arg1 uuid.UUID) ([]database.TailnetClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTailnetClientsForAgent", arg0, arg1) + ret0, _ := ret[0].([]database.TailnetClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTailnetClientsForAgent indicates an expected call of GetTailnetClientsForAgent. 
+func (mr *MockStoreMockRecorder) GetTailnetClientsForAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTailnetClientsForAgent", reflect.TypeOf((*MockStore)(nil).GetTailnetClientsForAgent), arg0, arg1) +} + // GetTemplateAverageBuildTime mocks base method. func (m *MockStore) GetTemplateAverageBuildTime(arg0 context.Context, arg1 database.GetTemplateAverageBuildTimeParams) (database.GetTemplateAverageBuildTimeRow, error) { m.ctrl.T.Helper() @@ -3189,6 +3263,51 @@ func (mr *MockStoreMockRecorder) UpsertServiceBanner(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertServiceBanner", reflect.TypeOf((*MockStore)(nil).UpsertServiceBanner), arg0, arg1) } +// UpsertTailnetAgent mocks base method. +func (m *MockStore) UpsertTailnetAgent(arg0 context.Context, arg1 database.UpsertTailnetAgentParams) (database.TailnetAgent, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetAgent", arg0, arg1) + ret0, _ := ret[0].(database.TailnetAgent) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetAgent indicates an expected call of UpsertTailnetAgent. +func (mr *MockStoreMockRecorder) UpsertTailnetAgent(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetAgent", reflect.TypeOf((*MockStore)(nil).UpsertTailnetAgent), arg0, arg1) +} + +// UpsertTailnetClient mocks base method. +func (m *MockStore) UpsertTailnetClient(arg0 context.Context, arg1 database.UpsertTailnetClientParams) (database.TailnetClient, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetClient", arg0, arg1) + ret0, _ := ret[0].(database.TailnetClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetClient indicates an expected call of UpsertTailnetClient. 
+func (mr *MockStoreMockRecorder) UpsertTailnetClient(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetClient", reflect.TypeOf((*MockStore)(nil).UpsertTailnetClient), arg0, arg1) +} + +// UpsertTailnetCoordinator mocks base method. +func (m *MockStore) UpsertTailnetCoordinator(arg0 context.Context, arg1 uuid.UUID) (database.TailnetCoordinator, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertTailnetCoordinator", arg0, arg1) + ret0, _ := ret[0].(database.TailnetCoordinator) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertTailnetCoordinator indicates an expected call of UpsertTailnetCoordinator. +func (mr *MockStoreMockRecorder) UpsertTailnetCoordinator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTailnetCoordinator", reflect.TypeOf((*MockStore)(nil).UpsertTailnetCoordinator), arg0, arg1) +} + // Wrappers mocks base method. func (m *MockStore) Wrappers() []string { m.ctrl.T.Helper() diff --git a/coderd/database/dbtestutil/db.go b/coderd/database/dbtestutil/db.go index 932e4aaf47..ad8cecf143 100644 --- a/coderd/database/dbtestutil/db.go +++ b/coderd/database/dbtestutil/db.go @@ -14,12 +14,17 @@ import ( "github.com/coder/coder/coderd/database/pubsub" ) +// WillUsePostgres returns true if a call to NewDB() will return a real, postgres-backed Store and Pubsub. 
+func WillUsePostgres() bool { + return os.Getenv("DB") != "" +} + func NewDB(t testing.TB) (database.Store, pubsub.Pubsub) { t.Helper() db := dbfake.New() ps := pubsub.NewInMemory() - if os.Getenv("DB") != "" { + if WillUsePostgres() { connectionURL := os.Getenv("CODER_PG_CONNECTION_URL") if connectionURL == "" { var ( diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 25cf2107ba..05a5a3057a 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -171,6 +171,45 @@ BEGIN END; $$; +CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', OLD.id::text); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', NEW.id::text); + RETURN NULL; + END IF; +END; +$$; + +CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); + RETURN NULL; +END; +$$; + CREATE TABLE api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, @@ -383,6 +422,28 @@ CREATE TABLE site_configs ( value character varying(8192) NOT NULL ); +CREATE TABLE tailnet_agents ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL +); + +CREATE TABLE tailnet_clients ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL +); + +CREATE TABLE tailnet_coordinators ( + id uuid NOT NULL, + 
heartbeat_at timestamp with time zone NOT NULL +); + +COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service'; + CREATE TABLE template_version_parameters ( template_version_id uuid NOT NULL, name text NOT NULL, @@ -835,6 +896,15 @@ ALTER TABLE ONLY provisioner_jobs ALTER TABLE ONLY site_configs ADD CONSTRAINT site_configs_key_key UNIQUE (key); +ALTER TABLE ONLY tailnet_agents + ADD CONSTRAINT tailnet_agents_pkey PRIMARY KEY (id, coordinator_id); + +ALTER TABLE ONLY tailnet_clients + ADD CONSTRAINT tailnet_clients_pkey PRIMARY KEY (id, coordinator_id); + +ALTER TABLE ONLY tailnet_coordinators + ADD CONSTRAINT tailnet_coordinators_pkey PRIMARY KEY (id); + ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_name_key UNIQUE (template_version_id, name); @@ -922,6 +992,12 @@ CREATE UNIQUE INDEX idx_organization_name ON organizations USING btree (name); CREATE UNIQUE INDEX idx_organization_name_lower ON organizations USING btree (lower(name)); +CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents USING btree (coordinator_id); + +CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients USING btree (agent_id); + +CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients USING btree (coordinator_id); + CREATE UNIQUE INDEX idx_users_email ON users USING btree (email) WHERE (deleted = false); CREATE UNIQUE INDEX idx_users_username ON users USING btree (username) WHERE (deleted = false); @@ -948,6 +1024,12 @@ CREATE INDEX workspace_resources_job_id_idx ON workspace_resources USING btree ( CREATE UNIQUE INDEX workspaces_owner_id_lower_idx ON workspaces USING btree (owner_id, lower((name)::text)) WHERE (deleted = false); +CREATE TRIGGER tailnet_notify_agent_change AFTER INSERT OR DELETE OR UPDATE ON tailnet_agents FOR EACH ROW EXECUTE FUNCTION tailnet_notify_agent_change(); + +CREATE TRIGGER tailnet_notify_client_change 
AFTER INSERT OR DELETE OR UPDATE ON tailnet_clients FOR EACH ROW EXECUTE FUNCTION tailnet_notify_client_change(); + +CREATE TRIGGER tailnet_notify_coordinator_heartbeat AFTER INSERT OR UPDATE ON tailnet_coordinators FOR EACH ROW EXECUTE FUNCTION tailnet_notify_coordinator_heartbeat(); + CREATE TRIGGER trigger_insert_apikeys BEFORE INSERT ON api_keys FOR EACH ROW EXECUTE FUNCTION insert_apikey_fail_if_user_deleted(); CREATE TRIGGER trigger_update_users AFTER INSERT OR UPDATE ON users FOR EACH ROW WHEN ((new.deleted = true)) EXECUTE FUNCTION delete_deleted_user_api_keys(); @@ -982,6 +1064,12 @@ ALTER TABLE ONLY provisioner_job_logs ALTER TABLE ONLY provisioner_jobs ADD CONSTRAINT provisioner_jobs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY tailnet_agents + ADD CONSTRAINT tailnet_agents_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + +ALTER TABLE ONLY tailnet_clients + ADD CONSTRAINT tailnet_clients_coordinator_id_fkey FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE; + ALTER TABLE ONLY template_version_parameters ADD CONSTRAINT template_version_parameters_template_version_id_fkey FOREIGN KEY (template_version_id) REFERENCES template_versions(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000130_ha_coordinator.down.sql b/coderd/database/migrations/000130_ha_coordinator.down.sql new file mode 100644 index 0000000000..54c8b02539 --- /dev/null +++ b/coderd/database/migrations/000130_ha_coordinator.down.sql @@ -0,0 +1,18 @@ +BEGIN; + +DROP TRIGGER IF EXISTS tailnet_notify_client_change ON tailnet_clients; +DROP FUNCTION IF EXISTS tailnet_notify_client_change; +DROP INDEX IF EXISTS idx_tailnet_clients_agent; +DROP INDEX IF EXISTS idx_tailnet_clients_coordinator; +DROP TABLE tailnet_clients; + +DROP TRIGGER IF EXISTS tailnet_notify_agent_change ON tailnet_agents; +DROP FUNCTION IF EXISTS 
tailnet_notify_agent_change; +DROP INDEX IF EXISTS idx_tailnet_agents_coordinator; +DROP TABLE IF EXISTS tailnet_agents; + +DROP TRIGGER IF EXISTS tailnet_notify_coordinator_heartbeat ON tailnet_coordinators; +DROP FUNCTION IF EXISTS tailnet_notify_coordinator_heartbeat; +DROP TABLE IF EXISTS tailnet_coordinators; + +COMMIT; diff --git a/coderd/database/migrations/000130_ha_coordinator.up.sql b/coderd/database/migrations/000130_ha_coordinator.up.sql new file mode 100644 index 0000000000..f30bd077c7 --- /dev/null +++ b/coderd/database/migrations/000130_ha_coordinator.up.sql @@ -0,0 +1,97 @@ +BEGIN; + +CREATE TABLE tailnet_coordinators ( + id uuid NOT NULL PRIMARY KEY, + heartbeat_at timestamp with time zone NOT NULL +); + +COMMENT ON TABLE tailnet_coordinators IS 'We keep this separate from replicas in case we need to break the coordinator out into its own service'; + +CREATE TABLE tailnet_clients ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + agent_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL, + PRIMARY KEY (id, coordinator_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE +); + + +-- For querying/deleting mappings +CREATE INDEX idx_tailnet_clients_agent ON tailnet_clients (agent_id); + +-- For shutting down / GC a coordinator +CREATE INDEX idx_tailnet_clients_coordinator ON tailnet_clients (coordinator_id); + +CREATE TABLE tailnet_agents ( + id uuid NOT NULL, + coordinator_id uuid NOT NULL, + updated_at timestamp with time zone NOT NULL, + node jsonb NOT NULL, + PRIMARY KEY (id, coordinator_id), + FOREIGN KEY (coordinator_id) REFERENCES tailnet_coordinators(id) ON DELETE CASCADE +); + + +-- For shutting down / GC a coordinator +CREATE INDEX idx_tailnet_agents_coordinator ON tailnet_agents (coordinator_id); + +-- Any time the tailnet_clients table changes, send an update with the affected client and agent IDs +CREATE FUNCTION tailnet_notify_client_change() RETURNS trigger + 
LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', OLD.id || ',' || OLD.agent_id); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_client_update', NEW.id || ',' || NEW.agent_id); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER tailnet_notify_client_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_clients + FOR EACH ROW +EXECUTE PROCEDURE tailnet_notify_client_change(); + +-- Any time tailnet_agents table changes, send an update with the affected agent ID. +CREATE FUNCTION tailnet_notify_agent_change() RETURNS trigger + LANGUAGE plpgsql + AS $$ +BEGIN + IF (OLD IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', OLD.id::text); + RETURN NULL; + END IF; + IF (NEW IS NOT NULL) THEN + PERFORM pg_notify('tailnet_agent_update', NEW.id::text); + RETURN NULL; + END IF; +END; +$$; + +CREATE TRIGGER tailnet_notify_agent_change + AFTER INSERT OR UPDATE OR DELETE ON tailnet_agents + FOR EACH ROW +EXECUTE PROCEDURE tailnet_notify_agent_change(); + +-- Send coordinator heartbeats +CREATE FUNCTION tailnet_notify_coordinator_heartbeat() RETURNS trigger + LANGUAGE plpgsql +AS $$ +BEGIN + PERFORM pg_notify('tailnet_coordinator_heartbeat', NEW.id::text); + RETURN NULL; +END; +$$; + +CREATE TRIGGER tailnet_notify_coordinator_heartbeat + AFTER INSERT OR UPDATE ON tailnet_coordinators + FOR EACH ROW +EXECUTE PROCEDURE tailnet_notify_coordinator_heartbeat(); + +COMMIT; diff --git a/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql new file mode 100644 index 0000000000..8af4fa4827 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000130_ha_coordinator.up.sql @@ -0,0 +1,28 @@ +INSERT INTO tailnet_coordinators + (id, heartbeat_at) +VALUES + ( + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00' + ); + +INSERT INTO tailnet_clients + (id, agent_id, 
coordinator_id, updated_at, node) +VALUES + ( + 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00', + '{"preferred_derp": 12}'::json + ); + +INSERT INTO tailnet_agents +(id, coordinator_id, updated_at, node) +VALUES + ( + 'c0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11', + '2023-06-15 10:23:54+00', + '{"preferred_derp": 13}'::json + ); diff --git a/coderd/database/models.go b/coderd/database/models.go index 5d8c0d810a..d3c7f9c436 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -1534,6 +1534,27 @@ type SiteConfig struct { Value string `db:"value" json:"value"` } +type TailnetAgent struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Node json.RawMessage `db:"node" json:"node"` +} + +type TailnetClient struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + Node json.RawMessage `db:"node" json:"node"` +} + +// We keep this separate from replicas in case we need to break the coordinator out into its own service +type TailnetCoordinator struct { + ID uuid.UUID `db:"id" json:"id"` + HeartbeatAt time.Time `db:"heartbeat_at" json:"heartbeat_at"` +} + type Template struct { ID uuid.UUID `db:"id" json:"id"` CreatedAt time.Time `db:"created_at" json:"created_at"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index f2e1358004..292029a340 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -29,6 +29,7 @@ type sqlcQuerier interface { DeleteAPIKeyByID(ctx context.Context, id string) error DeleteAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error 
DeleteApplicationConnectAPIKeysByUserID(ctx context.Context, userID uuid.UUID) error + DeleteCoordinator(ctx context.Context, id uuid.UUID) error DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteGroupByID(ctx context.Context, id uuid.UUID) error DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error @@ -39,6 +40,8 @@ type sqlcQuerier interface { DeleteOldWorkspaceAgentStartupLogs(ctx context.Context) error DeleteOldWorkspaceAgentStats(ctx context.Context) error DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error + DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) + DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) // there is no unique constraint on empty token names GetAPIKeyByName(ctx context.Context, arg GetAPIKeyByNameParams) (APIKey, error) @@ -97,6 +100,8 @@ type sqlcQuerier interface { GetQuotaConsumedForUser(ctx context.Context, ownerID uuid.UUID) (int64, error) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) GetServiceBanner(ctx context.Context) (string, error) + GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) + GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error) GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error) GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error) @@ -264,6 +269,9 @@ type sqlcQuerier interface { UpsertLastUpdateCheck(ctx context.Context, value string) error UpsertLogoURL(ctx context.Context, value string) error UpsertServiceBanner(ctx context.Context, value string) error + 
UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) + UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) + UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) } var _ sqlcQuerier = (*sqlQuerier)(nil) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 3f6baeb2c5..2980c75594 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3261,6 +3261,239 @@ func (q *sqlQuerier) UpsertServiceBanner(ctx context.Context, value string) erro return err } +const deleteCoordinator = `-- name: DeleteCoordinator :exec +DELETE +FROM tailnet_coordinators +WHERE id = $1 +` + +func (q *sqlQuerier) DeleteCoordinator(ctx context.Context, id uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteCoordinator, id) + return err +} + +const deleteTailnetAgent = `-- name: DeleteTailnetAgent :one +DELETE +FROM tailnet_agents +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id +` + +type DeleteTailnetAgentParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +type DeleteTailnetAgentRow struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetAgent(ctx context.Context, arg DeleteTailnetAgentParams) (DeleteTailnetAgentRow, error) { + row := q.db.QueryRowContext(ctx, deleteTailnetAgent, arg.ID, arg.CoordinatorID) + var i DeleteTailnetAgentRow + err := row.Scan(&i.ID, &i.CoordinatorID) + return i, err +} + +const deleteTailnetClient = `-- name: DeleteTailnetClient :one +DELETE +FROM tailnet_clients +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id +` + +type DeleteTailnetClientParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +type 
DeleteTailnetClientRow struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` +} + +func (q *sqlQuerier) DeleteTailnetClient(ctx context.Context, arg DeleteTailnetClientParams) (DeleteTailnetClientRow, error) { + row := q.db.QueryRowContext(ctx, deleteTailnetClient, arg.ID, arg.CoordinatorID) + var i DeleteTailnetClientRow + err := row.Scan(&i.ID, &i.CoordinatorID) + return i, err +} + +const getTailnetAgents = `-- name: GetTailnetAgents :many +SELECT id, coordinator_id, updated_at, node +FROM tailnet_agents +WHERE id = $1 +` + +func (q *sqlQuerier) GetTailnetAgents(ctx context.Context, id uuid.UUID) ([]TailnetAgent, error) { + rows, err := q.db.QueryContext(ctx, getTailnetAgents, id) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetAgent + for rows.Next() { + var i TailnetAgent + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getTailnetClientsForAgent = `-- name: GetTailnetClientsForAgent :many +SELECT id, coordinator_id, agent_id, updated_at, node +FROM tailnet_clients +WHERE agent_id = $1 +` + +func (q *sqlQuerier) GetTailnetClientsForAgent(ctx context.Context, agentID uuid.UUID) ([]TailnetClient, error) { + rows, err := q.db.QueryContext(ctx, getTailnetClientsForAgent, agentID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TailnetClient + for rows.Next() { + var i TailnetClient + if err := rows.Scan( + &i.ID, + &i.CoordinatorID, + &i.AgentID, + &i.UpdatedAt, + &i.Node, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const 
upsertTailnetAgent = `-- name: UpsertTailnetAgent :one +INSERT INTO + tailnet_agents ( + id, + coordinator_id, + node, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + node = $3, + updated_at = now() at time zone 'utc' +RETURNING id, coordinator_id, updated_at, node +` + +type UpsertTailnetAgentParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + Node json.RawMessage `db:"node" json:"node"` +} + +func (q *sqlQuerier) UpsertTailnetAgent(ctx context.Context, arg UpsertTailnetAgentParams) (TailnetAgent, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetAgent, arg.ID, arg.CoordinatorID, arg.Node) + var i TailnetAgent + err := row.Scan( + &i.ID, + &i.CoordinatorID, + &i.UpdatedAt, + &i.Node, + ) + return i, err +} + +const upsertTailnetClient = `-- name: UpsertTailnetClient :one +INSERT INTO + tailnet_clients ( + id, + coordinator_id, + agent_id, + node, + updated_at +) +VALUES + ($1, $2, $3, $4, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + agent_id = $3, + node = $4, + updated_at = now() at time zone 'utc' +RETURNING id, coordinator_id, agent_id, updated_at, node +` + +type UpsertTailnetClientParams struct { + ID uuid.UUID `db:"id" json:"id"` + CoordinatorID uuid.UUID `db:"coordinator_id" json:"coordinator_id"` + AgentID uuid.UUID `db:"agent_id" json:"agent_id"` + Node json.RawMessage `db:"node" json:"node"` +} + +func (q *sqlQuerier) UpsertTailnetClient(ctx context.Context, arg UpsertTailnetClientParams) (TailnetClient, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetClient, + arg.ID, + arg.CoordinatorID, + arg.AgentID, + arg.Node, + ) + var i TailnetClient + err := row.Scan( + &i.ID, + &i.CoordinatorID, + &i.AgentID, + &i.UpdatedAt, + &i.Node, + ) + return i, err +} + +const upsertTailnetCoordinator = `-- name: 
UpsertTailnetCoordinator :one +INSERT INTO + tailnet_coordinators ( + id, + heartbeat_at +) +VALUES + ($1, now() at time zone 'utc') +ON CONFLICT (id) +DO UPDATE SET + id = $1, + heartbeat_at = now() at time zone 'utc' +RETURNING id, heartbeat_at +` + +func (q *sqlQuerier) UpsertTailnetCoordinator(ctx context.Context, id uuid.UUID) (TailnetCoordinator, error) { + row := q.db.QueryRowContext(ctx, upsertTailnetCoordinator, id) + var i TailnetCoordinator + err := row.Scan(&i.ID, &i.HeartbeatAt) + return i, err +} + const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one WITH build_times AS ( SELECT diff --git a/coderd/database/queries/tailnet.sql b/coderd/database/queries/tailnet.sql new file mode 100644 index 0000000000..e45cb480b1 --- /dev/null +++ b/coderd/database/queries/tailnet.sql @@ -0,0 +1,79 @@ +-- name: UpsertTailnetClient :one +INSERT INTO + tailnet_clients ( + id, + coordinator_id, + agent_id, + node, + updated_at +) +VALUES + ($1, $2, $3, $4, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + agent_id = $3, + node = $4, + updated_at = now() at time zone 'utc' +RETURNING *; + +-- name: UpsertTailnetAgent :one +INSERT INTO + tailnet_agents ( + id, + coordinator_id, + node, + updated_at +) +VALUES + ($1, $2, $3, now() at time zone 'utc') +ON CONFLICT (id, coordinator_id) +DO UPDATE SET + id = $1, + coordinator_id = $2, + node = $3, + updated_at = now() at time zone 'utc' +RETURNING *; + + +-- name: DeleteTailnetClient :one +DELETE +FROM tailnet_clients +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id; + +-- name: DeleteTailnetAgent :one +DELETE +FROM tailnet_agents +WHERE id = $1 and coordinator_id = $2 +RETURNING id, coordinator_id; + +-- name: DeleteCoordinator :exec +DELETE +FROM tailnet_coordinators +WHERE id = $1; + +-- name: GetTailnetAgents :many +SELECT * +FROM tailnet_agents +WHERE id = $1; + +-- name: GetTailnetClientsForAgent :many +SELECT * 
+FROM tailnet_clients +WHERE agent_id = $1; + +-- name: UpsertTailnetCoordinator :one +INSERT INTO + tailnet_coordinators ( + id, + heartbeat_at +) +VALUES + ($1, now() at time zone 'utc') +ON CONFLICT (id) +DO UPDATE SET + id = $1, + heartbeat_at = now() at time zone 'utc' +RETURNING *; diff --git a/coderd/rbac/object.go b/coderd/rbac/object.go index e867abfb69..060f325c60 100644 --- a/coderd/rbac/object.go +++ b/coderd/rbac/object.go @@ -173,6 +173,11 @@ var ( ResourceSystem = Object{ Type: "system", } + + // ResourceTailnetCoordinator is a pseudo-resource for use by the tailnet coordinator + ResourceTailnetCoordinator = Object{ + Type: "tailnet_coordinator", + } ) // Object is used to create objects for authz checks when you have none in diff --git a/coderd/rbac/object_gen.go b/coderd/rbac/object_gen.go index 9af80010cf..d0a7bb5e68 100644 --- a/coderd/rbac/object_gen.go +++ b/coderd/rbac/object_gen.go @@ -18,6 +18,7 @@ func AllResources() []Object { ResourceReplicas, ResourceRoleAssignment, ResourceSystem, + ResourceTailnetCoordinator, ResourceTemplate, ResourceUser, ResourceUserData, diff --git a/enterprise/tailnet/coordinator_test.go b/enterprise/tailnet/coordinator_test.go index bcc3ddca34..a29bf2ad27 100644 --- a/enterprise/tailnet/coordinator_test.go +++ b/enterprise/tailnet/coordinator_test.go @@ -95,7 +95,7 @@ func TestCoordinatorSingle(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -117,12 +117,12 @@ func TestCoordinatorSingle(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&agpl.Node{}) + sendClientNode(&agpl.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! 
- sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1) @@ -188,7 +188,7 @@ func TestCoordinatorHA(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator1.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -214,13 +214,13 @@ func TestCoordinatorHA(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&agpl.Node{}) + sendClientNode(&agpl.Node{PreferredDERP: 2}) _ = sendClientNode clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! - sendAgentNode(&agpl.Node{}) + sendAgentNode(&agpl.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1) diff --git a/enterprise/tailnet/pgcoord.go b/enterprise/tailnet/pgcoord.go new file mode 100644 index 0000000000..3fca584c28 --- /dev/null +++ b/enterprise/tailnet/pgcoord.go @@ -0,0 +1,1213 @@ +package tailnet + +import ( + "context" + "database/sql" + "encoding/json" + "io" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/pubsub" + agpl "github.com/coder/coder/tailnet" +) + +const ( + EventHeartbeats = "tailnet_coordinator_heartbeat" + eventClientUpdate = "tailnet_client_update" + eventAgentUpdate = "tailnet_agent_update" + HeartbeatPeriod = time.Second * 2 + MissedHeartbeats = 3 + numQuerierWorkers = 10 + numBinderWorkers = 10 + dbMaxBackoff = 10 * time.Second +) + +// pgCoord is a postgres-backed coordinator +// +// ┌────────┐ ┌────────┐ ┌───────┐ +// │ connIO ├───────► binder ├────────► store │ +// └───▲────┘ │ │ │ │ +// │ └────────┘ ┌──────┤ │ +// │ │ └───────┘ 
+// │ │ +// │ ┌──────────▼┐ ┌────────┐ +// │ │ │ │ │ +// └────────────┤ querier ◄─────┤ pubsub │ +// │ │ │ │ +// └───────────┘ └────────┘ +// +// each incoming connection (websocket) from a client or agent is wrapped in a connIO which handles reading & writing +// from it. Node updates from a connIO are sent to the binder, which writes them to the database.Store. The querier +// is responsible for querying the store for the nodes the connection needs (e.g. for a client, the corresponding +// agent). The querier receives pubsub notifications about changes, which trigger queries for the latest state. +// +// The querier also sends the coordinator's heartbeat, and monitors the heartbeats of other coordinators. When +// heartbeats cease for a coordinator, it stops using any nodes discovered from that coordinator and pushes an update +// to affected connIOs. +// +// This package uses the term "binding" to mean the act of registering an association between some connection (client +// or agent) and an agpl.Node. It uses the term "mapping" to mean the act of determining the nodes that the connection +// needs to receive (e.g. for a client, the node bound to the corresponding agent, or for an agent, the nodes bound to +// all clients of the agent). +type pgCoord struct { + ctx context.Context + logger slog.Logger + pubsub pubsub.Pubsub + store database.Store + + bindings chan binding + newConnections chan *connIO + id uuid.UUID + + cancel context.CancelFunc + closeOnce sync.Once + closed chan struct{} + + binder *binder + querier *querier +} + +// NewPGCoord creates a high-availability coordinator that stores state in the PostgreSQL database and +// receives notifications of updates via the pubsub. 
+func NewPGCoord(ctx context.Context, logger slog.Logger, ps pubsub.Pubsub, store database.Store) (agpl.Coordinator, error) { + ctx, cancel := context.WithCancel(ctx) + id := uuid.New() + logger = logger.Named("pgcoord").With(slog.F("coordinator_id", id)) + bCh := make(chan binding) + cCh := make(chan *connIO) + // signals when first heartbeat has been sent, so it's safe to start binding. + fHB := make(chan struct{}) + + c := &pgCoord{ + ctx: ctx, + cancel: cancel, + logger: logger, + pubsub: ps, + store: store, + binder: newBinder(ctx, logger, id, store, bCh, fHB), + bindings: bCh, + newConnections: cCh, + id: id, + querier: newQuerier(ctx, logger, ps, store, id, cCh, numQuerierWorkers, fHB), + closed: make(chan struct{}), + } + return c, nil +} + +func (*pgCoord) ServeHTTPDebug(w http.ResponseWriter, _ *http.Request) { + // TODO(spikecurtis) I'd like to hold off implementing this until after the rest of this is code reviewed. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("Coder Enterprise PostgreSQL distributed tailnet coordinator")) +} + +func (c *pgCoord) Node(id uuid.UUID) *agpl.Node { + // In production, we only ever get this request for an agent. + // We're going to directly query the database, since we would only have the agent mapping stored locally if we had + // a client of that agent connected, which isn't always the case. 
+ mappings, err := c.querier.queryAgent(id) + if err != nil { + c.logger.Error(c.ctx, "failed to query agents", slog.Error(err)) + } + mappings = c.querier.heartbeats.filter(mappings) + var bestT time.Time + var bestN *agpl.Node + for _, m := range mappings { + if m.updatedAt.After(bestT) { + bestN = m.node + bestT = m.updatedAt + } + } + return bestN +} + +func (c *pgCoord) ServeClient(conn net.Conn, id uuid.UUID, agent uuid.UUID) error { + defer func() { + err := conn.Close() + if err != nil { + c.logger.Debug(c.ctx, "closing client connection", + slog.F("client_id", id), + slog.F("agent_id", agent), + slog.Error(err)) + } + }() + cIO := newConnIO(c.ctx, c.logger, c.bindings, conn, id, agent) + if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + // can only be a context error, no need to log here. + return err + } + <-cIO.ctx.Done() + return nil +} + +func (c *pgCoord) ServeAgent(conn net.Conn, id uuid.UUID, name string) error { + defer func() { + err := conn.Close() + if err != nil { + c.logger.Debug(c.ctx, "closing agent connection", + slog.F("agent_id", id), + slog.Error(err)) + } + }() + logger := c.logger.With(slog.F("name", name)) + cIO := newConnIO(c.ctx, logger, c.bindings, conn, uuid.Nil, id) + if err := sendCtx(c.ctx, c.newConnections, cIO); err != nil { + // can only be a context error, no need to log here. + return err + } + <-cIO.ctx.Done() + return nil +} + +func (c *pgCoord) Close() error { + c.cancel() + c.closeOnce.Do(func() { close(c.closed) }) + return nil +} + +// connIO manages the reading and writing to a connected client or agent. Agent connIOs have their client field set to +// uuid.Nil. It reads node updates via its decoder, then pushes them onto the bindings channel. It receives mappings +// via its updates TrackedConn, which then writes them. 
+type connIO struct { + pCtx context.Context + ctx context.Context + cancel context.CancelFunc + logger slog.Logger + client uuid.UUID + agent uuid.UUID + decoder *json.Decoder + updates *agpl.TrackedConn + bindings chan<- binding +} + +func newConnIO( + pCtx context.Context, logger slog.Logger, bindings chan<- binding, conn net.Conn, client, agent uuid.UUID, +) *connIO { + ctx, cancel := context.WithCancel(pCtx) + id := agent + logger = logger.With(slog.F("agent_id", agent)) + if client != uuid.Nil { + logger = logger.With(slog.F("client_id", client)) + id = client + } + c := &connIO{ + pCtx: pCtx, + ctx: ctx, + cancel: cancel, + logger: logger, + client: client, + agent: agent, + decoder: json.NewDecoder(conn), + updates: agpl.NewTrackedConn(ctx, cancel, conn, id, logger, 0), + bindings: bindings, + } + go c.recvLoop() + go c.updates.SendUpdates() + logger.Info(ctx, "serving connection") + return c +} + +func (c *connIO) recvLoop() { + defer func() { + // withdraw bindings when we exit. We need to use the parent context here, since our own context might be + // canceled, but we still need to withdraw bindings. 
+ b := binding{ + bKey: bKey{ + client: c.client, + agent: c.agent, + }, + } + if err := sendCtx(c.pCtx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "parent context expired while withdrawing bindings", slog.Error(err)) + } + }() + defer c.cancel() + for { + var node agpl.Node + err := c.decoder.Decode(&node) + if err != nil { + if xerrors.Is(err, io.EOF) || xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, context.Canceled) { + c.logger.Debug(c.ctx, "exiting recvLoop", slog.Error(err)) + } else { + c.logger.Error(c.ctx, "failed to decode Node update", slog.Error(err)) + } + return + } + c.logger.Debug(c.ctx, "got node update", slog.F("node", node)) + b := binding{ + bKey: bKey{ + client: c.client, + agent: c.agent, + }, + node: &node, + } + if err := sendCtx(c.ctx, c.bindings, b); err != nil { + c.logger.Debug(c.ctx, "recvLoop ctx expired", slog.Error(err)) + return + } + } +} + +func sendCtx[A any](ctx context.Context, c chan<- A, a A) (err error) { + select { + case <-ctx.Done(): + return ctx.Err() + case c <- a: + return nil + } +} + +// bKey, or "binding key" identifies a client or agent in a binding. Agents have their client field set to uuid.Nil. +type bKey struct { + client uuid.UUID + agent uuid.UUID +} + +// binding represents an association between a client or agent and a Node. +type binding struct { + bKey + node *agpl.Node +} + +// binder reads node bindings from the channel and writes them to the database. It handles retries with a backoff. 
+type binder struct { + ctx context.Context + logger slog.Logger + coordinatorID uuid.UUID + store database.Store + bindings <-chan binding + + mu sync.Mutex + latest map[bKey]binding + workQ *workQ[bKey] +} + +func newBinder(ctx context.Context, logger slog.Logger, + id uuid.UUID, store database.Store, + bindings <-chan binding, startWorkers <-chan struct{}, +) *binder { + b := &binder{ + ctx: ctx, + logger: logger, + coordinatorID: id, + store: store, + bindings: bindings, + latest: make(map[bKey]binding), + workQ: newWorkQ[bKey](ctx), + } + go b.handleBindings() + go func() { + <-startWorkers + for i := 0; i < numBinderWorkers; i++ { + go b.worker() + } + }() + return b +} + +func (b *binder) handleBindings() { + for { + select { + case <-b.ctx.Done(): + b.logger.Debug(b.ctx, "binder exiting", slog.Error(b.ctx.Err())) + return + case bnd := <-b.bindings: + b.storeBinding(bnd) + b.workQ.enqueue(bnd.bKey) + } + } +} + +func (b *binder) worker() { + eb := backoff.NewExponentialBackOff() + eb.MaxElapsedTime = 0 // retry indefinitely + eb.MaxInterval = dbMaxBackoff + bkoff := backoff.WithContext(eb, b.ctx) + for { + bk, err := b.workQ.acquire() + if err != nil { + // context expired + return + } + err = backoff.Retry(func() error { + bnd := b.retrieveBinding(bk) + return b.writeOne(bnd) + }, bkoff) + if err != nil { + bkoff.Reset() + } + b.workQ.done(bk) + } +} + +func (b *binder) writeOne(bnd binding) error { + var nodeRaw json.RawMessage + var err error + if bnd.node != nil { + nodeRaw, err = json.Marshal(*bnd.node) + if err != nil { + // this is very bad news, but it should never happen because the node was Unmarshalled by this process + // earlier. 
+ b.logger.Error(b.ctx, "failed to marshall node", slog.Error(err)) + return err + } + } + + switch { + case bnd.client == uuid.Nil && len(nodeRaw) > 0: + _, err = b.store.UpsertTailnetAgent(b.ctx, database.UpsertTailnetAgentParams{ + ID: bnd.agent, + CoordinatorID: b.coordinatorID, + Node: nodeRaw, + }) + case bnd.client == uuid.Nil && len(nodeRaw) == 0: + _, err = b.store.DeleteTailnetAgent(b.ctx, database.DeleteTailnetAgentParams{ + ID: bnd.agent, + CoordinatorID: b.coordinatorID, + }) + if xerrors.Is(err, sql.ErrNoRows) { + // treat deletes as idempotent + err = nil + } + case bnd.client != uuid.Nil && len(nodeRaw) > 0: + _, err = b.store.UpsertTailnetClient(b.ctx, database.UpsertTailnetClientParams{ + ID: bnd.client, + CoordinatorID: b.coordinatorID, + AgentID: bnd.agent, + Node: nodeRaw, + }) + case bnd.client != uuid.Nil && len(nodeRaw) == 0: + _, err = b.store.DeleteTailnetClient(b.ctx, database.DeleteTailnetClientParams{ + ID: bnd.client, + CoordinatorID: b.coordinatorID, + }) + if xerrors.Is(err, sql.ErrNoRows) { + // treat deletes as idempotent + err = nil + } + default: + panic("unhittable") + } + if err != nil { + b.logger.Error(b.ctx, "failed to write binding to database", + slog.F("client_id", bnd.client), + slog.F("agent_id", bnd.agent), + slog.F("node", string(nodeRaw)), + slog.Error(err)) + } + return err +} + +// storeBinding stores the latest binding, where we interpret node == nil as removing the binding. This keeps the map +// from growing without bound. +func (b *binder) storeBinding(bnd binding) { + b.mu.Lock() + defer b.mu.Unlock() + if bnd.node != nil { + b.latest[bnd.bKey] = bnd + } else { + // nil node is interpreted as removing binding + delete(b.latest, bnd.bKey) + } +} + +// retrieveBinding gets the latest binding for a key. 
+func (b *binder) retrieveBinding(bk bKey) binding { + b.mu.Lock() + defer b.mu.Unlock() + bnd, ok := b.latest[bk] + if !ok { + bnd = binding{ + bKey: bk, + node: nil, + } + } + return bnd +} + +// mapper tracks a single client or agent ID, and fans out updates to that ID->node mapping to every local connection +// that needs it. +type mapper struct { + ctx context.Context + logger slog.Logger + + add chan *connIO + del chan *connIO + + // reads from this channel trigger sending latest nodes to + // all connections. It is used when coordinators are added + // or removed + update chan struct{} + + mappings chan []mapping + + conns map[bKey]*connIO + latest []mapping + + heartbeats *heartbeats +} + +func newMapper(ctx context.Context, logger slog.Logger, mk mKey, h *heartbeats) *mapper { + logger = logger.With( + slog.F("agent_id", mk.agent), + slog.F("clients_of_agent", mk.clientsOfAgent), + ) + m := &mapper{ + ctx: ctx, + logger: logger, + add: make(chan *connIO), + del: make(chan *connIO), + update: make(chan struct{}), + conns: make(map[bKey]*connIO), + mappings: make(chan []mapping), + heartbeats: h, + } + go m.run() + return m +} + +func (m *mapper) run() { + for { + select { + case <-m.ctx.Done(): + return + case c := <-m.add: + m.conns[bKey{c.client, c.agent}] = c + nodes := m.mappingsToNodes(m.latest) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + if err := c.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) + } + case c := <-m.del: + delete(m.conns, bKey{c.client, c.agent}) + case mappings := <-m.mappings: + m.latest = mappings + nodes := m.mappingsToNodes(mappings) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + for _, conn := range m.conns { + if err := conn.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue node update", slog.Error(err)) + } + } + case <-m.update: + 
nodes := m.mappingsToNodes(m.latest) + if len(nodes) == 0 { + m.logger.Debug(m.ctx, "skipping 0 length node update") + continue + } + for _, conn := range m.conns { + if err := conn.updates.Enqueue(nodes); err != nil { + m.logger.Error(m.ctx, "failed to enqueue triggered node update", slog.Error(err)) + } + } + } + } +} + +// mappingsToNodes takes a set of mappings and resolves the best set of nodes. We may get several mappings for a +// particular connection, from different coordinators in the distributed system. Furthermore, some coordinators +// might be considered invalid on account of missing heartbeats. We take the most recent mapping from a valid +// coordinator as the "best" mapping. +func (m *mapper) mappingsToNodes(mappings []mapping) []*agpl.Node { + mappings = m.heartbeats.filter(mappings) + best := make(map[bKey]mapping, len(mappings)) + for _, m := range mappings { + bk := bKey{client: m.client, agent: m.agent} + bestM, ok := best[bk] + if !ok || m.updatedAt.After(bestM.updatedAt) { + best[bk] = m + } + } + nodes := make([]*agpl.Node, 0, len(best)) + for _, m := range best { + nodes = append(nodes, m.node) + } + return nodes +} + +// querier is responsible for monitoring pubsub notifications and querying the database for the mappings that all +// connected clients and agents need. It also checks heartbeats and withdraws mappings from coordinators that have +// failed heartbeats. 
+type querier struct { + ctx context.Context + logger slog.Logger + pubsub pubsub.Pubsub + store database.Store + newConnections chan *connIO + + workQ *workQ[mKey] + heartbeats *heartbeats + updates <-chan struct{} + + mu sync.Mutex + mappers map[mKey]*countedMapper +} + +type countedMapper struct { + *mapper + count int + cancel context.CancelFunc +} + +func newQuerier( + ctx context.Context, logger slog.Logger, + ps pubsub.Pubsub, store database.Store, + self uuid.UUID, newConnections chan *connIO, numWorkers int, + firstHeartbeat chan<- struct{}, +) *querier { + updates := make(chan struct{}) + q := &querier{ + ctx: ctx, + logger: logger.Named("querier"), + pubsub: ps, + store: store, + newConnections: newConnections, + workQ: newWorkQ[mKey](ctx), + heartbeats: newHeartbeats(ctx, logger, ps, store, self, updates, firstHeartbeat), + mappers: make(map[mKey]*countedMapper), + updates: updates, + } + go q.subscribe() + go q.handleConnIO() + for i := 0; i < numWorkers; i++ { + go q.worker() + } + go q.handleUpdates() + return q +} + +func (q *querier) handleConnIO() { + for { + select { + case <-q.ctx.Done(): + return + case c := <-q.newConnections: + q.newConn(c) + } + } +} + +func (q *querier) newConn(c *connIO) { + q.mu.Lock() + defer q.mu.Unlock() + mk := mKey{ + agent: c.agent, + // if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself + clientsOfAgent: c.client == uuid.Nil, + } + cm, ok := q.mappers[mk] + if !ok { + ctx, cancel := context.WithCancel(q.ctx) + mpr := newMapper(ctx, q.logger, mk, q.heartbeats) + cm = &countedMapper{ + mapper: mpr, + count: 0, + cancel: cancel, + } + q.mappers[mk] = cm + // we don't have any mapping state for this key yet + q.workQ.enqueue(mk) + } + if err := sendCtx(cm.ctx, cm.add, c); err != nil { + return + } + cm.count++ + go q.cleanupConn(c) +} + +func (q *querier) cleanupConn(c *connIO) { + <-c.ctx.Done() + q.mu.Lock() + defer q.mu.Unlock() + mk := mKey{ + agent: c.agent, + // 
if client is Nil, this is an agent connection, and it wants the mappings for all the clients of itself
		clientsOfAgent: c.client == uuid.Nil,
	}
	cm := q.mappers[mk]
	if err := sendCtx(cm.ctx, cm.del, c); err != nil {
		return
	}
	cm.count--
	if cm.count == 0 {
		cm.cancel()
		delete(q.mappers, mk)
	}
}

// worker runs in its own goroutine: it repeatedly acquires a key from the workQ, executes the
// corresponding database query with exponential backoff, and marks the key done. It returns only
// when the querier context is canceled (acquire returns an error).
func (q *querier) worker() {
	eb := backoff.NewExponentialBackOff()
	eb.MaxElapsedTime = 0 // retry indefinitely
	eb.MaxInterval = dbMaxBackoff
	bkoff := backoff.WithContext(eb, q.ctx)
	for {
		mk, err := q.workQ.acquire()
		if err != nil {
			// context expired
			return
		}
		err = backoff.Retry(func() error {
			return q.query(mk)
		}, bkoff)
		if err != nil {
			// With MaxElapsedTime == 0, Retry only gives up when q.ctx is canceled; reset the
			// backoff so any subsequent key starts from the minimum interval again.
			// NOTE(review): confirm resetting on the error path (rather than on success) is intended.
			bkoff.Reset()
		}
		q.workQ.done(mk)
	}
}

// query resolves the node mappings for mk from the database and delivers them to the mapper
// registered for mk. If the mapper has since been removed, the result is dropped with a debug log.
// The mutex is held only for the map lookup; the channel send happens outside the lock and may
// block until the mapper consumes the mappings.
func (q *querier) query(mk mKey) error {
	var mappings []mapping
	var err error
	if mk.clientsOfAgent {
		mappings, err = q.queryClientsOfAgent(mk.agent)
		if err != nil {
			return err
		}
	} else {
		mappings, err = q.queryAgent(mk.agent)
		if err != nil {
			return err
		}
	}
	q.mu.Lock()
	mpr, ok := q.mappers[mk]
	q.mu.Unlock()
	if !ok {
		q.logger.Debug(q.ctx, "query for missing mapper",
			slog.F("agent_id", mk.agent), slog.F("clients_of_agent", mk.clientsOfAgent))
		return nil
	}
	mpr.mappings <- mappings
	return nil
}

// queryClientsOfAgent fetches all client rows for the given agent and converts each into a mapping.
// A JSON unmarshal failure is returned as backoff.Permanent so the worker's retry loop does not
// re-run a query that can never succeed.
func (q *querier) queryClientsOfAgent(agent uuid.UUID) ([]mapping, error) {
	clients, err := q.store.GetTailnetClientsForAgent(q.ctx, agent)
	if err != nil {
		return nil, err
	}
	mappings := make([]mapping, 0, len(clients))
	for _, client := range clients {
		node := new(agpl.Node)
		err := json.Unmarshal(client.Node, node)
		if err != nil {
			q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err))
			return nil, backoff.Permanent(err)
		}
		mappings = append(mappings, mapping{
			client:      client.ID,
			agent:       client.AgentID,
			coordinator: client.CoordinatorID,
			updatedAt:   client.UpdatedAt,
			node:        node,
		})
	}
	return mappings, nil
}

// queryAgent fetches the agent row(s) for agentID and converts each into a mapping. Agent mappings
// leave the client field at uuid.Nil. As in queryClientsOfAgent, an unmarshal failure is permanent.
func (q *querier) queryAgent(agentID uuid.UUID) ([]mapping, error) {
	agents, err := q.store.GetTailnetAgents(q.ctx, agentID)
	if err != nil {
		return nil, err
	}
	mappings := make([]mapping, 0, len(agents))
	for _, agent := range agents {
		node := new(agpl.Node)
		err := json.Unmarshal(agent.Node, node)
		if err != nil {
			q.logger.Error(q.ctx, "failed to unmarshal node", slog.Error(err))
			return nil, backoff.Permanent(err)
		}
		mappings = append(mappings, mapping{
			agent:       agent.ID,
			coordinator: agent.CoordinatorID,
			updatedAt:   agent.UpdatedAt,
			node:        node,
		})
	}
	return mappings, nil
}

// subscribe establishes the client-update and agent-update pubsub subscriptions, retrying each
// indefinitely until it succeeds or the querier context is canceled, then holds both subscriptions
// open until the context is done. A retry failure with a live context indicates a code bug and is
// logged as an error.
func (q *querier) subscribe() {
	eb := backoff.NewExponentialBackOff()
	eb.MaxElapsedTime = 0 // retry indefinitely
	eb.MaxInterval = dbMaxBackoff
	bkoff := backoff.WithContext(eb, q.ctx)
	var cancelClient context.CancelFunc
	err := backoff.Retry(func() error {
		cancelFn, err := q.pubsub.SubscribeWithErr(eventClientUpdate, q.listenClient)
		if err != nil {
			q.logger.Warn(q.ctx, "failed to subscribe to client updates", slog.Error(err))
			return err
		}
		cancelClient = cancelFn
		return nil
	}, bkoff)
	if err != nil {
		if q.ctx.Err() == nil {
			q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
		}
		return
	}
	defer cancelClient()
	// reuse the backoff for the second subscription, starting from the minimum interval again
	bkoff.Reset()

	var cancelAgent context.CancelFunc
	err = backoff.Retry(func() error {
		cancelFn, err := q.pubsub.SubscribeWithErr(eventAgentUpdate, q.listenAgent)
		if err != nil {
			q.logger.Warn(q.ctx, "failed to subscribe to agent updates", slog.Error(err))
			return err
		}
		cancelAgent = cancelFn
		return nil
	}, bkoff)
	if err != nil {
		if q.ctx.Err() == nil {
			q.logger.Error(q.ctx, "code bug: retry failed before context canceled", slog.Error(err))
		}
		return
	}
	defer cancelAgent()

	// hold subscriptions open until context is canceled
	<-q.ctx.Done()
}

// listenClient handles a client-update pubsub notification. Dropped messages trigger a full resync
// of client mappings; otherwise the message ("<clientID>,<agentID>") is parsed and, if we hold a
// mapper for that agent's clients, a re-query is enqueued.
// NOTE(review): on a non-dropped pubsub error we log and still fall through to parse msg, which may
// be empty — confirm the parse-error log is the intended handling in that case.
func (q *querier) listenClient(_ context.Context, msg []byte, err error) {
	if xerrors.Is(err, pubsub.ErrDroppedMessages) {
		q.logger.Warn(q.ctx, "pubsub may have dropped client updates")
		// we need to schedule a full resync of client mappings
		q.resyncClientMappings()
		return
	}
	if err != nil {
		q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
	}
	client, agent, err := parseClientUpdate(string(msg))
	if err != nil {
		q.logger.Error(q.ctx, "failed to parse client update", slog.F("msg", string(msg)), slog.Error(err))
		return
	}
	logger := q.logger.With(slog.F("client_id", client), slog.F("agent_id", agent))
	logger.Debug(q.ctx, "got client update")
	mk := mKey{
		agent:          agent,
		clientsOfAgent: true,
	}
	q.mu.Lock()
	_, ok := q.mappers[mk]
	q.mu.Unlock()
	if !ok {
		logger.Debug(q.ctx, "ignoring update because we have no mapper")
		return
	}
	q.workQ.enqueue(mk)
}

// listenAgent handles an agent-update pubsub notification; the message is the agent UUID. Dropped
// messages trigger a full resync of agent mappings; otherwise a re-query is enqueued if we hold a
// mapper for that agent.
func (q *querier) listenAgent(_ context.Context, msg []byte, err error) {
	if xerrors.Is(err, pubsub.ErrDroppedMessages) {
		q.logger.Warn(q.ctx, "pubsub may have dropped agent updates")
		// we need to schedule a full resync of agent mappings
		q.resyncAgentMappings()
		return
	}
	if err != nil {
		q.logger.Warn(q.ctx, "unhandled pubsub error", slog.Error(err))
	}
	agent, err := parseAgentUpdate(string(msg))
	if err != nil {
		q.logger.Error(q.ctx, "failed to parse agent update", slog.F("msg", string(msg)), slog.Error(err))
		return
	}
	logger := q.logger.With(slog.F("agent_id", agent))
	logger.Debug(q.ctx, "got agent update")
	mk := mKey{
		agent:          agent,
		clientsOfAgent: false,
	}
	q.mu.Lock()
	_, ok := q.mappers[mk]
	q.mu.Unlock()
	if !ok {
		logger.Debug(q.ctx, "ignoring update because we have no mapper")
		return
	}
	q.workQ.enqueue(mk)
}

// resyncClientMappings enqueues a re-query for every clients-of-agent mapper, used after pubsub may
// have dropped client-update messages.
func (q *querier) resyncClientMappings() {
	q.mu.Lock()
	defer q.mu.Unlock()
	for mk := range q.mappers {
		if mk.clientsOfAgent {
			q.workQ.enqueue(mk)
		}
	}
}

// resyncAgentMappings enqueues a re-query for every agent mapper, used after pubsub may have
// dropped agent-update messages.
func (q *querier) resyncAgentMappings() {
	q.mu.Lock()
	defer q.mu.Unlock()
	for mk := range q.mappers {
		if !mk.clientsOfAgent {
			q.workQ.enqueue(mk)
		}
	}
}

// handleUpdates waits for signals on q.updates (e.g. heartbeat failures) and tells every mapper to
// refresh, until the querier context is canceled.
func (q *querier) handleUpdates() {
	for {
		select {
		case <-q.ctx.Done():
			return
		case <-q.updates:
			q.updateAll()
		}
	}
}

// updateAll asks every registered mapper to recompute and push its mappings.
func (q *querier) updateAll() {
	q.mu.Lock()
	defer q.mu.Unlock()

	for _, cm := range q.mappers {
		// send on goroutine to avoid holding the q.mu. Heartbeat failures come asynchronously with respect to
		// other kinds of work, so it's fine to deliver the command to refresh async.
		go func(m *mapper) {
			// make sure we send on the _mapper_ context, not our own in case the mapper is
			// shutting down or shut down.
			_ = sendCtx(m.ctx, m.update, struct{}{})
		}(cm.mapper)
	}
}

// parseClientUpdate parses a client-update pubsub payload of the form "<clientUUID>,<agentUUID>".
func parseClientUpdate(msg string) (client, agent uuid.UUID, err error) {
	parts := strings.Split(msg, ",")
	if len(parts) != 2 {
		return uuid.Nil, uuid.Nil, xerrors.Errorf("expected 2 parts separated by comma")
	}
	client, err = uuid.Parse(parts[0])
	if err != nil {
		return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse client UUID: %w", err)
	}
	agent, err = uuid.Parse(parts[1])
	if err != nil {
		return uuid.Nil, uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err)
	}
	return client, agent, nil
}

// parseAgentUpdate parses an agent-update pubsub payload, which is a single agent UUID.
func parseAgentUpdate(msg string) (agent uuid.UUID, err error) {
	agent, err = uuid.Parse(msg)
	if err != nil {
		return uuid.Nil, xerrors.Errorf("failed to parse agent UUID: %w", err)
	}
	return agent, nil
}

// mKey identifies a set of node mappings we want to query.
type mKey struct {
	agent uuid.UUID
	// we always query based on the agent ID, but if we have client connection(s), we query the agent itself. If we
	// have an agent connection, we need the node mappings for all clients of the agent.
	clientsOfAgent bool
}

// mapping associates a particular client or agent, and its respective coordinator with a node. It is generalized to
// include clients or agents: agent mappings will have client set to uuid.Nil.
type mapping struct {
	client      uuid.UUID
	agent       uuid.UUID
	coordinator uuid.UUID
	updatedAt   time.Time
	node        *agpl.Node
}

// workQ allows scheduling work based on a key. Multiple enqueue requests for the same key are coalesced, and
// only one in-progress job per key is scheduled.
type workQ[K mKey | bKey] struct {
	ctx context.Context

	// cond.L guards pending and inProgress; the condition is "a pending key exists that is not
	// currently in progress, or ctx is done".
	cond       *sync.Cond
	pending    []K
	inProgress map[K]bool
}

// newWorkQ constructs a workQ bound to ctx. A background goroutine broadcasts on the condition
// variable when ctx is done so that blocked acquire() callers are released.
func newWorkQ[K mKey | bKey](ctx context.Context) *workQ[K] {
	q := &workQ[K]{
		ctx:        ctx,
		cond:       sync.NewCond(&sync.Mutex{}),
		inProgress: make(map[K]bool),
	}
	// wake up all waiting workers when context is done
	go func() {
		<-ctx.Done()
		q.cond.L.Lock()
		defer q.cond.L.Unlock()
		q.cond.Broadcast()
	}()
	return q
}

// enqueue adds the key to the workQ if it is not already pending.
func (q *workQ[K]) enqueue(key K) {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()
	for _, mk := range q.pending {
		if mk == key {
			// already pending, no-op
			return
		}
	}
	q.pending = append(q.pending, key)
	q.cond.Signal()
}

// acquire gets a new key to begin working on. This call blocks until work is available. After acquiring a key, the
// worker MUST call done() with the same key to mark it complete and allow new pending work to be acquired for the key.
// An error is returned if the workQ context is canceled to unblock waiting workers.
func (q *workQ[K]) acquire() (key K, err error) {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()
	for !q.workAvailable() && q.ctx.Err() == nil {
		q.cond.Wait()
	}
	if q.ctx.Err() != nil {
		return key, q.ctx.Err()
	}
	for i, mk := range q.pending {
		_, ok := q.inProgress[mk]
		if !ok {
			// remove from pending and mark in progress
			q.pending = append(q.pending[:i], q.pending[i+1:]...)
			q.inProgress[mk] = true
			return mk, nil
		}
	}
	// this should not be possible because we are holding the lock when we exit the loop that waits
	panic("woke with no work available")
}

// workAvailable returns true if there is work we can do. Must be called while holding q.cond.L
func (q workQ[K]) workAvailable() bool {
	for _, mk := range q.pending {
		_, ok := q.inProgress[mk]
		if !ok {
			return true
		}
	}
	return false
}

// done marks the key completed; MUST be called after acquire() for each key.
func (q *workQ[K]) done(key K) {
	q.cond.L.Lock()
	defer q.cond.L.Unlock()
	delete(q.inProgress, key)
	q.cond.Signal()
}

// heartbeats sends heartbeats for this coordinator on a timer, and monitors heartbeats from other coordinators. If a
// coordinator misses their heartbeat, we remove it from our map of "valid" coordinators, such that we will filter out
// any mappings for it when filter() is called, and we send a signal on the update channel, which triggers all mappers
// to recompute their mappings and push them out to their connections.
type heartbeats struct {
	ctx    context.Context
	logger slog.Logger
	pubsub pubsub.Pubsub
	store  database.Store
	self   uuid.UUID

	update         chan<- struct{}
	firstHeartbeat chan<- struct{}

	// lock guards coordinators and timer
	lock         sync.RWMutex
	coordinators map[uuid.UUID]time.Time
	timer        *time.Timer
}

// newHeartbeats constructs a heartbeats tracker for coordinator `self` and starts its two
// goroutines: one subscribing to other coordinators' beats and one sending our own.
func newHeartbeats(
	ctx context.Context, logger slog.Logger,
	ps pubsub.Pubsub, store database.Store,
	self uuid.UUID, update chan<- struct{},
	firstHeartbeat chan<- struct{},
) *heartbeats {
	h := &heartbeats{
		ctx:            ctx,
		logger:         logger,
		pubsub:         ps,
		store:          store,
		self:           self,
		update:         update,
		firstHeartbeat: firstHeartbeat,
		coordinators:   make(map[uuid.UUID]time.Time),
	}
	go h.subscribe()
	go h.sendBeats()
	return h
}

// filter returns only the mappings whose coordinator is ourselves or a coordinator that is still
// passing heartbeat checks.
func (h *heartbeats) filter(mappings []mapping) []mapping {
	out := make([]mapping, 0, len(mappings))
	h.lock.RLock()
	defer h.lock.RUnlock()
	for _, m := range mappings {
		ok := m.coordinator == h.self
		if !ok {
			_, ok = h.coordinators[m.coordinator]
		}
		if ok {
			out = append(out, m)
		}
	}
	return out
}

// subscribe establishes the heartbeat pubsub subscription, retrying indefinitely until it succeeds
// or the context is canceled, then holds the subscription open until the context is done.
func (h *heartbeats) subscribe() {
	eb := backoff.NewExponentialBackOff()
	eb.MaxElapsedTime = 0 // retry indefinitely
	eb.MaxInterval = dbMaxBackoff
	bkoff := backoff.WithContext(eb, h.ctx)
	var cancel context.CancelFunc
	bErr := backoff.Retry(func() error {
		cancelFn, err := h.pubsub.SubscribeWithErr(EventHeartbeats, h.listen)
		if err != nil {
			h.logger.Warn(h.ctx, "failed to subscribe to heartbeats", slog.Error(err))
			return err
		}
		cancel = cancelFn
		return nil
	}, bkoff)
	if bErr != nil {
		if h.ctx.Err() == nil {
			h.logger.Error(h.ctx, "code bug: retry failed before context canceled", slog.Error(bErr))
		}
		return
	}
	// cancel subscription when context finishes
	defer cancel()
	<-h.ctx.Done()
}

// listen handles a heartbeat pubsub notification; the message is the sending coordinator's UUID.
// Our own beats are ignored; everyone else's are recorded via recvBeat.
func (h *heartbeats) listen(_ context.Context, msg []byte, err error) {
	if err != nil {
		// in the context of heartbeats, if we miss some messages it will be OK as long
		// as we aren't disconnected for multiple beats. Still, even if we are disconnected
		// for longer, there isn't much to do except log. Once we reconnect we will reinstate
		// any expired coordinators that are still alive and continue on.
		h.logger.Warn(h.ctx, "heartbeat notification error", slog.Error(err))
		return
	}
	id, err := uuid.Parse(string(msg))
	if err != nil {
		h.logger.Error(h.ctx, "unable to parse heartbeat", slog.F("msg", string(msg)), slog.Error(err))
		return
	}
	if id == h.self {
		h.logger.Debug(h.ctx, "ignoring our own heartbeat")
		return
	}
	h.recvBeat(id)
}

// recvBeat records a heartbeat from coordinator id and (re)schedules the expiry check to fire when
// the oldest recorded beat would exceed the allowed number of missed heartbeats.
func (h *heartbeats) recvBeat(id uuid.UUID) {
	h.logger.Debug(h.ctx, "got heartbeat", slog.F("other_coordinator_id", id))
	h.lock.Lock()
	defer h.lock.Unlock()
	var oldestTime time.Time
	h.coordinators[id] = time.Now()

	if h.timer == nil {
		// this can only happen for the very first beat
		h.timer = time.AfterFunc(MissedHeartbeats*HeartbeatPeriod, h.checkExpiry)
		h.logger.Debug(h.ctx, "set initial heartbeat timeout")
		return
	}

	for _, t := range h.coordinators {
		if oldestTime.IsZero() || t.Before(oldestTime) {
			oldestTime = t
		}
	}
	d := time.Until(oldestTime.Add(MissedHeartbeats * HeartbeatPeriod))
	h.logger.Debug(h.ctx, "computed oldest heartbeat", slog.F("oldest", oldestTime), slog.F("time_to_expiry", d))
	// only reschedule if it's in the future.
	if d > 0 {
		h.timer.Reset(d)
	}
}

// checkExpiry removes any coordinator whose last heartbeat is older than the allowed window and, if
// any were removed, signals the update channel so mappers recompute (filter() will now drop that
// coordinator's mappings).
// NOTE(review): the timer is not re-armed here; it is only reset from recvBeat when a new beat
// arrives — confirm surviving-but-silent coordinators cannot linger unexpired indefinitely.
func (h *heartbeats) checkExpiry() {
	h.logger.Debug(h.ctx, "checking heartbeat expiry")
	h.lock.Lock()
	now := time.Now()
	expired := false
	for id, t := range h.coordinators {
		lastHB := now.Sub(t)
		h.logger.Debug(h.ctx, "last heartbeat from coordinator", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB))
		if lastHB > MissedHeartbeats*HeartbeatPeriod {
			expired = true
			delete(h.coordinators, id)
			h.logger.Info(h.ctx, "coordinator failed heartbeat check", slog.F("other_coordinator_id", id), slog.F("last_heartbeat", lastHB))
		}
	}
	h.lock.Unlock()
	if expired {
		_ = sendCtx(h.ctx, h.update, struct{}{})
	}
}

// sendBeats periodically upserts our coordinator row as a liveness signal, and deletes it on exit.
func (h *heartbeats) sendBeats() {
	// send an initial heartbeat so that other coordinators can start using our bindings right away.
+ h.sendBeat() + close(h.firstHeartbeat) // signal binder it can start writing + defer h.sendDelete() + tkr := time.NewTicker(HeartbeatPeriod) + defer tkr.Stop() + for { + select { + case <-h.ctx.Done(): + h.logger.Debug(h.ctx, "ending heartbeats", slog.Error(h.ctx.Err())) + return + case <-tkr.C: + h.sendBeat() + } + } +} + +func (h *heartbeats) sendBeat() { + _, err := h.store.UpsertTailnetCoordinator(h.ctx, h.self) + if err != nil { + // just log errors, heartbeats are rescheduled on a timer + h.logger.Error(h.ctx, "failed to send heartbeat", slog.Error(err)) + return + } + h.logger.Debug(h.ctx, "sent heartbeat") +} + +func (h *heartbeats) sendDelete() { + // here we don't want to use the main context, since it will have been canceled + err := h.store.DeleteCoordinator(context.Background(), h.self) + if err != nil { + h.logger.Error(h.ctx, "failed to send coordinator delete", slog.Error(err)) + return + } + h.logger.Debug(h.ctx, "deleted coordinator") +} diff --git a/enterprise/tailnet/pgcoord_test.go b/enterprise/tailnet/pgcoord_test.go new file mode 100644 index 0000000000..b2fbb8d0f9 --- /dev/null +++ b/enterprise/tailnet/pgcoord_test.go @@ -0,0 +1,655 @@ +package tailnet_test + +import ( + "context" + "database/sql" + "encoding/json" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/enterprise/tailnet" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestPGCoordinatorSingle_ClientWithoutAgent(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := 
dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agentID := uuid.New() + client := newTestClient(t, coordinator, agentID) + defer client.close() + client.sendNode(&agpl.Node{PreferredDERP: 10}) + require.Eventually(t, func() bool { + clients, err := store.GetTailnetClientsForAgent(ctx, agentID) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(clients) == 0 { + return false + } + var node agpl.Node + err = json.Unmarshal(clients[0].Node, &node) + assert.NoError(t, err) + assert.Equal(t, 10, node.PreferredDERP) + return true + }, testutil.WaitShort, testutil.IntervalFast) + + err = client.close() + require.NoError(t, err) + <-client.errChan + <-client.closeChan + assertEventuallyNoClientsForAgent(ctx, t, store, agentID) +} + +func TestPGCoordinatorSingle_AgentWithoutClients(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + require.Eventually(t, func() bool { + agents, err := store.GetTailnetAgents(ctx, agent.id) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + t.Fatalf("database error: %v", err) + } + if len(agents) == 0 { + return false + } + var node agpl.Node + err = json.Unmarshal(agents[0].Node, &node) + assert.NoError(t, err) + assert.Equal(t, 10, 
node.PreferredDERP) + return true + }, testutil.WaitShort, testutil.IntervalFast) + err = agent.close() + require.NoError(t, err) + <-agent.errChan + <-agent.closeChan + assertEventuallyNoAgents(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_AgentWithClient(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + + client := newTestClient(t, coordinator, agent.id) + defer client.close() + + agentNodes := client.recvNodes(ctx, t) + require.Len(t, agentNodes, 1) + assert.Equal(t, 10, agentNodes[0].PreferredDERP) + client.sendNode(&agpl.Node{PreferredDERP: 11}) + clientNodes := agent.recvNodes(ctx, t) + require.Len(t, clientNodes, 1) + assert.Equal(t, 11, clientNodes[0].PreferredDERP) + + // Ensure an update to the agent node reaches the connIO! + agent.sendNode(&agpl.Node{PreferredDERP: 12}) + agentNodes = client.recvNodes(ctx, t) + require.Len(t, agentNodes, 1) + assert.Equal(t, 12, agentNodes[0].PreferredDERP) + + // Close the agent WebSocket so a new one can connect. + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + // Create a new agent connection. This is to simulate a reconnect! + agent = newTestAgent(t, coordinator, agent.id) + // Ensure the existing listening connIO sends its node immediately! + clientNodes = agent.recvNodes(ctx, t) + require.Len(t, clientNodes, 1) + assert.Equal(t, 11, clientNodes[0].PreferredDERP) + + // Send a bunch of updates in rapid succession, and test that we eventually get the latest. 
We don't want the + // coordinator accidentally reordering things. + for d := 13; d < 36; d++ { + agent.sendNode(&agpl.Node{PreferredDERP: d}) + } + for { + nodes := client.recvNodes(ctx, t) + if !assert.Len(t, nodes, 1) { + break + } + if nodes[0].PreferredDERP == 35 { + // got latest! + break + } + } + + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + _ = client.recvErr(ctx, t) + client.waitForClose(ctx, t) + + assertEventuallyNoAgents(ctx, t, store, agent.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_MissedHeartbeats(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + agent := newTestAgent(t, coordinator) + defer agent.close() + agent.sendNode(&agpl.Node{PreferredDERP: 10}) + + client := newTestClient(t, coordinator, agent.id) + defer client.close() + + nodes := client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 10) + client.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes = agent.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + // simulate a second coordinator via DB calls only --- our goal is to test broken heart-beating, so we can't use a + // real coordinator + fCoord := &fakeCoordinator{ + ctx: ctx, + t: t, + store: store, + id: uuid.New(), + } + start := time.Now() + fCoord.heartbeat() + fCoord.agentNode(agent.id, &agpl.Node{PreferredDERP: 12}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 12) + + // when the fake coordinator misses enough heartbeats, the real coordinator should send an update with the 
old + // node for the agent. + nodes = client.recvNodes(ctx, t) + assert.Greater(t, time.Since(start), tailnet.HeartbeatPeriod*tailnet.MissedHeartbeats) + assertHasDERPs(t, nodes, 10) + + err = agent.close() + require.NoError(t, err) + _ = agent.recvErr(ctx, t) + agent.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + _ = client.recvErr(ctx, t) + client.waitForClose(ctx, t) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent.id) +} + +func TestPGCoordinatorSingle_SendsHeartbeats(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + mu := sync.Mutex{} + heartbeats := []time.Time{} + unsub, err := pubsub.SubscribeWithErr(tailnet.EventHeartbeats, func(_ context.Context, msg []byte, err error) { + assert.NoError(t, err) + mu.Lock() + defer mu.Unlock() + heartbeats = append(heartbeats, time.Now()) + }) + require.NoError(t, err) + defer unsub() + + start := time.Now() + coordinator, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coordinator.Close() + + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + if len(heartbeats) < 2 { + return false + } + require.Greater(t, heartbeats[0].Sub(start), time.Duration(0)) + require.Greater(t, heartbeats[1].Sub(start), time.Duration(0)) + return assert.Greater(t, heartbeats[1].Sub(heartbeats[0]), tailnet.HeartbeatPeriod*9/10) + }, testutil.WaitMedium, testutil.IntervalMedium) +} + +// TestPGCoordinatorDual_Mainline tests with 2 coordinators, one agent connected to each, and 2 clients per agent. 
+// +// +---------+ +// agent1 ---> | coord1 | <--- client11 (coord 1, agent 1) +// | | +// | | <--- client12 (coord 1, agent 2) +// +---------+ +// +---------+ +// agent2 ---> | coord2 | <--- client21 (coord 2, agent 1) +// | | +// | | <--- client22 (coord2, agent 2) +// +---------+ +func TestPGCoordinatorDual_Mainline(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord2.Close() + + agent1 := newTestAgent(t, coord1) + defer agent1.close() + agent2 := newTestAgent(t, coord2) + defer agent2.close() + + client11 := newTestClient(t, coord1, agent1.id) + defer client11.close() + client12 := newTestClient(t, coord1, agent2.id) + defer client12.close() + client21 := newTestClient(t, coord2, agent1.id) + defer client21.close() + client22 := newTestClient(t, coord2, agent2.id) + defer client22.close() + + client11.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes := agent1.recvNodes(ctx, t) + assert.Len(t, nodes, 1) + assertHasDERPs(t, nodes, 11) + + client21.sendNode(&agpl.Node{PreferredDERP: 21}) + nodes = agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 21, 11) + + client22.sendNode(&agpl.Node{PreferredDERP: 22}) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 22) + + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + nodes = client22.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + nodes = client12.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + + client12.sendNode(&agpl.Node{PreferredDERP: 12}) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 12, 22) + + 
agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + nodes = client21.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + nodes = client11.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + // let's close coord2 + err = coord2.Close() + require.NoError(t, err) + + // this closes agent2, client22, client21 + err = agent2.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client22.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client21.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + + // agent1 will see an update that drops client21. + // In this case the update is superfluous because client11's node hasn't changed, and agents don't deprogram clients + // from the dataplane even if they are missing. Suppressing this kind of update would require the coordinator to + // store all the data its sent to each connection, so we don't bother. + nodes = agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + // note that although agent2 is disconnected, client12 does NOT get an update because we suppress empty updates. + // (Its easy to tell these are superfluous.) 
+ + assertEventuallyNoAgents(ctx, t, store, agent2.id) + + // Close coord1 + err = coord1.Close() + require.NoError(t, err) + // this closes agent1, client12, client11 + err = agent1.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client12.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + err = client11.recvErr(ctx, t) + require.ErrorIs(t, err, io.EOF) + + // wait for all connections to close + err = agent1.close() + require.NoError(t, err) + agent1.waitForClose(ctx, t) + + err = agent2.close() + require.NoError(t, err) + agent2.waitForClose(ctx, t) + + err = client11.close() + require.NoError(t, err) + client11.waitForClose(ctx, t) + + err = client12.close() + require.NoError(t, err) + client12.waitForClose(ctx, t) + + err = client21.close() + require.NoError(t, err) + client21.waitForClose(ctx, t) + + err = client22.close() + require.NoError(t, err) + client22.waitForClose(ctx, t) + + assertEventuallyNoAgents(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoClientsForAgent(ctx, t, store, agent2.id) +} + +// TestPGCoordinator_MultiAgent tests when a single agent connects to multiple coordinators. +// We use two agent connections, but they share the same AgentID. This could happen due to a reconnection, +// or an infrastructure problem where an old workspace is not fully cleaned up before a new one started. 
+// +// +---------+ +// agent1 ---> | coord1 | +// +---------+ +// +---------+ +// agent2 ---> | coord2 | +// +---------+ +// +---------+ +// | coord3 | <--- client +// +---------+ +func TestPGCoordinator_MultiAgent(t *testing.T) { + t.Parallel() + if !dbtestutil.WillUsePostgres() { + t.Skip("test only with postgres") + } + store, pubsub := dbtestutil.NewDB(t) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + defer cancel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + coord1, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord1.Close() + coord2, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord2.Close() + coord3, err := tailnet.NewPGCoord(ctx, logger, pubsub, store) + require.NoError(t, err) + defer coord3.Close() + + agent1 := newTestAgent(t, coord1) + defer agent1.close() + agent2 := newTestAgent(t, coord2, agent1.id) + defer agent2.close() + + client := newTestClient(t, coord3, agent1.id) + defer client.close() + + client.sendNode(&agpl.Node{PreferredDERP: 3}) + nodes := agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 3) + nodes = agent2.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 3) + + agent1.sendNode(&agpl.Node{PreferredDERP: 1}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + // agent2's update overrides agent1 because it is newer + agent2.sendNode(&agpl.Node{PreferredDERP: 2}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 2) + + // agent2 disconnects, and we should revert back to agent1 + err = agent2.close() + require.NoError(t, err) + err = agent2.recvErr(ctx, t) + require.ErrorIs(t, err, io.ErrClosedPipe) + agent2.waitForClose(ctx, t) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 1) + + agent1.sendNode(&agpl.Node{PreferredDERP: 11}) + nodes = client.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 11) + + client.sendNode(&agpl.Node{PreferredDERP: 31}) + nodes = 
agent1.recvNodes(ctx, t) + assertHasDERPs(t, nodes, 31) + + err = agent1.close() + require.NoError(t, err) + err = agent1.recvErr(ctx, t) + require.ErrorIs(t, err, io.ErrClosedPipe) + agent1.waitForClose(ctx, t) + + err = client.close() + require.NoError(t, err) + err = client.recvErr(ctx, t) + require.ErrorIs(t, err, io.ErrClosedPipe) + client.waitForClose(ctx, t) + + assertEventuallyNoClientsForAgent(ctx, t, store, agent1.id) + assertEventuallyNoAgents(ctx, t, store, agent1.id) +} + +type testConn struct { + ws, serverWS net.Conn + nodeChan chan []*agpl.Node + sendNode func(node *agpl.Node) + errChan <-chan error + id uuid.UUID + closeChan chan struct{} +} + +func newTestConn(ids []uuid.UUID) *testConn { + a := &testConn{} + a.ws, a.serverWS = net.Pipe() + a.nodeChan = make(chan []*agpl.Node) + a.sendNode, a.errChan = agpl.ServeCoordinator(a.ws, func(nodes []*agpl.Node) error { + a.nodeChan <- nodes + return nil + }) + if len(ids) > 1 { + panic("too many") + } + if len(ids) == 1 { + a.id = ids[0] + } else { + a.id = uuid.New() + } + a.closeChan = make(chan struct{}) + return a +} + +func newTestAgent(t *testing.T, coord agpl.Coordinator, id ...uuid.UUID) *testConn { + a := newTestConn(id) + go func() { + err := coord.ServeAgent(a.serverWS, a.id, "") + assert.NoError(t, err) + close(a.closeChan) + }() + return a +} + +func (c *testConn) close() error { + return c.ws.Close() +} + +func (c *testConn) recvNodes(ctx context.Context, t *testing.T) []*agpl.Node { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout receiving nodes") + return nil + case nodes := <-c.nodeChan: + return nodes + } +} + +func (c *testConn) recvErr(ctx context.Context, t *testing.T) error { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout receiving error") + return ctx.Err() + case err := <-c.errChan: + return err + } +} + +func (c *testConn) waitForClose(ctx context.Context, t *testing.T) { + t.Helper() + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for 
connection to close") + return + case <-c.closeChan: + return + } +} + +func newTestClient(t *testing.T, coord agpl.Coordinator, agentID uuid.UUID, id ...uuid.UUID) *testConn { + c := newTestConn(id) + go func() { + err := coord.ServeClient(c.serverWS, c.id, agentID) + assert.NoError(t, err) + close(c.closeChan) + }() + return c +} + +func assertHasDERPs(t *testing.T, nodes []*agpl.Node, expected ...int) { + if !assert.Len(t, nodes, len(expected), "expected %d node(s), got %d", len(expected), len(nodes)) { + return + } + derps := make([]int, 0, len(nodes)) + for _, n := range nodes { + derps = append(derps, n.PreferredDERP) + } + for _, e := range expected { + assert.Contains(t, derps, e, "expected DERP %v, got %v", e, derps) + } +} + +func assertEventuallyNoAgents(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + assert.Eventually(t, func() bool { + agents, err := store.GetTailnetAgents(ctx, agentID) + if xerrors.Is(err, sql.ErrNoRows) { + return true + } + if err != nil { + t.Fatal(err) + } + return len(agents) == 0 + }, testutil.WaitShort, testutil.IntervalFast) +} + +func assertEventuallyNoClientsForAgent(ctx context.Context, t *testing.T, store database.Store, agentID uuid.UUID) { + assert.Eventually(t, func() bool { + clients, err := store.GetTailnetClientsForAgent(ctx, agentID) + if xerrors.Is(err, sql.ErrNoRows) { + return true + } + if err != nil { + t.Fatal(err) + } + return len(clients) == 0 + }, testutil.WaitShort, testutil.IntervalFast) +} + +type fakeCoordinator struct { + ctx context.Context + t *testing.T + store database.Store + id uuid.UUID +} + +func (c *fakeCoordinator) heartbeat() { + c.t.Helper() + _, err := c.store.UpsertTailnetCoordinator(c.ctx, c.id) + require.NoError(c.t, err) +} + +func (c *fakeCoordinator) agentNode(agentID uuid.UUID, node *agpl.Node) { + c.t.Helper() + nodeRaw, err := json.Marshal(node) + require.NoError(c.t, err) + _, err = c.store.UpsertTailnetAgent(c.ctx, 
database.UpsertTailnetAgentParams{ + ID: agentID, + CoordinatorID: c.id, + Node: nodeRaw, + }) + require.NoError(c.t, err) +} diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index f0e0c1475b..ee675ef666 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -1,6 +1,7 @@ package tailnet import ( + "bytes" "context" "encoding/json" "errors" @@ -174,11 +175,12 @@ func newCore(logger slog.Logger) *core { var ErrWouldBlock = xerrors.New("would block") type TrackedConn struct { - ctx context.Context - cancel func() - conn net.Conn - updates chan []*Node - logger slog.Logger + ctx context.Context + cancel func() + conn net.Conn + updates chan []*Node + logger slog.Logger + lastData []byte // ID is an ephemeral UUID used to uniquely identify the owner of the // connection. @@ -224,6 +226,10 @@ func (t *TrackedConn) SendUpdates() { t.logger.Error(t.ctx, "unable to marshal nodes update", slog.Error(err), slog.F("nodes", nodes)) return } + if bytes.Equal(t.lastData, data) { + t.logger.Debug(t.ctx, "skipping duplicate update", slog.F("nodes", nodes)) + continue + } // Set a deadline so that hung connections don't put back pressure on the system. // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. 
@@ -255,6 +261,7 @@ func (t *TrackedConn) SendUpdates() { _ = t.Close() return } + t.lastData = data } } } diff --git a/tailnet/coordinator_test.go b/tailnet/coordinator_test.go index 94c6f6da58..300a89ad5f 100644 --- a/tailnet/coordinator_test.go +++ b/tailnet/coordinator_test.go @@ -96,7 +96,7 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeAgentChan) }() - sendAgentNode(&tailnet.Node{}) + sendAgentNode(&tailnet.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -122,7 +122,7 @@ func TestCoordinator(t *testing.T) { case <-ctx.Done(): t.Fatal("timed out") } - sendClientNode(&tailnet.Node{}) + sendClientNode(&tailnet.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan require.Len(t, clientNodes, 1) @@ -131,7 +131,7 @@ func TestCoordinator(t *testing.T) { time.Sleep(tailnet.WriteTimeout * 3 / 2) // Ensure an update to the agent node reaches the client! - sendAgentNode(&tailnet.Node{}) + sendAgentNode(&tailnet.Node{PreferredDERP: 3}) select { case agentNodes := <-clientNodeChan: require.Len(t, agentNodes, 1) @@ -193,7 +193,7 @@ func TestCoordinator(t *testing.T) { assert.NoError(t, err) close(closeAgentChan1) }() - sendAgentNode1(&tailnet.Node{}) + sendAgentNode1(&tailnet.Node{PreferredDERP: 1}) require.Eventually(t, func() bool { return coordinator.Node(agentID) != nil }, testutil.WaitShort, testutil.IntervalFast) @@ -215,12 +215,12 @@ func TestCoordinator(t *testing.T) { }() agentNodes := <-clientNodeChan require.Len(t, agentNodes, 1) - sendClientNode(&tailnet.Node{}) + sendClientNode(&tailnet.Node{PreferredDERP: 2}) clientNodes := <-agentNodeChan1 require.Len(t, clientNodes, 1) // Ensure an update to the agent node reaches the client! - sendAgentNode1(&tailnet.Node{}) + sendAgentNode1(&tailnet.Node{PreferredDERP: 3}) agentNodes = <-clientNodeChan require.Len(t, agentNodes, 1)