chore: Add helper for uniform flags and env vars (#588)

This commit is contained in:
Garrett Delfosse
2022-03-28 14:26:41 -05:00
committed by GitHub
parent be8389fd74
commit 01957da040
4 changed files with 248 additions and 57 deletions

70
cli/cliflag/cliflag.go Normal file
View File

@ -0,0 +1,70 @@
// Package cliflag extends flagset with environment variable defaults.
//
// Usage:
//
// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
//
// Will produce the following usage docs:
//
// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000")
//
package cliflag
import (
"fmt"
"os"
"strconv"
"github.com/spf13/pflag"
)
// StringVarP sets a string flag on the given flag set.
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
v = def
}
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
}
// Uint8VarP sets a uint8 flag on the given flag set.
func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
vi64, err := strconv.ParseUint(val, 10, 8)
if err != nil {
flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env))
}
// BoolVarP sets a bool flag on the given flag set.
func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) {
val, ok := os.LookupEnv(env)
if !ok || val == "" {
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
valb, err := strconv.ParseBool(val)
if err != nil {
flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env))
return
}
flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env))
}
func fmtUsage(u string, env string) string {
if env == "" {
return fmt.Sprintf("%s.", u)
}
return fmt.Sprintf("%s - consumes $%s.", u, env)
}

145
cli/cliflag/cliflag_test.go Normal file
View File

@ -0,0 +1,145 @@
package cliflag_test
import (
"fmt"
"strconv"
"testing"
"github.com/spf13/pflag"
"github.com/stretchr/testify/require"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cryptorand"
)
// Testcliflag cannot run in parallel because it uses t.Setenv.
//nolint:paralleltest
func TestCliflag(t *testing.T) {
t.Run("StringDefault", func(t *testing.T) {
var p string
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
})
t.Run("StringEnvVar", func(t *testing.T) {
var p string
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &p, name, shorthand, env, def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("EmptyEnvVar", func(t *testing.T) {
var p string
flagset, name, shorthand, _, usage := randomFlag()
def, _ := cryptorand.String(10)
cliflag.StringVarP(flagset, &p, name, shorthand, "", def, usage)
got, err := flagset.GetString(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.NotContains(t, flagset.FlagUsages(), " - consumes")
})
t.Run("IntDefault", func(t *testing.T) {
var p uint8
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Int63n(10)
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(def), got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
})
t.Run("IntEnvVar", func(t *testing.T) {
var p uint8
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Int63n(10)
t.Setenv(env, strconv.FormatUint(uint64(envValue), 10))
def, _ := cryptorand.Int()
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(envValue), got)
})
t.Run("IntFailParse", func(t *testing.T) {
var p uint8
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Int63n(10)
cliflag.Uint8VarP(flagset, &p, name, shorthand, env, uint8(def), usage)
got, err := flagset.GetUint8(name)
require.NoError(t, err)
require.Equal(t, uint8(def), got)
})
t.Run("BoolDefault", func(t *testing.T) {
var p bool
flagset, name, shorthand, env, usage := randomFlag()
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, def, got)
require.Contains(t, flagset.FlagUsages(), usage)
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
})
t.Run("BoolEnvVar", func(t *testing.T) {
var p bool
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.Bool()
t.Setenv(env, strconv.FormatBool(envValue))
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, envValue, got)
})
t.Run("BoolFailParse", func(t *testing.T) {
var p bool
flagset, name, shorthand, env, usage := randomFlag()
envValue, _ := cryptorand.String(10)
t.Setenv(env, envValue)
def, _ := cryptorand.Bool()
cliflag.BoolVarP(flagset, &p, name, shorthand, env, def, usage)
got, err := flagset.GetBool(name)
require.NoError(t, err)
require.Equal(t, def, got)
})
}
func randomFlag() (*pflag.FlagSet, string, string, string, string) {
fsname, _ := cryptorand.String(10)
flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError)
name, _ := cryptorand.String(10)
shorthand, _ := cryptorand.String(1)
env, _ := cryptorand.String(10)
usage, _ := cryptorand.String(10)
return flagset, name, shorthand, env, usage
}

View File

