mirror of
https://github.com/coder/coder.git
synced 2025-07-06 15:41:45 +00:00
Part of #10532 DRPC transport over yamux and in-mem pipes was previously only used on the provisioner APIs, but now will also be used in tailnet. Moved to subpackage of codersdk to avoid import loops.
2551 lines
81 KiB
Go
2551 lines
81 KiB
Go
//go:build !slim
|
|
|
|
package cli
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/http/pprof"
|
|
"net/url"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/coreos/go-systemd/daemon"
|
|
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
|
"github.com/google/go-github/v43/github"
|
|
"github.com/google/uuid"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/collectors"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"go.opentelemetry.io/otel"
|
|
"go.opentelemetry.io/otel/propagation"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/mod/semver"
|
|
"golang.org/x/oauth2"
|
|
xgithub "golang.org/x/oauth2/github"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/xerrors"
|
|
"google.golang.org/api/idtoken"
|
|
"google.golang.org/api/option"
|
|
"gopkg.in/yaml.v3"
|
|
"tailscale.com/tailcfg"
|
|
|
|
"github.com/coder/pretty"
|
|
|
|
"cdr.dev/slog"
|
|
"cdr.dev/slog/sloggers/sloghuman"
|
|
"cdr.dev/slog/sloggers/slogjson"
|
|
"cdr.dev/slog/sloggers/slogstackdriver"
|
|
"github.com/coder/coder/v2/buildinfo"
|
|
"github.com/coder/coder/v2/cli/clibase"
|
|
"github.com/coder/coder/v2/cli/cliui"
|
|
"github.com/coder/coder/v2/cli/cliutil"
|
|
"github.com/coder/coder/v2/cli/config"
|
|
"github.com/coder/coder/v2/coderd"
|
|
"github.com/coder/coder/v2/coderd/autobuild"
|
|
"github.com/coder/coder/v2/coderd/batchstats"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbmem"
|
|
"github.com/coder/coder/v2/coderd/database/dbmetrics"
|
|
"github.com/coder/coder/v2/coderd/database/dbpurge"
|
|
"github.com/coder/coder/v2/coderd/database/migrations"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/devtunnel"
|
|
"github.com/coder/coder/v2/coderd/externalauth"
|
|
"github.com/coder/coder/v2/coderd/gitsshkey"
|
|
"github.com/coder/coder/v2/coderd/httpapi"
|
|
"github.com/coder/coder/v2/coderd/httpmw"
|
|
"github.com/coder/coder/v2/coderd/oauthpki"
|
|
"github.com/coder/coder/v2/coderd/prometheusmetrics"
|
|
"github.com/coder/coder/v2/coderd/prometheusmetrics/insights"
|
|
"github.com/coder/coder/v2/coderd/schedule"
|
|
"github.com/coder/coder/v2/coderd/telemetry"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/coderd/unhanger"
|
|
"github.com/coder/coder/v2/coderd/updatecheck"
|
|
"github.com/coder/coder/v2/coderd/util/slice"
|
|
stringutil "github.com/coder/coder/v2/coderd/util/strings"
|
|
"github.com/coder/coder/v2/coderd/workspaceapps"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/drpc"
|
|
"github.com/coder/coder/v2/cryptorand"
|
|
"github.com/coder/coder/v2/provisioner/echo"
|
|
"github.com/coder/coder/v2/provisioner/terraform"
|
|
"github.com/coder/coder/v2/provisionerd"
|
|
"github.com/coder/coder/v2/provisionerd/proto"
|
|
"github.com/coder/coder/v2/provisionersdk"
|
|
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
|
|
"github.com/coder/coder/v2/tailnet"
|
|
"github.com/coder/retry"
|
|
"github.com/coder/wgtunnel/tunnelsdk"
|
|
)
|
|
|
|
func createOIDCConfig(ctx context.Context, vals *codersdk.DeploymentValues) (*coderd.OIDCConfig, error) {
|
|
if vals.OIDC.ClientID == "" {
|
|
return nil, xerrors.Errorf("OIDC client ID must be set!")
|
|
}
|
|
if vals.OIDC.IssuerURL == "" {
|
|
return nil, xerrors.Errorf("OIDC issuer URL must be set!")
|
|
}
|
|
|
|
oidcProvider, err := oidc.NewProvider(
|
|
ctx, vals.OIDC.IssuerURL.String(),
|
|
)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("configure oidc provider: %w", err)
|
|
}
|
|
redirectURL, err := vals.AccessURL.Value().Parse("/api/v2/users/oidc/callback")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse oidc oauth callback url: %w", err)
|
|
}
|
|
// If the scopes contain 'groups', we enable group support.
|
|
// Do not override any custom value set by the user.
|
|
if slice.Contains(vals.OIDC.Scopes, "groups") && vals.OIDC.GroupField == "" {
|
|
vals.OIDC.GroupField = "groups"
|
|
}
|
|
oauthCfg := &oauth2.Config{
|
|
ClientID: vals.OIDC.ClientID.String(),
|
|
ClientSecret: vals.OIDC.ClientSecret.String(),
|
|
RedirectURL: redirectURL.String(),
|
|
Endpoint: oidcProvider.Endpoint(),
|
|
Scopes: vals.OIDC.Scopes,
|
|
}
|
|
|
|
var useCfg httpmw.OAuth2Config = oauthCfg
|
|
if vals.OIDC.ClientKeyFile != "" {
|
|
// PKI authentication is done in the params. If a
|
|
// counter example is found, we can add a config option to
|
|
// change this.
|
|
oauthCfg.Endpoint.AuthStyle = oauth2.AuthStyleInParams
|
|
if vals.OIDC.ClientSecret != "" {
|
|
return nil, xerrors.Errorf("cannot specify both oidc client secret and oidc client key file")
|
|
}
|
|
|
|
pkiCfg, err := configureOIDCPKI(oauthCfg, vals.OIDC.ClientKeyFile.Value(), vals.OIDC.ClientCertFile.Value())
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("configure oauth pki authentication: %w", err)
|
|
}
|
|
useCfg = pkiCfg
|
|
}
|
|
if len(vals.OIDC.GroupAllowList) > 0 && vals.OIDC.GroupField == "" {
|
|
return nil, xerrors.Errorf("'oidc-group-field' must be set if 'oidc-allowed-groups' is set. Either unset 'oidc-allowed-groups' or set 'oidc-group-field'")
|
|
}
|
|
|
|
groupAllowList := make(map[string]bool)
|
|
for _, group := range vals.OIDC.GroupAllowList.Value() {
|
|
groupAllowList[group] = true
|
|
}
|
|
|
|
return &coderd.OIDCConfig{
|
|
OAuth2Config: useCfg,
|
|
Provider: oidcProvider,
|
|
Verifier: oidcProvider.Verifier(&oidc.Config{
|
|
ClientID: vals.OIDC.ClientID.String(),
|
|
}),
|
|
EmailDomain: vals.OIDC.EmailDomain,
|
|
AllowSignups: vals.OIDC.AllowSignups.Value(),
|
|
UsernameField: vals.OIDC.UsernameField.String(),
|
|
EmailField: vals.OIDC.EmailField.String(),
|
|
AuthURLParams: vals.OIDC.AuthURLParams.Value,
|
|
IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(),
|
|
GroupField: vals.OIDC.GroupField.String(),
|
|
GroupFilter: vals.OIDC.GroupRegexFilter.Value(),
|
|
GroupAllowList: groupAllowList,
|
|
CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(),
|
|
GroupMapping: vals.OIDC.GroupMapping.Value,
|
|
UserRoleField: vals.OIDC.UserRoleField.String(),
|
|
UserRoleMapping: vals.OIDC.UserRoleMapping.Value,
|
|
UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(),
|
|
SignInText: vals.OIDC.SignInText.String(),
|
|
IconURL: vals.OIDC.IconURL.String(),
|
|
IgnoreEmailVerified: vals.OIDC.IgnoreEmailVerified.Value(),
|
|
}, nil
|
|
}
|
|
|
|
// afterCtx returns immediately and, on a separate goroutine, invokes fn once
// ctx is done. If ctx is never canceled, fn is never called.
func afterCtx(ctx context.Context, fn func()) {
	waitThenRun := func(done <-chan struct{}) {
		<-done
		fn()
	}
	go waitThenRun(ctx.Done())
}
|
|
|
|
func enablePrometheus(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
vals *codersdk.DeploymentValues,
|
|
options *coderd.Options,
|
|
) (closeFn func(), err error) {
|
|
options.PrometheusRegistry.MustRegister(collectors.NewGoCollector())
|
|
options.PrometheusRegistry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
|
|
|
|
closeUsersFunc, err := prometheusmetrics.ActiveUsers(ctx, options.PrometheusRegistry, options.Database, 0)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("register active users prometheus metric: %w", err)
|
|
}
|
|
afterCtx(ctx, closeUsersFunc)
|
|
|
|
closeWorkspacesFunc, err := prometheusmetrics.Workspaces(ctx, options.PrometheusRegistry, options.Database, 0)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("register workspaces prometheus metric: %w", err)
|
|
}
|
|
afterCtx(ctx, closeWorkspacesFunc)
|
|
|
|
insightsMetricsCollector, err := insights.NewMetricsCollector(options.Database, options.Logger, 0, 0)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("unable to initialize insights metrics collector: %w", err)
|
|
}
|
|
err = options.PrometheusRegistry.Register(insightsMetricsCollector)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("unable to register insights metrics collector: %w", err)
|
|
}
|
|
|
|
closeInsightsMetricsCollector, err := insightsMetricsCollector.Run(ctx)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("unable to run insights metrics collector: %w", err)
|
|
}
|
|
afterCtx(ctx, closeInsightsMetricsCollector)
|
|
|
|
if vals.Prometheus.CollectAgentStats {
|
|
closeAgentStatsFunc, err := prometheusmetrics.AgentStats(ctx, logger, options.PrometheusRegistry, options.Database, time.Now(), 0)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("register agent stats prometheus metric: %w", err)
|
|
}
|
|
afterCtx(ctx, closeAgentStatsFunc)
|
|
|
|
metricsAggregator, err := prometheusmetrics.NewMetricsAggregator(logger, options.PrometheusRegistry, 0)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("can't initialize metrics aggregator: %w", err)
|
|
}
|
|
|
|
cancelMetricsAggregator := metricsAggregator.Run(ctx)
|
|
afterCtx(ctx, cancelMetricsAggregator)
|
|
|
|
options.UpdateAgentMetrics = metricsAggregator.Update
|
|
err = options.PrometheusRegistry.Register(metricsAggregator)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("can't register metrics aggregator as collector: %w", err)
|
|
}
|
|
}
|
|
|
|
//nolint:revive
|
|
return ServeHandler(
|
|
ctx, logger, promhttp.InstrumentMetricHandler(
|
|
options.PrometheusRegistry, promhttp.HandlerFor(options.PrometheusRegistry, promhttp.HandlerOpts{}),
|
|
), vals.Prometheus.Address.String(), "prometheus",
|
|
), nil
|
|
}
|
|
|
|
func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *clibase.Cmd {
|
|
if newAPI == nil {
|
|
newAPI = func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) {
|
|
api := coderd.New(o)
|
|
return api, api, nil
|
|
}
|
|
}
|
|
|
|
var (
|
|
vals = new(codersdk.DeploymentValues)
|
|
opts = vals.Options()
|
|
)
|
|
serverCmd := &clibase.Cmd{
|
|
Use: "server",
|
|
Short: "Start a Coder server",
|
|
Options: opts,
|
|
Middleware: clibase.Chain(
|
|
WriteConfigMW(vals),
|
|
PrintDeprecatedOptions(),
|
|
clibase.RequireNArgs(0),
|
|
),
|
|
Handler: func(inv *clibase.Invocation) error {
|
|
// Main command context for managing cancellation of running
|
|
// services.
|
|
ctx, cancel := context.WithCancel(inv.Context())
|
|
defer cancel()
|
|
|
|
if vals.Config != "" {
|
|
cliui.Warnf(inv.Stderr, "YAML support is experimental and offers no compatibility guarantees.")
|
|
}
|
|
|
|
go DumpHandler(ctx)
|
|
|
|
// Validate bind addresses.
|
|
if vals.Address.String() != "" {
|
|
if vals.TLS.Enable {
|
|
vals.HTTPAddress = ""
|
|
vals.TLS.Address = vals.Address
|
|
} else {
|
|
_ = vals.HTTPAddress.Set(vals.Address.String())
|
|
vals.TLS.Address.Host = ""
|
|
vals.TLS.Address.Port = ""
|
|
}
|
|
}
|
|
if vals.TLS.Enable && vals.TLS.Address.String() == "" {
|
|
return xerrors.Errorf("TLS address must be set if TLS is enabled")
|
|
}
|
|
if !vals.TLS.Enable && vals.HTTPAddress.String() == "" {
|
|
return xerrors.Errorf("TLS is disabled. Enable with --tls-enable or specify a HTTP address")
|
|
}
|
|
|
|
if vals.AccessURL.String() != "" &&
|
|
!(vals.AccessURL.Scheme == "http" || vals.AccessURL.Scheme == "https") {
|
|
return xerrors.Errorf("access-url must include a scheme (e.g. 'http://' or 'https://)")
|
|
}
|
|
|
|
// Disable rate limits if the `--dangerous-disable-rate-limits` flag
|
|
// was specified.
|
|
loginRateLimit := 60
|
|
filesRateLimit := 12
|
|
if vals.RateLimit.DisableAll {
|
|
vals.RateLimit.API = -1
|
|
loginRateLimit = -1
|
|
filesRateLimit = -1
|
|
}
|
|
|
|
PrintLogo(inv, "Coder")
|
|
logger, logCloser, err := BuildLogger(inv, vals)
|
|
if err != nil {
|
|
return xerrors.Errorf("make logger: %w", err)
|
|
}
|
|
defer logCloser()
|
|
|
|
// This line is helpful in tests.
|
|
logger.Debug(ctx, "started debug logging")
|
|
logger.Sync()
|
|
|
|
// Register signals early on so that graceful shutdown can't
|
|
// be interrupted by additional signals. Note that we avoid
|
|
// shadowing cancel() (from above) here because notifyStop()
|
|
// restores default behavior for the signals. This protects
|
|
// the shutdown sequence from abruptly terminating things
|
|
// like: database migrations, provisioner work, workspace
|
|
// cleanup in dev-mode, etc.
|
|
//
|
|
// To get out of a graceful shutdown, the user can send
|
|
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
|
|
notifyCtx, notifyStop := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
|
defer notifyStop()
|
|
|
|
cacheDir := vals.CacheDir.String()
|
|
err = os.MkdirAll(cacheDir, 0o700)
|
|
if err != nil {
|
|
return xerrors.Errorf("create cache directory: %w", err)
|
|
}
|
|
|
|
// Clean up idle connections at the end, e.g.
|
|
// embedded-postgres can leave an idle connection
|
|
// which is caught by goleaks.
|
|
defer http.DefaultClient.CloseIdleConnections()
|
|
|
|
tracerProvider, sqlDriver, closeTracing := ConfigureTraceProvider(ctx, logger, vals)
|
|
defer func() {
|
|
logger.Debug(ctx, "closing tracing")
|
|
traceCloseErr := shutdownWithTimeout(closeTracing, 5*time.Second)
|
|
logger.Debug(ctx, "tracing closed", slog.Error(traceCloseErr))
|
|
}()
|
|
|
|
httpServers, err := ConfigureHTTPServers(logger, inv, vals)
|
|
if err != nil {
|
|
return xerrors.Errorf("configure http(s): %w", err)
|
|
}
|
|
defer httpServers.Close()
|
|
|
|
config := r.createConfig()
|
|
|
|
builtinPostgres := false
|
|
// Only use built-in if PostgreSQL URL isn't specified!
|
|
if !vals.InMemoryDatabase && vals.PostgresURL == "" {
|
|
var closeFunc func() error
|
|
cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath())
|
|
pgURL, closeFunc, err := startBuiltinPostgres(ctx, config, logger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = vals.PostgresURL.Set(pgURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
builtinPostgres = true
|
|
defer func() {
|
|
cliui.Infof(inv.Stdout, "Stopping built-in PostgreSQL...")
|
|
// Gracefully shut PostgreSQL down!
|
|
if err := closeFunc(); err != nil {
|
|
cliui.Errorf(inv.Stderr, "Failed to stop built-in PostgreSQL: %v", err)
|
|
} else {
|
|
cliui.Infof(inv.Stdout, "Stopped built-in PostgreSQL")
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Prefer HTTP because it's less prone to TLS errors over localhost.
|
|
localURL := httpServers.TLSUrl
|
|
if httpServers.HTTPUrl != nil {
|
|
localURL = httpServers.HTTPUrl
|
|
}
|
|
|
|
ctx, httpClient, err := ConfigureHTTPClient(
|
|
ctx,
|
|
vals.TLS.ClientCertFile.String(),
|
|
vals.TLS.ClientKeyFile.String(),
|
|
vals.TLS.ClientCAFile.String(),
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("configure http client: %w", err)
|
|
}
|
|
|
|
// If the access URL is empty, we attempt to run a reverse-proxy
|
|
// tunnel to make the initial setup really simple.
|
|
var (
|
|
tunnel *tunnelsdk.Tunnel
|
|
tunnelDone <-chan struct{} = make(chan struct{}, 1)
|
|
)
|
|
if vals.AccessURL.String() == "" {
|
|
cliui.Infof(inv.Stderr, "Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL")
|
|
tunnel, err = devtunnel.New(ctx, logger.Named("net.devtunnel"), vals.WgtunnelHost.String())
|
|
if err != nil {
|
|
return xerrors.Errorf("create tunnel: %w", err)
|
|
}
|
|
defer tunnel.Close()
|
|
tunnelDone = tunnel.Wait()
|
|
vals.AccessURL = clibase.URL(*tunnel.URL)
|
|
|
|
if vals.WildcardAccessURL.String() == "" {
|
|
// Suffixed wildcard access URL.
|
|
u, err := url.Parse(fmt.Sprintf("*--%s", tunnel.URL.Hostname()))
|
|
if err != nil {
|
|
return xerrors.Errorf("parse wildcard url: %w", err)
|
|
}
|
|
vals.WildcardAccessURL = clibase.URL(*u)
|
|
}
|
|
}
|
|
|
|
_, accessURLPortRaw, _ := net.SplitHostPort(vals.AccessURL.Host)
|
|
if accessURLPortRaw == "" {
|
|
accessURLPortRaw = "80"
|
|
if vals.AccessURL.Scheme == "https" {
|
|
accessURLPortRaw = "443"
|
|
}
|
|
}
|
|
|
|
accessURLPort, err := strconv.Atoi(accessURLPortRaw)
|
|
if err != nil {
|
|
return xerrors.Errorf("parse access URL port: %w", err)
|
|
}
|
|
|
|
// Warn the user if the access URL is loopback or unresolvable.
|
|
isLocal, err := IsLocalURL(ctx, vals.AccessURL.Value())
|
|
if isLocal || err != nil {
|
|
reason := "could not be resolved"
|
|
if isLocal {
|
|
reason = "isn't externally reachable"
|
|
}
|
|
cliui.Warnf(
|
|
inv.Stderr,
|
|
"The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n",
|
|
pretty.Sprint(cliui.DefaultStyles.Field, vals.AccessURL.String()), reason,
|
|
)
|
|
}
|
|
|
|
// A newline is added before for visibility in terminal output.
|
|
cliui.Infof(inv.Stdout, "\nView the Web UI: %s", vals.AccessURL.String())
|
|
|
|
// Used for zero-trust instance identity with Google Cloud.
|
|
googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(vals.SSHKeygenAlgorithm.String())
|
|
if err != nil {
|
|
return xerrors.Errorf("parse ssh keygen algorithm %s: %w", vals.SSHKeygenAlgorithm, err)
|
|
}
|
|
|
|
defaultRegion := &tailcfg.DERPRegion{
|
|
EmbeddedRelay: true,
|
|
RegionID: int(vals.DERP.Server.RegionID.Value()),
|
|
RegionCode: vals.DERP.Server.RegionCode.String(),
|
|
RegionName: vals.DERP.Server.RegionName.String(),
|
|
Nodes: []*tailcfg.DERPNode{{
|
|
Name: fmt.Sprintf("%db", vals.DERP.Server.RegionID),
|
|
RegionID: int(vals.DERP.Server.RegionID.Value()),
|
|
HostName: vals.AccessURL.Value().Hostname(),
|
|
DERPPort: accessURLPort,
|
|
STUNPort: -1,
|
|
ForceHTTP: vals.AccessURL.Scheme == "http",
|
|
}},
|
|
}
|
|
if !vals.DERP.Server.Enable {
|
|
defaultRegion = nil
|
|
}
|
|
|
|
derpMap, err := tailnet.NewDERPMap(
|
|
ctx, defaultRegion, vals.DERP.Server.STUNAddresses,
|
|
vals.DERP.Config.URL.String(), vals.DERP.Config.Path.String(),
|
|
vals.DERP.Config.BlockDirect.Value(),
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("create derp map: %w", err)
|
|
}
|
|
|
|
appHostname := vals.WildcardAccessURL.String()
|
|
var appHostnameRegex *regexp.Regexp
|
|
if appHostname != "" {
|
|
appHostnameRegex, err = httpapi.CompileHostnamePattern(appHostname)
|
|
if err != nil {
|
|
return xerrors.Errorf("parse wildcard access URL %q: %w", appHostname, err)
|
|
}
|
|
}
|
|
|
|
extAuthEnv, err := ReadExternalAuthProvidersFromEnv(os.Environ())
|
|
if err != nil {
|
|
return xerrors.Errorf("read external auth providers from env: %w", err)
|
|
}
|
|
|
|
vals.ExternalAuthConfigs.Value = append(vals.ExternalAuthConfigs.Value, extAuthEnv...)
|
|
externalAuthConfigs, err := externalauth.ConvertConfig(
|
|
vals.ExternalAuthConfigs.Value,
|
|
vals.AccessURL.Value(),
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("convert external auth config: %w", err)
|
|
}
|
|
for _, c := range externalAuthConfigs {
|
|
logger.Debug(
|
|
ctx, "loaded external auth config",
|
|
slog.F("id", c.ID),
|
|
)
|
|
}
|
|
|
|
realIPConfig, err := httpmw.ParseRealIPConfig(vals.ProxyTrustedHeaders, vals.ProxyTrustedOrigins)
|
|
if err != nil {
|
|
return xerrors.Errorf("parse real ip config: %w", err)
|
|
}
|
|
|
|
configSSHOptions, err := vals.SSHConfig.ParseOptions()
|
|
if err != nil {
|
|
return xerrors.Errorf("parse ssh config options %q: %w", vals.SSHConfig.SSHConfigOptions.String(), err)
|
|
}
|
|
|
|
options := &coderd.Options{
|
|
AccessURL: vals.AccessURL.Value(),
|
|
AppHostname: appHostname,
|
|
AppHostnameRegex: appHostnameRegex,
|
|
Logger: logger.Named("coderd"),
|
|
Database: dbmem.New(),
|
|
BaseDERPMap: derpMap,
|
|
Pubsub: pubsub.NewInMemory(),
|
|
CacheDir: cacheDir,
|
|
GoogleTokenValidator: googleTokenValidator,
|
|
ExternalAuthConfigs: externalAuthConfigs,
|
|
RealIPConfig: realIPConfig,
|
|
SecureAuthCookie: vals.SecureAuthCookie.Value(),
|
|
SSHKeygenAlgorithm: sshKeygenAlgorithm,
|
|
TracerProvider: tracerProvider,
|
|
Telemetry: telemetry.NewNoop(),
|
|
MetricsCacheRefreshInterval: vals.MetricsCacheRefreshInterval.Value(),
|
|
AgentStatsRefreshInterval: vals.AgentStatRefreshInterval.Value(),
|
|
DeploymentValues: vals,
|
|
// Do not pass secret values to DeploymentOptions. All values should be read from
|
|
// the DeploymentValues instead, this just serves to indicate the source of each
|
|
// option. This is just defensive to prevent accidentally leaking.
|
|
DeploymentOptions: codersdk.DeploymentOptionsWithoutSecrets(opts),
|
|
PrometheusRegistry: prometheus.NewRegistry(),
|
|
APIRateLimit: int(vals.RateLimit.API.Value()),
|
|
LoginRateLimit: loginRateLimit,
|
|
FilesRateLimit: filesRateLimit,
|
|
HTTPClient: httpClient,
|
|
TemplateScheduleStore: &atomic.Pointer[schedule.TemplateScheduleStore]{},
|
|
UserQuietHoursScheduleStore: &atomic.Pointer[schedule.UserQuietHoursScheduleStore]{},
|
|
SSHConfig: codersdk.SSHConfigResponse{
|
|
HostnamePrefix: vals.SSHConfig.DeploymentName.String(),
|
|
SSHConfigOptions: configSSHOptions,
|
|
},
|
|
}
|
|
if httpServers.TLSConfig != nil {
|
|
options.TLSCertificates = httpServers.TLSConfig.Certificates
|
|
}
|
|
|
|
if vals.StrictTransportSecurity > 0 {
|
|
options.StrictTransportSecurityCfg, err = httpmw.HSTSConfigOptions(
|
|
int(vals.StrictTransportSecurity.Value()), vals.StrictTransportSecurityOptions,
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("coderd: setting hsts header failed (options: %v): %w", vals.StrictTransportSecurityOptions, err)
|
|
}
|
|
}
|
|
|
|
if vals.UpdateCheck {
|
|
options.UpdateCheckOptions = &updatecheck.Options{
|
|
// Avoid spamming GitHub API checking for updates.
|
|
Interval: 24 * time.Hour,
|
|
// Inform server admins of new versions.
|
|
Notify: func(r updatecheck.Result) {
|
|
if semver.Compare(r.Version, buildinfo.Version()) > 0 {
|
|
options.Logger.Info(
|
|
context.Background(),
|
|
"new version of coder available",
|
|
slog.F("new_version", r.Version),
|
|
slog.F("url", r.URL),
|
|
slog.F("upgrade_instructions", "https://coder.com/docs/coder-oss/latest/admin/upgrade"),
|
|
)
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
if vals.OAuth2.Github.ClientSecret != "" {
|
|
options.GithubOAuth2Config, err = configureGithubOAuth2(vals.AccessURL.Value(),
|
|
vals.OAuth2.Github.ClientID.String(),
|
|
vals.OAuth2.Github.ClientSecret.String(),
|
|
vals.OAuth2.Github.AllowSignups.Value(),
|
|
vals.OAuth2.Github.AllowEveryone.Value(),
|
|
vals.OAuth2.Github.AllowedOrgs,
|
|
vals.OAuth2.Github.AllowedTeams,
|
|
vals.OAuth2.Github.EnterpriseBaseURL.String(),
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("configure github oauth2: %w", err)
|
|
}
|
|
}
|
|
|
|
if vals.OIDC.ClientKeyFile != "" || vals.OIDC.ClientSecret != "" {
|
|
if vals.OIDC.IgnoreEmailVerified {
|
|
logger.Warn(ctx, "coder will not check email_verified for OIDC logins")
|
|
}
|
|
|
|
oc, err := createOIDCConfig(ctx, vals)
|
|
if err != nil {
|
|
return xerrors.Errorf("create oidc config: %w", err)
|
|
}
|
|
options.OIDCConfig = oc
|
|
}
|
|
|
|
if vals.InMemoryDatabase {
|
|
// This is only used for testing.
|
|
options.Database = dbmem.New()
|
|
options.Pubsub = pubsub.NewInMemory()
|
|
} else {
|
|
sqlDB, err := ConnectToPostgres(ctx, logger, sqlDriver, vals.PostgresURL.String())
|
|
if err != nil {
|
|
return xerrors.Errorf("connect to postgres: %w", err)
|
|
}
|
|
defer func() {
|
|
_ = sqlDB.Close()
|
|
}()
|
|
|
|
options.Database = database.New(sqlDB)
|
|
options.Pubsub, err = pubsub.New(ctx, sqlDB, vals.PostgresURL.String())
|
|
if err != nil {
|
|
return xerrors.Errorf("create pubsub: %w", err)
|
|
}
|
|
defer options.Pubsub.Close()
|
|
}
|
|
|
|
if options.DeploymentValues.Prometheus.Enable && options.DeploymentValues.Prometheus.CollectDBMetrics {
|
|
options.Database = dbmetrics.New(options.Database, options.PrometheusRegistry)
|
|
}
|
|
|
|
var deploymentID string
|
|
err = options.Database.InTx(func(tx database.Store) error {
|
|
// This will block until the lock is acquired, and will be
|
|
// automatically released when the transaction ends.
|
|
err := tx.AcquireLock(ctx, database.LockIDDeploymentSetup)
|
|
if err != nil {
|
|
return xerrors.Errorf("acquire lock: %w", err)
|
|
}
|
|
|
|
deploymentID, err = tx.GetDeploymentID(ctx)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return xerrors.Errorf("get deployment id: %w", err)
|
|
}
|
|
if deploymentID == "" {
|
|
deploymentID = uuid.NewString()
|
|
err = tx.InsertDeploymentID(ctx, deploymentID)
|
|
if err != nil {
|
|
return xerrors.Errorf("set deployment id: %w", err)
|
|
}
|
|
}
|
|
|
|
// Read the app signing key from the DB. We store it hex encoded
|
|
// since the config table uses strings for the value and we
|
|
// don't want to deal with automatic encoding issues.
|
|
appSecurityKeyStr, err := tx.GetAppSecurityKey(ctx)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return xerrors.Errorf("get app signing key: %w", err)
|
|
}
|
|
// If the string in the DB is an invalid hex string or the
|
|
// length is not equal to the current key length, generate a new
|
|
// one.
|
|
//
|
|
// If the key is regenerated, old signed tokens and encrypted
|
|
// strings will become invalid. New signed app tokens will be
|
|
// generated automatically on failure. Any workspace app token
|
|
// smuggling operations in progress may fail, although with a
|
|
// helpful error.
|
|
if decoded, err := hex.DecodeString(appSecurityKeyStr); err != nil || len(decoded) != len(workspaceapps.SecurityKey{}) {
|
|
b := make([]byte, len(workspaceapps.SecurityKey{}))
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return xerrors.Errorf("generate fresh app signing key: %w", err)
|
|
}
|
|
|
|
appSecurityKeyStr = hex.EncodeToString(b)
|
|
err = tx.UpsertAppSecurityKey(ctx, appSecurityKeyStr)
|
|
if err != nil {
|
|
return xerrors.Errorf("insert freshly generated app signing key to database: %w", err)
|
|
}
|
|
}
|
|
|
|
appSecurityKey, err := workspaceapps.KeyFromString(appSecurityKeyStr)
|
|
if err != nil {
|
|
return xerrors.Errorf("decode app signing key from database: %w", err)
|
|
}
|
|
|
|
options.AppSecurityKey = appSecurityKey
|
|
|
|
// Read the oauth signing key from the database. Like the app security, generate a new one
|
|
// if it is invalid for any reason.
|
|
oauthSigningKeyStr, err := tx.GetOAuthSigningKey(ctx)
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return xerrors.Errorf("get app oauth signing key: %w", err)
|
|
}
|
|
if decoded, err := hex.DecodeString(oauthSigningKeyStr); err != nil || len(decoded) != len(options.OAuthSigningKey) {
|
|
b := make([]byte, len(options.OAuthSigningKey))
|
|
_, err := rand.Read(b)
|
|
if err != nil {
|
|
return xerrors.Errorf("generate fresh oauth signing key: %w", err)
|
|
}
|
|
|
|
oauthSigningKeyStr = hex.EncodeToString(b)
|
|
err = tx.UpsertOAuthSigningKey(ctx, oauthSigningKeyStr)
|
|
if err != nil {
|
|
return xerrors.Errorf("insert freshly generated oauth signing key to database: %w", err)
|
|
}
|
|
}
|
|
|
|
keyBytes, err := hex.DecodeString(oauthSigningKeyStr)
|
|
if err != nil {
|
|
return xerrors.Errorf("decode oauth signing key from database: %w", err)
|
|
}
|
|
if len(keyBytes) != len(options.OAuthSigningKey) {
|
|
return xerrors.Errorf("oauth signing key in database is not the correct length, expect %d got %d", len(options.OAuthSigningKey), len(keyBytes))
|
|
}
|
|
copy(options.OAuthSigningKey[:], keyBytes)
|
|
if options.OAuthSigningKey == [32]byte{} {
|
|
return xerrors.Errorf("oauth signing key in database is empty")
|
|
}
|
|
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if vals.Telemetry.Enable {
|
|
gitAuth := make([]telemetry.GitAuth, 0)
|
|
// TODO:
|
|
var gitAuthConfigs []codersdk.ExternalAuthConfig
|
|
for _, cfg := range gitAuthConfigs {
|
|
gitAuth = append(gitAuth, telemetry.GitAuth{
|
|
Type: cfg.Type,
|
|
})
|
|
}
|
|
|
|
options.Telemetry, err = telemetry.New(telemetry.Options{
|
|
BuiltinPostgres: builtinPostgres,
|
|
DeploymentID: deploymentID,
|
|
Database: options.Database,
|
|
Logger: logger.Named("telemetry"),
|
|
URL: vals.Telemetry.URL.Value(),
|
|
Wildcard: vals.WildcardAccessURL.String() != "",
|
|
DERPServerRelayURL: vals.DERP.Server.RelayURL.String(),
|
|
GitAuth: gitAuth,
|
|
GitHubOAuth: vals.OAuth2.Github.ClientID != "",
|
|
OIDCAuth: vals.OIDC.ClientID != "",
|
|
OIDCIssuerURL: vals.OIDC.IssuerURL.String(),
|
|
Prometheus: vals.Prometheus.Enable.Value(),
|
|
STUN: len(vals.DERP.Server.STUNAddresses) != 0,
|
|
Tunnel: tunnel != nil,
|
|
ParseLicenseJWT: func(lic *telemetry.License) error {
|
|
// This will be nil when running in AGPL-only mode.
|
|
if options.ParseLicenseClaims == nil {
|
|
return nil
|
|
}
|
|
|
|
email, trial, err := options.ParseLicenseClaims(lic.JWT)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if email != "" {
|
|
lic.Email = &email
|
|
}
|
|
lic.Trial = &trial
|
|
return nil
|
|
},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("create telemetry reporter: %w", err)
|
|
}
|
|
defer options.Telemetry.Close()
|
|
} else {
|
|
logger.Warn(ctx, `telemetry disabled, unable to notify of security issues. Read more: https://coder.com/docs/v2/latest/admin/telemetry`)
|
|
}
|
|
|
|
// This prevents the pprof import from being accidentally deleted.
|
|
_ = pprof.Handler
|
|
if vals.Pprof.Enable {
|
|
//nolint:revive
|
|
defer ServeHandler(ctx, logger, nil, vals.Pprof.Address.String(), "pprof")()
|
|
}
|
|
if vals.Prometheus.Enable {
|
|
closeFn, err := enablePrometheus(
|
|
ctx,
|
|
logger.Named("prometheus"),
|
|
vals,
|
|
options,
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("enable prometheus: %w", err)
|
|
}
|
|
defer closeFn()
|
|
}
|
|
|
|
if vals.Swagger.Enable {
|
|
options.SwaggerEndpoint = vals.Swagger.Enable.Value()
|
|
}
|
|
|
|
batcher, closeBatcher, err := batchstats.New(ctx,
|
|
batchstats.WithLogger(options.Logger.Named("batchstats")),
|
|
batchstats.WithStore(options.Database),
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("failed to create agent stats batcher: %w", err)
|
|
}
|
|
options.StatsBatcher = batcher
|
|
defer closeBatcher()
|
|
|
|
// We use a separate coderAPICloser so the Enterprise API
|
|
// can have its own close functions. This is cleaner
|
|
// than abstracting the Coder API itself.
|
|
coderAPI, coderAPICloser, err := newAPI(ctx, options)
|
|
if err != nil {
|
|
return xerrors.Errorf("create coder API: %w", err)
|
|
}
|
|
|
|
if vals.Prometheus.Enable {
|
|
// Agent metrics require reference to the tailnet coordinator, so must be initiated after Coder API.
|
|
closeAgentsFunc, err := prometheusmetrics.Agents(ctx, logger, options.PrometheusRegistry, coderAPI.Database, &coderAPI.TailnetCoordinator, coderAPI.DERPMap, coderAPI.Options.AgentInactiveDisconnectTimeout, 0)
|
|
if err != nil {
|
|
return xerrors.Errorf("register agents prometheus metric: %w", err)
|
|
}
|
|
defer closeAgentsFunc()
|
|
}
|
|
|
|
client := codersdk.New(localURL)
|
|
if localURL.Scheme == "https" && IsLocalhost(localURL.Hostname()) {
|
|
// The certificate will likely be self-signed or for a different
|
|
// hostname, so we need to skip verification.
|
|
client.HTTPClient.Transport = &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
//nolint:gosec
|
|
InsecureSkipVerify: true,
|
|
},
|
|
}
|
|
}
|
|
defer client.HTTPClient.CloseIdleConnections()
|
|
|
|
// This is helpful for tests, but can be silently ignored.
|
|
// Coder may be ran as users that don't have permission to write in the homedir,
|
|
// such as via the systemd service.
|
|
err = config.URL().Write(client.URL.String())
|
|
if err != nil && flag.Lookup("test.v") != nil {
|
|
return xerrors.Errorf("write config url: %w", err)
|
|
}
|
|
|
|
// Since errCh only has one buffered slot, all routines
|
|
// sending on it must be wrapped in a select/default to
|
|
// avoid leaving dangling goroutines waiting for the
|
|
// channel to be consumed.
|
|
errCh := make(chan error, 1)
|
|
provisionerDaemons := make([]*provisionerd.Server, 0)
|
|
defer func() {
|
|
// We have no graceful shutdown of provisionerDaemons
|
|
// here because that's handled at the end of main, this
|
|
// is here in case the program exits early.
|
|
for _, daemon := range provisionerDaemons {
|
|
_ = daemon.Close()
|
|
}
|
|
}()
|
|
|
|
var provisionerdWaitGroup sync.WaitGroup
|
|
defer provisionerdWaitGroup.Wait()
|
|
provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry)
|
|
for i := int64(0); i < vals.Provisioner.Daemons.Value(); i++ {
|
|
suffix := fmt.Sprintf("%d", i)
|
|
// The suffix is added to the hostname, so we may need to trim to fit into
|
|
// the 64 character limit.
|
|
hostname := stringutil.Truncate(cliutil.Hostname(), 63-len(suffix))
|
|
name := fmt.Sprintf("%s-%s", hostname, suffix)
|
|
daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i))
|
|
daemon, err := newProvisionerDaemon(
|
|
ctx, coderAPI, provisionerdMetrics, logger, vals, daemonCacheDir, errCh, &provisionerdWaitGroup, name,
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("create provisioner daemon: %w", err)
|
|
}
|
|
provisionerDaemons = append(provisionerDaemons, daemon)
|
|
}
|
|
provisionerdMetrics.Runner.NumDaemons.Set(float64(len(provisionerDaemons)))
|
|
|
|
shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
|
|
defer shutdownConns()
|
|
|
|
// Ensures that old database entries are cleaned up over time!
|
|
purger := dbpurge.New(ctx, logger, options.Database)
|
|
defer purger.Close()
|
|
|
|
// Wrap the server in middleware that redirects to the access URL if
|
|
// the request is not to a local IP.
|
|
var handler http.Handler = coderAPI.RootHandler
|
|
if vals.RedirectToAccessURL {
|
|
handler = redirectToAccessURL(handler, vals.AccessURL.Value(), tunnel != nil, appHostnameRegex)
|
|
}
|
|
|
|
// ReadHeaderTimeout is purposefully not enabled. It caused some
|
|
// issues with websockets over the dev tunnel.
|
|
// See: https://github.com/coder/coder/pull/3730
|
|
//nolint:gosec
|
|
httpServer := &http.Server{
|
|
// These errors are typically noise like "TLS: EOF". Vault does
|
|
// similar:
|
|
// https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
|
|
ErrorLog: log.New(io.Discard, "", 0),
|
|
Handler: handler,
|
|
BaseContext: func(_ net.Listener) context.Context {
|
|
return shutdownConnsCtx
|
|
},
|
|
}
|
|
defer func() {
|
|
_ = shutdownWithTimeout(httpServer.Shutdown, 5*time.Second)
|
|
}()
|
|
|
|
// We call this in the routine so we can kill the other listeners if
|
|
// one of them fails.
|
|
closeListenersNow := func() {
|
|
httpServers.Close()
|
|
if tunnel != nil {
|
|
_ = tunnel.Listener.Close()
|
|
}
|
|
}
|
|
|
|
eg := errgroup.Group{}
|
|
eg.Go(func() error {
|
|
defer closeListenersNow()
|
|
return httpServers.Serve(httpServer)
|
|
})
|
|
if tunnel != nil {
|
|
eg.Go(func() error {
|
|
defer closeListenersNow()
|
|
return httpServer.Serve(tunnel.Listener)
|
|
})
|
|
}
|
|
|
|
go func() {
|
|
select {
|
|
case errCh <- eg.Wait():
|
|
default:
|
|
}
|
|
}()
|
|
|
|
cliui.Infof(inv.Stdout, "\n==> Logs will stream in below (press ctrl+c to gracefully exit):")
|
|
|
|
// Updates the systemd status from activating to activated.
|
|
_, err = daemon.SdNotify(false, daemon.SdNotifyReady)
|
|
if err != nil {
|
|
return xerrors.Errorf("notify systemd: %w", err)
|
|
}
|
|
|
|
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
|
|
defer autobuildTicker.Stop()
|
|
autobuildExecutor := autobuild.NewExecutor(
|
|
ctx, options.Database, options.Pubsub, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C)
|
|
autobuildExecutor.Run()
|
|
|
|
hangDetectorTicker := time.NewTicker(vals.JobHangDetectorInterval.Value())
|
|
defer hangDetectorTicker.Stop()
|
|
hangDetector := unhanger.New(ctx, options.Database, options.Pubsub, logger, hangDetectorTicker.C)
|
|
hangDetector.Start()
|
|
defer hangDetector.Close()
|
|
|
|
// Currently there is no way to ask the server to shut
|
|
// itself down, so any exit signal will result in a non-zero
|
|
// exit of the server.
|
|
var exitErr error
|
|
select {
|
|
case <-notifyCtx.Done():
|
|
exitErr = notifyCtx.Err()
|
|
_, _ = io.WriteString(inv.Stdout, cliui.Bold("Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
|
|
case <-tunnelDone:
|
|
exitErr = xerrors.New("dev tunnel closed unexpectedly")
|
|
case exitErr = <-errCh:
|
|
}
|
|
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
|
|
cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr)
|
|
}
|
|
|
|
// Begin clean shut down stage, we try to shut down services
|
|
// gracefully in an order that gives the best experience.
|
|
// This procedure should not differ greatly from the order
|
|
// of `defer`s in this function, but allows us to inform
|
|
// the user about what's going on and handle errors more
|
|
// explicitly.
|
|
|
|
_, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
|
|
if err != nil {
|
|
cliui.Errorf(inv.Stderr, "Notify systemd failed: %s", err)
|
|
}
|
|
|
|
// Stop accepting new connections without interrupting
|
|
// in-flight requests, give in-flight requests 5 seconds to
|
|
// complete.
|
|
cliui.Info(inv.Stdout, "Shutting down API server..."+"\n")
|
|
err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second)
|
|
if err != nil {
|
|
cliui.Errorf(inv.Stderr, "API server shutdown took longer than 3s: %s\n", err)
|
|
} else {
|
|
cliui.Info(inv.Stdout, "Gracefully shut down API server\n")
|
|
}
|
|
// Cancel any remaining in-flight requests.
|
|
shutdownConns()
|
|
|
|
// Shut down provisioners before waiting for WebSockets
|
|
// connections to close.
|
|
var wg sync.WaitGroup
|
|
for i, provisionerDaemon := range provisionerDaemons {
|
|
id := i + 1
|
|
provisionerDaemon := provisionerDaemon
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
r.Verbosef(inv, "Shutting down provisioner daemon %d...", id)
|
|
err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second)
|
|
if err != nil {
|
|
cliui.Errorf(inv.Stderr, "Failed to shut down provisioner daemon %d: %s\n", id, err)
|
|
return
|
|
}
|
|
err = provisionerDaemon.Close()
|
|
if err != nil {
|
|
cliui.Errorf(inv.Stderr, "Close provisioner daemon %d: %s\n", id, err)
|
|
return
|
|
}
|
|
r.Verbosef(inv, "Gracefully shut down provisioner daemon %d", id)
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
cliui.Info(inv.Stdout, "Waiting for WebSocket connections to close..."+"\n")
|
|
_ = coderAPICloser.Close()
|
|
cliui.Info(inv.Stdout, "Done waiting for WebSocket connections"+"\n")
|
|
|
|
// Close tunnel after we no longer have in-flight connections.
|
|
if tunnel != nil {
|
|
cliui.Infof(inv.Stdout, "Waiting for tunnel to close...")
|
|
_ = tunnel.Close()
|
|
<-tunnel.Wait()
|
|
cliui.Infof(inv.Stdout, "Done waiting for tunnel")
|
|
}
|
|
|
|
// Ensures a last report can be sent before exit!
|
|
options.Telemetry.Close()
|
|
|
|
// Trigger context cancellation for any remaining services.
|
|
cancel()
|
|
|
|
switch {
|
|
case xerrors.Is(exitErr, context.DeadlineExceeded):
|
|
cliui.Warnf(inv.Stderr, "Graceful shutdown timed out")
|
|
// Errors here cause a significant number of benign CI failures.
|
|
return nil
|
|
case xerrors.Is(exitErr, context.Canceled):
|
|
return nil
|
|
case exitErr != nil:
|
|
return xerrors.Errorf("graceful shutdown: %w", exitErr)
|
|
default:
|
|
return nil
|
|
}
|
|
},
|
|
}
|
|
|
|
var pgRawURL bool
|
|
|
|
postgresBuiltinURLCmd := &clibase.Cmd{
|
|
Use: "postgres-builtin-url",
|
|
Short: "Output the connection URL for the built-in PostgreSQL deployment.",
|
|
Handler: func(inv *clibase.Invocation) error {
|
|
url, err := embeddedPostgresURL(r.createConfig())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if pgRawURL {
|
|
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
|
|
} else {
|
|
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", pretty.Sprint(cliui.DefaultStyles.Code, fmt.Sprintf("psql %q", url)))
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
|
|
postgresBuiltinServeCmd := &clibase.Cmd{
|
|
Use: "postgres-builtin-serve",
|
|
Short: "Run the built-in PostgreSQL deployment.",
|
|
Handler: func(inv *clibase.Invocation) error {
|
|
ctx := inv.Context()
|
|
|
|
cfg := r.createConfig()
|
|
logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr))
|
|
if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok {
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
}
|
|
|
|
ctx, cancel := inv.SignalNotifyContext(ctx, InterruptSignals...)
|
|
defer cancel()
|
|
|
|
url, closePg, err := startBuiltinPostgres(ctx, cfg, logger)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = closePg() }()
|
|
|
|
if pgRawURL {
|
|
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", url)
|
|
} else {
|
|
_, _ = fmt.Fprintf(inv.Stdout, "%s\n", pretty.Sprint(cliui.DefaultStyles.Code, fmt.Sprintf("psql %q", url)))
|
|
}
|
|
|
|
<-ctx.Done()
|
|
return nil
|
|
},
|
|
}
|
|
|
|
createAdminUserCmd := r.newCreateAdminUserCommand()
|
|
|
|
rawURLOpt := clibase.Option{
|
|
Flag: "raw-url",
|
|
|
|
Value: clibase.BoolOf(&pgRawURL),
|
|
Description: "Output the raw connection URL instead of a psql command.",
|
|
}
|
|
createAdminUserCmd.Options.Add(rawURLOpt)
|
|
postgresBuiltinURLCmd.Options.Add(rawURLOpt)
|
|
postgresBuiltinServeCmd.Options.Add(rawURLOpt)
|
|
|
|
serverCmd.Children = append(
|
|
serverCmd.Children,
|
|
createAdminUserCmd, postgresBuiltinURLCmd, postgresBuiltinServeCmd,
|
|
)
|
|
|
|
return serverCmd
|
|
}
|
|
|
|
// printDeprecatedOptions loops through all command options, and prints
|
|
// a warning for usage of deprecated options.
|
|
func PrintDeprecatedOptions() clibase.MiddlewareFunc {
|
|
return func(next clibase.HandlerFunc) clibase.HandlerFunc {
|
|
return func(inv *clibase.Invocation) error {
|
|
opts := inv.Command.Options
|
|
// Print deprecation warnings.
|
|
for _, opt := range opts {
|
|
if opt.UseInstead == nil {
|
|
continue
|
|
}
|
|
|
|
if opt.ValueSource == clibase.ValueSourceNone || opt.ValueSource == clibase.ValueSourceDefault {
|
|
continue
|
|
}
|
|
|
|
warnStr := opt.Name + " is deprecated, please use "
|
|
for i, use := range opt.UseInstead {
|
|
warnStr += use.Name + " "
|
|
if i != len(opt.UseInstead)-1 {
|
|
warnStr += "and "
|
|
}
|
|
}
|
|
warnStr += "instead.\n"
|
|
|
|
cliui.Warn(inv.Stderr,
|
|
warnStr,
|
|
)
|
|
}
|
|
|
|
return next(inv)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writeConfigMW will prevent the main command from running if the write-config
|
|
// flag is set. Instead, it will marshal the command options to YAML and write
|
|
// them to stdout.
|
|
func WriteConfigMW(cfg *codersdk.DeploymentValues) clibase.MiddlewareFunc {
|
|
return func(next clibase.HandlerFunc) clibase.HandlerFunc {
|
|
return func(inv *clibase.Invocation) error {
|
|
if !cfg.WriteConfig {
|
|
return next(inv)
|
|
}
|
|
|
|
opts := inv.Command.Options
|
|
n, err := opts.MarshalYAML()
|
|
if err != nil {
|
|
return xerrors.Errorf("generate yaml: %w", err)
|
|
}
|
|
enc := yaml.NewEncoder(inv.Stdout)
|
|
enc.SetIndent(2)
|
|
err = enc.Encode(n)
|
|
if err != nil {
|
|
return xerrors.Errorf("encode yaml: %w", err)
|
|
}
|
|
err = enc.Close()
|
|
if err != nil {
|
|
return xerrors.Errorf("close yaml encoder: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// IsLocalURL reports whether the hostname of u resolves to at least one
// loopback address. Resolution errors are returned to the caller.
//
// When running under `go test`, the hostnames "example.com" and "google.com"
// return false without a DNS lookup.
func IsLocalURL(ctx context.Context, u *url.URL) (bool, error) {
	host := u.Hostname()

	// In tests, we commonly use "example.com" or "google.com", which
	// are not loopback, so avoid the DNS lookup to avoid flakes.
	if flag.Lookup("test.v") != nil && (host == "example.com" || host == "google.com") {
		return false, nil
	}

	ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return false, err
	}

	for _, ip := range ips {
		if ip.IP.IsLoopback() {
			return true, nil
		}
	}
	return false, nil
}
|
|
|
|
// shutdownWithTimeout calls shutdown with a fresh context (not derived from
// any caller context) that expires after timeout, and returns shutdown's
// result unchanged. The context is cancelled once shutdown returns.
func shutdownWithTimeout(shutdown func(context.Context) error, timeout time.Duration) error {
	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), timeout)
	defer cancelShutdown()

	return shutdown(shutdownCtx)
}
|
|
|
|
// nolint:revive
|
|
func newProvisionerDaemon(
|
|
ctx context.Context,
|
|
coderAPI *coderd.API,
|
|
metrics provisionerd.Metrics,
|
|
logger slog.Logger,
|
|
cfg *codersdk.DeploymentValues,
|
|
cacheDir string,
|
|
errCh chan error,
|
|
wg *sync.WaitGroup,
|
|
name string,
|
|
) (srv *provisionerd.Server, err error) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer func() {
|
|
if err != nil {
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
err = os.MkdirAll(cacheDir, 0o700)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
|
|
}
|
|
|
|
workDir := filepath.Join(cacheDir, "work")
|
|
err = os.MkdirAll(workDir, 0o700)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("mkdir work dir: %w", err)
|
|
}
|
|
|
|
connector := provisionerd.LocalProvisioners{}
|
|
if cfg.Provisioner.DaemonsEcho {
|
|
echoClient, echoServer := drpc.MemTransportPipe()
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
<-ctx.Done()
|
|
_ = echoClient.Close()
|
|
_ = echoServer.Close()
|
|
}()
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
defer cancel()
|
|
|
|
err := echo.Serve(ctx, &provisionersdk.ServeOptions{
|
|
Listener: echoServer,
|
|
WorkDirectory: workDir,
|
|
Logger: logger.Named("echo"),
|
|
})
|
|
if err != nil {
|
|
select {
|
|
case errCh <- err:
|
|
default:
|
|
}
|
|
}
|
|
}()
|
|
connector[string(database.ProvisionerTypeEcho)] = sdkproto.NewDRPCProvisionerClient(echoClient)
|
|
} else {
|
|
tfDir := filepath.Join(cacheDir, "tf")
|
|
err = os.MkdirAll(tfDir, 0o700)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("mkdir terraform dir: %w", err)
|
|
}
|
|
|
|
tracer := coderAPI.TracerProvider.Tracer(tracing.TracerName)
|
|
terraformClient, terraformServer := drpc.MemTransportPipe()
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
<-ctx.Done()
|
|
_ = terraformClient.Close()
|
|
_ = terraformServer.Close()
|
|
}()
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
defer cancel()
|
|
|
|
err := terraform.Serve(ctx, &terraform.ServeOptions{
|
|
ServeOptions: &provisionersdk.ServeOptions{
|
|
Listener: terraformServer,
|
|
Logger: logger.Named("terraform"),
|
|
WorkDirectory: workDir,
|
|
},
|
|
CachePath: tfDir,
|
|
Tracer: tracer,
|
|
})
|
|
if err != nil && !xerrors.Is(err, context.Canceled) {
|
|
select {
|
|
case errCh <- err:
|
|
default:
|
|
}
|
|
}
|
|
}()
|
|
|
|
connector[string(database.ProvisionerTypeTerraform)] = sdkproto.NewDRPCProvisionerClient(terraformClient)
|
|
}
|
|
|
|
return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
|
|
// This debounces calls to listen every second. Read the comment
|
|
// in provisionerdserver.go to learn more!
|
|
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, name)
|
|
}, &provisionerd.Options{
|
|
Logger: logger.Named(fmt.Sprintf("provisionerd-%s", name)),
|
|
UpdateInterval: time.Second,
|
|
ForceCancelInterval: cfg.Provisioner.ForceCancelInterval.Value(),
|
|
Connector: connector,
|
|
TracerProvider: coderAPI.TracerProvider,
|
|
Metrics: &metrics,
|
|
}), nil
|
|
}
|
|
|
|
// nolint: revive
|
|
func PrintLogo(inv *clibase.Invocation, daemonTitle string) {
|
|
// Only print the logo in TTYs.
|
|
if !isTTYOut(inv) {
|
|
return
|
|
}
|
|
|
|
versionString := cliui.Bold(daemonTitle + " " + buildinfo.Version())
|
|
|
|
_, _ = fmt.Fprintf(inv.Stdout, "%s - Your Self-Hosted Remote Development Platform\n", versionString)
|
|
}
|
|
|
|
func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) {
|
|
if len(tlsCertFiles) != len(tlsKeyFiles) {
|
|
return nil, xerrors.New("--tls-cert-file and --tls-key-file must be used the same amount of times")
|
|
}
|
|
|
|
certs := make([]tls.Certificate, len(tlsCertFiles))
|
|
for i := range tlsCertFiles {
|
|
certFile, keyFile := tlsCertFiles[i], tlsKeyFiles[i]
|
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf(
|
|
"load TLS key pair %d (%q, %q): %w\ncertFiles: %+v\nkeyFiles: %+v",
|
|
i, certFile, keyFile, err,
|
|
tlsCertFiles, tlsKeyFiles,
|
|
)
|
|
}
|
|
|
|
certs[i] = cert
|
|
}
|
|
|
|
return certs, nil
|
|
}
|
|
|
|
// generateSelfSignedCertificate creates an unsafe self-signed certificate
// at random that allows users to proceed with setup in the event they
// haven't configured any TLS certificates.
//
// The certificate uses a fresh P-256 ECDSA key, is valid for 180 days from
// now, is usable for TLS server authentication, and lists 127.0.0.1 as its
// only IP address. Each call picks a random 128-bit serial number. A fixed
// serial would give every generated certificate the same issuer+serial pair
// with a different key, which some clients (e.g. Firefox) reject.
func generateSelfSignedCertificate() (*tls.Certificate, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}

	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, err
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		NotBefore:    now,
		NotAfter:     now.Add(time.Hour * 24 * 180),

		// KeyEncipherment is only meaningful for RSA keys (RFC 5480 does not
		// allow it for EC keys); ECDSA server certs need DigitalSignature.
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		return nil, err
	}

	var cert tls.Certificate
	cert.Certificate = append(cert.Certificate, derBytes)
	cert.PrivateKey = privateKey
	return &cert, nil
}
|
|
|
|
// configureServerTLS returns the TLS config used for the Coderd server
|
|
// connections to clients. A logger is passed in to allow printing warning
|
|
// messages that do not block startup.
|
|
//
|
|
//nolint:revive
|
|
func configureServerTLS(ctx context.Context, logger slog.Logger, tlsMinVersion, tlsClientAuth string, tlsCertFiles, tlsKeyFiles []string, tlsClientCAFile string, ciphers []string, allowInsecureCiphers bool) (*tls.Config, error) {
|
|
tlsConfig := &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
}
|
|
switch tlsMinVersion {
|
|
case "tls10":
|
|
tlsConfig.MinVersion = tls.VersionTLS10
|
|
case "tls11":
|
|
tlsConfig.MinVersion = tls.VersionTLS11
|
|
case "tls12":
|
|
tlsConfig.MinVersion = tls.VersionTLS12
|
|
case "tls13":
|
|
tlsConfig.MinVersion = tls.VersionTLS13
|
|
default:
|
|
return nil, xerrors.Errorf("unrecognized tls version: %q", tlsMinVersion)
|
|
}
|
|
|
|
// A custom set of supported ciphers.
|
|
if len(ciphers) > 0 {
|
|
cipherIDs, err := configureCipherSuites(ctx, logger, ciphers, allowInsecureCiphers, tlsConfig.MinVersion, tls.VersionTLS13)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tlsConfig.CipherSuites = cipherIDs
|
|
}
|
|
|
|
switch tlsClientAuth {
|
|
case "none":
|
|
tlsConfig.ClientAuth = tls.NoClientCert
|
|
case "request":
|
|
tlsConfig.ClientAuth = tls.RequestClientCert
|
|
case "require-any":
|
|
tlsConfig.ClientAuth = tls.RequireAnyClientCert
|
|
case "verify-if-given":
|
|
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
|
case "require-and-verify":
|
|
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
|
default:
|
|
return nil, xerrors.Errorf("unrecognized tls client auth: %q", tlsClientAuth)
|
|
}
|
|
|
|
certs, err := loadCertificates(tlsCertFiles, tlsKeyFiles)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("load certificates: %w", err)
|
|
}
|
|
if len(certs) == 0 {
|
|
selfSignedCertificate, err := generateSelfSignedCertificate()
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("generate self signed certificate: %w", err)
|
|
}
|
|
certs = append(certs, *selfSignedCertificate)
|
|
}
|
|
|
|
tlsConfig.Certificates = certs
|
|
tlsConfig.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
// If there's only one certificate, return it.
|
|
if len(certs) == 1 {
|
|
return &certs[0], nil
|
|
}
|
|
|
|
// Expensively check which certificate matches the client hello.
|
|
for _, cert := range certs {
|
|
cert := cert
|
|
if err := hi.SupportsCertificate(&cert); err == nil {
|
|
return &cert, nil
|
|
}
|
|
}
|
|
|
|
// Return the first certificate if we have one, or return nil so the
|
|
// server doesn't fail.
|
|
if len(certs) > 0 {
|
|
return &certs[0], nil
|
|
}
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
err = configureCAPool(tlsClientCAFile, tlsConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
//nolint:revive
|
|
func configureCipherSuites(ctx context.Context, logger slog.Logger, ciphers []string, allowInsecureCiphers bool, minTLS, maxTLS uint16) ([]uint16, error) {
|
|
if minTLS > maxTLS {
|
|
return nil, xerrors.Errorf("minimum tls version (%s) cannot be greater than maximum tls version (%s)", versionName(minTLS), versionName(maxTLS))
|
|
}
|
|
if minTLS >= tls.VersionTLS13 {
|
|
// The cipher suites config option is ignored for tls 1.3 and higher.
|
|
// So this user flag is a no-op if the min version is 1.3.
|
|
return nil, xerrors.Errorf("'--tls-ciphers' cannot be specified when using minimum tls version 1.3 or higher, %d ciphers found as input.", len(ciphers))
|
|
}
|
|
// Configure the cipher suites which parses the strings and converts them
|
|
// to golang cipher suites.
|
|
supported, err := parseTLSCipherSuites(ciphers)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("tls ciphers: %w", err)
|
|
}
|
|
|
|
// allVersions is all tls versions the server supports.
|
|
// We enumerate these to ensure if ciphers are configured, at least
|
|
// 1 cipher for each version exists.
|
|
allVersions := make(map[uint16]bool)
|
|
for v := minTLS; v <= maxTLS; v++ {
|
|
allVersions[v] = false
|
|
}
|
|
|
|
var insecure []string
|
|
cipherIDs := make([]uint16, 0, len(supported))
|
|
for _, cipher := range supported {
|
|
if cipher.Insecure {
|
|
// Always show this warning, even if they have allowInsecureCiphers
|
|
// specified.
|
|
logger.Warn(ctx, "insecure tls cipher specified for server use", slog.F("cipher", cipher.Name))
|
|
insecure = append(insecure, cipher.Name)
|
|
}
|
|
|
|
// This is a warning message to tell the user if they are specifying
|
|
// a cipher that does not support the tls versions they have specified.
|
|
// This makes the cipher essentially a "noop" cipher.
|
|
if !hasSupportedVersion(minTLS, maxTLS, cipher.SupportedVersions) {
|
|
versions := make([]string, 0, len(cipher.SupportedVersions))
|
|
for _, sv := range cipher.SupportedVersions {
|
|
versions = append(versions, versionName(sv))
|
|
}
|
|
logger.Warn(ctx, "cipher not supported for tls versions enabled, cipher will not be used",
|
|
slog.F("cipher", cipher.Name),
|
|
slog.F("cipher_supported_versions", strings.Join(versions, ",")),
|
|
slog.F("server_min_version", versionName(minTLS)),
|
|
slog.F("server_max_version", versionName(maxTLS)),
|
|
)
|
|
}
|
|
|
|
for _, v := range cipher.SupportedVersions {
|
|
allVersions[v] = true
|
|
}
|
|
|
|
cipherIDs = append(cipherIDs, cipher.ID)
|
|
}
|
|
|
|
if len(insecure) > 0 && !allowInsecureCiphers {
|
|
return nil, xerrors.Errorf("insecure tls ciphers specified, must use '--tls-allow-insecure-ciphers' to allow these: %s", strings.Join(insecure, ", "))
|
|
}
|
|
|
|
// This is an additional sanity check. The user can specify ciphers that
|
|
// do not cover the full range of tls versions they have specified.
|
|
// They can unintentionally break TLS for some tls configured versions.
|
|
var missedVersions []string
|
|
for version, covered := range allVersions {
|
|
if version == tls.VersionTLS13 {
|
|
continue // v1.3 ignores configured cipher suites.
|
|
}
|
|
if !covered {
|
|
missedVersions = append(missedVersions, versionName(version))
|
|
}
|
|
}
|
|
if len(missedVersions) > 0 {
|
|
return nil, xerrors.Errorf("no tls ciphers supported for tls versions %q."+
|
|
"Add additional ciphers, set the minimum version to 'tls13, or remove the ciphers configured and rely on the default",
|
|
strings.Join(missedVersions, ","))
|
|
}
|
|
|
|
return cipherIDs, nil
|
|
}
|
|
|
|
// parseTLSCipherSuites will parse cipher suite names like 'TLS_RSA_WITH_AES_128_CBC_SHA'
|
|
// to their tls cipher suite structs. If a cipher suite that is unsupported is
|
|
// passed in, this function will return an error.
|
|
// This function can return insecure cipher suites.
|
|
func parseTLSCipherSuites(ciphers []string) ([]tls.CipherSuite, error) {
|
|
if len(ciphers) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var unsupported []string
|
|
var supported []tls.CipherSuite
|
|
// A custom set of supported ciphers.
|
|
allCiphers := append(tls.CipherSuites(), tls.InsecureCipherSuites()...)
|
|
for _, cipher := range ciphers {
|
|
// For each cipher specified by the client, find the cipher in the
|
|
// list of golang supported ciphers.
|
|
var found *tls.CipherSuite
|
|
for _, supported := range allCiphers {
|
|
if strings.EqualFold(supported.Name, cipher) {
|
|
found = supported
|
|
break
|
|
}
|
|
}
|
|
|
|
if found == nil {
|
|
unsupported = append(unsupported, cipher)
|
|
continue
|
|
}
|
|
|
|
supported = append(supported, *found)
|
|
}
|
|
|
|
if len(unsupported) > 0 {
|
|
return nil, xerrors.Errorf("unsupported tls ciphers specified, see https://github.com/golang/go/blob/master/src/crypto/tls/cipher_suites.go#L53-L75: %s", strings.Join(unsupported, ", "))
|
|
}
|
|
|
|
return supported, nil
|
|
}
|
|
|
|
// hasSupportedVersion reports whether any entry of versions lies within the
// inclusive range [lo, hi]. It returns false for an empty list or when every
// entry is outside the range.
func hasSupportedVersion(lo, hi uint16, versions []uint16) bool {
	for i := range versions {
		if lo <= versions[i] && versions[i] <= hi {
			return true
		}
	}
	return false
}
|
|
|
|
// versionName returns a human-readable name for a TLS protocol version, such
// as "TLS 1.2". Unknown versions are formatted as "0x%04X". It mirrors
// tls.VersionName from Go 1.21 and is kept locally until the switch.
func versionName(version uint16) (name string) {
	switch version {
	case tls.VersionTLS13:
		name = "TLS 1.3"
	case tls.VersionTLS12:
		name = "TLS 1.2"
	case tls.VersionTLS11:
		name = "TLS 1.1"
	case tls.VersionTLS10:
		name = "TLS 1.0"
	case tls.VersionSSL30:
		name = "SSLv3"
	default:
		name = fmt.Sprintf("0x%04X", version)
	}
	return name
}
|
|
|
|
func configureOIDCPKI(orig *oauth2.Config, keyFile string, certFile string) (*oauthpki.Config, error) {
|
|
// Read the files
|
|
keyData, err := os.ReadFile(keyFile)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("read oidc client key file: %w", err)
|
|
}
|
|
|
|
var certData []byte
|
|
// According to the spec, this is not required. So do not require it on the initial loading
|
|
// of the PKI config.
|
|
if certFile != "" {
|
|
certData, err = os.ReadFile(certFile)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("read oidc client cert file: %w", err)
|
|
}
|
|
}
|
|
|
|
return oauthpki.NewOauth2PKIConfig(oauthpki.ConfigParams{
|
|
ClientID: orig.ClientID,
|
|
TokenURL: orig.Endpoint.TokenURL,
|
|
Scopes: orig.Scopes,
|
|
PemEncodedKey: keyData,
|
|
PemEncodedCert: certData,
|
|
Config: orig,
|
|
})
|
|
}
|
|
|
|
func configureCAPool(tlsClientCAFile string, tlsConfig *tls.Config) error {
|
|
if tlsClientCAFile != "" {
|
|
caPool := x509.NewCertPool()
|
|
data, err := os.ReadFile(tlsClientCAFile)
|
|
if err != nil {
|
|
return xerrors.Errorf("read %q: %w", tlsClientCAFile, err)
|
|
}
|
|
if !caPool.AppendCertsFromPEM(data) {
|
|
return xerrors.Errorf("failed to parse CA certificate in tls-client-ca-file")
|
|
}
|
|
tlsConfig.ClientCAs = caPool
|
|
}
|
|
return nil
|
|
}
|
|
|
|
//nolint:revive // Ignore flag-parameter: parameter 'allowEveryone' seems to be a control flag, avoid control coupling (revive)
|
|
func configureGithubOAuth2(accessURL *url.URL, clientID, clientSecret string, allowSignups, allowEveryone bool, allowOrgs []string, rawTeams []string, enterpriseBaseURL string) (*coderd.GithubOAuth2Config, error) {
|
|
redirectURL, err := accessURL.Parse("/api/v2/users/oauth2/github/callback")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse github oauth callback url: %w", err)
|
|
}
|
|
if allowEveryone && len(allowOrgs) > 0 {
|
|
return nil, xerrors.New("allow everyone and allowed orgs cannot be used together")
|
|
}
|
|
if allowEveryone && len(rawTeams) > 0 {
|
|
return nil, xerrors.New("allow everyone and allowed teams cannot be used together")
|
|
}
|
|
if !allowEveryone && len(allowOrgs) == 0 {
|
|
return nil, xerrors.New("allowed orgs is empty: must specify at least one org or allow everyone")
|
|
}
|
|
allowTeams := make([]coderd.GithubOAuth2Team, 0, len(rawTeams))
|
|
for _, rawTeam := range rawTeams {
|
|
parts := strings.SplitN(rawTeam, "/", 2)
|
|
if len(parts) != 2 {
|
|
return nil, xerrors.Errorf("github team allowlist is formatted incorrectly. got %s; wanted <organization>/<team>", rawTeam)
|
|
}
|
|
allowTeams = append(allowTeams, coderd.GithubOAuth2Team{
|
|
Organization: parts[0],
|
|
Slug: parts[1],
|
|
})
|
|
}
|
|
createClient := func(client *http.Client) (*github.Client, error) {
|
|
if enterpriseBaseURL != "" {
|
|
return github.NewEnterpriseClient(enterpriseBaseURL, "", client)
|
|
}
|
|
return github.NewClient(client), nil
|
|
}
|
|
|
|
endpoint := xgithub.Endpoint
|
|
if enterpriseBaseURL != "" {
|
|
enterpriseURL, err := url.Parse(enterpriseBaseURL)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse enterprise base url: %w", err)
|
|
}
|
|
authURL, err := enterpriseURL.Parse("/login/oauth/authorize")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse enterprise auth url: %w", err)
|
|
}
|
|
tokenURL, err := enterpriseURL.Parse("/login/oauth/access_token")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse enterprise token url: %w", err)
|
|
}
|
|
endpoint = oauth2.Endpoint{
|
|
AuthURL: authURL.String(),
|
|
TokenURL: tokenURL.String(),
|
|
}
|
|
}
|
|
|
|
return &coderd.GithubOAuth2Config{
|
|
OAuth2Config: &oauth2.Config{
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
Endpoint: endpoint,
|
|
RedirectURL: redirectURL.String(),
|
|
Scopes: []string{
|
|
"read:user",
|
|
"read:org",
|
|
"user:email",
|
|
},
|
|
},
|
|
AllowSignups: allowSignups,
|
|
AllowEveryone: allowEveryone,
|
|
AllowOrganizations: allowOrgs,
|
|
AllowTeams: allowTeams,
|
|
AuthenticatedUser: func(ctx context.Context, client *http.Client) (*github.User, error) {
|
|
api, err := createClient(client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user, _, err := api.Users.Get(ctx, "")
|
|
return user, err
|
|
},
|
|
ListEmails: func(ctx context.Context, client *http.Client) ([]*github.UserEmail, error) {
|
|
api, err := createClient(client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
emails, _, err := api.Users.ListEmails(ctx, &github.ListOptions{})
|
|
return emails, err
|
|
},
|
|
ListOrganizationMemberships: func(ctx context.Context, client *http.Client) ([]*github.Membership, error) {
|
|
api, err := createClient(client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
memberships, _, err := api.Organizations.ListOrgMemberships(ctx, &github.ListOrgMembershipsOptions{
|
|
State: "active",
|
|
ListOptions: github.ListOptions{
|
|
PerPage: 100,
|
|
},
|
|
})
|
|
return memberships, err
|
|
},
|
|
TeamMembership: func(ctx context.Context, client *http.Client, org, teamSlug, username string) (*github.Membership, error) {
|
|
api, err := createClient(client)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
team, _, err := api.Teams.GetTeamMembershipBySlug(ctx, org, teamSlug, username)
|
|
return team, err
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// embeddedPostgresURL returns the URL for the embedded PostgreSQL deployment.
|
|
func embeddedPostgresURL(cfg config.Root) (string, error) {
|
|
pgPassword, err := cfg.PostgresPassword().Read()
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
pgPassword, err = cryptorand.String(16)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("generate password: %w", err)
|
|
}
|
|
err = cfg.PostgresPassword().Write(pgPassword)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("write password: %w", err)
|
|
}
|
|
}
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return "", err
|
|
}
|
|
pgPort, err := cfg.PostgresPort().Read()
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
|
if err != nil {
|
|
return "", xerrors.Errorf("listen for random port: %w", err)
|
|
}
|
|
_ = listener.Close()
|
|
tcpAddr, valid := listener.Addr().(*net.TCPAddr)
|
|
if !valid {
|
|
return "", xerrors.Errorf("listener returned non TCP addr: %T", tcpAddr)
|
|
}
|
|
pgPort = strconv.Itoa(tcpAddr.Port)
|
|
err = cfg.PostgresPort().Write(pgPort)
|
|
if err != nil {
|
|
return "", xerrors.Errorf("write postgres port: %w", err)
|
|
}
|
|
}
|
|
return fmt.Sprintf("postgres://coder@localhost:%s/coder?sslmode=disable&password=%s", pgPort, pgPassword), nil
|
|
}
|
|
|
|
func startBuiltinPostgres(ctx context.Context, cfg config.Root, logger slog.Logger) (string, func() error, error) {
|
|
usr, err := user.Current()
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
if usr.Uid == "0" {
|
|
return "", nil, xerrors.New("The built-in PostgreSQL cannot run as the root user. Create a non-root user and run again!")
|
|
}
|
|
|
|
// Ensure a password and port have been generated!
|
|
connectionURL, err := embeddedPostgresURL(cfg)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
pgPassword, err := cfg.PostgresPassword().Read()
|
|
if err != nil {
|
|
return "", nil, xerrors.Errorf("read postgres password: %w", err)
|
|
}
|
|
pgPortRaw, err := cfg.PostgresPort().Read()
|
|
if err != nil {
|
|
return "", nil, xerrors.Errorf("read postgres port: %w", err)
|
|
}
|
|
pgPort, err := strconv.ParseUint(pgPortRaw, 10, 16)
|
|
if err != nil {
|
|
return "", nil, xerrors.Errorf("parse postgres port: %w", err)
|
|
}
|
|
|
|
stdlibLogger := slog.Stdlib(ctx, logger.Named("postgres"), slog.LevelDebug)
|
|
ep := embeddedpostgres.NewDatabase(
|
|
embeddedpostgres.DefaultConfig().
|
|
Version(embeddedpostgres.V13).
|
|
BinariesPath(filepath.Join(cfg.PostgresPath(), "bin")).
|
|
DataPath(filepath.Join(cfg.PostgresPath(), "data")).
|
|
RuntimePath(filepath.Join(cfg.PostgresPath(), "runtime")).
|
|
CachePath(filepath.Join(cfg.PostgresPath(), "cache")).
|
|
Username("coder").
|
|
Password(pgPassword).
|
|
Database("coder").
|
|
Port(uint32(pgPort)).
|
|
Logger(stdlibLogger.Writer()),
|
|
)
|
|
err = ep.Start()
|
|
if err != nil {
|
|
return "", nil, xerrors.Errorf("Failed to start built-in PostgreSQL. Optionally, specify an external deployment with `--postgres-url`: %w", err)
|
|
}
|
|
return connectionURL, ep.Stop, nil
|
|
}
|
|
|
|
func ConfigureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile string, tlsClientCAFile string) (context.Context, *http.Client, error) {
|
|
if clientCertFile != "" && clientKeyFile != "" {
|
|
certificates, err := loadCertificates([]string{clientCertFile}, []string{clientKeyFile})
|
|
if err != nil {
|
|
return ctx, nil, err
|
|
}
|
|
|
|
tlsClientConfig := &tls.Config{ //nolint:gosec
|
|
Certificates: certificates,
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
}
|
|
err = configureCAPool(tlsClientCAFile, tlsClientConfig)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
httpClient := &http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: tlsClientConfig,
|
|
},
|
|
}
|
|
return context.WithValue(ctx, oauth2.HTTPClient, httpClient), httpClient, nil
|
|
}
|
|
return ctx, &http.Client{}, nil
|
|
}
|
|
|
|
// nolint:revive
|
|
func redirectToAccessURL(handler http.Handler, accessURL *url.URL, tunnel bool, appHostnameRegex *regexp.Regexp) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
redirect := func() {
|
|
http.Redirect(w, r, accessURL.String(), http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
// Exception: DERP
|
|
// We use this endpoint when creating a DERP-mesh in the enterprise version to directly
|
|
// dial other Coderd derpers. Redirecting to the access URL breaks direct dial since the
|
|
// access URL will be load-balanced in a multi-replica deployment.
|
|
//
|
|
// It's totally fine to access DERP over TLS, but we also don't need to redirect HTTP to
|
|
// HTTPS as DERP is itself an encrypted protocol.
|
|
if isDERPPath(r.URL.Path) {
|
|
handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Only do this if we aren't tunneling.
|
|
// If we are tunneling, we want to allow the request to go through
|
|
// because the tunnel doesn't proxy with TLS.
|
|
if !tunnel && accessURL.Scheme == "https" && r.TLS == nil {
|
|
redirect()
|
|
return
|
|
}
|
|
|
|
if r.Host == accessURL.Host {
|
|
handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if r.Header.Get("X-Forwarded-Host") == accessURL.Host {
|
|
handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
if appHostnameRegex != nil && appHostnameRegex.MatchString(r.Host) {
|
|
handler.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
redirect()
|
|
})
|
|
}
|
|
|
|
// isDERPPath reports whether the first segment after the first "/" in the URL
// path p is exactly "derp", e.g. "/derp" or "/derp/latency-check". Text before
// the first "/" is ignored; a path with no "/" is never a DERP path.
func isDERPPath(p string) bool {
	_, rest, found := strings.Cut(p, "/")
	if !found {
		return false
	}
	firstSegment, _, _ := strings.Cut(rest, "/")
	return firstSegment == "derp"
}
|
|
|
|
// IsLocalhost returns true if the host points to the local machine. Intended to
// be called with `u.Hostname()`, i.e. without a port or IPv6 brackets.
//
// "localhost" is matched case-insensitively (hostnames are case-insensitive).
// Any loopback IP literal is accepted, including all of 127.0.0.0/8, ::1 in
// any textual form, and IPv4-mapped IPv6 forms of an IPv4 loopback address.
func IsLocalhost(host string) bool {
	if strings.EqualFold(host, "localhost") {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}
|
|
|
|
var _ slog.Sink = &debugFilterSink{}
|
|
|
|
type debugFilterSink struct {
|
|
next []slog.Sink
|
|
re *regexp.Regexp
|
|
}
|
|
|
|
func (f *debugFilterSink) compile(res []string) error {
|
|
if len(res) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var reb strings.Builder
|
|
for i, re := range res {
|
|
_, _ = fmt.Fprintf(&reb, "(%s)", re)
|
|
if i != len(res)-1 {
|
|
_, _ = reb.WriteRune('|')
|
|
}
|
|
}
|
|
|
|
re, err := regexp.Compile(reb.String())
|
|
if err != nil {
|
|
return xerrors.Errorf("compile regex: %w", err)
|
|
}
|
|
f.re = re
|
|
return nil
|
|
}
|
|
|
|
func (f *debugFilterSink) LogEntry(ctx context.Context, ent slog.SinkEntry) {
|
|
if ent.Level == slog.LevelDebug {
|
|
logName := strings.Join(ent.LoggerNames, ".")
|
|
if f.re != nil && !f.re.MatchString(logName) && !f.re.MatchString(ent.Message) {
|
|
return
|
|
}
|
|
}
|
|
for _, sink := range f.next {
|
|
sink.LogEntry(ctx, ent)
|
|
}
|
|
}
|
|
|
|
func (f *debugFilterSink) Sync() {
|
|
for _, sink := range f.next {
|
|
sink.Sync()
|
|
}
|
|
}
|
|
|
|
func BuildLogger(inv *clibase.Invocation, cfg *codersdk.DeploymentValues) (slog.Logger, func(), error) {
|
|
var (
|
|
sinks = []slog.Sink{}
|
|
closers = []func() error{}
|
|
)
|
|
|
|
addSinkIfProvided := func(sinkFn func(io.Writer) slog.Sink, loc string) error {
|
|
switch loc {
|
|
case "":
|
|
|
|
case "/dev/stdout":
|
|
sinks = append(sinks, sinkFn(inv.Stdout))
|
|
|
|
case "/dev/stderr":
|
|
sinks = append(sinks, sinkFn(inv.Stderr))
|
|
|
|
default:
|
|
fi, err := os.OpenFile(loc, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
|
|
if err != nil {
|
|
return xerrors.Errorf("open log file %q: %w", loc, err)
|
|
}
|
|
closers = append(closers, fi.Close)
|
|
sinks = append(sinks, sinkFn(fi))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
err := addSinkIfProvided(sloghuman.Sink, cfg.Logging.Human.String())
|
|
if err != nil {
|
|
return slog.Logger{}, nil, xerrors.Errorf("add human sink: %w", err)
|
|
}
|
|
err = addSinkIfProvided(slogjson.Sink, cfg.Logging.JSON.String())
|
|
if err != nil {
|
|
return slog.Logger{}, nil, xerrors.Errorf("add json sink: %w", err)
|
|
}
|
|
err = addSinkIfProvided(slogstackdriver.Sink, cfg.Logging.Stackdriver.String())
|
|
if err != nil {
|
|
return slog.Logger{}, nil, xerrors.Errorf("add stackdriver sink: %w", err)
|
|
}
|
|
|
|
if cfg.Trace.CaptureLogs {
|
|
sinks = append(sinks, tracing.SlogSink{})
|
|
}
|
|
|
|
// User should log to null device if they don't want logs.
|
|
if len(sinks) == 0 {
|
|
return slog.Logger{}, nil, xerrors.New("no loggers provided")
|
|
}
|
|
|
|
filter := &debugFilterSink{next: sinks}
|
|
|
|
err = filter.compile(cfg.Logging.Filter.Value())
|
|
if err != nil {
|
|
return slog.Logger{}, nil, xerrors.Errorf("compile filters: %w", err)
|
|
}
|
|
|
|
level := slog.LevelInfo
|
|
// Debug logging is always enabled if a filter is present.
|
|
if cfg.Verbose || filter.re != nil {
|
|
level = slog.LevelDebug
|
|
}
|
|
|
|
return inv.Logger.AppendSinks(filter).Leveled(level), func() {
|
|
for _, closer := range closers {
|
|
_ = closer()
|
|
}
|
|
}, nil
|
|
}
|
|
|
|
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
|
|
logger.Debug(ctx, "connecting to postgresql")
|
|
|
|
// Try to connect for 30 seconds.
|
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer cancel()
|
|
|
|
defer func() {
|
|
if err == nil {
|
|
return
|
|
}
|
|
if sqlDB != nil {
|
|
_ = sqlDB.Close()
|
|
sqlDB = nil
|
|
}
|
|
logger.Error(ctx, "connect to postgres failed", slog.Error(err))
|
|
}()
|
|
|
|
var tries int
|
|
for r := retry.New(time.Second, 3*time.Second); r.Wait(ctx); {
|
|
tries++
|
|
|
|
sqlDB, err = sql.Open(driver, dbURL)
|
|
if err != nil {
|
|
logger.Warn(ctx, "connect to postgres: retrying", slog.Error(err), slog.F("try", tries))
|
|
continue
|
|
}
|
|
|
|
err = pingPostgres(ctx, sqlDB)
|
|
if err != nil {
|
|
logger.Warn(ctx, "ping postgres: retrying", slog.Error(err), slog.F("try", tries))
|
|
_ = sqlDB.Close()
|
|
sqlDB = nil
|
|
continue
|
|
}
|
|
|
|
break
|
|
}
|
|
if err == nil {
|
|
err = ctx.Err()
|
|
}
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("unable to connect after %d tries; last error: %w", tries, err)
|
|
}
|
|
|
|
// Ensure the PostgreSQL version is >=13.0.0!
|
|
version, err := sqlDB.QueryContext(ctx, "SHOW server_version_num;")
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get postgres version: %w", err)
|
|
}
|
|
if !version.Next() {
|
|
return nil, xerrors.Errorf("no rows returned for version select")
|
|
}
|
|
var versionNum int
|
|
err = version.Scan(&versionNum)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("scan version: %w", err)
|
|
}
|
|
_ = version.Close()
|
|
|
|
if versionNum < 130000 {
|
|
return nil, xerrors.Errorf("PostgreSQL version must be v13.0.0 or higher! Got: %d", versionNum)
|
|
}
|
|
logger.Debug(ctx, "connected to postgresql", slog.F("version", versionNum))
|
|
|
|
err = migrations.Up(sqlDB)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("migrate up: %w", err)
|
|
}
|
|
// The default is 0 but the request will fail with a 500 if the DB
|
|
// cannot accept new connections, so we try to limit that here.
|
|
// Requests will wait for a new connection instead of a hard error
|
|
// if a limit is set.
|
|
sqlDB.SetMaxOpenConns(10)
|
|
// Allow a max of 3 idle connections at a time. Lower values end up
|
|
// creating a lot of connection churn. Since each connection uses about
|
|
// 10MB of memory, we're allocating 30MB to Postgres connections per
|
|
// replica, but is better than causing Postgres to spawn a thread 15-20
|
|
// times/sec. PGBouncer's transaction pooling is not the greatest so
|
|
// it's not optimal for us to deploy.
|
|
//
|
|
// This was set to 10 before we started doing HA deployments, but 3 was
|
|
// later determined to be a better middle ground as to not use up all
|
|
// of PGs default connection limit while simultaneously avoiding a lot
|
|
// of connection churn.
|
|
sqlDB.SetMaxIdleConns(3)
|
|
|
|
return sqlDB, nil
|
|
}
|
|
|
|
// pingPostgres verifies that db can reach its database. It gives up after
// five seconds, or sooner if ctx is cancelled or expires first.
func pingPostgres(ctx context.Context, db *sql.DB) error {
	const pingTimeout = 5 * time.Second
	pingCtx, cancelPing := context.WithTimeout(ctx, pingTimeout)
	defer cancelPing()
	return db.PingContext(pingCtx)
}
|
|
|
|
// HTTPServers holds the listeners the server accepts connections on, plus the
// URLs it can use to reach itself through them. Either listener may be nil
// when that protocol is not configured.
type HTTPServers struct {
	// HTTPUrl is the URL for reaching HTTPListener; an unspecified bind IP is
	// replaced with 127.0.0.1 by ConfigureHTTPServers.
	HTTPUrl *url.URL
	// HTTPListener accepts plain-HTTP connections.
	HTTPListener net.Listener

	// TLS
	// TLSUrl is the URL for reaching TLSListener, built the same way as HTTPUrl.
	TLSUrl *url.URL
	// TLSListener accepts TLS connections using TLSConfig.
	TLSListener net.Listener
	// TLSConfig is the server TLS configuration wrapped around TLSListener.
	TLSConfig *tls.Config
}
|
|
|
|
// Serve acts just like http.Serve. It is a blocking call until the server
|
|
// is closed, and an error is returned if any underlying Serve call fails.
|
|
func (s *HTTPServers) Serve(srv *http.Server) error {
|
|
eg := errgroup.Group{}
|
|
if s.HTTPListener != nil {
|
|
eg.Go(func() error {
|
|
defer s.Close() // close all listeners on error
|
|
return srv.Serve(s.HTTPListener)
|
|
})
|
|
}
|
|
if s.TLSListener != nil {
|
|
eg.Go(func() error {
|
|
defer s.Close() // close all listeners on error
|
|
return srv.Serve(s.TLSListener)
|
|
})
|
|
}
|
|
return eg.Wait()
|
|
}
|
|
|
|
func (s *HTTPServers) Close() {
|
|
if s.HTTPListener != nil {
|
|
_ = s.HTTPListener.Close()
|
|
}
|
|
if s.TLSListener != nil {
|
|
_ = s.TLSListener.Close()
|
|
}
|
|
}
|
|
|
|
func ConfigureTraceProvider(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg *codersdk.DeploymentValues,
|
|
) (trace.TracerProvider, string, func(context.Context) error) {
|
|
var (
|
|
tracerProvider = trace.NewNoopTracerProvider()
|
|
closeTracing = func(context.Context) error { return nil }
|
|
sqlDriver = "postgres"
|
|
)
|
|
|
|
otel.SetTextMapPropagator(
|
|
propagation.NewCompositeTextMapPropagator(
|
|
propagation.TraceContext{},
|
|
propagation.Baggage{},
|
|
),
|
|
)
|
|
|
|
if cfg.Trace.Enable.Value() || cfg.Trace.DataDog.Value() || cfg.Trace.HoneycombAPIKey != "" {
|
|
sdkTracerProvider, _closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
|
|
Default: cfg.Trace.Enable.Value(),
|
|
DataDog: cfg.Trace.DataDog.Value(),
|
|
Honeycomb: cfg.Trace.HoneycombAPIKey.String(),
|
|
})
|
|
if err != nil {
|
|
logger.Warn(ctx, "start telemetry exporter", slog.Error(err))
|
|
} else {
|
|
d, err := tracing.PostgresDriver(sdkTracerProvider, "coderd.database")
|
|
if err != nil {
|
|
logger.Warn(ctx, "start postgres tracing driver", slog.Error(err))
|
|
} else {
|
|
sqlDriver = d
|
|
}
|
|
|
|
tracerProvider = sdkTracerProvider
|
|
closeTracing = _closeTracing
|
|
}
|
|
}
|
|
return tracerProvider, sqlDriver, closeTracing
|
|
}
|
|
|
|
func ConfigureHTTPServers(logger slog.Logger, inv *clibase.Invocation, cfg *codersdk.DeploymentValues) (_ *HTTPServers, err error) {
|
|
ctx := inv.Context()
|
|
httpServers := &HTTPServers{}
|
|
defer func() {
|
|
if err != nil {
|
|
// Always close the listeners if we fail.
|
|
httpServers.Close()
|
|
}
|
|
}()
|
|
// Validate bind addresses.
|
|
if cfg.Address.String() != "" {
|
|
if cfg.TLS.Enable {
|
|
cfg.HTTPAddress = ""
|
|
cfg.TLS.Address = cfg.Address
|
|
} else {
|
|
_ = cfg.HTTPAddress.Set(cfg.Address.String())
|
|
cfg.TLS.Address.Host = ""
|
|
cfg.TLS.Address.Port = ""
|
|
}
|
|
}
|
|
if cfg.TLS.Enable && cfg.TLS.Address.String() == "" {
|
|
return nil, xerrors.Errorf("TLS address must be set if TLS is enabled")
|
|
}
|
|
if !cfg.TLS.Enable && cfg.HTTPAddress.String() == "" {
|
|
return nil, xerrors.Errorf("TLS is disabled. Enable with --tls-enable or specify a HTTP address")
|
|
}
|
|
|
|
if cfg.AccessURL.String() != "" &&
|
|
!(cfg.AccessURL.Scheme == "http" || cfg.AccessURL.Scheme == "https") {
|
|
return nil, xerrors.Errorf("access-url must include a scheme (e.g. 'http://' or 'https://)")
|
|
}
|
|
|
|
addrString := func(l net.Listener) string {
|
|
listenAddrStr := l.Addr().String()
|
|
// For some reason if 0.0.0.0:x is provided as the https
|
|
// address, httpsListener.Addr().String() likes to return it as
|
|
// an ipv6 address (i.e. [::]:x). If the input ip is 0.0.0.0,
|
|
// try to coerce the output back to ipv4 to make it less
|
|
// confusing.
|
|
if strings.Contains(cfg.HTTPAddress.String(), "0.0.0.0") {
|
|
listenAddrStr = strings.ReplaceAll(listenAddrStr, "[::]", "0.0.0.0")
|
|
}
|
|
return listenAddrStr
|
|
}
|
|
|
|
if cfg.HTTPAddress.String() != "" {
|
|
httpServers.HTTPListener, err = net.Listen("tcp", cfg.HTTPAddress.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// We want to print out the address the user supplied, not the
|
|
// loopback device.
|
|
_, _ = fmt.Fprintf(inv.Stdout, "Started HTTP listener at %s\n", (&url.URL{Scheme: "http", Host: addrString(httpServers.HTTPListener)}).String())
|
|
|
|
// Set the http URL we want to use when connecting to ourselves.
|
|
tcpAddr, tcpAddrValid := httpServers.HTTPListener.Addr().(*net.TCPAddr)
|
|
if !tcpAddrValid {
|
|
return nil, xerrors.Errorf("invalid TCP address type %T", httpServers.HTTPListener.Addr())
|
|
}
|
|
if tcpAddr.IP.IsUnspecified() {
|
|
tcpAddr.IP = net.IPv4(127, 0, 0, 1)
|
|
}
|
|
httpServers.HTTPUrl = &url.URL{
|
|
Scheme: "http",
|
|
Host: tcpAddr.String(),
|
|
}
|
|
}
|
|
|
|
if cfg.TLS.Enable {
|
|
if cfg.TLS.Address.String() == "" {
|
|
return nil, xerrors.New("tls address must be set if tls is enabled")
|
|
}
|
|
|
|
redirectHTTPToHTTPSDeprecation(ctx, logger, inv, cfg)
|
|
|
|
tlsConfig, err := configureServerTLS(
|
|
ctx,
|
|
logger,
|
|
cfg.TLS.MinVersion.String(),
|
|
cfg.TLS.ClientAuth.String(),
|
|
cfg.TLS.CertFiles,
|
|
cfg.TLS.KeyFiles,
|
|
cfg.TLS.ClientCAFile.String(),
|
|
cfg.TLS.SupportedCiphers.Value(),
|
|
cfg.TLS.AllowInsecureCiphers.Value(),
|
|
)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("configure tls: %w", err)
|
|
}
|
|
httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpServers.TLSConfig = tlsConfig
|
|
httpServers.TLSListener = tls.NewListener(httpsListenerInner, tlsConfig)
|
|
|
|
// We want to print out the address the user supplied, not the
|
|
// loopback device.
|
|
_, _ = fmt.Fprintf(inv.Stdout, "Started TLS/HTTPS listener at %s\n", (&url.URL{Scheme: "https", Host: addrString(httpServers.TLSListener)}).String())
|
|
|
|
// Set the https URL we want to use when connecting to
|
|
// ourselves.
|
|
tcpAddr, tcpAddrValid := httpServers.TLSListener.Addr().(*net.TCPAddr)
|
|
if !tcpAddrValid {
|
|
return nil, xerrors.Errorf("invalid TCP address type %T", httpServers.TLSListener.Addr())
|
|
}
|
|
if tcpAddr.IP.IsUnspecified() {
|
|
tcpAddr.IP = net.IPv4(127, 0, 0, 1)
|
|
}
|
|
httpServers.TLSUrl = &url.URL{
|
|
Scheme: "https",
|
|
Host: tcpAddr.String(),
|
|
}
|
|
}
|
|
|
|
if httpServers.HTTPListener == nil && httpServers.TLSListener == nil {
|
|
return nil, xerrors.New("must listen on at least one address")
|
|
}
|
|
|
|
return httpServers, nil
|
|
}
|
|
|
|
// redirectHTTPToHTTPSDeprecation handles deprecation of the --tls-redirect-http-to-https flag and
|
|
// "related" environment variables.
|
|
//
|
|
// --tls-redirect-http-to-https used to default to true.
|
|
// It made more sense to have the redirect be opt-in.
|
|
//
|
|
// Also, for a while we have been accepting the environment variable (but not the
|
|
// corresponding flag!) "CODER_TLS_REDIRECT_HTTP", and it appeared in a configuration
|
|
// example, so we keep accepting it to not break backward compat.
|
|
func redirectHTTPToHTTPSDeprecation(ctx context.Context, logger slog.Logger, inv *clibase.Invocation, cfg *codersdk.DeploymentValues) {
|
|
truthy := func(s string) bool {
|
|
b, err := strconv.ParseBool(s)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return b
|
|
}
|
|
if truthy(inv.Environ.Get("CODER_TLS_REDIRECT_HTTP")) ||
|
|
truthy(inv.Environ.Get("CODER_TLS_REDIRECT_HTTP_TO_HTTPS")) ||
|
|
inv.ParsedFlags().Changed("tls-redirect-http-to-https") {
|
|
logger.Warn(ctx, "⚠️ --tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead")
|
|
cfg.RedirectToAccessURL = cfg.TLS.RedirectHTTP
|
|
}
|
|
}
|
|
|
|
// ReadExternalAuthProvidersFromEnv is provided for compatibility purposes with
|
|
// the viper CLI.
|
|
func ReadExternalAuthProvidersFromEnv(environ []string) ([]codersdk.ExternalAuthConfig, error) {
|
|
providers, err := parseExternalAuthProvidersFromEnv("CODER_EXTERNAL_AUTH_", environ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Deprecated: To support legacy git auth!
|
|
gitProviders, err := parseExternalAuthProvidersFromEnv("CODER_GITAUTH_", environ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return append(providers, gitProviders...), nil
|
|
}
|
|
|
|
// parseExternalAuthProvidersFromEnv consumes environment variables to parse
|
|
// external auth providers. A prefix is provided to support the legacy
|
|
// parsing of `GITAUTH` environment variables.
|
|
func parseExternalAuthProvidersFromEnv(prefix string, environ []string) ([]codersdk.ExternalAuthConfig, error) {
|
|
// The index numbers must be in-order.
|
|
sort.Strings(environ)
|
|
|
|
var providers []codersdk.ExternalAuthConfig
|
|
for _, v := range clibase.ParseEnviron(environ, prefix) {
|
|
tokens := strings.SplitN(v.Name, "_", 2)
|
|
if len(tokens) != 2 {
|
|
return nil, xerrors.Errorf("invalid env var: %s", v.Name)
|
|
}
|
|
|
|
providerNum, err := strconv.Atoi(tokens[0])
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse number: %s", v.Name)
|
|
}
|
|
|
|
var provider codersdk.ExternalAuthConfig
|
|
switch {
|
|
case len(providers) < providerNum:
|
|
return nil, xerrors.Errorf(
|
|
"provider num %v skipped: %s",
|
|
len(providers),
|
|
v.Name,
|
|
)
|
|
case len(providers) == providerNum:
|
|
// At the next next provider.
|
|
providers = append(providers, provider)
|
|
case len(providers) == providerNum+1:
|
|
// At the current provider.
|
|
provider = providers[providerNum]
|
|
}
|
|
|
|
key := tokens[1]
|
|
switch key {
|
|
case "ID":
|
|
provider.ID = v.Value
|
|
case "TYPE":
|
|
provider.Type = v.Value
|
|
case "CLIENT_ID":
|
|
provider.ClientID = v.Value
|
|
case "CLIENT_SECRET":
|
|
provider.ClientSecret = v.Value
|
|
case "AUTH_URL":
|
|
provider.AuthURL = v.Value
|
|
case "TOKEN_URL":
|
|
provider.TokenURL = v.Value
|
|
case "VALIDATE_URL":
|
|
provider.ValidateURL = v.Value
|
|
case "REGEX":
|
|
provider.Regex = v.Value
|
|
case "DEVICE_FLOW":
|
|
b, err := strconv.ParseBool(v.Value)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse bool: %s", v.Value)
|
|
}
|
|
provider.DeviceFlow = b
|
|
case "DEVICE_CODE_URL":
|
|
provider.DeviceCodeURL = v.Value
|
|
case "NO_REFRESH":
|
|
b, err := strconv.ParseBool(v.Value)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("parse bool: %s", v.Value)
|
|
}
|
|
provider.NoRefresh = b
|
|
case "SCOPES":
|
|
provider.Scopes = strings.Split(v.Value, " ")
|
|
case "EXTRA_TOKEN_KEYS":
|
|
provider.ExtraTokenKeys = strings.Split(v.Value, " ")
|
|
case "APP_INSTALL_URL":
|
|
provider.AppInstallURL = v.Value
|
|
case "APP_INSTALLATIONS_URL":
|
|
provider.AppInstallationsURL = v.Value
|
|
case "DISPLAY_NAME":
|
|
provider.DisplayName = v.Value
|
|
case "DISPLAY_ICON":
|
|
provider.DisplayIcon = v.Value
|
|
}
|
|
providers[providerNum] = provider
|
|
}
|
|
return providers, nil
|
|
}
|