feat: Add high availability for multiple replicas (#4555)

* feat: HA tailnet coordinator

* fixup! feat: HA tailnet coordinator

* fixup! feat: HA tailnet coordinator

* remove printlns

* close all connections on coordinator

* implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* fixup! implement high availability feature

* Add replicas

* Add DERP meshing to arbitrary addresses

* Move packages to highavailability folder

* Move coordinator to high availability package

* Add flags for HA

* Rename to replicasync

* Denest packages for replicas

* Add test for multiple replicas

* Fix coordination test

* Add HA to the helm chart

* Rename function pointer

* Add warnings for HA

* Add the ability to block endpoints

* Add flag to disable P2P connections

* Wow, I made the tests pass

* Add replicas endpoint

* Ensure close kills replica

* Update sql

* Add database latency to high availability

* Pipe TLS to DERP mesh

* Fix DERP mesh with TLS

* Add tests for TLS

* Fix replica sync TLS

* Fix RootCA for replica meshing

* Remove ID from replicasync

* Fix getting certificates for meshing

* Remove excessive locking

* Fix linting

* Store mesh key in the database

* Fix replica key for tests

* Fix types gen

* Fix unlocking unlocked

* Fix race in tests

* Update enterprise/derpmesh/derpmesh.go

Co-authored-by: Colin Adler <colin1adler@gmail.com>

* Rename to syncReplicas

* Reuse http client

* Delete old replicas on a CRON

* Fix race condition in connection tests

* Fix linting

* Fix nil type

* Move pubsub to in-memory for twenty test

* Add comment for configuration tweaking

* Fix leak with transport

* Fix close leak in derpmesh

* Fix race when creating server

* Remove handler update

* Skip test on Windows

* Fix DERP mesh test

* Wrap HTTP handler replacement in mutex

* Fix error message for relay

* Fix API handler for normal tests

* Fix speedtest

* Fix replica resend

* Fix derpmesh send

* Ping async

* Increase wait time of template version job

* Fix race when closing replica sync

* Add name to client

* Log the derpmap being used

* Don't connect if DERP is empty

* Improve agent coordinator logging

* Fix lock in coordinator

* Fix relay addr

* Fix race when updating durations

* Fix client publish race

* Run pubsub loop in a queue

* Store agent nodes in order

* Fix coordinator locking

* Check for closed pipe

Co-authored-by: Colin Adler <colin1adler@gmail.com>
This commit is contained in:
Kyle Carberry
2022-10-17 08:43:30 -05:00
committed by GitHub
parent dc3519e973
commit 2ba4a62a0d
76 changed files with 3437 additions and 404 deletions

View File

@ -72,7 +72,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
"deadline %v never updated", firstDeadline,
)
require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, time.Second)
require.WithinDuration(t, database.Now().Add(time.Hour), workspace.LatestBuild.Deadline.Time, 3*time.Second)
}
}
@ -82,7 +82,9 @@ func TestWorkspaceActivityBump(t *testing.T) {
client, workspace, assertBumped := setupActivityTest(t)
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil),
})
require.NoError(t, err)
defer conn.Close()

View File