@ -13,7 +13,6 @@ import (
"net/url"
"os"
"os/signal"
"strconv"
"time"
"github.com/briandowns/spinner"
@ -25,6 +24,7 @@ import (
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/cli/config"
"github.com/coder/coder/coderd"
@ -44,6 +44,7 @@ func start() *cobra.Command {
address string
dev bool
postgresURL string
// provisionerDaemonCount is a uint8 to ensure a number > 0.
provisionerDaemonCount uint8
tlsCertFile string
tlsClientCAFile string
@ -57,10 +58,6 @@ func start() *cobra.Command {
Use: "start",
RunE: func(cmd *cobra.Command, args []string) error {
printLogo(cmd)
if postgresURL == "" {
// Default to the environment variable!
postgresURL = os.Getenv("CODER_PG_CONNECTION_URL")
}
listener, err := net.Listen("tcp", address)
if err != nil {
@ -163,7 +160,7 @@ func start() *cobra.Command {
}
provisionerDaemons := make([]*provisionerd.Server, 0)
for i := uint8(0); i < provisionerDaemonCount; i++ {
for i := 0; uint8(i) < provisionerDaemonCount; i++ {
daemonClose, err := newProvisionerDaemon(cmd.Context(), client, logger)
if err != nil {
return xerrors.Errorf("create provisioner daemon: %w", err)
@ -305,46 +302,27 @@ func start() *cobra.Command {
return nil
},
}
defaultAddress := os.Getenv("CODER_ADDRESS")
if defaultAddress == "" {
defaultAddress = "127.0.0.1:3000"
}
root.Flags().StringVarP(&accessURL, "access-url", "", os.Getenv("CODER_ACCESS_URL"), "Specifies the external URL to access Coder (uses $CODER_ACCESS_URL).")
root.Flags().StringVarP(&address, "address", "a", defaultAddress, "The address to serve the API and dashboard (uses $CODER_ADDRESS).")
defaultDev, _ := strconv.ParseBool(os.Getenv("CODER_DEV_MODE"))
root.Flags().BoolVarP(&dev, "dev", "", defaultDev, "Serve Coder in dev mode for tinkering (uses $CODER_DEV_MODE).")
root.Flags().StringVarP(&postgresURL, "postgres-url", "", "",
"URL of a PostgreSQL database to connect to (defaults to $CODER_PG_CONNECTION_URL).")
root.Flags().Uint8VarP(&provisionerDaemonCount, "provisioner-daemons", "", 1, "The amount of provisioner daemons to create on start.")
defaultTLSEnable, _ := strconv.ParseBool(os.Getenv("CODER_TLS_ENABLE"))
root.Flags().BoolVarP(&tlsEnable, "tls-enable", "", defaultTLSEnable, "Specifies if TLS will be enabled (uses $CODER_TLS_ENABLE).")
root.Flags().StringVarP(&tlsCertFile, "tls-cert-file", "", os.Getenv("CODER_TLS_CERT_FILE"),
cliflag.StringVarP(root.Flags(), &accessURL, "access-url", "", "CODER_ACCESS_URL", "", "Specifies the external URL to access Coder")
cliflag.StringVarP(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard")
cliflag.BoolVarP(root.Flags(), &dev, "dev", "", "CODER_DEV_MODE", false, "Serve Coder in dev mode for tinkering")
cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to")
cliflag.Uint8VarP(root.Flags(), &provisionerDaemonCount, "provisioner-daemons", "", "CODER_PROVISIONER_DAEMONS", 1, "The amount of provisioner daemons to create on start.")
cliflag.BoolVarP(root.Flags(), &tlsEnable, "tls-enable", "", "CODER_TLS_ENABLE", false, "Specifies if TLS will be enabled")
cliflag.StringVarP(root.Flags(), &tlsCertFile, "tls-cert-file", "", "CODER_TLS_CERT_FILE", "",
"Specifies the path to the certificate for TLS. It requires a PEM-encoded file. "+
"To configure the listener to use a CA certificate, concatenate the primary certificate "+
"and the CA certificate together. The primary certificate should appear first in the combined file (uses $CODER_TLS_CERT_FILE).")
root.Flags().StringVarP(&tlsClientCAFile, "tls-client-ca-file", "", os.Getenv("CODER_TLS_CLIENT_CA_FILE"),
"PEM-encoded Certificate Authority file used for checking the authenticity of client (uses $CODER_TLS_CLIENT_CA_FILE).")
defaultTLSClientAuth := os.Getenv("CODER_TLS_CLIENT_AUTH")
if defaultTLSClientAuth == "" {
defaultTLSClientAuth = "request"
}
root.Flags().StringVarP(&tlsClientAuth, "tls-client-auth", "", defaultTLSClientAuth,
"and the CA certificate together. The primary certificate should appear first in the combined file")
cliflag.StringVarP(root.Flags(), &tlsClientCAFile, "tls-client-ca-file", "", "CODER_TLS_CLIENT_CA_FILE", "",
"PEM-encoded Certificate Authority file used for checking the authenticity of client")
cliflag.StringVarP(root.Flags(), &tlsClientAuth, "tls-client-auth", "", "CODER_TLS_CLIENT_AUTH", "request",
`Specifies the policy the server will follow for TLS Client Authentication. `+
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify" (uses $CODER_TLS_CLIENT_AUTH).`)
root.Flags().StringVarP(&tlsKeyFile, "tls-key-file", "", os.Getenv("CODER_TLS_KEY_FILE"),
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file (uses $CODER_TLS_KEY_FILE).")
defaultTLSMinVersion := os.Getenv("CODER_TLS_MIN_VERSION")
if defaultTLSMinVersion == "" {
defaultTLSMinVersion = "tls12"
}
root.Flags().StringVarP(&tlsMinVersion, "tls-min-version", "", defaultTLSMinVersion,
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13" (uses $CODER_TLS_MIN_VERSION).`)
defaultTunnelRaw := os.Getenv("CODER_DEV_TUNNEL")
if defaultTunnelRaw == "" {
defaultTunnelRaw = "true"
}
defaultTunnel, _ := strconv.ParseBool(defaultTunnelRaw)
root.Flags().BoolVarP(&useTunnel, "tunnel", "", defaultTunnel, "Serve dev mode through a Cloudflare Tunnel for easy setup (uses $CODER_DEV_TUNNEL).")
`Accepted values are "none", "request", "require-any", "verify-if-given", or "require-and-verify"`)
cliflag.StringVarP(root.Flags(), &tlsKeyFile, "tls-key-file", "", "CODER_TLS_KEY_FILE", "",
"Specifies the path to the private key for the certificate. It requires a PEM-encoded file")
cliflag.StringVarP(root.Flags(), &tlsMinVersion, "tls-min-version", "", "CODER_TLS_MIN_VERSION", "tls12",
`Specifies the minimum supported version of TLS. Accepted values are "tls10", "tls11", "tls12" or "tls13"`)
cliflag.BoolVarP(root.Flags(), &useTunnel, "tunnel", "", "CODER_DEV_TUNNEL", false, "Serve dev mode through a Cloudflare Tunnel for easy setup")
_ = root.Flags().MarkHidden("tunnel")
return root

View File

@ -3,7 +3,6 @@ package cli
import (
"context"
"net/url"
"os"
"time"
"cloud.google.com/go/compute/metadata"
@ -14,6 +13,7 @@ import (
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/retry"
@ -23,6 +23,7 @@ func workspaceAgent() *cobra.Command {
var (
rawURL string
auth string
token string
)
cmd := &cobra.Command{
Use: "agent",
@ -40,11 +41,10 @@ func workspaceAgent() *cobra.Command {
client := codersdk.New(coderURL)
switch auth {
case "token":
sessionToken, exists := os.LookupEnv("CODER_TOKEN")
if !exists {
if token == "" {
return xerrors.Errorf("CODER_TOKEN must be set for token auth")
}
client.SessionToken = sessionToken
client.SessionToken = token
case "google-instance-identity":
// This is *only* done for testing to mock client authentication.
// This will never be set in a production scenario.
@ -83,12 +83,10 @@ func workspaceAgent() *cobra.Command {
return closer.Close()
},
}
defaultAuth := os.Getenv("CODER_AUTH")
if defaultAuth == "" {
defaultAuth = "token"
}
cmd.Flags().StringVarP(&auth, "auth", "", defaultAuth, "Specify the authentication type to use for the agent.")
cmd.Flags().StringVarP(&rawURL, "url", "", os.Getenv("CODER_URL"), "Specify the URL to access Coder.")
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent")
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder")
cliflag.StringVarP(cmd.Flags(), &auth, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder")
return cmd
}