Files
coder/cli/clibase/cmd.go
2023-03-10 13:41:32 -06:00

418 lines
10 KiB
Go

package clibase
import (
"context"
"errors"
"io"
"os"
"strings"
"github.com/spf13/pflag"
"golang.org/x/xerrors"
)
// Cmd describes an executable command.
type Cmd struct {
	// Parent is the direct parent of the command. Invocation.Run sets it on
	// each child command at the moment that child is selected.
	Parent *Cmd
	// Children is a list of direct descendants. A child is selected when the
	// next positional argument equals its Name() or one of its Aliases.
	Children []*Cmd
	// Use is provided in form "command [flags] [args...]". Its first
	// space-separated word is the command's Name().
	Use string
	// Aliases is a list of alternative names for the command.
	Aliases []string
	// Short is a one-line description of the command.
	Short string
	// Hidden determines whether the command should be hidden from help.
	Hidden bool
	// RawArgs determines whether the command should receive unparsed arguments.
	// No flags are parsed when set, and the command is responsible for parsing
	// its own flags.
	RawArgs bool
	// Long is a detailed description of the command,
	// presented on its help page. It may contain examples.
	Long string
	// Options are the command's options. Before the command runs, their
	// defaults are set, they are parsed from the invocation's environment,
	// and their flags are registered for command-line parsing.
	Options OptionSet
	// Annotations holds additional metadata about the command.
	// NOTE(review): not read anywhere in this file; confirm its consumers.
	Annotations Annotations
	// Middleware is called before the Handler; it wraps Handler only and is
	// not applied to HelpHandler.
	// Use Chain() to combine multiple middlewares.
	Middleware MiddlewareFunc
	// Handler runs the command. When it is nil, HelpHandler runs instead.
	Handler HandlerFunc
	// HelpHandler runs when Handler is nil or when flag parsing reported
	// pflag.ErrHelp (help was requested on the command line).
	HelpHandler HandlerFunc
}
// Walk visits the command tree rooted at c in depth-first pre-order: fn is
// invoked on c itself, then on each child's subtree in slice order.
func (c *Cmd) Walk(fn func(*Cmd)) {
	fn(c)
	// Snapshot the slice header once (as range would) and visit each subtree.
	kids := c.Children
	for idx := 0; idx < len(kids); idx++ {
		kids[idx].Walk(fn)
	}
}
// Name returns the command's name: the text of Use up to, but not including,
// the first space. If Use contains no space, all of Use is returned.
func (c *Cmd) Name() string {
	if cut := strings.IndexByte(c.Use, ' '); cut >= 0 {
		return c.Use[:cut]
	}
	return c.Use
}
// FullName returns the command's name prefixed by the names of all of its
// ancestors, separated by single spaces (e.g. "root sub leaf"). Only primary
// names are used; aliases never appear in the result.
func (c *Cmd) FullName() string {
	if c.Parent == nil {
		return c.Name()
	}
	return c.Parent.FullName() + " " + c.Name()
}
// FullUsage returns the usage (Use) of the command, preceded by the usage of
// each of its ancestors, outermost first, joined by single spaces.
func (c *Cmd) FullUsage() string {
	var uses []string
	if c.Parent != nil {
		uses = append(uses, c.Parent.FullUsage())
	}
	uses = append(uses, c.Use)
	return strings.Join(uses, " ")
}
// Invoke prepares an invocation of c with the given arguments. Output written
// to Stdout and Stderr is discarded, and Stdin yields no data.
//
// Nothing executes until Run() is called on the returned invocation.
func (c *Cmd) Invoke(args ...string) *Invocation {
	inv := &Invocation{Command: c, Args: args}
	inv.Stdout, inv.Stderr = io.Discard, io.Discard
	inv.Stdin = strings.NewReader("")
	return inv
}
// Invocation represents an instance of a command being executed.
type Invocation struct {
	// ctx is attached with WithContext. Run hands handlers a copy of the
	// invocation whose ctx is a cancellable context derived from this one
	// (or from context.Background() when unset).
	ctx context.Context
	// Command is the command to run. Run replaces it with each selected
	// subcommand as it descends the command tree, so after Run it refers to
	// the deepest command that was selected.
	Command *Cmd
	// parsedFlags accumulates the flags of every command on the selected
	// path; Run creates and populates it.
	parsedFlags *pflag.FlagSet
	// Args are the command-line arguments, excluding the program name
	// (WithOS uses os.Args[1:]). Run overwrites them with the arguments that
	// belong to the selected command.
	Args []string
	// Environ is a list of environment variables. WithOS fills it from
	// os.Environ() via ParseEnviron.
	Environ Environ
	Stdout  io.Writer
	Stderr  io.Writer
	Stdin   io.Reader
}
// WithOS returns a copy of the invocation in which every unset field is
// filled with the running process's defaults, as a main package would want:
// os.Stdout, os.Stderr, os.Stdin, os.Args[1:] and os.Environ(). Fields that
// are already set are left untouched, and the receiver is not modified.
func (i *Invocation) WithOS() *Invocation {
	fillFromOS := func(inv *Invocation) {
		if inv.Stdout == nil {
			inv.Stdout = os.Stdout
		}
		if inv.Stderr == nil {
			inv.Stderr = os.Stderr
		}
		if inv.Stdin == nil {
			inv.Stdin = os.Stdin
		}
		if inv.Args == nil {
			inv.Args = os.Args[1:]
		}
		if inv.Environ == nil {
			inv.Environ = ParseEnviron(os.Environ(), "")
		}
	}
	return i.with(fillFromOS)
}
// Context returns the context attached to the invocation. It panics when no
// context has been attached. The invocation Run passes to a handler always
// carries one; the caller's own invocation only does after WithContext.
func (i *Invocation) Context() context.Context {
	if ctx := i.ctx; ctx != nil {
		return ctx
	}
	// Consider returning context.Background() instead?
	panic("context not set, has WithContext() or Run() been called?")
}
// ParsedFlags returns the flag set built by Run, which holds the flags of
// every command on the selected path. It panics if Run has not created it.
func (i *Invocation) ParsedFlags() *pflag.FlagSet {
	if fs := i.parsedFlags; fs != nil {
		return fs
	}
	panic("flags not parsed, has Run() been called?")
}
// runState is the mutable state threaded through the recursive calls of
// Invocation.run.
type runState struct {
	// allArgs is the complete argument list of the invocation; every level of
	// the command tree re-parses it so global flags work anywhere.
	allArgs []string
	// commandDepth is the number of subcommands descended into so far
	// (0 while running the root command).
	commandDepth int
	// flagParseErr is the error from the most recent flag parse. Intermediate
	// commands may fail to parse flags that belong to their descendants, so it
	// is only acted on once the final command has been selected.
	flagParseErr error
}
// run recursively executes the command and its children.
//
// state.allArgs always holds the complete argument list and is re-parsed at
// every level of the command tree, so that global (ancestor) flags are
// accepted anywhere in the command invocation. Because of that re-parsing,
// the positional arguments seen at depth d start with the d command names
// already consumed; the candidate for the next subcommand is therefore the
// positional at index state.commandDepth.
//
// run mutates i: i.Command is replaced by each selected subcommand,
// i.parsedFlags accumulates the flag sets along the path, and i.Args is
// overwritten with the arguments of the selected command.
func (i *Invocation) run(state *runState) error {
	err := i.Command.Options.SetDefaults()
	if err != nil {
		return xerrors.Errorf("setting defaults: %w", err)
	}
	err = i.Command.Options.ParseEnv(i.Environ)
	if err != nil {
		return xerrors.Errorf("parsing env: %w", err)
	}

	// Now the fun part, argument parsing!

	// Index the direct children by alias and by name. A fresh slice is built
	// instead of append(child.Aliases, ...), which could write into spare
	// capacity of the shared Aliases backing array.
	children := make(map[string]*Cmd)
	for _, child := range i.Command.Children {
		names := make([]string, 0, len(child.Aliases)+1)
		names = append(names, child.Aliases...)
		names = append(names, child.Name())
		for _, name := range names {
			if _, ok := children[name]; ok {
				return xerrors.Errorf("duplicate command name: %s", name)
			}
			children[name] = child
		}
	}

	if i.parsedFlags == nil {
		i.parsedFlags = pflag.NewFlagSet(i.Command.Name(), pflag.ContinueOnError)
		// We handle Usage ourselves.
		i.parsedFlags.Usage = func() {}
	}
	i.parsedFlags.AddFlagSet(i.Command.Options.FlagSet())

	var parsedArgs []string
	if !i.Command.RawArgs {
		// Flag parsing will fail on intermediate commands in the command tree,
		// so we check the error after looking for a child command.
		state.flagParseErr = i.parsedFlags.Parse(state.allArgs)
		parsedArgs = i.parsedFlags.Args()
	}

	// Run child command if found (next child only).
	// We must do subcommand detection after flag parsing so we don't mistake
	// flag values for subcommand names. The first commandDepth positionals
	// are the names of the commands already descended into.
	if len(parsedArgs) > state.commandDepth {
		nextArg := parsedArgs[state.commandDepth]
		if child, ok := children[nextArg]; ok {
			child.Parent = i.Command
			i.Command = child
			state.commandDepth++
			err = i.run(state)
			if err != nil {
				return xerrors.Errorf(
					"subcommand %s: %w", child.Name(), err,
				)
			}
			return nil
		}
	}

	// Flag parse errors are irrelevant for raw args commands.
	if !i.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) {
		return xerrors.Errorf(
			"parsing flags (%v) for %q: %w",
			state.allArgs,
			i.Command.FullName(), state.flagParseErr,
		)
	}

	switch {
	case !i.Command.RawArgs:
		// In non-raw-arg mode, skip over flags and the names of the commands
		// on the path.
		i.Args = parsedArgs[state.commandDepth:]
	case state.commandDepth == 0:
		// If we're at the root command, then the name is omitted
		// from the arguments, so we can just use the entire slice.
		i.Args = state.allArgs
	default:
		rawArgs, err := rawCommandArgs(i.Command, state.allArgs, i.parsedFlags)
		if err != nil {
			return xerrors.Errorf(
				"locating arguments for raw command %q: %w",
				i.Command.FullName(), err,
			)
		}
		i.Args = rawArgs
	}

	mw := i.Command.Middleware
	if mw == nil {
		mw = Chain()
	}

	ctx := i.ctx
	if ctx == nil {
		ctx = context.Background()
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	// Handlers receive a copy carrying the cancellable context; the caller's
	// Invocation keeps its original ctx.
	i = i.WithContext(ctx)

	if i.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) {
		if i.Command.HelpHandler == nil {
			return xerrors.Errorf("no handler or help for command %s", i.Command.FullName())
		}
		return i.Command.HelpHandler(i)
	}

	err = mw(i.Command.Handler)(i)
	if err != nil {
		return xerrors.Errorf("running command %s: %w", i.Command.FullName(), err)
	}
	return nil
}

