mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat(cli): add coder exp mcp
command (#17066)
Adds a `coder exp mcp` command which will start a local MCP server listening on stdio with the following capabilities: * Show logged in user (`coder whoami`) * List workspaces (`coder list`) * List templates (`coder templates list`) * Start a workspace (`coder start`) * Stop a workspace (`coder stop`) * Fetch a single workspace (no direct CLI analogue) * Execute a command inside a workspace (`coder exp rpty`) * Report the status of a task (currently a no-op, pending task support) This can be tested as follows: ``` # Start a local Coder server. ./scripts/develop.sh # Start a workspace. Currently, creating workspaces is not supported. ./scripts/coder-dev.sh create -t docker --yes # Add the MCP to your Claude config. claude mcp add coder ./scripts/coder-dev.sh exp mcp # Tell Claude to do something Coder-related. You may need to nudge it to use the tools. claude 'start a docker workspace and tell me what version of python is installed' ```
This commit is contained in:
@ -11,7 +11,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/config"
|
||||
@ -117,11 +119,7 @@ func TestGoldenFile(t *testing.T, fileName string, actual []byte, replacements m
|
||||
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
|
||||
|
||||
expected = normalizeGoldenFile(t, expected)
|
||||
require.Equal(
|
||||
t, string(expected), string(actual),
|
||||
"golden file mismatch: %s, run \"make gen/golden-files\", verify and commit the changes",
|
||||
goldenPath,
|
||||
)
|
||||
assert.Empty(t, cmp.Diff(string(expected), string(actual)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenPath)
|
||||
}
|
||||
|
||||
// normalizeGoldenFile replaces any strings that are system or timing dependent
|
||||
|
@ -13,6 +13,7 @@ func (r *RootCmd) expCmd() *serpent.Command {
|
||||
Children: []*serpent.Command{
|
||||
r.scaletestCmd(),
|
||||
r.errorExample(),
|
||||
r.mcpCommand(),
|
||||
r.promptExample(),
|
||||
r.rptyCommand(),
|
||||
},
|
||||
|
284
cli/exp_mcp.go
Normal file
284
cli/exp_mcp.go
Normal file
@ -0,0 +1,284 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/cli/cliui"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
codermcp "github.com/coder/coder/v2/mcp"
|
||||
"github.com/coder/serpent"
|
||||
)
|
||||
|
||||
func (r *RootCmd) mcpCommand() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "mcp",
|
||||
Short: "Run the Coder MCP server and configure it to work with AI tools.",
|
||||
Long: "The Coder MCP server allows you to automatically create workspaces with parameters.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.mcpConfigure(),
|
||||
r.mcpServer(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) mcpConfigure() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "configure",
|
||||
Short: "Automatically configure the MCP server.",
|
||||
Handler: func(i *serpent.Invocation) error {
|
||||
return i.Command.HelpHandler(i)
|
||||
},
|
||||
Children: []*serpent.Command{
|
||||
r.mcpConfigureClaudeDesktop(),
|
||||
r.mcpConfigureClaudeCode(),
|
||||
r.mcpConfigureCursor(),
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "claude-desktop",
|
||||
Short: "Configure the Claude Desktop server.",
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
configPath, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath = filepath.Join(configPath, "Claude")
|
||||
err = os.MkdirAll(configPath, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
configPath = filepath.Join(configPath, "claude_desktop_config.json")
|
||||
_, err = os.Stat(configPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
contents := map[string]any{}
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = json.Unmarshal(data, &contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
binPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
contents["mcpServers"] = map[string]any{
|
||||
"coder": map[string]any{"command": binPath, "args": []string{"exp", "mcp", "server"}},
|
||||
}
|
||||
data, err = json.MarshalIndent(contents, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = os.WriteFile(configPath, data, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) mcpConfigureClaudeCode() *serpent.Command {
|
||||
cmd := &serpent.Command{
|
||||
Use: "claude-code",
|
||||
Short: "Configure the Claude Code server.",
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (*RootCmd) mcpConfigureCursor() *serpent.Command {
|
||||
var project bool
|
||||
cmd := &serpent.Command{
|
||||
Use: "cursor",
|
||||
Short: "Configure Cursor to use Coder MCP.",
|
||||
Options: serpent.OptionSet{
|
||||
serpent.Option{
|
||||
Flag: "project",
|
||||
Env: "CODER_MCP_CURSOR_PROJECT",
|
||||
Description: "Use to configure a local project to use the Cursor MCP.",
|
||||
Value: serpent.BoolOf(&project),
|
||||
},
|
||||
},
|
||||
Handler: func(_ *serpent.Invocation) error {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !project {
|
||||
dir, err = os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
cursorDir := filepath.Join(dir, ".cursor")
|
||||
err = os.MkdirAll(cursorDir, 0o755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mcpConfig := filepath.Join(cursorDir, "mcp.json")
|
||||
_, err = os.Stat(mcpConfig)
|
||||
contents := map[string]any{}
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
data, err := os.ReadFile(mcpConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// The config can be empty, so we don't want to return an error if it is.
|
||||
if len(data) > 0 {
|
||||
err = json.Unmarshal(data, &contents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
mcpServers, ok := contents["mcpServers"].(map[string]any)
|
||||
if !ok {
|
||||
mcpServers = map[string]any{}
|
||||
}
|
||||
binPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mcpServers["coder"] = map[string]any{
|
||||
"command": binPath,
|
||||
"args": []string{"exp", "mcp", "server"},
|
||||
}
|
||||
contents["mcpServers"] = mcpServers
|
||||
data, err := json.MarshalIndent(contents, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = os.WriteFile(mcpConfig, data, 0o600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (r *RootCmd) mcpServer() *serpent.Command {
|
||||
var (
|
||||
client = new(codersdk.Client)
|
||||
instructions string
|
||||
allowedTools []string
|
||||
)
|
||||
return &serpent.Command{
|
||||
Use: "server",
|
||||
Handler: func(inv *serpent.Invocation) error {
|
||||
return mcpServerHandler(inv, client, instructions, allowedTools)
|
||||
},
|
||||
Short: "Start the Coder MCP server.",
|
||||
Middleware: serpent.Chain(
|
||||
r.InitClient(client),
|
||||
),
|
||||
Options: []serpent.Option{
|
||||
{
|
||||
Name: "instructions",
|
||||
Description: "The instructions to pass to the MCP server.",
|
||||
Flag: "instructions",
|
||||
Value: serpent.StringOf(&instructions),
|
||||
},
|
||||
{
|
||||
Name: "allowed-tools",
|
||||
Description: "Comma-separated list of allowed tools. If not specified, all tools are allowed.",
|
||||
Flag: "allowed-tools",
|
||||
Value: serpent.StringArrayOf(&allowedTools),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string) error {
|
||||
ctx, cancel := context.WithCancel(inv.Context())
|
||||
defer cancel()
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(inv.Stdout))
|
||||
|
||||
me, err := client.User(ctx, codersdk.Me)
|
||||
if err != nil {
|
||||
cliui.Errorf(inv.Stderr, "Failed to log in to the Coder deployment.")
|
||||
cliui.Errorf(inv.Stderr, "Please check your URL and credentials.")
|
||||
cliui.Errorf(inv.Stderr, "Tip: Run `coder whoami` to check your credentials.")
|
||||
return err
|
||||
}
|
||||
cliui.Infof(inv.Stderr, "Starting MCP server")
|
||||
cliui.Infof(inv.Stderr, "User : %s", me.Username)
|
||||
cliui.Infof(inv.Stderr, "URL : %s", client.URL)
|
||||
cliui.Infof(inv.Stderr, "Instructions : %q", instructions)
|
||||
if len(allowedTools) > 0 {
|
||||
cliui.Infof(inv.Stderr, "Allowed Tools : %v", allowedTools)
|
||||
}
|
||||
cliui.Infof(inv.Stderr, "Press Ctrl+C to stop the server")
|
||||
|
||||
// Capture the original stdin, stdout, and stderr.
|
||||
invStdin := inv.Stdin
|
||||
invStdout := inv.Stdout
|
||||
invStderr := inv.Stderr
|
||||
defer func() {
|
||||
inv.Stdin = invStdin
|
||||
inv.Stdout = invStdout
|
||||
inv.Stderr = invStderr
|
||||
}()
|
||||
|
||||
options := []codermcp.Option{
|
||||
codermcp.WithInstructions(instructions),
|
||||
codermcp.WithLogger(&logger),
|
||||
}
|
||||
|
||||
// Add allowed tools option if specified
|
||||
if len(allowedTools) > 0 {
|
||||
options = append(options, codermcp.WithAllowedTools(allowedTools))
|
||||
}
|
||||
|
||||
srv := codermcp.NewStdio(client, options...)
|
||||
srv.SetErrorLogger(log.New(invStderr, "", log.LstdFlags))
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
defer close(done)
|
||||
srvErr := srv.Listen(ctx, invStdin, invStdout)
|
||||
done <- srvErr
|
||||
}()
|
||||
|
||||
if err := <-done; err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
cliui.Errorf(inv.Stderr, "Failed to start the MCP server: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
142
cli/exp_mcp_test.go
Normal file
142
cli/exp_mcp_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
package cli_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/cli/clitest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestExpMcp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Reading to / writing from the PTY is flaky on non-linux systems.
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux")
|
||||
}
|
||||
|
||||
t.Run("AllowedTools", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Given: a running coder deployment
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
|
||||
// Given: we run the exp mcp command with allowed tools set
|
||||
inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_whoami,coder_list_templates")
|
||||
inv = inv.WithContext(cancelCtx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cmdDone)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// When: we send a tools/list request
|
||||
toolsPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
|
||||
pty.WriteLine(toolsPayload)
|
||||
_ = pty.ReadLine(ctx) // ignore echoed output
|
||||
output := pty.ReadLine(ctx)
|
||||
|
||||
cancel()
|
||||
<-cmdDone
|
||||
|
||||
// Then: we should only see the allowed tools in the response
|
||||
var toolsResponse struct {
|
||||
Result struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
err := json.Unmarshal([]byte(output), &toolsResponse)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools")
|
||||
foundTools := make([]string, 0, 2)
|
||||
for _, tool := range toolsResponse.Result.Tools {
|
||||
foundTools = append(foundTools, tool.Name)
|
||||
}
|
||||
slices.Sort(foundTools)
|
||||
require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools)
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
inv, root := clitest.New(t, "exp", "mcp", "server")
|
||||
inv = inv.WithContext(cancelCtx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
cmdDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(cmdDone)
|
||||
err := inv.Run()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
|
||||
pty.WriteLine(payload)
|
||||
_ = pty.ReadLine(ctx) // ignore echoed output
|
||||
output := pty.ReadLine(ctx)
|
||||
cancel()
|
||||
<-cmdDone
|
||||
|
||||
// Ensure the initialize output is valid JSON
|
||||
t.Logf("/initialize output: %s", output)
|
||||
var initializeResponse map[string]interface{}
|
||||
err := json.Unmarshal([]byte(output), &initializeResponse)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "2.0", initializeResponse["jsonrpc"])
|
||||
require.Equal(t, 1.0, initializeResponse["id"])
|
||||
require.NotNil(t, initializeResponse["result"])
|
||||
})
|
||||
|
||||
t.Run("NoCredentials", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client := coderdtest.New(t, nil)
|
||||
inv, root := clitest.New(t, "exp", "mcp", "server")
|
||||
inv = inv.WithContext(cancelCtx)
|
||||
|
||||
pty := ptytest.New(t)
|
||||
inv.Stdin = pty.Input()
|
||||
inv.Stdout = pty.Output()
|
||||
clitest.SetupConfig(t, client, root)
|
||||
|
||||
err := inv.Run()
|
||||
assert.ErrorContains(t, err, "your session has expired")
|
||||
})
|
||||
}
|
4
go.mod
4
go.mod
@ -480,3 +480,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
)
|
||||
|
||||
require github.com/mark3labs/mcp-go v0.17.0
|
||||
|
||||
require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
|
4
go.sum
4
go.sum
@ -658,6 +658,8 @@ github.com/makeworld-the-better-one/dither/v2 v2.4.0 h1:Az/dYXiTcwcRSe59Hzw4RI1r
|
||||
github.com/makeworld-the-better-one/dither/v2 v2.4.0/go.mod h1:VBtN8DXO7SNtyGmLiGA7IsFeKrBkQPze1/iAeM95arc=
|
||||
github.com/marekm4/color-extractor v1.2.1 h1:3Zb2tQsn6bITZ8MBVhc33Qn1k5/SEuZ18mrXGUqIwn0=
|
||||
github.com/marekm4/color-extractor v1.2.1/go.mod h1:90VjmiHI6M8ez9eYUaXLdcKnS+BAOp7w+NpwBdkJmpA=
|
||||
github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930=
|
||||
github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
|
||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
@ -972,6 +974,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg=
|
||||
github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
|
||||
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
|
||||
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=
|
||||
|
643
mcp/mcp.go
Normal file
643
mcp/mcp.go
Normal file
@ -0,0 +1,643 @@
|
||||
package codermcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/sloghuman"
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
)
|
||||
|
||||
// mcpOptions collects the settings that Option functions apply before
// NewStdio builds the server.
type mcpOptions struct {
|
||||
// instructions is forwarded to server.WithInstructions by NewStdio.
instructions string
|
||||
// logger is set via WithLogger; NewStdio seeds it with a default.
logger *slog.Logger
|
||||
// allowedTools, when non-empty, restricts registration to the named tools.
allowedTools []string
|
||||
}
|
||||
|
||||
// Option is a function that configures the MCP server. NewStdio applies the
// options it is given, in order, to its mcpOptions before building the server.
|
||||
type Option func(*mcpOptions)
|
||||
|
||||
// WithInstructions sets the instructions for the MCP server.
|
||||
func WithInstructions(instructions string) Option {
|
||||
return func(o *mcpOptions) {
|
||||
o.instructions = instructions
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger sets the logger for the MCP server.
|
||||
func WithLogger(logger *slog.Logger) Option {
|
||||
return func(o *mcpOptions) {
|
||||
o.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowedTools sets the allowed tools for the MCP server.
|
||||
func WithAllowedTools(tools []string) Option {
|
||||
return func(o *mcpOptions) {
|
||||
o.allowedTools = tools
|
||||
}
|
||||
}
|
||||
|
||||
// NewStdio creates a new MCP stdio server with the given client and options.
|
||||
// It is the responsibility of the caller to start and stop the server.
|
||||
func NewStdio(client *codersdk.Client, opts ...Option) *server.StdioServer {
|
||||
options := &mcpOptions{
|
||||
instructions: ``,
|
||||
logger: ptr.Ref(slog.Make(sloghuman.Sink(os.Stdout))),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
mcpSrv := server.NewMCPServer(
|
||||
"Coder Agent",
|
||||
buildinfo.Version(),
|
||||
server.WithInstructions(options.instructions),
|
||||
)
|
||||
|
||||
logger := slog.Make(sloghuman.Sink(os.Stdout))
|
||||
|
||||
// Register tools based on the allowed list (if specified)
|
||||
reg := AllTools()
|
||||
if len(options.allowedTools) > 0 {
|
||||
reg = reg.WithOnlyAllowed(options.allowedTools...)
|
||||
}
|
||||
reg.Register(mcpSrv, ToolDeps{
|
||||
Client: client,
|
||||
Logger: &logger,
|
||||
})
|
||||
|
||||
srv := server.NewStdioServer(mcpSrv)
|
||||
return srv
|
||||
}
|
||||
|
||||
// allTools is the list of all available tools. When adding a new tool,
|
||||
// make sure to update this list. Each entry pairs the tool definition that is
// advertised to MCP clients (name, description, argument schema) with the
// function that builds its handler. AllTools returns a copy of this list.
var allTools = ToolRegistry{
|
||||
// coder_report_task: progress reporting; all four arguments are required.
{
|
||||
Tool: mcp.NewTool("coder_report_task",
|
||||
mcp.WithDescription(`Report progress on a user task in Coder.
|
||||
Use this tool to keep the user informed about your progress with their request.
|
||||
For long-running operations, call this periodically to provide status updates.
|
||||
This is especially useful when performing multi-step operations like workspace creation or deployment.`),
|
||||
mcp.WithString("summary", mcp.Description(`A concise summary of your current progress on the task.
|
||||
|
||||
Good Summaries:
|
||||
- "Taking a look at the login page..."
|
||||
- "Found a bug! Fixing it now..."
|
||||
- "Investigating the GitHub Issue..."
|
||||
- "Waiting for workspace to start (1/3 resources ready)"
|
||||
- "Downloading template files from repository"`), mcp.Required()),
|
||||
mcp.WithString("link", mcp.Description(`A relevant URL related to your work, such as:
|
||||
- GitHub issue link
|
||||
- Pull request URL
|
||||
- Documentation reference
|
||||
- Workspace URL
|
||||
Use complete URLs (including https://) when possible.`), mcp.Required()),
|
||||
mcp.WithString("emoji", mcp.Description(`A relevant emoji that visually represents the current status:
|
||||
- 🔍 for investigating/searching
|
||||
- 🚀 for deploying/starting
|
||||
- 🐛 for debugging
|
||||
- ✅ for completion
|
||||
- ⏳ for waiting
|
||||
Choose an emoji that helps the user understand the current phase at a glance.`), mcp.Required()),
|
||||
mcp.WithBoolean("done", mcp.Description(`Whether the overall task the user requested is complete.
|
||||
Set to true only when the entire requested operation is finished successfully.
|
||||
For multi-step processes, use false until all steps are complete.`), mcp.Required()),
|
||||
),
|
||||
MakeHandler: handleCoderReportTask,
|
||||
},
|
||||
// coder_whoami: takes no arguments.
{
|
||||
Tool: mcp.NewTool("coder_whoami",
|
||||
mcp.WithDescription(`Get information about the currently logged-in Coder user.
|
||||
Returns JSON with the user's profile including fields: id, username, email, created_at, status, roles, etc.
|
||||
Use this to identify the current user context before performing workspace operations.
|
||||
This tool is useful for verifying permissions and checking the user's identity.
|
||||
|
||||
Common errors:
|
||||
- Authentication failure: The session may have expired
|
||||
- Server unavailable: The Coder deployment may be unreachable`),
|
||||
),
|
||||
MakeHandler: handleCoderWhoami,
|
||||
},
|
||||
// coder_list_templates: takes no arguments.
{
|
||||
Tool: mcp.NewTool("coder_list_templates",
|
||||
mcp.WithDescription(`List all templates available on the Coder deployment.
|
||||
Returns JSON with detailed information about each template, including:
|
||||
- Template name, ID, and description
|
||||
- Creation/modification timestamps
|
||||
- Version information
|
||||
- Associated organization
|
||||
|
||||
Use this tool to discover available templates before creating workspaces.
|
||||
Templates define the infrastructure and configuration for workspaces.
|
||||
|
||||
Common errors:
|
||||
- Authentication failure: Check user permissions
|
||||
- No templates available: The deployment may not have any templates configured`),
|
||||
),
|
||||
MakeHandler: handleCoderListTemplates,
|
||||
},
|
||||
// coder_list_workspaces: owner, offset and limit are optional and carry
// schema defaults (codersdk.Me, 0 and 10 respectively).
{
|
||||
Tool: mcp.NewTool("coder_list_workspaces",
|
||||
mcp.WithDescription(`List workspaces available on the Coder deployment.
|
||||
Returns JSON with workspace metadata including status, resources, and configurations.
|
||||
Use this before other workspace operations to find valid workspace names/IDs.
|
||||
Results are paginated - use offset and limit parameters for large deployments.
|
||||
|
||||
Common errors:
|
||||
- Authentication failure: Check user permissions
|
||||
- Invalid owner parameter: Ensure the owner exists`),
|
||||
mcp.WithString(`owner`, mcp.Description(`The username of the workspace owner to filter by.
|
||||
Defaults to "me" which represents the currently authenticated user.
|
||||
Use this to view workspaces belonging to other users (requires appropriate permissions).
|
||||
Special value: "me" - List workspaces owned by the authenticated user.`), mcp.DefaultString(codersdk.Me)),
|
||||
mcp.WithNumber(`offset`, mcp.Description(`Pagination offset - the starting index for listing workspaces.
|
||||
Used with the 'limit' parameter to implement pagination.
|
||||
For example, to get the second page of results with 10 items per page, use offset=10.
|
||||
Defaults to 0 (first page).`), mcp.DefaultNumber(0)),
|
||||
mcp.WithNumber(`limit`, mcp.Description(`Maximum number of workspaces to return in a single request.
|
||||
Used with the 'offset' parameter to implement pagination.
|
||||
Higher values return more results but may increase response time.
|
||||
Valid range: 1-100. Defaults to 10.`), mcp.DefaultNumber(10)),
|
||||
),
|
||||
MakeHandler: handleCoderListWorkspaces,
|
||||
},
|
||||
// coder_get_workspace: the workspace argument is required.
{
|
||||
Tool: mcp.NewTool("coder_get_workspace",
|
||||
mcp.WithDescription(`Get detailed information about a specific Coder workspace.
|
||||
Returns comprehensive JSON with the workspace's configuration, status, and resources.
|
||||
Use this to check workspace status before performing operations like exec or start/stop.
|
||||
The response includes the latest build status, agent connectivity, and resource details.
|
||||
|
||||
Common errors:
|
||||
- Workspace not found: Check the workspace name or ID
|
||||
- Permission denied: The user may not have access to this workspace`),
|
||||
mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to retrieve.
|
||||
Can be specified as either:
|
||||
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
|
||||
- Workspace name: e.g., "dev", "python-project"
|
||||
Use coder_list_workspaces first if you're not sure about available workspace names.`), mcp.Required()),
|
||||
),
|
||||
MakeHandler: handleCoderGetWorkspace,
|
||||
},
|
||||
// coder_workspace_exec: workspace and command are both required.
{
|
||||
Tool: mcp.NewTool("coder_workspace_exec",
|
||||
mcp.WithDescription(`Execute a shell command in a remote Coder workspace.
|
||||
Runs the specified command and returns the complete output (stdout/stderr).
|
||||
Use this for file operations, running build commands, or checking workspace state.
|
||||
The workspace must be running with a connected agent for this to succeed.
|
||||
|
||||
Before using this tool:
|
||||
1. Verify the workspace is running using coder_get_workspace
|
||||
2. Start the workspace if needed using coder_start_workspace
|
||||
|
||||
Common errors:
|
||||
- Workspace not running: Start the workspace first
|
||||
- Command not allowed: Check security restrictions
|
||||
- Agent not connected: The workspace may still be starting up`),
|
||||
mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name where the command will execute.
|
||||
Can be specified as either:
|
||||
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
|
||||
- Workspace name: e.g., "dev", "python-project"
|
||||
The workspace must be running with a connected agent.
|
||||
Use coder_get_workspace first to check the workspace status.`), mcp.Required()),
|
||||
mcp.WithString("command", mcp.Description(`The shell command to execute in the workspace.
|
||||
Commands are executed in the default shell of the workspace.
|
||||
|
||||
Examples:
|
||||
- "ls -la" - List files with details
|
||||
- "cd /path/to/directory && command" - Execute in specific directory
|
||||
- "cat ~/.bashrc" - View a file's contents
|
||||
- "python -m pip list" - List installed Python packages
|
||||
|
||||
Note: Very long-running commands may time out.`), mcp.Required()),
|
||||
),
|
||||
MakeHandler: handleCoderWorkspaceExec,
|
||||
},
|
||||
// coder_workspace_transition: workspace is required.
// NOTE(review): "transition" is not marked mcp.Required() although its
// description only allows "start" or "stop" - confirm how the handler treats
// a missing value.
// NOTE(review): the coder_workspace_exec description above refers to
// "coder_start_workspace", which is not a tool in this list - presumably it
// means this tool; confirm.
{
|
||||
Tool: mcp.NewTool("coder_workspace_transition",
|
||||
mcp.WithDescription(`Start or stop a running Coder workspace.
|
||||
If stopping, initiates the workspace stop transition.
|
||||
Only works on workspaces that are currently running or failed.
|
||||
|
||||
If starting, initiates the workspace start transition.
|
||||
Only works on workspaces that are currently stopped or failed.
|
||||
|
||||
Stopping or starting a workspace is an asynchronous operation - it may take several minutes to complete.
|
||||
|
||||
After calling this tool:
|
||||
1. Use coder_report_task to inform the user that the workspace is stopping or starting
|
||||
2. Use coder_get_workspace periodically to check for completion
|
||||
|
||||
Common errors:
|
||||
- Workspace already started/starting/stopped/stopping: No action needed
|
||||
- Cancellation failed: There may be issues with the underlying infrastructure
|
||||
- User doesn't own workspace: Permission issues`),
|
||||
mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to start or stop.
|
||||
Can be specified as either:
|
||||
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
|
||||
- Workspace name: e.g., "dev", "python-project"
|
||||
The workspace must be in a running state to be stopped, or in a stopped or failed state to be started.
|
||||
Use coder_get_workspace first to check the current workspace status.`), mcp.Required()),
|
||||
mcp.WithString("transition", mcp.Description(`The transition to apply to the workspace.
|
||||
Can be either "start" or "stop".`)),
|
||||
),
|
||||
MakeHandler: handleCoderWorkspaceTransition,
|
||||
},
|
||||
}
|
||||
|
||||
// ToolDeps contains all dependencies needed by tool handlers. One value is
// passed to every entry's MakeHandler when a ToolRegistry is registered.
|
||||
type ToolDeps struct {
|
||||
// Client is the Coder API client; the handlers visible in this file return
// a "developer error" when it is nil.
Client *codersdk.Client
|
||||
// Logger receives log output from handlers (e.g. the report-task handler).
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// ToolHandler associates a tool with its handler creation function.
|
||||
type ToolHandler struct {
|
||||
// Tool is the tool definition registered with the MCP server; its Name is
// what WithOnlyAllowed matches against.
Tool mcp.Tool
|
||||
// MakeHandler builds the handler for Tool from the shared dependencies.
MakeHandler func(ToolDeps) server.ToolHandlerFunc
|
||||
}
|
||||
|
||||
// ToolRegistry is an ordered list (a slice, not a map) of available tools
|
||||
// together with their handler creation functions.
|
||||
type ToolRegistry []ToolHandler
|
||||
|
||||
// WithOnlyAllowed returns a new ToolRegistry containing only the tools
|
||||
// specified in the allowed list.
|
||||
func (r ToolRegistry) WithOnlyAllowed(allowed ...string) ToolRegistry {
|
||||
if len(allowed) == 0 {
|
||||
return []ToolHandler{}
|
||||
}
|
||||
|
||||
filtered := make(ToolRegistry, 0, len(r))
|
||||
|
||||
// The overhead of a map lookup is likely higher than a linear scan
|
||||
// for a small number of tools.
|
||||
for _, entry := range r {
|
||||
if slices.Contains(allowed, entry.Tool.Name) {
|
||||
filtered = append(filtered, entry)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// Register registers all tools in the registry with the given tool adder
|
||||
// and dependencies.
|
||||
func (r ToolRegistry) Register(srv *server.MCPServer, deps ToolDeps) {
|
||||
for _, entry := range r {
|
||||
srv.AddTool(entry.Tool, entry.MakeHandler(deps))
|
||||
}
|
||||
}
|
||||
|
||||
// AllTools returns all available tools.
|
||||
func AllTools() ToolRegistry {
|
||||
// return a copy of allTools to avoid mutating the original
|
||||
return slices.Clone(allTools)
|
||||
}
|
||||
|
||||
// handleCoderReportTaskArgs is the decoded argument payload of the
// coder_report_task tool; the JSON keys match the tool's declared arguments.
type handleCoderReportTaskArgs struct {
|
||||
// Summary is the short progress description.
Summary string `json:"summary"`
|
||||
// Link is a URL related to the work being reported.
Link string `json:"link"`
|
||||
// Emoji visually represents the current status.
Emoji string `json:"emoji"`
|
||||
// Done reports whether the overall task is complete.
Done bool `json:"done"`
|
||||
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_report_task", "arguments": {"summary": "I'm working on the login page.", "link": "https://github.com/coder/coder/pull/1234", "emoji": "🔍", "done": false}}}
|
||||
func handleCoderReportTask(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
|
||||
// Convert the request parameters to a json.RawMessage so we can unmarshal
|
||||
// them into the correct struct.
|
||||
args, err := unmarshalArgs[handleCoderReportTaskArgs](request.Params.Arguments)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
// TODO: Waiting on support for tasks.
|
||||
deps.Logger.Info(ctx, "report task tool called", slog.F("summary", args.Summary), slog.F("link", args.Link), slog.F("done", args.Done), slog.F("emoji", args.Emoji))
|
||||
/*
|
||||
err := sdk.PostTask(ctx, agentsdk.PostTaskRequest{
|
||||
Reporter: "claude",
|
||||
Summary: summary,
|
||||
URL: link,
|
||||
Completion: done,
|
||||
Icon: emoji,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent("Thanks for reporting!"),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_whoami", "arguments": {}}}
|
||||
func handleCoderWhoami(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
me, err := deps.Client.User(ctx, codersdk.Me)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("Failed to fetch the current user: %s", err.Error())
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(me); err != nil {
|
||||
return nil, xerrors.Errorf("Failed to encode the current user: %s", err.Error())
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(strings.TrimSpace(buf.String())),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCoderListWorkspacesArgs holds the decoded arguments of a
// coder_list_workspaces tool call. The values are passed through to the
// workspace filter unchanged.
type handleCoderListWorkspacesArgs struct {
	Owner  string `json:"owner"`
	Offset int    `json:"offset"`
	Limit  int    `json:"limit"`
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_workspaces", "arguments": {"owner": "me", "offset": 0, "limit": 10}}}
|
||||
func handleCoderListWorkspaces(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
args, err := unmarshalArgs[handleCoderListWorkspacesArgs](request.Params.Arguments)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
workspaces, err := deps.Client.Workspaces(ctx, codersdk.WorkspaceFilter{
|
||||
Owner: args.Owner,
|
||||
Offset: args.Offset,
|
||||
Limit: args.Limit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to fetch workspaces: %w", err)
|
||||
}
|
||||
|
||||
// Encode it as JSON. TODO: It might be nicer for the agent to have a tabulated response.
|
||||
data, err := json.Marshal(workspaces)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode workspaces: %s", err.Error())
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(data)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCoderGetWorkspaceArgs holds the decoded arguments of a
// coder_get_workspace tool call.
type handleCoderGetWorkspaceArgs struct {
	// Workspace identifies the workspace, either by UUID or by name.
	Workspace string `json:"workspace"`
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_get_workspace", "arguments": {"workspace": "dev"}}}
|
||||
func handleCoderGetWorkspace(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
args, err := unmarshalArgs[handleCoderGetWorkspaceArgs](request.Params.Arguments)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
|
||||
}
|
||||
|
||||
workspaceJSON, err := json.Marshal(workspace)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode workspace: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(workspaceJSON)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCoderWorkspaceExecArgs holds the decoded arguments of a
// coder_workspace_exec tool call.
type handleCoderWorkspaceExecArgs struct {
	// Workspace identifies the workspace, either by UUID or by name.
	Workspace string `json:"workspace"`
	// Command is the command line passed to the reconnecting PTY.
	Command string `json:"command"`
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_workspace_exec", "arguments": {"workspace": "dev", "command": "ps -ef"}}}
|
||||
func handleCoderWorkspaceExec(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
args, err := unmarshalArgs[handleCoderWorkspaceExecArgs](request.Params.Arguments)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
// Attempt to fetch the workspace. We may get a UUID or a name, so try to
|
||||
// handle both.
|
||||
ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
|
||||
}
|
||||
|
||||
// Ensure the workspace is started.
|
||||
// Select the first agent of the workspace.
|
||||
var agt *codersdk.WorkspaceAgent
|
||||
for _, r := range ws.LatestBuild.Resources {
|
||||
for _, a := range r.Agents {
|
||||
if a.Status != codersdk.WorkspaceAgentConnected {
|
||||
continue
|
||||
}
|
||||
agt = ptr.Ref(a)
|
||||
break
|
||||
}
|
||||
}
|
||||
if agt == nil {
|
||||
return nil, xerrors.Errorf("no connected agents for workspace %s", ws.ID)
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
conn, err := workspacesdk.New(deps.Client).AgentReconnectingPTY(ctx, workspacesdk.WorkspaceAgentReconnectingPTYOpts{
|
||||
AgentID: agt.ID,
|
||||
Reconnect: uuid.New(),
|
||||
Width: 80,
|
||||
Height: 24,
|
||||
Command: args.Command,
|
||||
BackendType: "buffered", // the screen backend is annoying to use here.
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to open reconnecting PTY: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
connectedAt := time.Now()
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, conn); err != nil {
|
||||
// EOF is expected when the connection is closed.
|
||||
// We can ignore this error.
|
||||
if !errors.Is(err, io.EOF) {
|
||||
return nil, xerrors.Errorf("failed to read from reconnecting PTY: %w", err)
|
||||
}
|
||||
}
|
||||
completedAt := time.Now()
|
||||
connectionTime := connectedAt.Sub(startedAt)
|
||||
executionTime := completedAt.Sub(connectedAt)
|
||||
|
||||
resp := map[string]string{
|
||||
"connection_time": connectionTime.String(),
|
||||
"execution_time": executionTime.String(),
|
||||
"output": buf.String(),
|
||||
}
|
||||
respJSON, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode workspace build: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(respJSON)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_templates", "arguments": {}}}
|
||||
func handleCoderListTemplates(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
templates, err := deps.Client.Templates(ctx, codersdk.TemplateFilter{})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to fetch templates: %w", err)
|
||||
}
|
||||
|
||||
templateJSON, err := json.Marshal(templates)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode templates: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(templateJSON)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleCoderWorkspaceTransitionArgs holds the decoded arguments of a
// coder_workspace_transition tool call.
type handleCoderWorkspaceTransitionArgs struct {
	// Workspace identifies the workspace, either by UUID or by name.
	Workspace string `json:"workspace"`
	// Transition is the requested transition; the handler accepts only the
	// start and stop transitions.
	Transition string `json:"transition"`
}
|
||||
|
||||
// Example payload:
|
||||
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name":
|
||||
// "coder_workspace_transition", "arguments": {"workspace": "dev", "transition": "stop"}}}
|
||||
func handleCoderWorkspaceTransition(deps ToolDeps) server.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
if deps.Client == nil {
|
||||
return nil, xerrors.New("developer error: client is required")
|
||||
}
|
||||
args, err := unmarshalArgs[handleCoderWorkspaceTransitionArgs](request.Params.Arguments)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
|
||||
workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
|
||||
}
|
||||
|
||||
wsTransition := codersdk.WorkspaceTransition(args.Transition)
|
||||
switch wsTransition {
|
||||
case codersdk.WorkspaceTransitionStart:
|
||||
case codersdk.WorkspaceTransitionStop:
|
||||
default:
|
||||
return nil, xerrors.New("invalid transition")
|
||||
}
|
||||
|
||||
// We're not going to check the workspace status here as it is checked on the
|
||||
// server side.
|
||||
wb, err := deps.Client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
|
||||
Transition: wsTransition,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to stop workspace: %w", err)
|
||||
}
|
||||
|
||||
resp := map[string]any{"status": wb.Status, "transition": wb.Transition}
|
||||
respJSON, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode workspace build: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(respJSON)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getWorkspaceByIDOrOwnerName(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) {
|
||||
if wsid, err := uuid.Parse(identifier); err == nil {
|
||||
return client.Workspace(ctx, wsid)
|
||||
}
|
||||
return client.WorkspaceByOwnerAndName(ctx, codersdk.Me, identifier, codersdk.WorkspaceOptions{})
|
||||
}
|
||||
|
||||
// unmarshalArgs is a helper function to convert the map[string]any we get from
|
||||
// the MCP server into a typed struct. It does this by marshaling and unmarshalling
|
||||
// the arguments.
|
||||
func unmarshalArgs[T any](args map[string]interface{}) (t T, err error) {
|
||||
argsJSON, err := json.Marshal(args)
|
||||
if err != nil {
|
||||
return t, xerrors.Errorf("failed to marshal arguments: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(argsJSON, &t); err != nil {
|
||||
return t, xerrors.Errorf("failed to unmarshal arguments: %w", err)
|
||||
}
|
||||
return t, nil
|
||||
}
|
361
mcp/mcp_test.go
Normal file
361
mcp/mcp_test.go
Normal file
@ -0,0 +1,361 @@
|
||||
package codermcp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
codermcp "github.com/coder/coder/v2/mcp"
|
||||
"github.com/coder/coder/v2/pty/ptytest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// These tests are dependent on the state of the coder server.
|
||||
// Running them in parallel is prone to racy behavior.
|
||||
// nolint:tparallel,paralleltest
|
||||
func TestCoderTools(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping on non-linux due to pty issues")
|
||||
}
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
// Given: a coder server, workspace, and agent.
|
||||
client, store := coderdtest.NewWithDatabase(t, nil)
|
||||
owner := coderdtest.CreateFirstUser(t, client)
|
||||
// Given: a member user with which to test the tools.
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
|
||||
// Given: a workspace with an agent.
|
||||
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
// Note: we want to test the list_workspaces tool before starting the
|
||||
// workspace agent. Starting the workspace agent will modify the workspace
|
||||
// state, which will affect the results of the list_workspaces tool.
|
||||
listWorkspacesDone := make(chan struct{})
|
||||
agentStarted := make(chan struct{})
|
||||
go func() {
|
||||
defer close(agentStarted)
|
||||
<-listWorkspacesDone
|
||||
agt := agenttest.New(t, client.URL, r.AgentToken)
|
||||
t.Cleanup(func() {
|
||||
_ = agt.Close()
|
||||
})
|
||||
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
|
||||
}()
|
||||
|
||||
// Given: a MCP server listening on a pty.
|
||||
pty := ptytest.New(t)
|
||||
mcpSrv, closeSrv := startTestMCPServer(ctx, t, pty.Input(), pty.Output())
|
||||
t.Cleanup(func() {
|
||||
_ = closeSrv()
|
||||
})
|
||||
|
||||
// Register tools using our registry
|
||||
logger := slogtest.Make(t, nil)
|
||||
codermcp.AllTools().Register(mcpSrv, codermcp.ToolDeps{
|
||||
Client: memberClient,
|
||||
Logger: &logger,
|
||||
})
|
||||
|
||||
t.Run("coder_list_templates", func(t *testing.T) {
|
||||
// When: the coder_list_templates tool is called
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_templates", map[string]any{})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
templates, err := memberClient.Templates(ctx, codersdk.TemplateFilter{})
|
||||
require.NoError(t, err)
|
||||
templatesJSON, err := json.Marshal(templates)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: the response is a list of templates visible to the user.
|
||||
expected := makeJSONRPCTextResponse(t, string(templatesJSON))
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("coder_report_task", func(t *testing.T) {
|
||||
// When: the coder_report_task tool is called
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_report_task", map[string]any{
|
||||
"summary": "Test summary",
|
||||
"link": "https://example.com",
|
||||
"emoji": "🔍",
|
||||
"done": false,
|
||||
"coder_url": client.URL.String(),
|
||||
"coder_session_token": client.SessionToken(),
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is a success message.
|
||||
// TODO: check the task was created. This functionality is not yet implemented.
|
||||
expected := makeJSONRPCTextResponse(t, "Thanks for reporting!")
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("coder_whoami", func(t *testing.T) {
|
||||
// When: the coder_whoami tool is called
|
||||
me, err := memberClient.User(ctx, codersdk.Me)
|
||||
require.NoError(t, err)
|
||||
meJSON, err := json.Marshal(me)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is a valid JSON respresentation of the calling user.
|
||||
expected := makeJSONRPCTextResponse(t, string(meJSON))
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("coder_list_workspaces", func(t *testing.T) {
|
||||
defer close(listWorkspacesDone)
|
||||
// When: the coder_list_workspaces tool is called
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_workspaces", map[string]any{
|
||||
"coder_url": client.URL.String(),
|
||||
"coder_session_token": client.SessionToken(),
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
ws, err := memberClient.Workspaces(ctx, codersdk.WorkspaceFilter{})
|
||||
require.NoError(t, err)
|
||||
wsJSON, err := json.Marshal(ws)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: the response is a valid JSON respresentation of the calling user's workspaces.
|
||||
expected := makeJSONRPCTextResponse(t, string(wsJSON))
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("coder_get_workspace", func(t *testing.T) {
|
||||
// Given: the workspace agent is connected.
|
||||
// The act of starting the agent will modify the workspace state.
|
||||
<-agentStarted
|
||||
// When: the coder_get_workspace tool is called
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_get_workspace", map[string]any{
|
||||
"workspace": r.Workspace.ID.String(),
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
ws, err := memberClient.Workspace(ctx, r.Workspace.ID)
|
||||
require.NoError(t, err)
|
||||
wsJSON, err := json.Marshal(ws)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: the response is a valid JSON respresentation of the workspace.
|
||||
expected := makeJSONRPCTextResponse(t, string(wsJSON))
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
// NOTE: this test runs after the list_workspaces tool is called.
|
||||
t.Run("coder_workspace_exec", func(t *testing.T) {
|
||||
// Given: the workspace agent is connected
|
||||
<-agentStarted
|
||||
|
||||
// When: the coder_workspace_exec tools is called with a command
|
||||
randString := testutil.GetRandomName(t)
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
|
||||
"workspace": r.Workspace.ID.String(),
|
||||
"command": "echo " + randString,
|
||||
"coder_url": client.URL.String(),
|
||||
"coder_session_token": client.SessionToken(),
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is the output of the command.
|
||||
actual := pty.ReadLine(ctx)
|
||||
require.Contains(t, actual, randString)
|
||||
})
|
||||
|
||||
// NOTE: this test runs after the list_workspaces tool is called.
|
||||
t.Run("tool_restrictions", func(t *testing.T) {
|
||||
// Given: the workspace agent is connected
|
||||
<-agentStarted
|
||||
|
||||
// Given: a restricted MCP server with only allowed tools and commands
|
||||
restrictedPty := ptytest.New(t)
|
||||
allowedTools := []string{"coder_workspace_exec"}
|
||||
restrictedMCPSrv, closeRestrictedSrv := startTestMCPServer(ctx, t, restrictedPty.Input(), restrictedPty.Output())
|
||||
t.Cleanup(func() {
|
||||
_ = closeRestrictedSrv()
|
||||
})
|
||||
codermcp.AllTools().
|
||||
WithOnlyAllowed(allowedTools...).
|
||||
Register(restrictedMCPSrv, codermcp.ToolDeps{
|
||||
Client: memberClient,
|
||||
Logger: &logger,
|
||||
})
|
||||
|
||||
// When: the tools/list command is called
|
||||
toolsListCmd := makeJSONRPCRequest(t, "tools/list", "", nil)
|
||||
restrictedPty.WriteLine(toolsListCmd)
|
||||
_ = restrictedPty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is a list of only the allowed tools.
|
||||
toolsListResponse := restrictedPty.ReadLine(ctx)
|
||||
require.Contains(t, toolsListResponse, "coder_workspace_exec")
|
||||
require.NotContains(t, toolsListResponse, "coder_whoami")
|
||||
|
||||
// When: a disallowed tool is called
|
||||
disallowedToolCmd := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
|
||||
restrictedPty.WriteLine(disallowedToolCmd)
|
||||
_ = restrictedPty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is an error indicating the tool is not available.
|
||||
disallowedToolResponse := restrictedPty.ReadLine(ctx)
|
||||
require.Contains(t, disallowedToolResponse, "error")
|
||||
require.Contains(t, disallowedToolResponse, "not found")
|
||||
})
|
||||
|
||||
t.Run("coder_workspace_transition_stop", func(t *testing.T) {
|
||||
// Given: a separate workspace in the running state
|
||||
stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
}).WithAgent().Do()
|
||||
|
||||
// When: the coder_workspace_transition tool is called with a stop transition
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
|
||||
"workspace": stopWs.Workspace.ID.String(),
|
||||
"transition": "stop",
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is as expected.
|
||||
expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"stop"}`) // no provisionerd yet
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
|
||||
t.Run("coder_workspace_transition_start", func(t *testing.T) {
|
||||
// Given: a separate workspace in the stopped state
|
||||
stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
|
||||
OrganizationID: owner.OrganizationID,
|
||||
OwnerID: member.ID,
|
||||
}).Seed(database.WorkspaceBuild{
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
}).Do()
|
||||
|
||||
// When: the coder_workspace_transition tool is called with a start transition
|
||||
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
|
||||
"workspace": stopWs.Workspace.ID.String(),
|
||||
"transition": "start",
|
||||
})
|
||||
|
||||
pty.WriteLine(ctr)
|
||||
_ = pty.ReadLine(ctx) // skip the echo
|
||||
|
||||
// Then: the response is as expected
|
||||
expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"start"}`) // no provisionerd yet
|
||||
actual := pty.ReadLine(ctx)
|
||||
testutil.RequireJSONEq(t, expected, actual)
|
||||
})
|
||||
}
|
||||
|
||||
// makeJSONRPCRequest is a helper function that makes a JSON RPC request.
|
||||
func makeJSONRPCRequest(t *testing.T, method, name string, args map[string]any) string {
|
||||
t.Helper()
|
||||
req := mcp.JSONRPCRequest{
|
||||
ID: "1",
|
||||
JSONRPC: "2.0",
|
||||
Request: mcp.Request{Method: method},
|
||||
Params: struct { // Unfortunately, there is no type for this yet.
|
||||
Name string "json:\"name\""
|
||||
Arguments map[string]any "json:\"arguments,omitempty\""
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken "json:\"progressToken,omitempty\""
|
||||
} "json:\"_meta,omitempty\""
|
||||
}{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
bs, err := json.Marshal(req)
|
||||
require.NoError(t, err, "failed to marshal JSON RPC request")
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
// makeJSONRPCTextResponse is a helper function that makes a JSON RPC text response
|
||||
func makeJSONRPCTextResponse(t *testing.T, text string) string {
|
||||
t.Helper()
|
||||
|
||||
resp := mcp.JSONRPCResponse{
|
||||
ID: "1",
|
||||
JSONRPC: "2.0",
|
||||
Result: mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(text),
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := json.Marshal(resp)
|
||||
require.NoError(t, err, "failed to marshal JSON RPC response")
|
||||
return string(bs)
|
||||
}
|
||||
|
||||
// startTestMCPServer is a helper function that starts a MCP server listening on
|
||||
// a pty. It is the responsibility of the caller to close the server.
|
||||
func startTestMCPServer(ctx context.Context, t testing.TB, stdin io.Reader, stdout io.Writer) (*server.MCPServer, func() error) {
|
||||
t.Helper()
|
||||
|
||||
mcpSrv := server.NewMCPServer(
|
||||
"Test Server",
|
||||
"0.0.0",
|
||||
server.WithInstructions(""),
|
||||
server.WithLogging(),
|
||||
)
|
||||
|
||||
stdioSrv := server.NewStdioServer(mcpSrv)
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
closeCh := make(chan struct{})
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
defer close(done)
|
||||
srvErr := stdioSrv.Listen(cancelCtx, stdin, stdout)
|
||||
done <- srvErr
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-closeCh:
|
||||
cancel()
|
||||
case <-done:
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
return mcpSrv, func() error {
|
||||
close(closeCh)
|
||||
return <-done
|
||||
}
|
||||
}
|
27
testutil/json.go
Normal file
27
testutil/json.go
Normal file
@ -0,0 +1,27 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// RequireJSONEq is like assert.RequireJSONEq, but it's actually readable.
|
||||
// Note that this calls t.Fatalf under the hood, so it should never
|
||||
// be called in a goroutine.
|
||||
func RequireJSONEq(t *testing.T, expected, actual string) {
|
||||
t.Helper()
|
||||
|
||||
var expectedJSON, actualJSON any
|
||||
if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil {
|
||||
t.Fatalf("failed to unmarshal expected JSON: %s", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(actual), &actualJSON); err != nil {
|
||||
t.Fatalf("failed to unmarshal actual JSON: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedJSON, actualJSON); diff != "" {
|
||||
t.Fatalf("JSON diff (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user