mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
418 lines
10 KiB
Go
418 lines
10 KiB
Go
package clibase
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/spf13/pflag"
|
|
"golang.org/x/xerrors"
|
|
)
|
|
|
|
// Cmd describes an executable command.
|
|
type Cmd struct {
|
|
// Parent is the direct parent of the command.
|
|
Parent *Cmd
|
|
// Children is a list of direct descendants.
|
|
Children []*Cmd
|
|
// Use is provided in form "command [flags] [args...]".
|
|
Use string
|
|
|
|
// Aliases is a list of alternative names for the command.
|
|
Aliases []string
|
|
|
|
// Short is a one-line description of the command.
|
|
Short string
|
|
|
|
// Hidden determines whether the command should be hidden from help.
|
|
Hidden bool
|
|
|
|
// RawArgs determines whether the command should receive unparsed arguments.
|
|
// No flags are parsed when set, and the command is responsible for parsing
|
|
// its own flags.
|
|
RawArgs bool
|
|
|
|
// Long is a detailed description of the command,
|
|
// presented on its help page. It may contain examples.
|
|
Long string
|
|
Options OptionSet
|
|
Annotations Annotations
|
|
|
|
// Middleware is called before the Handler.
|
|
// Use Chain() to combine multiple middlewares.
|
|
Middleware MiddlewareFunc
|
|
Handler HandlerFunc
|
|
HelpHandler HandlerFunc
|
|
}
|
|
|
|
// Walk calls fn for the command and all its children.
|
|
func (c *Cmd) Walk(fn func(*Cmd)) {
|
|
fn(c)
|
|
for _, child := range c.Children {
|
|
child.Walk(fn)
|
|
}
|
|
}
|
|
|
|
// Name returns the first word in the Use string.
|
|
func (c *Cmd) Name() string {
|
|
return strings.Split(c.Use, " ")[0]
|
|
}
|
|
|
|
// FullName returns the full invocation name of the command,
|
|
// as seen on the command line.
|
|
func (c *Cmd) FullName() string {
|
|
var names []string
|
|
|
|
if c.Parent != nil {
|
|
names = append(names, c.Parent.FullName())
|
|
}
|
|
names = append(names, c.Name())
|
|
return strings.Join(names, " ")
|
|
}
|
|
|
|
// FullName returns usage of the command, preceded
|
|
// by the usage of its parents.
|
|
func (c *Cmd) FullUsage() string {
|
|
var uses []string
|
|
if c.Parent != nil {
|
|
uses = append(uses, c.Parent.FullUsage())
|
|
}
|
|
uses = append(uses, c.Use)
|
|
return strings.Join(uses, " ")
|
|
}
|
|
|
|
// Invoke creates a new invocation of the command, with
|
|
// stdio discarded.
|
|
//
|
|
// The returned invocation is not live until Run() is called.
|
|
func (c *Cmd) Invoke(args ...string) *Invocation {
|
|
return &Invocation{
|
|
Command: c,
|
|
Args: args,
|
|
Stdout: io.Discard,
|
|
Stderr: io.Discard,
|
|
Stdin: strings.NewReader(""),
|
|
}
|
|
}
|
|
|
|
// Invocation represents an instance of a command being executed.
|
|
type Invocation struct {
|
|
ctx context.Context
|
|
Command *Cmd
|
|
parsedFlags *pflag.FlagSet
|
|
Args []string
|
|
// Environ is a list of environment variables. Use EnvsWithPrefix to parse
|
|
// os.Environ.
|
|
Environ Environ
|
|
Stdout io.Writer
|
|
Stderr io.Writer
|
|
Stdin io.Reader
|
|
}
|
|
|
|
// WithOS returns the invocation as a main package, filling in the invocation's unset
|
|
// fields with OS defaults.
|
|
func (i *Invocation) WithOS() *Invocation {
|
|
return i.with(func(i *Invocation) {
|
|
if i.Stdout == nil {
|
|
i.Stdout = os.Stdout
|
|
}
|
|
if i.Stderr == nil {
|
|
i.Stderr = os.Stderr
|
|
}
|
|
if i.Stdin == nil {
|
|
i.Stdin = os.Stdin
|
|
}
|
|
if i.Args == nil {
|
|
i.Args = os.Args[1:]
|
|
}
|
|
if i.Environ == nil {
|
|
i.Environ = ParseEnviron(os.Environ(), "")
|
|
}
|
|
})
|
|
}
|
|
|
|
func (i *Invocation) Context() context.Context {
|
|
if i.ctx == nil {
|
|
// Consider returning context.Background() instead?
|
|
panic("context not set, has WithContext() or Run() been called?")
|
|
}
|
|
return i.ctx
|
|
}
|
|
|
|
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
|
|
if i.parsedFlags == nil {
|
|
panic("flags not parsed, has Run() been called?")
|
|
}
|
|
return i.parsedFlags
|
|
}
|
|
|
|
// runState is the bookkeeping shared by all levels of a single run.
type runState struct {
	// allArgs is the complete argument list the run was started with. It is
	// passed unchanged to every level.
	allArgs []string
	// commandDepth counts how many subcommands have been entered so far;
	// it is 0 at the root command.
	commandDepth int

	// flagParseErr is the outcome of the most recent flag parse. It is only
	// acted upon once no further child command is found.
	flagParseErr error
}
|
|
|
|
// run recursively executes the command and its children.
|
|
// allArgs is wired through the stack so that global flags can be accepted
|
|
// anywhere in the command invocation.
|
|
func (i *Invocation) run(state *runState) error {
|
|
err := i.Command.Options.SetDefaults()
|
|
if err != nil {
|
|
return xerrors.Errorf("setting defaults: %w", err)
|
|
}
|
|
|
|
err = i.Command.Options.ParseEnv(i.Environ)
|
|
if err != nil {
|
|
return xerrors.Errorf("parsing env: %w", err)
|
|
}
|
|
|
|
// Now the fun part, argument parsing!
|
|
|
|
children := make(map[string]*Cmd)
|
|
for _, child := range i.Command.Children {
|
|
for _, name := range append(child.Aliases, child.Name()) {
|
|
if _, ok := children[name]; ok {
|
|
return xerrors.Errorf("duplicate command name: %s", name)
|
|
}
|
|
children[name] = child
|
|
}
|
|
}
|
|
|
|
if i.parsedFlags == nil {
|
|
i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
|
|
// We handle Usage ourselves.
|
|
i.parsedFlags.Usage = func() {}
|
|
}
|
|
|
|
i.parsedFlags.AddFlagSet(i.Command.Options.FlagSet())
|
|
|
|
var parsedArgs []string
|
|
|
|
if !i.Command.RawArgs {
|
|
// Flag parsing will fail on intermediate commands in the command tree,
|
|
// so we check the error after looking for a child command.
|
|
state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
|
|
parsedArgs = i.parsedFlags.Args()
|
|
}
|
|
|
|
// Run child command if found (next child only)
|
|
// We must do subcommand detection after flag parsing so we don't mistake flag
|
|
// values for subcommand names.
|
|
if len(parsedArgs) > 0 {
|
|
nextArg := parsedArgs[0]
|
|
if child, ok := children[nextArg]; ok {
|
|
child.Parent = i.Command
|
|
i.Command = child
|
|
state.commandDepth++
|
|
err = i.run(state)
|
|
if err != nil {
|
|
return xerrors.Errorf(
|
|
"subcommand %s: %w", child.Name(), err,
|
|
)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Flag parse errors are irrelevant for raw args commands.
|
|
if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
|
|
return xerrors.Errorf(
|
|
"parsing flags (%v) for %q: %w",
|
|
state.allArgs,
|
|
i.Command.FullName(), state.flagParseErr,
|
|
)
|
|
}
|
|
|
|
if i.Command.RawArgs {
|
|
// If we're at the root command, then the name is omitted
|
|
// from the arguments, so we can just use the entire slice.
|
|
if state.commandDepth == 0 {
|
|
i.Args = state.allArgs
|
|
} else {
|
|
argPos, err := findArg(i.Command.Name(), state.allArgs, i.parsedFlags)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
i.Args = state.allArgs[argPos+1:]
|
|
}
|
|
} else {
|
|
// In non-raw-arg mode, we want to skip over flags.
|
|
i.Args = parsedArgs[state.commandDepth:]
|
|
}
|
|
|
|
mw := i.Command.Middleware
|
|
if mw == nil {
|
|
mw = Chain()
|
|
}
|
|
|
|
ctx := i.ctx
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
i = i.WithContext(ctx)
|
|
|
|
if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
|
|
if i.Command.HelpHandler == nil {
|
|
return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
|
|
}
|
|
return i.Command.HelpHandler(i)
|
|
}
|
|
|
|
err = mw(i.Command.Handler)(i)
|
|
if err != nil {
|
|
return xerrors.Errorf("running command %s: %w", i.Command.FullName(), err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// findArg returns the index of the first occurrence of arg in args, skipping
|
|
// over all flags.
|
|
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
|
|
for i := 0; i < len(args); i++ {
|
|
arg := args[i]
|
|
if !strings.HasPrefix(arg, "-") {
|
|
if arg == want {
|
|
return i, nil
|
|
}
|
|
continue
|
|
}
|
|
|
|
// This is a flag!
|
|
if strings.Contains(arg, "=") {
|
|
// The flag contains the value in the same arg, just skip.
|
|
continue
|
|
}
|
|
|
|
// We need to check if NoOptValue is set, then we should not wait
|
|
// for the next arg to be the value.
|
|
f := fs.Lookup(strings.TrimLeft(arg, "-"))
|
|
if f == nil {
|
|
return -1, xerrors.Errorf("unknown flag: %s", arg)
|
|
}
|
|
if f.NoOptDefVal != "" {
|
|
continue
|
|
}
|
|
|
|
if i == len(args)-1 {
|
|
return -1, xerrors.Errorf("flag %s requires a value", arg)
|
|
}
|
|
|
|
// Skip the value.
|
|
i++
|
|
}
|
|
|
|
return -1, xerrors.Errorf("arg %s not found", want)
|
|
}
|
|
|
|
// Run executes the command.
|
|
// If two command share a flag name, the first command wins.
|
|
//
|
|
//nolint:revive
|
|
func (i *Invocation) Run() error {
|
|
return i.run(&runState{
|
|
allArgs: i.Args,
|
|
})
|
|
}
|
|
|
|
// WithContext returns a copy of the Invocation with the given context.
|
|
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
|
|
return i.with(func(i *Invocation) {
|
|
i.ctx = ctx
|
|
})
|
|
}
|
|
|
|
// with returns a copy of the Invocation with the given function applied.
|
|
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
|
|
i2 := *i
|
|
fn(&i2)
|
|
return &i2
|
|
}
|
|
|
|
// MiddlewareFunc returns the next handler in the chain,
|
|
// or nil if there are no more.
|
|
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
|
|
|
|
func chain(ms ...MiddlewareFunc) MiddlewareFunc {
|
|
return MiddlewareFunc(func(next HandlerFunc) HandlerFunc {
|
|
if len(ms) > 0 {
|
|
return chain(ms[1:]...)(ms[0](next))
|
|
}
|
|
return next
|
|
})
|
|
}
|
|
|
|
// Chain returns a Handler that first calls middleware in order.
|
|
//
|
|
//nolint:revive
|
|
func Chain(ms ...MiddlewareFunc) MiddlewareFunc {
|
|
// We need to reverse the array to provide top-to-bottom execution
|
|
// order when defining a command.
|
|
reversed := make([]MiddlewareFunc, len(ms))
|
|
for i := range ms {
|
|
reversed[len(ms)-1-i] = ms[i]
|
|
}
|
|
return chain(reversed...)
|
|
}
|
|
|
|
func RequireNArgs(want int) MiddlewareFunc {
|
|
return RequireRangeArgs(want, want)
|
|
}
|
|
|
|
// RequireRangeArgs returns a Middleware that requires the number of arguments
|
|
// to be between start and end (inclusive). If end is -1, then the number of
|
|
// arguments must be at least start.
|
|
func RequireRangeArgs(start, end int) MiddlewareFunc {
|
|
if start < 0 {
|
|
panic("start must be >= 0")
|
|
}
|
|
return func(next HandlerFunc) HandlerFunc {
|
|
return func(i *Invocation) error {
|
|
got := len(i.Args)
|
|
switch {
|
|
case start == end && got != start:
|
|
switch start {
|
|
case 0:
|
|
return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
|
|
default:
|
|
return xerrors.Errorf(
|
|
"wanted %v args but got %v %v",
|
|
start,
|
|
got,
|
|
i.Args,
|
|
)
|
|
}
|
|
case start > 0 && end == -1:
|
|
switch {
|
|
case got < start:
|
|
return xerrors.Errorf(
|
|
"wanted at least %v args but got %v",
|
|
start,
|
|
got,
|
|
)
|
|
default:
|
|
return next(i)
|
|
}
|
|
case start > end:
|
|
panic("start must be <= end")
|
|
case got < start || got > end:
|
|
return xerrors.Errorf(
|
|
"wanted between %v and %v args but got %v",
|
|
start, end,
|
|
got,
|
|
)
|
|
default:
|
|
return next(i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandlerFunc handles an Invocation of a command.
|
|
type HandlerFunc func(i *Invocation) error
|