// rawCommandArgs returns the elements of allArgs that follow the token which
// selected the raw-args command cmd. The command may have been selected by
// its name or by any alias. Every positional before that token names an
// ancestor command, so the earliest occurrence of any of cmd's names is the
// selecting token. Like a name-only lookup, this assumes no ancestor on the
// path shares a name or alias with cmd.
func rawCommandArgs(cmd *Cmd, allArgs []string, fs *pflag.FlagSet) ([]string, error) {
	names := make([]string, 0, len(cmd.Aliases)+1)
	names = append(names, cmd.Name())
	names = append(names, cmd.Aliases...)

	pos := -1
	var lookupErr error
	for _, name := range names {
		idx, err := findArg(name, allArgs, fs)
		if err != nil {
			lookupErr = err
			continue
		}
		if pos < 0 || idx < pos {
			pos = idx
		}
	}
	if pos < 0 {
		return nil, lookupErr
	}
	return allArgs[pos+1:], nil
}
// findArg returns the index of the first occurrence of want among the
// positional (non-flag) arguments in args. It uses fs to decide which flags
// consume the following argument as their value, following pflag's rules:
//
//   - any flag containing "=" ("--name=v", "-n=v") carries its value inline;
//   - "--name" consumes the next argument unless the flag has a NoOptDefVal;
//   - "-abc" is a group of shorthands, and the first shorthand that needs a
//     value takes the rest of the group, or the next argument if it is last;
//   - "-" on its own is a positional argument;
//   - "--" ends flag parsing, and every later argument is positional.
//
// An error is returned if a flag is not defined in fs, if a flag is missing
// its value, or if want does not occur.
func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) {
	for i := 0; i < len(args); i++ {
		arg := args[i]
		if arg == "--" {
			// Flag terminator: the remaining args are all positional.
			for j := i + 1; j < len(args); j++ {
				if args[j] == want {
					return j, nil
				}
			}
			break
		}
		if len(arg) < 2 || arg[0] != '-' {
			if arg == want {
				return i, nil
			}
			continue
		}
		// This is a flag!
		takesNext, err := flagTakesNextArg(arg, fs)
		if err != nil {
			return -1, err
		}
		if !takesNext {
			continue
		}
		if i == len(args)-1 {
			return -1, xerrors.Errorf("flag %s requires a value", arg)
		}
		// Skip the value.
		i++
	}
	return -1, xerrors.Errorf("arg %s not found", want)
}