@ -1,6 +1,7 @@
package coderd
import (
"crypto/tls"
"crypto/x509"
"io"
"net/http"
@ -82,7 +83,10 @@ type Options struct {
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
TailnetCoordinator *tailnet.Coordinator
// TLSCertificates is used to mesh DERP servers securely.
TLSCertificates []tls.Certificate
TailnetCoordinator tailnet.Coordinator
DERPServer *derp.Server
DERPMap *tailcfg.DERPMap
MetricsCacheRefreshInterval time.Duration
@ -130,6 +134,9 @@ func New(options *Options) *API {
if options.TailnetCoordinator == nil {
options.TailnetCoordinator = tailnet.NewCoordinator()
}
if options.DERPServer == nil {
options.DERPServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger.Named("derp")))
}
if options.Auditor == nil {
options.Auditor = audit.NewNop()
}
@ -168,7 +175,7 @@ func New(options *Options) *API {
api.Auditor.Store(&options.Auditor)
api.WorkspaceQuotaEnforcer.Store(&options.WorkspaceQuotaEnforcer)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
OIDC: options.OIDCConfig,
@ -246,7 +253,7 @@ func New(options *Options) *API {
r.Route("/%40{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Route("/@{user}/{workspace_and_agent}/apps/{workspaceapp}", apps)
r.Route("/derp", func(r chi.Router) {
r.Get("/", derphttp.Handler(api.derpServer).ServeHTTP)
r.Get("/", derphttp.Handler(api.DERPServer).ServeHTTP)
// This is used when UDP is blocked, and latency must be checked via HTTP(s).
r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@ -550,6 +557,7 @@ type API struct {
Auditor atomic.Pointer[audit.Auditor]
WorkspaceClientCoordinateOverride atomic.Pointer[func(rw http.ResponseWriter) bool]
WorkspaceQuotaEnforcer atomic.Pointer[workspacequota.Enforcer]
TailnetCoordinator atomic.Pointer[tailnet.Coordinator]
HTTPAuth *HTTPAuthorizer
// APIHandler serves "/api/v2"
@ -557,7 +565,6 @@ type API struct {
// RootHandler serves "/"
RootHandler chi.Router
derpServer *derp.Server
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
@ -572,7 +579,10 @@ func (api *API) Close() error {
api.websocketWaitMutex.Unlock()
api.metricsCache.Close()
coordinator := api.TailnetCoordinator.Load()
if coordinator != nil {
_ = (*coordinator).Close()
}
return api.workspaceAgentCache.Close()
}

View File

@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
@ -23,6 +24,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"testing"
"time"
@ -37,8 +39,10 @@ import (
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"google.golang.org/api/option"
"tailscale.com/derp"
"tailscale.com/net/stun/stuntest"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/nettype"
"cdr.dev/slog"
@ -60,6 +64,7 @@ import (
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
"github.com/coder/coder/testutil"
)
@ -77,12 +82,19 @@ type Options struct {
AutobuildTicker <-chan time.Time
AutobuildStats chan<- executor.Stats
Auditor audit.Auditor
TLSCertificates []tls.Certificate
// IncludeProvisionerDaemon when true means to start an in-memory provisionerD
IncludeProvisionerDaemon bool
MetricsCacheRefreshInterval time.Duration
AgentStatsRefreshInterval time.Duration
DeploymentFlags *codersdk.DeploymentFlags
// Overriding the database is heavily discouraged.
// It should only be used in cases where multiple Coder
// test instances are running against the same database.
Database database.Store
Pubsub database.Pubsub
}
// New constructs a codersdk client connected to an in-memory API instance.
@ -116,7 +128,7 @@ func newWithCloser(t *testing.T, options *Options) (*codersdk.Client, io.Closer)
return client, closer
}
func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.CancelFunc, *coderd.Options) {
func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.CancelFunc, *coderd.Options) {
if options == nil {
options = &Options{}
}
@ -137,23 +149,40 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
close(options.AutobuildStats)
})
}
db, pubsub := dbtestutil.NewDB(t)
if options.Database == nil {
options.Database, options.Pubsub = dbtestutil.NewDB(t)
}
ctx, cancelFunc := context.WithCancel(context.Background())
lifecycleExecutor := executor.New(
ctx,
db,
options.Database,
slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug),
options.AutobuildTicker,
).WithStatsChannel(options.AutobuildStats)
lifecycleExecutor.Run()
srv := httptest.NewUnstartedServer(nil)
var mutex sync.RWMutex
var handler http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mutex.RLock()
defer mutex.RUnlock()
if handler != nil {
handler.ServeHTTP(w, r)
}
}))
srv.Config.BaseContext = func(_ net.Listener) context.Context {
return ctx
}
srv.Start()
if options.TLSCertificates != nil {
srv.TLS = &tls.Config{
Certificates: options.TLSCertificates,
MinVersion: tls.VersionTLS12,
}
srv.StartTLS()
} else {
srv.Start()
}
t.Cleanup(srv.Close)
tcpAddr, ok := srv.Listener.Addr().(*net.TCPAddr)
@ -169,6 +198,9 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, nettype.Std{})
t.Cleanup(stunCleanup)
derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(slogtest.Make(t, nil).Named("derp")))
derpServer.SetMeshKey("test-key")
// match default with cli default
if options.SSHKeygenAlgorithm == "" {
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
@ -181,53 +213,59 @@ func NewOptions(t *testing.T, options *Options) (*httptest.Server, context.Cance
require.NoError(t, err)
}
return srv, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
AccessURL: serverURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
CacheDir: t.TempDir(),
Database: db,
Pubsub: pubsub,
return func(h http.Handler) {
mutex.Lock()
defer mutex.Unlock()
handler = h
}, cancelFunc, &coderd.Options{
AgentConnectionUpdateFrequency: 150 * time.Millisecond,
// Force a long disconnection timeout to ensure
// agents are not marked as disconnected during slow tests.
AgentInactiveDisconnectTimeout: testutil.WaitShort,
AccessURL: serverURL,
AppHostname: options.AppHostname,
AppHostnameRegex: appHostnameRegex,
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
CacheDir: t.TempDir(),
Database: options.Database,
Pubsub: options.Pubsub,
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
EmbeddedRelay: true,
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: stunAddr.Port,
InsecureForTests: true,
ForceHTTP: true,
}},
Auditor: options.Auditor,
AWSCertificates: options.AWSCertificates,
AzureCertificates: options.AzureCertificates,
GithubOAuth2Config: options.GithubOAuth2Config,
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
DERPServer: derpServer,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),
TLSCertificates: options.TLSCertificates,
DERPMap: &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
EmbeddedRelay: true,
RegionID: 1,
RegionCode: "coder",
RegionName: "Coder",
Nodes: []*tailcfg.DERPNode{{
Name: "1a",
RegionID: 1,
IPv4: "127.0.0.1",
DERPPort: derpPort,
STUNPort: stunAddr.Port,
InsecureForTests: true,
ForceHTTP: options.TLSCertificates == nil,
}},
},
},
},
},
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentFlags: options.DeploymentFlags,
}
AutoImportTemplates: options.AutoImportTemplates,
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentFlags: options.DeploymentFlags,
}
}
// NewWithAPI constructs an in-memory API instance and returns a client to talk to it.
@ -237,10 +275,10 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
if options == nil {
options = &Options{}
}
srv, cancelFunc, newOptions := NewOptions(t, options)
setHandler, cancelFunc, newOptions := NewOptions(t, options)
// We set the handler after server creation for the access URL.
coderAPI := coderd.New(newOptions)
srv.Config.Handler = coderAPI.RootHandler
setHandler(coderAPI.RootHandler)
var provisionerCloser io.Closer = nopcloser{}
if options.IncludeProvisionerDaemon {
provisionerCloser = NewProvisionerDaemon(t, coderAPI)
@ -459,7 +497,7 @@ func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid
var err error
templateVersion, err = client.TemplateVersion(context.Background(), version)
return assert.NoError(t, err) && templateVersion.Job.CompletedAt != nil
}, testutil.WaitShort, testutil.IntervalFast)
}, testutil.WaitMedium, testutil.IntervalFast)
return templateVersion
}

