mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
chore: Add helper for uniform flags and env vars (#588)
This commit is contained in:
70
cli/cliflag/cliflag.go
Normal file
70
cli/cliflag/cliflag.go
Normal file
@ -0,0 +1,70 @@
|
||||
// Package cliflag extends flagset with environment variable defaults.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
|
||||
//
|
||||
// Will produce the following usage docs:
|
||||
//
|
||||
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
|
||||
//
|
||||
package cliflag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
// StringVarP sets a string flag on the given flag set.
|
||||
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
|
||||
v, ok := os.LookupEnv(env)
|
||||
if !ok || v == "" {
|
||||
v = def
|
||||
}
|
||||
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
// Uint8VarP sets a uint8 flag on the given flag set.
|
||||
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
if !ok || val == "" {
|
||||
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
|
||||
return
|
||||
}
|
||||
|
||||
vi64, err := strconv.ParseUint(val, 10, 8)
|
||||
if err != nil {
|
||||
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
|
||||
return
|
||||
}
|
||||
|
||||
flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
// BoolVarP sets a bool flag on the given flag set.
|
||||
func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) {
|
||||
val, ok := os.LookupEnv(env)
|
||||
if !ok || val == "" {
|
||||
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
|
||||
return
|
||||
}
|
||||
|
||||
valb, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
|
||||
return
|
||||
}
|
||||
|
||||
flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
|
||||
}
|
||||
|
||||
func fmtUsage(u string, env string) string {
|
||||
if env == "" {
|
||||
return fmt.Sprintf("%s.", u)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s - consumes $%s.", u, env)
|
||||
}
|
145
cli/cliflag/cliflag_test.go
Normal file
145
cli/cliflag/cliflag_test.go
Normal file
@ -0,0 +1,145 @@
|
||||
package cliflag_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// Testcliflag cannot run in parallel because it uses t.Setenv.
|
||||
//nolint:paralleltest
|
||||
func TestCliflag(t *testing.T) {
|
||||
t.Run("StringDefault", func(t *testing.T) {
|
||||
var p string
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
def, _ := cryptorand.String(10)
|
||||
|
||||
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetString(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, def, got)
|
||||
require.Contains(t, flagset.FlagUsages(), usage)
|
||||
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
|
||||
})
|
||||
|
||||
t.Run("StringEnvVar", func(t *testing.T) {
|
||||
var p string
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
envValue, _ := cryptorand.String(10)
|
||||
t.Setenv(env, envValue)
|
||||
def, _ := cryptorand.String(10)
|
||||
|
||||
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetString(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envValue, got)
|
||||
})
|
||||
|
||||
t.Run("EmptyEnvVar", func(t *testing.T) {
|
||||
var p string
|
||||
flagset, name, shorthand, _, usage := randomFlag()
|
||||
def, _ := cryptorand.String(10)
|
||||
|
||||
cliflag.StringVarP(flagset, &p, name, shorthand, "", def, usage)
|
||||
got, err := flagset.GetString(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, def, got)
|
||||
require.Contains(t, flagset.FlagUsages(), usage)
|
||||
require.NotContains(t, flagset.FlagUsages(), " - consumes")
|
||||
})
|
||||
|
||||
t.Run("IntDefault", func(t *testing.T) {
|
||||
var p uint8
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
def, _ := cryptorand.Int63n(10)
|
||||
|
||||
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
|
||||
got, err := flagset.GetUint8(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(def), got)
|
||||
require.Contains(t, flagset.FlagUsages(), usage)
|
||||
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
|
||||
})
|
||||
|
||||
t.Run("IntEnvVar", func(t *testing.T) {
|
||||
var p uint8
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
envValue, _ := cryptorand.Int63n(10)
|
||||
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
|
||||
def, _ := cryptorand.Int()
|
||||
|
||||
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
|
||||
got, err := flagset.GetUint8(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(envValue), got)
|
||||
})
|
||||
|
||||
t.Run("IntFailParse", func(t *testing.T) {
|
||||
var p uint8
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
envValue, _ := cryptorand.String(10)
|
||||
t.Setenv(env, envValue)
|
||||
def, _ := cryptorand.Int63n(10)
|
||||
|
||||
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
|
||||
got, err := flagset.GetUint8(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint8(def), got)
|
||||
})
|
||||
|
||||
t.Run("BoolDefault", func(t *testing.T) {
|
||||
var p bool
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
def, _ := cryptorand.Bool()
|
||||
|
||||
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetBool(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, def, got)
|
||||
require.Contains(t, flagset.FlagUsages(), usage)
|
||||
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
|
||||
})
|
||||
|
||||
t.Run("BoolEnvVar", func(t *testing.T) {
|
||||
var p bool
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
envValue, _ := cryptorand.Bool()
|
||||
t.Setenv(env, strconv.FormatBool(envValue))
|
||||
def, _ := cryptorand.Bool()
|
||||
|
||||
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetBool(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envValue, got)
|
||||
})
|
||||
|
||||
t.Run("BoolFailParse", func(t *testing.T) {
|
||||
var p bool
|
||||
flagset, name, shorthand, env, usage := randomFlag()
|
||||
envValue, _ := cryptorand.String(10)
|
||||
t.Setenv(env, envValue)
|
||||
def, _ := cryptorand.Bool()
|
||||
|
||||
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
|
||||
got, err := flagset.GetBool(name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, def, got)
|
||||
})
|
||||
}
|
||||
|
||||
func randomFlag() (*pflag.FlagSet, string, string, string, string) {
|
||||
fsname, _ := cryptorand.String(10)
|
||||
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
|
||||
name, _ := cryptorand.String(10)
|
||||
shorthand, _ := cryptorand.String(1)
|
||||
env, _ := cryptorand.String(10)
|
||||
usage, _ := cryptorand.String(10)
|
||||
|
||||
return flagset, name, shorthand, env, usage
|
||||
}
|
64
cli/start.go
64
cli/start.go
@ -13,7 +13,6 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
@ -25,6 +24,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/cli/cliui"
|
||||
"github.com/coder/coder/cli/config"
|
||||
"github.com/coder/coder/coderd"
|
||||
@ -44,6 +44,7 @@ func start() *cobra.Command {
|
||||
address string
|
||||
dev bool
|
||||
postgresURL string
|
||||
// provisionerDaemonCount is a uint8 to ensure a number > 0.
|
||||
provisionerDaemonCount uint8
|
||||
tlsCertFile string
|
||||
tlsClientCAFile string
|
||||
@ -57,10 +58,6 @@ func start() *cobra.Command {
|
||||
Use: "start",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
printLogo(cmd)
|
||||
if postgresURL == "" {
|
||||
// Default to the environment variable!
|
||||
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
@ -163,7 +160,7 @@ func start() *cobra.Command {
|
||||
}
|
||||
|
||||
provisionerDaemons := make([]*provisionerd.Server, 0)
|
||||
for i := uint8(0); i < provisionerDaemonCount; i++ {
|
||||
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
|
||||
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("create provisioner daemon: %w", err)
|
||||
@ -305,46 +302,27 @@ func start() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
defaultAddress := os.Getenv("CODER_ADDRESS")
|
||||
if defaultAddress == "" {
|
||||
defaultAddress = "127.0.0.1:3000"
|
||||
}
|
||||
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
|
||||
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
|
||||
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
|
||||
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
|
||||
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
|
||||
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
|
||||
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
|
||||
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
|
||||
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
|
||||
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
|
||||
|
||||
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
|
||||
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
|
||||
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
|
||||
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
|
||||
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
|
||||
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
|
||||
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
|
||||
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
|
||||
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
|
||||
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
|
||||
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
|
||||
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
|
||||
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
|
||||
if defaultTLSClientAuth == "" {
|
||||
defaultTLSClientAuth = "request"
|
||||
}
|
||||
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
|
||||
"and the CA certificate together. The primary certificate should appear first in the combined file")
|
||||
cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
|
||||
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
|
||||
cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
|
||||
`Specifies the policy the server will follow for TLS Client Authentication. `+
|
||||
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
|
||||
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
|
||||
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
|
||||
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
|
||||
if defaultTLSMinVersion == "" {
|
||||
defaultTLSMinVersion = "tls12"
|
||||
}
|
||||
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
|
||||
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
|
||||
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
|
||||
if defaultTunnelRaw == "" {
|
||||
defaultTunnelRaw = "true"
|
||||
}
|
||||
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
|
||||
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
|
||||
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
|
||||
cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
|
||||
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
|
||||
cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
|
||||
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
|
||||
cliflag.BoolVarP(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup")
|
||||
_ = root.Flags().MarkHidden("tunnel")
|
||||
|
||||
return root
|
||||
|
@ -3,7 +3,6 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/compute/metadata"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/cli/cliflag"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/retry"
|
||||
@ -23,6 +23,7 @@ func workspaceAgent() *cobra.Command {
|
||||
var (
|
||||
rawURL string
|
||||
auth string
|
||||
token string
|
||||
)
|
||||
cmd := &cobra.Command{
|
||||
Use: "agent",
|
||||
@ -40,11 +41,10 @@ func workspaceAgent() *cobra.Command {
|
||||
client := codersdk.New(coderURL)
|
||||
switch auth {
|
||||
case "token":
|
||||
sessionToken, exists := os.LookupEnv("CODER_TOKEN")
|
||||
if !exists {
|
||||
if token == "" {
|
||||
return xerrors.Errorf("CODER_TOKEN must be set for token auth")
|
||||
}
|
||||
client.SessionToken = sessionToken
|
||||
client.SessionToken = token
|
||||
case "google-instance-identity":
|
||||
// This is *only* done for testing to mock client authentication.
|
||||
// This will never be set in a production scenario.
|
||||
@ -83,12 +83,10 @@ func workspaceAgent() *cobra.Command {
|
||||
return closer.Close()
|
||||
},
|
||||
}
|
||||
defaultAuth := os.Getenv("CODER_AUTH")
|
||||
if defaultAuth == "" {
|
||||
defaultAuth = "token"
|
||||
}
|
||||
cmd.Flags().StringVarP(&auth, "auth", "", defaultAuth, "Specify the authentication type to use for the agent.")
|
||||
cmd.Flags().StringVarP(&rawURL, "url", "", os.Getenv("CODER_URL"), "Specify the URL to access Coder.")
|
||||
|
||||
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent")
|
||||
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder")
|
||||
cliflag.StringVarP(cmd.Flags(), &auth, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
Reference in New Issue
Block a user