// flagTakesNextArg reports whether the flag argument arg consumes the
// following argument as its value. arg starts with "-" and is neither "-"
// nor "--".
func flagTakesNextArg(arg string, fs *pflag.FlagSet) (bool, error) {
	if strings.Contains(arg, "=") {
		// The flag contains the value in the same arg.
		return false, nil
	}
	if strings.HasPrefix(arg, "--") {
		f := fs.Lookup(arg[2:])
		if f == nil {
			return false, xerrors.Errorf("unknown flag: %s", arg)
		}
		// Flags with a NoOptDefVal (e.g. booleans) don't wait for a value.
		return f.NoOptDefVal == "", nil
	}
	// One or more shorthands grouped together, e.g. "-v" or "-vo".
	shorthands := arg[1:]
	for len(shorthands) > 0 {
		// ShorthandLookup requires a single-byte name.
		f := fs.ShorthandLookup(shorthands[:1])
		if f == nil {
			return false, xerrors.Errorf("unknown flag: %s", arg)
		}
		if f.NoOptDefVal == "" {
			// This shorthand needs a value: the rest of the group when there
			// is one, otherwise the next argument.
			return len(shorthands) == 1, nil
		}
		shorthands = shorthands[1:]
	}
	return false, nil
}
// Run executes the command, descending into whichever subcommand the
// arguments select. Flags of every command on the selected path are accepted
// anywhere on the command line. If two of those commands define a flag with
// the same name, the one registered first (the outermost command) wins.
//
//nolint:revive
func (i *Invocation) Run() error {
	state := runState{allArgs: i.Args}
	return i.run(&state)
}
// WithContext returns a shallow copy of the Invocation whose context is ctx.
// The receiver itself is left unchanged.
func (i *Invocation) WithContext(ctx context.Context) *Invocation {
	attach := func(inv *Invocation) { inv.ctx = ctx }
	return i.with(attach)
}
// with makes a shallow copy of the Invocation, lets fn modify the copy, and
// returns it; the receiver is never modified.
func (i *Invocation) with(fn func(*Invocation)) *Invocation {
	clone := new(Invocation)
	*clone = *i
	fn(clone)
	return clone
}
// MiddlewareFunc wraps next and returns a new HandlerFunc, which typically
// does some work before and/or after delegating to next. Use Chain to
// combine several middlewares into one.
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
// chain composes ms so that ms[0] wraps the handler first (innermost) and the
// last element wraps last (outermost): chain(a, b)(h) == b(a(h)).
func chain(ms ...MiddlewareFunc) MiddlewareFunc {
	return func(next HandlerFunc) HandlerFunc {
		handler := next
		for _, m := range ms {
			handler = m(handler)
		}
		return handler
	}
}
// Chain combines middlewares into a single MiddlewareFunc that runs them in
// the order given. ms[0] is the outermost wrapper, so its handler is entered
// first and then hands off to ms[1], and so on, until the wrapped handler.
//
//nolint:revive
func Chain(ms ...MiddlewareFunc) MiddlewareFunc {
	// chain wraps innermost-first, so hand it the list back-to-front to get
	// top-to-bottom execution order when defining a command.
	reversed := make([]MiddlewareFunc, 0, len(ms))
	for idx := len(ms) - 1; idx >= 0; idx-- {
		reversed = append(reversed, ms[idx])
	}
	return chain(reversed...)
}
// RequireNArgs returns a Middleware that requires exactly want arguments.
// It panics if want is negative.
func RequireNArgs(want int) MiddlewareFunc {
	return RequireRangeArgs(want, want)
}