View File

@ -107,11 +107,17 @@ type data struct {
workspaceApps []database.WorkspaceApp
workspaces []database.Workspace
licenses []database.License
replicas []database.Replica
deploymentID string
derpMeshKey string
lastLicenseID int32
}
// Ping satisfies the store's Ping method for the fake querier. There is no
// round trip to measure here, so it always reports a zero duration and a nil
// error.
func (*fakeQuerier) Ping(_ context.Context) (time.Duration, error) {
	var latency time.Duration
	return latency, nil
}
// InTx doesn't rollback data properly for in-memory yet.
func (q *fakeQuerier) InTx(fn func(database.Store) error) error {
q.mutex.Lock()
@ -2931,6 +2937,21 @@ func (q *fakeQuerier) GetDeploymentID(_ context.Context) (string, error) {
return q.deploymentID, nil
}
// InsertDERPMeshKey records the DERP mesh key on the fake querier, replacing
// whatever value was stored before. It never fails.
func (q *fakeQuerier) InsertDERPMeshKey(_ context.Context, value string) error {
	q.mutex.Lock()
	q.derpMeshKey = value
	q.mutex.Unlock()

	return nil
}
// GetDERPMeshKey returns the DERP mesh key currently held by the fake
// querier. The stored string is returned as-is (empty if nothing was
// inserted) and the error is always nil.
func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
	q.mutex.RLock()
	meshKey := q.derpMeshKey
	q.mutex.RUnlock()

	return meshKey, nil
}
func (q *fakeQuerier) InsertLicense(
_ context.Context, arg database.InsertLicenseParams,
) (database.License, error) {
@ -3196,3 +3217,70 @@ func (q *fakeQuerier) DeleteGroupByID(_ context.Context, id uuid.UUID) error {
return sql.ErrNoRows
}
// DeleteReplicasUpdatedBefore removes every stored replica whose UpdatedAt is
// strictly before the given time, matching the SQL query
// `DELETE FROM replicas WHERE updated_at < $1`. Replicas updated exactly at
// `before` are kept. It always returns nil.
func (q *fakeQuerier) DeleteReplicasUpdatedBefore(_ context.Context, before time.Time) error {
	q.mutex.Lock()
	defer q.mutex.Unlock()

	// Collect the survivors in a single pass, reusing the backing array.
	// Deleting from q.replicas while ranging over it is unsafe: the range
	// walks the original length, so consecutive stale entries are skipped and
	// the re-slice can go out of bounds once the slice has shrunk.
	kept := q.replicas[:0]
	for _, replica := range q.replicas {
		if replica.UpdatedAt.Before(before) {
			continue
		}
		kept = append(kept, replica)
	}
	q.replicas = kept
	return nil
}
// InsertReplica appends a new replica built from arg to the fake querier's
// replica list and returns the stored value. StoppedAt and Error are not part
// of the insert parameters and are left at their zero values. It never fails.
func (q *fakeQuerier) InsertReplica(_ context.Context, arg database.InsertReplicaParams) (database.Replica, error) {
	q.mutex.Lock()
	defer q.mutex.Unlock()

	var inserted database.Replica
	inserted.ID = arg.ID
	inserted.CreatedAt = arg.CreatedAt
	inserted.StartedAt = arg.StartedAt
	inserted.UpdatedAt = arg.UpdatedAt
	inserted.Hostname = arg.Hostname
	inserted.RegionID = arg.RegionID
	inserted.RelayAddress = arg.RelayAddress
	inserted.Version = arg.Version
	inserted.DatabaseLatency = arg.DatabaseLatency

	q.replicas = append(q.replicas, inserted)
	return inserted, nil
}
// UpdateReplica overwrites the mutable fields of the first stored replica
// whose ID equals arg.ID and returns the updated value. ID and CreatedAt are
// left untouched. If no replica matches, it returns a zero Replica and
// sql.ErrNoRows.
func (q *fakeQuerier) UpdateReplica(_ context.Context, arg database.UpdateReplicaParams) (database.Replica, error) {
	q.mutex.Lock()
	defer q.mutex.Unlock()

	for i := range q.replicas {
		// Work on the stored element directly so no write-back is needed.
		stored := &q.replicas[i]
		if stored.ID == arg.ID {
			stored.Hostname = arg.Hostname
			stored.StartedAt = arg.StartedAt
			stored.StoppedAt = arg.StoppedAt
			stored.UpdatedAt = arg.UpdatedAt
			stored.RelayAddress = arg.RelayAddress
			stored.RegionID = arg.RegionID
			stored.Version = arg.Version
			stored.Error = arg.Error
			stored.DatabaseLatency = arg.DatabaseLatency
			return *stored, nil
		}
	}
	return database.Replica{}, sql.ErrNoRows
}
// GetReplicasUpdatedAfter returns copies of the stored replicas whose
// UpdatedAt is strictly after updatedAt and whose StoppedAt is not set, in
// storage order. The result is a non-nil, possibly empty slice and the error
// is always nil.
func (q *fakeQuerier) GetReplicasUpdatedAfter(_ context.Context, updatedAt time.Time) ([]database.Replica, error) {
	q.mutex.RLock()
	defer q.mutex.RUnlock()

	matched := []database.Replica{}
	for i := range q.replicas {
		candidate := q.replicas[i]
		// Skip anything not updated after the cutoff or already stopped.
		if !candidate.UpdatedAt.After(updatedAt) || candidate.StoppedAt.Valid {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched, nil
}

View File

@ -12,6 +12,7 @@ import (
"context"
"database/sql"
"errors"
"time"
"github.com/jmoiron/sqlx"
"golang.org/x/xerrors"
@ -24,6 +25,7 @@ type Store interface {
// customQuerier contains custom queries that are not generated.
customQuerier
Ping(ctx context.Context) (time.Duration, error)
InTx(func(Store) error) error
}
@ -58,6 +60,13 @@ type sqlQuerier struct {
db DBTX
}
// Ping issues PingContext on the underlying database handle and reports how
// long the call took. The elapsed time is returned alongside the ping error,
// so a duration is still reported when the ping fails.
func (q *sqlQuerier) Ping(ctx context.Context) (time.Duration, error) {
	began := time.Now()
	pingErr := q.sdb.PingContext(ctx)
	elapsed := time.Since(began)
	return elapsed, pingErr
}
// InTx performs database operations inside a transaction.
func (q *sqlQuerier) InTx(function func(Store) error) error {
if _, ok := q.db.(*sqlx.Tx); ok {

View File

@ -256,7 +256,8 @@ CREATE TABLE provisioner_daemons (
created_at timestamp with time zone NOT NULL,
updated_at timestamp with time zone,
name character varying(64) NOT NULL,
provisioners provisioner_type[] NOT NULL
provisioners provisioner_type[] NOT NULL,
replica_id uuid
);
CREATE TABLE provisioner_job_logs (
@ -287,6 +288,20 @@ CREATE TABLE provisioner_jobs (
file_id uuid NOT NULL
);
CREATE TABLE replicas (
id uuid NOT NULL,
created_at timestamp with time zone NOT NULL,
started_at timestamp with time zone NOT NULL,
stopped_at timestamp with time zone,
updated_at timestamp with time zone NOT NULL,
hostname text NOT NULL,
region_id integer NOT NULL,
relay_address text NOT NULL,
database_latency integer NOT NULL,
version text NOT NULL,
error text DEFAULT ''::text NOT NULL
);
CREATE TABLE site_configs (
key character varying(256) NOT NULL,
value character varying(8192) NOT NULL

View File

@ -0,0 +1,2 @@
-- Down migration: removes the replicas table and the replica_id column
-- on provisioner_daemons.
DROP TABLE replicas;
ALTER TABLE provisioner_daemons DROP COLUMN replica_id;

View File

@ -0,0 +1,28 @@
-- Up migration: creates the replicas table (one row per replica) and adds a
-- nullable replica_id column to provisioner_daemons.
CREATE TABLE IF NOT EXISTS replicas (
-- A unique identifier for the replica that is stored on disk.
-- For persistent replicas, this will be reused.
-- For ephemeral replicas, this will be a new UUID for each one.
id uuid NOT NULL,
-- The time the replica row was created.
created_at timestamp with time zone NOT NULL,
-- The time the replica was started.
-- NOTE(review): presumably the most recent start for a reused ID; confirm.
started_at timestamp with time zone NOT NULL,
-- The time the replica was stopped.
-- NOTE(review): presumably NULL while the replica is still running; the
-- GetReplicasUpdatedAfter query only returns rows where this IS NULL.
stopped_at timestamp with time zone,
-- The time the replica was last seen.
-- Updated periodically to ensure the replica is still alive.
updated_at timestamp with time zone NOT NULL,
-- Hostname is the hostname of the replica.
hostname text NOT NULL,
-- Region is the region the replica is in.
-- We only DERP mesh to the same region ID of a running replica.
region_id integer NOT NULL,
-- An address that should be accessible to other replicas.
relay_address text NOT NULL,
-- The latency of the replica to the database in microseconds.
database_latency int NOT NULL,
-- Version is the Coder version of the replica.
version text NOT NULL,
-- NOTE(review): presumably the last error reported by the replica, empty
-- when there is none; confirm against the writer.
error text NOT NULL DEFAULT ''
);
-- Associates a provisioner daemon with a replica.
ALTER TABLE provisioner_daemons ADD COLUMN replica_id uuid;

View File

@ -508,6 +508,7 @@ type ProvisionerDaemon struct {
UpdatedAt sql.NullTime `db:"updated_at" json:"updated_at"`
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"`
}
type ProvisionerJob struct {
@ -538,6 +539,20 @@ type ProvisionerJobLog struct {
Output string `db:"output" json:"output"`
}
type Replica struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Hostname string `db:"hostname" json:"hostname"`
RegionID int32 `db:"region_id" json:"region_id"`
RelayAddress string `db:"relay_address" json:"relay_address"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
Version string `db:"version" json:"version"`
Error string `db:"error" json:"error"`
}
type SiteConfig struct {
Key string `db:"key" json:"key"`
Value string `db:"value" json:"value"`

View File

@ -47,8 +47,9 @@ func (m *memoryPubsub) Publish(event string, message []byte) error {
return nil
}
for _, listener := range listeners {
listener(context.Background(), message)
go listener(context.Background(), message)
}
return nil
}

View File

@ -26,6 +26,7 @@ type sqlcQuerier interface {
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteOldAgentStats(ctx context.Context) error
DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error
DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error
GetAPIKeyByID(ctx context.Context, id string) (APIKey, error)
GetAPIKeysByLoginType(ctx context.Context, loginType LoginType) ([]APIKey, error)
GetAPIKeysLastUsedAfter(ctx context.Context, lastUsed time.Time) ([]APIKey, error)
@ -38,6 +39,7 @@ type sqlcQuerier interface {
// This function returns roles for authorization purposes. Implied member roles
// are included.
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetDERPMeshKey(ctx context.Context) (string, error)
GetDeploymentID(ctx context.Context) (string, error)
GetFileByHashAndCreator(ctx context.Context, arg GetFileByHashAndCreatorParams) (File, error)
GetFileByID(ctx context.Context, id uuid.UUID) (File, error)
@ -67,6 +69,7 @@ type sqlcQuerier interface {
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error)
GetProvisionerLogsByIDBetween(ctx context.Context, arg GetProvisionerLogsByIDBetweenParams) ([]ProvisionerJobLog, error)
GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error)
GetTemplateAverageBuildTime(ctx context.Context, arg GetTemplateAverageBuildTimeParams) (GetTemplateAverageBuildTimeRow, error)
GetTemplateByID(ctx context.Context, id uuid.UUID) (Template, error)
GetTemplateByOrganizationAndName(ctx context.Context, arg GetTemplateByOrganizationAndNameParams) (Template, error)
@ -123,6 +126,7 @@ type sqlcQuerier interface {
// every member of the org.
InsertAllUsersGroup(ctx context.Context, organizationID uuid.UUID) (Group, error)
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertDERPMeshKey(ctx context.Context, value string) error
InsertDeploymentID(ctx context.Context, value string) error
InsertFile(ctx context.Context, arg InsertFileParams) (File, error)
InsertGitSSHKey(ctx context.Context, arg InsertGitSSHKeyParams) (GitSSHKey, error)
@ -136,6 +140,7 @@ type sqlcQuerier interface {
InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error)
InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error)
InsertProvisionerJobLogs(ctx context.Context, arg InsertProvisionerJobLogsParams) ([]ProvisionerJobLog, error)
InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error)
InsertTemplate(ctx context.Context, arg InsertTemplateParams) (Template, error)
InsertTemplateVersion(ctx context.Context, arg InsertTemplateVersionParams) (TemplateVersion, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
@ -156,6 +161,7 @@ type sqlcQuerier interface {
UpdateProvisionerJobByID(ctx context.Context, arg UpdateProvisionerJobByIDParams) error
UpdateProvisionerJobWithCancelByID(ctx context.Context, arg UpdateProvisionerJobWithCancelByIDParams) error
UpdateProvisionerJobWithCompleteByID(ctx context.Context, arg UpdateProvisionerJobWithCompleteByIDParams) error
UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error)
UpdateTemplateActiveVersionByID(ctx context.Context, arg UpdateTemplateActiveVersionByIDParams) error
UpdateTemplateDeletedByID(ctx context.Context, arg UpdateTemplateDeletedByIDParams) error
UpdateTemplateMetaByID(ctx context.Context, arg UpdateTemplateMetaByIDParams) (Template, error)

View File

@ -2031,7 +2031,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
SELECT
id, created_at, updated_at, name, provisioners
id, created_at, updated_at, name, provisioners, replica_id
FROM
provisioner_daemons
WHERE
@ -2047,13 +2047,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID)
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
)
return i, err
}
const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many
SELECT
id, created_at, updated_at, name, provisioners
id, created_at, updated_at, name, provisioners, replica_id
FROM
provisioner_daemons
`
@ -2073,6 +2074,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
); err != nil {
return nil, err
}
@ -2096,7 +2098,7 @@ INSERT INTO
provisioners
)
VALUES
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id
`
type InsertProvisionerDaemonParams struct {
@ -2120,6 +2122,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
&i.UpdatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
)
return i, err
}
@ -2577,6 +2580,177 @@ func (q *sqlQuerier) UpdateProvisionerJobWithCompleteByID(ctx context.Context, a
return err
}
const deleteReplicasUpdatedBefore = `-- name: DeleteReplicasUpdatedBefore :exec
DELETE FROM replicas WHERE updated_at < $1
`
func (q *sqlQuerier) DeleteReplicasUpdatedBefore(ctx context.Context, updatedAt time.Time) error {
_, err := q.db.ExecContext(ctx, deleteReplicasUpdatedBefore, updatedAt)
return err
}
const getReplicasUpdatedAfter = `-- name: GetReplicasUpdatedAfter :many
SELECT id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL
`
func (q *sqlQuerier) GetReplicasUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]Replica, error) {
rows, err := q.db.QueryContext(ctx, getReplicasUpdatedAfter, updatedAt)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Replica
for rows.Next() {
var i Replica
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertReplica = `-- name: InsertReplica :one
INSERT INTO replicas (
id,
created_at,
started_at,
updated_at,
hostname,
region_id,
relay_address,
version,
database_latency
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
`
type InsertReplicaParams struct {
ID uuid.UUID `db:"id" json:"id"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Hostname string `db:"hostname" json:"hostname"`
RegionID int32 `db:"region_id" json:"region_id"`
RelayAddress string `db:"relay_address" json:"relay_address"`
Version string `db:"version" json:"version"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
}
func (q *sqlQuerier) InsertReplica(ctx context.Context, arg InsertReplicaParams) (Replica, error) {
row := q.db.QueryRowContext(ctx, insertReplica,
arg.ID,
arg.CreatedAt,
arg.StartedAt,
arg.UpdatedAt,
arg.Hostname,
arg.RegionID,
arg.RelayAddress,
arg.Version,
arg.DatabaseLatency,
)
var i Replica
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
)
return i, err
}
const updateReplica = `-- name: UpdateReplica :one
UPDATE replicas SET
updated_at = $2,
started_at = $3,
stopped_at = $4,
relay_address = $5,
region_id = $6,
hostname = $7,
version = $8,
error = $9,
database_latency = $10
WHERE id = $1 RETURNING id, created_at, started_at, stopped_at, updated_at, hostname, region_id, relay_address, database_latency, version, error
`
type UpdateReplicaParams struct {
ID uuid.UUID `db:"id" json:"id"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
StartedAt time.Time `db:"started_at" json:"started_at"`
StoppedAt sql.NullTime `db:"stopped_at" json:"stopped_at"`
RelayAddress string `db:"relay_address" json:"relay_address"`
RegionID int32 `db:"region_id" json:"region_id"`
Hostname string `db:"hostname" json:"hostname"`
Version string `db:"version" json:"version"`
Error string `db:"error" json:"error"`
DatabaseLatency int32 `db:"database_latency" json:"database_latency"`
}
func (q *sqlQuerier) UpdateReplica(ctx context.Context, arg UpdateReplicaParams) (Replica, error) {
row := q.db.QueryRowContext(ctx, updateReplica,
arg.ID,
arg.UpdatedAt,
arg.StartedAt,
arg.StoppedAt,
arg.RelayAddress,
arg.RegionID,
arg.Hostname,
arg.Version,
arg.Error,
arg.DatabaseLatency,
)
var i Replica
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.StartedAt,
&i.StoppedAt,
&i.UpdatedAt,
&i.Hostname,
&i.RegionID,
&i.RelayAddress,
&i.DatabaseLatency,
&i.Version,
&i.Error,
)
return i, err
}
const getDERPMeshKey = `-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key'
`
func (q *sqlQuerier) GetDERPMeshKey(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getDERPMeshKey)
var value string
err := row.Scan(&value)
return value, err
}
const getDeploymentID = `-- name: GetDeploymentID :one
SELECT value FROM site_configs WHERE key = 'deployment_id'
`
@ -2588,6 +2762,15 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) {
return value, err
}
const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1)
`
func (q *sqlQuerier) InsertDERPMeshKey(ctx context.Context, value string) error {
_, err := q.db.ExecContext(ctx, insertDERPMeshKey, value)
return err
}
const insertDeploymentID = `-- name: InsertDeploymentID :exec
INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1)
`

View File

@ -0,0 +1,31 @@
-- name: GetReplicasUpdatedAfter :many
SELECT * FROM replicas WHERE updated_at > $1 AND stopped_at IS NULL;
-- name: InsertReplica :one
INSERT INTO replicas (
id,
created_at,
started_at,
updated_at,
hostname,
region_id,
relay_address,
version,
database_latency
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;
-- name: UpdateReplica :one
UPDATE replicas SET
updated_at = $2,
started_at = $3,
stopped_at = $4,
relay_address = $5,
region_id = $6,
hostname = $7,
version = $8,
error = $9,
database_latency = $10
WHERE id = $1 RETURNING *;
-- name: DeleteReplicasUpdatedBefore :exec
DELETE FROM replicas WHERE updated_at < $1;

View File

@ -3,3 +3,9 @@ INSERT INTO site_configs (key, value) VALUES ('deployment_id', $1);
-- name: GetDeploymentID :one
SELECT value FROM site_configs WHERE key = 'deployment_id';
-- name: InsertDERPMeshKey :exec
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1);
-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key';

View File

@ -270,7 +270,7 @@ func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request,
}
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading job agent.",

View File

@ -146,6 +146,10 @@ var (
ResourceDeploymentFlags = Object{
Type: "deployment_flags",
}
// ResourceReplicas is the Object whose Type is "replicas".
// NOTE(review): presumably used for authz checks on the replicas
// endpoint — confirm against the callers.
ResourceReplicas = Object{
Type: "replicas",
}
)
// Object is used to create objects for authz checks when you have none in

View File

@ -627,7 +627,9 @@ func TestTemplateMetrics(t *testing.T) {
require.NoError(t, err)
assert.Zero(t, workspaces[0].LastUsedAt)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("tailnet"), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Named("tailnet"),
})
require.NoError(t, err)
defer func() {
_ = conn.Close()

View File

@ -49,7 +49,7 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
})
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, convertApps(dbApps), api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -78,7 +78,7 @@ func (api *API) workspaceAgentApps(rw http.ResponseWriter, r *http.Request) {
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -98,7 +98,7 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request)
func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -152,7 +152,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -229,7 +229,7 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
@ -376,8 +376,9 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*
})
conn.SetNodeCallback(sendNodes)
go func() {
err := api.TailnetCoordinator.ServeClient(serverConn, uuid.New(), agentID)
err := (*api.TailnetCoordinator.Load()).ServeClient(serverConn, uuid.New(), agentID)
if err != nil {
api.Logger.Warn(r.Context(), "tailnet coordinator client error", slog.Error(err))
_ = conn.Close()
}
}()
@ -514,8 +515,9 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
closeChan := make(chan struct{})
go func() {
defer close(closeChan)
err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID)
if err != nil {
api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
@ -583,7 +585,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
go httpapi.Heartbeat(ctx, conn)
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
err = (*api.TailnetCoordinator.Load()).ServeClient(websocket.NetConn(ctx, conn, websocket.MessageBinary), uuid.New(), workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
@ -611,7 +613,7 @@ func convertApps(dbApps []database.WorkspaceApp) []codersdk.WorkspaceApp {
return apps
}
func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator *tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, agentInactiveDisconnectTimeout time.Duration) (codersdk.WorkspaceAgent, error) {
var envs map[string]string
if dbAgent.EnvironmentVariables.Valid {
err := json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs)

View File

@ -123,13 +123,13 @@ func TestWorkspaceAgentListen(t *testing.T) {
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
require.Eventually(t, func() bool {
_, err := conn.Ping()
_, err := conn.Ping(ctx)
return err == nil
}, testutil.WaitLong, testutil.IntervalFast)
})
@ -253,7 +253,9 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
conn, err := client.DialWorkspaceAgentTailnet(ctx, slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), resources[0].Agents[0].ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, &codersdk.DialWorkspaceAgentOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer conn.Close()
sshClient, err := conn.SSHClient()

View File

@ -861,7 +861,7 @@ func (api *API) convertWorkspaceBuild(
apiAgents := make([]codersdk.WorkspaceAgent, 0)
for _, agent := range agents {
apps := appsByAgentID[agent.ID]
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, *api.TailnetCoordinator.Load(), agent, convertApps(apps), api.AgentInactiveDisconnectTimeout)
if err != nil {
return codersdk.WorkspaceBuild{}, xerrors.Errorf("converting workspace agent: %w", err)
}

View File

@ -128,7 +128,9 @@ func TestCache(t *testing.T) {
return
}
defer release()
proxy.Transport = conn.HTTPTransport()
transport := conn.HTTPTransport()
defer transport.CloseIdleConnections()
proxy.Transport = transport
res := httptest.NewRecorder()
proxy.ServeHTTP(res, req)
resp := res.Result()