feat: Support config files with viper (#4558)

This commit is contained in:
Garrett Delfosse
2022-10-21 15:26:39 -04:00
committed by GitHub
parent 2c47cda3d1
commit c8e299c8f1
35 changed files with 920 additions and 1089 deletions

View File

@ -32,6 +32,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/oauth2"
@ -70,14 +71,18 @@ import (
)
// nolint:gocyclo
func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command {
root := &cobra.Command{
Use: "server",
Short: "Start a Coder server",
RunE: func(cmd *cobra.Command, args []string) error {
cfg, err := deployment.Config(cmd.Flags(), vip)
if err != nil {
return xerrors.Errorf("getting deployment config: %w", err)
}
printLogo(cmd)
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
if dflags.Verbose.Value {
if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
logger = logger.Leveled(slog.LevelDebug)
}
@ -106,22 +111,21 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
var (
tracerProvider trace.TracerProvider
err error
sqlDriver = "postgres"
)
// Coder tracing should be disabled if telemetry is disabled unless
// --telemetry-trace was explicitly provided.
shouldCoderTrace := dflags.TelemetryEnable.Value && !isTest()
shouldCoderTrace := cfg.TelemetryEnable.Value && !isTest()
// Only override if telemetryTraceEnable was specifically set.
// By default we want it to be controlled by telemetryEnable.
if cmd.Flags().Changed("telemetry-trace") {
shouldCoderTrace = dflags.TelemetryTraceEnable.Value
shouldCoderTrace = cfg.TelemetryTrace.Value
}
if dflags.TraceEnable.Value || shouldCoderTrace {
if cfg.TraceEnable.Value || shouldCoderTrace {
sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
Default: dflags.TraceEnable.Value,
Default: cfg.TraceEnable.Value,
Coder: shouldCoderTrace,
})
if err != nil {
@ -146,10 +150,10 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
config := createConfig(cmd)
builtinPostgres := false
// Only use built-in if PostgreSQL URL isn't specified!
if !dflags.InMemoryDatabase.Value && dflags.PostgresURL.Value == "" {
if !cfg.InMemoryDatabase.Value && cfg.PostgresURL.Value == "" {
var closeFunc func() error
cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
dflags.PostgresURL.Value, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
cfg.PostgresURL.Value, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
if err != nil {
return err
}
@ -162,20 +166,20 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
}()
}
listener, err := net.Listen("tcp", dflags.Address.Value)
listener, err := net.Listen("tcp", cfg.Address.Value)
if err != nil {
return xerrors.Errorf("listen %q: %w", dflags.Address.Value, err)
return xerrors.Errorf("listen %q: %w", cfg.Address.Value, err)
}
defer listener.Close()
var tlsConfig *tls.Config
if dflags.TLSEnable.Value {
if cfg.TLSEnable.Value {
tlsConfig, err = configureTLS(
dflags.TLSMinVersion.Value,
dflags.TLSClientAuth.Value,
dflags.TLSCertFiles.Value,
dflags.TLSKeyFiles.Value,
dflags.TLSClientCAFile.Value,
cfg.TLSMinVersion.Value,
cfg.TLSClientAuth.Value,
cfg.TLSCertFiles.Value,
cfg.TLSKeyFiles.Value,
cfg.TLSClientCAFile.Value,
)
if err != nil {
return xerrors.Errorf("configure tls: %w", err)
@ -197,7 +201,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
Scheme: "http",
Host: tcpAddr.String(),
}
if dflags.TLSEnable.Value {
if cfg.TLSEnable.Value {
localURL.Scheme = "https"
}
@ -210,26 +214,26 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
// If the access URL is empty, we attempt to run a reverse-proxy
// tunnel to make the initial setup really simple.
if dflags.AccessURL.Value == "" {
if cfg.AccessURL.Value == "" {
cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n")
tunnel, tunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
if err != nil {
return xerrors.Errorf("create tunnel: %w", err)
}
dflags.AccessURL.Value = tunnel.URL
cfg.AccessURL.Value = tunnel.URL
if dflags.WildcardAccessURL.Value == "" {
if cfg.WildcardAccessURL.Value == "" {
u, err := parseURL(ctx, tunnel.URL)
if err != nil {
return xerrors.Errorf("parse tunnel url: %w", err)
}
// Suffixed wildcard access URL.
dflags.WildcardAccessURL.Value = fmt.Sprintf("*--%s", u.Hostname())
cfg.WildcardAccessURL.Value = fmt.Sprintf("*--%s", u.Hostname())
}
}
accessURLParsed, err := parseURL(ctx, dflags.AccessURL.Value)
accessURLParsed, err := parseURL(ctx, cfg.AccessURL.Value)
if err != nil {
return xerrors.Errorf("parse URL: %w", err)
}
@ -264,17 +268,17 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
return err
}
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(dflags.SSHKeygenAlgorithm.Value)
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(cfg.SSHKeygenAlgorithm.Value)
if err != nil {
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", dflags.SSHKeygenAlgorithm.Value, err)
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", cfg.SSHKeygenAlgorithm.Value, err)
}
// Validate provided auto-import templates.
var (
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(dflags.AutoImportTemplates.Value))
seenValidatedAutoImportTemplates = make(map[coderd.AutoImportTemplate]struct{}, len(dflags.AutoImportTemplates.Value))
validatedAutoImportTemplates = make([]coderd.AutoImportTemplate, len(cfg.AutoImportTemplates.Value))
seenValidatedAutoImportTemplates = make(map[coderd.AutoImportTemplate]struct{}, len(cfg.AutoImportTemplates.Value))
)
for i, autoImportTemplate := range dflags.AutoImportTemplates.Value {
for i, autoImportTemplate := range cfg.AutoImportTemplates.Value {
var v coderd.AutoImportTemplate
switch autoImportTemplate {
case "kubernetes":
@ -292,27 +296,27 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
defaultRegion := &tailcfg.DERPRegion{
EmbeddedRelay: true,
RegionID: dflags.DerpServerRegionID.Value,
RegionCode: dflags.DerpServerRegionCode.Value,
RegionName: dflags.DerpServerRegionName.Value,
RegionID: cfg.DERPServerRegionID.Value,
RegionCode: cfg.DERPServerRegionCode.Value,
RegionName: cfg.DERPServerRegionName.Value,
Nodes: []*tailcfg.DERPNode{{
Name: fmt.Sprintf("%db", dflags.DerpServerRegionID.Value),
RegionID: dflags.DerpServerRegionID.Value,
Name: fmt.Sprintf("%db", cfg.DERPServerRegionID.Value),
RegionID: cfg.DERPServerRegionID.Value,
HostName: accessURLParsed.Hostname(),
DERPPort: accessURLPort,
STUNPort: -1,
ForceHTTP: accessURLParsed.Scheme == "http",
}},
}
if !dflags.DerpServerEnable.Value {
if !cfg.DERPServerEnable.Value {
defaultRegion = nil
}
derpMap, err := tailnet.NewDERPMap(ctx, defaultRegion, dflags.DerpServerSTUNAddresses.Value, dflags.DerpConfigURL.Value, dflags.DerpConfigPath.Value)
derpMap, err := tailnet.NewDERPMap(ctx, defaultRegion, cfg.DERPServerSTUNAddresses.Value, cfg.DERPConfigURL.Value, cfg.DERPConfigPath.Value)
if err != nil {
return xerrors.Errorf("create derp map: %w", err)
}
appHostname := strings.TrimSpace(dflags.WildcardAccessURL.Value)
appHostname := strings.TrimSpace(cfg.WildcardAccessURL.Value)
var appHostnameRegex *regexp.Regexp
if appHostname != "" {
appHostnameRegex, err = httpapi.CompileHostnamePattern(appHostname)
@ -329,45 +333,45 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
Database: databasefake.New(),
DERPMap: derpMap,
Pubsub: database.NewPubsubInMemory(),
CacheDir: dflags.CacheDir.Value,
CacheDir: cfg.CacheDirectory.Value,
GoogleTokenValidator: googleTokenValidator,
SecureAuthCookie: dflags.SecureAuthCookie.Value,
SecureAuthCookie: cfg.SecureAuthCookie.Value,
SSHKeygenAlgorithm: sshKeygenAlgorithm,
TracerProvider: tracerProvider,
Telemetry: telemetry.NewNoop(),
AutoImportTemplates: validatedAutoImportTemplates,
MetricsCacheRefreshInterval: dflags.MetricsCacheRefreshInterval.Value,
AgentStatsRefreshInterval: dflags.AgentStatRefreshInterval.Value,
MetricsCacheRefreshInterval: cfg.MetricsCacheRefreshInterval.Value,
AgentStatsRefreshInterval: cfg.AgentStatRefreshInterval.Value,
Experimental: ExperimentalEnabled(cmd),
DeploymentFlags: dflags,
DeploymentConfig: &cfg,
}
if tlsConfig != nil {
options.TLSCertificates = tlsConfig.Certificates
}
if dflags.OAuth2GithubClientSecret.Value != "" {
if cfg.OAuth2GithubClientSecret.Value != "" {
options.GithubOAuth2Config, err = configureGithubOAuth2(accessURLParsed,
dflags.OAuth2GithubClientID.Value,
dflags.OAuth2GithubClientSecret.Value,
dflags.OAuth2GithubAllowSignups.Value,
dflags.OAuth2GithubAllowedOrganizations.Value,
dflags.OAuth2GithubAllowedTeams.Value,
dflags.OAuth2GithubEnterpriseBaseURL.Value,
cfg.OAuth2GithubClientID.Value,
cfg.OAuth2GithubClientSecret.Value,
cfg.OAuth2GithubAllowSignups.Value,
cfg.OAuth2GithubAllowedOrganizations.Value,
cfg.OAuth2GithubAllowedTeams.Value,
cfg.OAuth2GithubEnterpriseBaseURL.Value,
)
if err != nil {
return xerrors.Errorf("configure github oauth2: %w", err)
}
}
if dflags.OIDCClientSecret.Value != "" {
if dflags.OIDCClientID.Value == "" {
if cfg.OIDCClientSecret.Value != "" {
if cfg.OIDCClientID.Value == "" {
return xerrors.Errorf("OIDC client ID be set!")
}
if dflags.OIDCIssuerURL.Value == "" {
if cfg.OIDCIssuerURL.Value == "" {
return xerrors.Errorf("OIDC issuer URL must be set!")
}
oidcProvider, err := oidc.NewProvider(ctx, dflags.OIDCIssuerURL.Value)
oidcProvider, err := oidc.NewProvider(ctx, cfg.OIDCIssuerURL.Value)
if err != nil {
return xerrors.Errorf("configure oidc provider: %w", err)
}
@ -377,25 +381,25 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
}
options.OIDCConfig = &coderd.OIDCConfig{
OAuth2Config: &oauth2.Config{
ClientID: dflags.OIDCClientID.Value,
ClientSecret: dflags.OIDCClientSecret.Value,
ClientID: cfg.OIDCClientID.Value,
ClientSecret: cfg.OIDCClientSecret.Value,
RedirectURL: redirectURL.String(),
Endpoint: oidcProvider.Endpoint(),
Scopes: dflags.OIDCScopes.Value,
Scopes: cfg.OIDCScopes.Value,
},
Verifier: oidcProvider.Verifier(&oidc.Config{
ClientID: dflags.OIDCClientID.Value,
ClientID: cfg.OIDCClientID.Value,
}),
EmailDomain: dflags.OIDCEmailDomain.Value,
AllowSignups: dflags.OIDCAllowSignups.Value,
EmailDomain: cfg.OIDCEmailDomain.Value,
AllowSignups: cfg.OIDCAllowSignups.Value,
}
}
if dflags.InMemoryDatabase.Value {
if cfg.InMemoryDatabase.Value {
options.Database = databasefake.New()
options.Pubsub = database.NewPubsubInMemory()
} else {
sqlDB, err := sql.Open(sqlDriver, dflags.PostgresURL.Value)
sqlDB, err := sql.Open(sqlDriver, cfg.PostgresURL.Value)
if err != nil {
return xerrors.Errorf("dial postgres: %w", err)
}
@ -427,7 +431,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
return xerrors.Errorf("migrate up: %w", err)
}
options.Database = database.New(sqlDB)
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, dflags.PostgresURL.Value)
options.Pubsub, err = database.NewPubsub(ctx, sqlDB, cfg.PostgresURL.Value)
if err != nil {
return xerrors.Errorf("create pubsub: %w", err)
}
@ -450,26 +454,26 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
}
// Parse the raw telemetry URL!
telemetryURL, err := parseURL(ctx, dflags.TelemetryURL.Value)
telemetryURL, err := parseURL(ctx, cfg.TelemetryURL.Value)
if err != nil {
return xerrors.Errorf("parse telemetry url: %w", err)
}
// Disable telemetry if the in-memory database is used unless explicitly defined!
if dflags.InMemoryDatabase.Value && !cmd.Flags().Changed(dflags.TelemetryEnable.Flag) {
dflags.TelemetryEnable.Value = false
if cfg.InMemoryDatabase.Value && !cmd.Flags().Changed(cfg.TelemetryEnable.Flag) {
cfg.TelemetryEnable.Value = false
}
if dflags.TelemetryEnable.Value {
if cfg.TelemetryEnable.Value {
options.Telemetry, err = telemetry.New(telemetry.Options{
BuiltinPostgres: builtinPostgres,
DeploymentID: deploymentID,
Database: options.Database,
Logger: logger.Named("telemetry"),
URL: telemetryURL,
GitHubOAuth: dflags.OAuth2GithubClientID.Value != "",
OIDCAuth: dflags.OIDCClientID.Value != "",
OIDCIssuerURL: dflags.OIDCIssuerURL.Value,
Prometheus: dflags.PromEnabled.Value,
STUN: len(dflags.DerpServerSTUNAddresses.Value) != 0,
GitHubOAuth: cfg.OAuth2GithubClientID.Value != "",
OIDCAuth: cfg.OIDCClientID.Value != "",
OIDCIssuerURL: cfg.OIDCIssuerURL.Value,
Prometheus: cfg.PrometheusEnable.Value,
STUN: len(cfg.DERPServerSTUNAddresses.Value) != 0,
Tunnel: tunnel != nil,
})
if err != nil {
@ -480,11 +484,11 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
// This prevents the pprof import from being accidentally deleted.
_ = pprof.Handler
if dflags.PprofEnabled.Value {
if cfg.PprofEnable.Value {
//nolint:revive
defer serveHandler(ctx, logger, nil, dflags.PprofAddress.Value, "pprof")()
defer serveHandler(ctx, logger, nil, cfg.PprofAddress.Value, "pprof")()
}
if dflags.PromEnabled.Value {
if cfg.PrometheusEnable.Value {
options.PrometheusRegistry = prometheus.NewRegistry()
closeUsersFunc, err := prometheusmetrics.ActiveUsers(ctx, options.PrometheusRegistry, options.Database, 0)
if err != nil {
@ -501,7 +505,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
//nolint:revive
defer serveHandler(ctx, logger, promhttp.InstrumentMetricHandler(
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
), dflags.PromAddress.Value, "prometheus")()
), cfg.PrometheusAddress.Value, "prometheus")()
}
// We use a separate coderAPICloser so the Enterprise API
@ -513,7 +517,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
}
client := codersdk.New(localURL)
if dflags.TLSEnable.Value {
if cfg.TLSEnable.Value {
// Secure transport isn't needed for locally communicating!
client.HTTPClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
@ -537,8 +541,8 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
_ = daemon.Close()
}
}()
for i := 0; i < dflags.ProvisionerDaemonCount.Value; i++ {
daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, dflags.CacheDir.Value, errCh, false)
for i := 0; i < cfg.ProvisionerDaemons.Value; i++ {
daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, cfg.CacheDirectory.Value, errCh, false)
if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err)
}
@ -604,7 +608,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
return xerrors.Errorf("notify systemd: %w", err)
}
autobuildPoller := time.NewTicker(dflags.AutobuildPollInterval.Value)
autobuildPoller := time.NewTicker(cfg.AutobuildPollInterval.Value)
defer autobuildPoller.Stop()
autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
autobuildExecutor.Run()
@ -669,7 +673,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
go func() {
defer wg.Done()
if dflags.Verbose.Value {
if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
cmd.Printf("Shutting down provisioner daemon %d...\n", id)
}
err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second)
@ -682,7 +686,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
return
}
if dflags.Verbose.Value {
if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
}
}()
@ -734,7 +738,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
RunE: func(cmd *cobra.Command, args []string) error {
cfg := createConfig(cmd)
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
if dflags.Verbose.Value {
if ok, _ := cmd.Flags().GetBool(varVerbose); ok {
logger = logger.Leveled(slog.LevelDebug)
}
@ -755,7 +759,7 @@ func Server(dflags *codersdk.DeploymentFlags, newAPI func(context.Context, *code
},
})
deployment.AttachFlags(root.Flags(), dflags, false)
deployment.AttachFlags(root.Flags(), vip, false)
return root
}