// RequireRangeArgs returns a Middleware that requires the number of arguments
// to be between start and end (inclusive). If end is -1, then the number of
// arguments must be at least start (so RequireRangeArgs(0, -1) accepts any
// number of arguments).
//
// It panics immediately, rather than on first invocation, if start is
// negative or if end is not -1 and start > end.
func RequireRangeArgs(start, end int) MiddlewareFunc {
	if start < 0 {
		panic("start must be >= 0")
	}
	if end != -1 && start > end {
		panic("start must be <= end")
	}
	return func(next HandlerFunc) HandlerFunc {
		return func(i *Invocation) error {
			got := len(i.Args)
			switch {
			case start == end && got != start:
				if start == 0 {
					return xerrors.Errorf("wanted no args but got %v %v", got, i.Args)
				}
				return xerrors.Errorf(
					"wanted %v args but got %v %v",
					start,
					got,
					i.Args,
				)
			case end == -1 && got < start:
				return xerrors.Errorf(
					"wanted at least %v args but got %v",
					start,
					got,
				)
			case end != -1 && (got < start || got > end):
				return xerrors.Errorf(
					"wanted between %v and %v args but got %v",
					start, end,
					got,
				)
			}
			return next(i)
		}
	}
}
// HandlerFunc handles an Invocation of a command. It is the type of both
// Cmd.Handler and Cmd.HelpHandler, and is what a MiddlewareFunc wraps.
type HandlerFunc func(i *Invocation) error