feat(cli): add coder exp mcp command (#17066)

Adds a `coder exp mcp` command group. Its `coder exp mcp server` subcommand
starts a local MCP server listening on stdio with the following capabilities:
* Show logged in user (`coder whoami`)
* List workspaces (`coder list`)
* List templates (`coder templates list`)
* Start a workspace (`coder start`)
* Stop a workspace (`coder stop`)
* Fetch a single workspace (no direct CLI analogue)
* Execute a command inside a workspace (`coder exp rpty`)
* Report the status of a task (currently a no-op, pending task support)

This can be tested as follows:

```
# Start a local Coder server.
./scripts/develop.sh
# Start a workspace. Currently, creating workspaces is not supported.
./scripts/coder-dev.sh create -t docker --yes
# Add the MCP to your Claude config.
claude mcp add coder ./scripts/coder-dev.sh exp mcp server
# Tell Claude to do something Coder-related. You may need to nudge it to use the tools.
claude 'start a docker workspace and tell me what version of python is installed'
```
This commit is contained in:
Cian Johnston
2025-03-31 18:52:09 +01:00
committed by GitHub
parent 8ea956fc11
commit 057cbd4d80
9 changed files with 1469 additions and 5 deletions

View File

@ -11,7 +11,9 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/config"
@ -117,11 +119,7 @@ func TestGoldenFile(t *testing.T, fileName string, actual []byte, replacements m
require.NoError(t, err, "read golden file, run \"make gen/golden-files\" and commit the changes")
expected = normalizeGoldenFile(t, expected)
require.Equal(
t, string(expected), string(actual),
"golden file mismatch: %s, run \"make gen/golden-files\", verify and commit the changes",
goldenPath,
)
assert.Empty(t, cmp.Diff(string(expected), string(actual)), "golden file mismatch (-want +got): %s, run \"make gen/golden-files\", verify and commit the changes", goldenPath)
}
// normalizeGoldenFile replaces any strings that are system or timing dependent

View File

@ -13,6 +13,7 @@ func (r *RootCmd) expCmd() *serpent.Command {
Children: []*serpent.Command{
r.scaletestCmd(),
r.errorExample(),
r.mcpCommand(),
r.promptExample(),
r.rptyCommand(),
},

284
cli/exp_mcp.go Normal file
View File

@ -0,0 +1,284 @@
package cli
import (
"context"
"encoding/json"
"errors"
"log"
"os"
"path/filepath"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/codersdk"
codermcp "github.com/coder/coder/v2/mcp"
"github.com/coder/serpent"
)
// mcpCommand builds the `mcp` parent command. It has no behavior of its own:
// invoking it prints help, and the real work lives in its `configure` and
// `server` subcommands.
func (r *RootCmd) mcpCommand() *serpent.Command {
	subcommands := []*serpent.Command{
		r.mcpConfigure(),
		r.mcpServer(),
	}
	return &serpent.Command{
		Use:      "mcp",
		Short:    "Run the Coder MCP server and configure it to work with AI tools.",
		Long:     "The Coder MCP server allows you to automatically create workspaces with parameters.",
		Children: subcommands,
		Handler: func(inv *serpent.Invocation) error {
			return inv.Command.HelpHandler(inv)
		},
	}
}
// mcpConfigure builds the `configure` parent command. Invoking it directly
// prints help; each child configures one specific AI tool.
func (r *RootCmd) mcpConfigure() *serpent.Command {
	targets := []*serpent.Command{
		r.mcpConfigureClaudeDesktop(),
		r.mcpConfigureClaudeCode(),
		r.mcpConfigureCursor(),
	}
	return &serpent.Command{
		Use:      "configure",
		Short:    "Automatically configure the MCP server.",
		Children: targets,
		Handler: func(inv *serpent.Invocation) error {
			return inv.Command.HelpHandler(inv)
		},
	}
}
// mcpConfigureClaudeDesktop builds the `claude-desktop` subcommand, which
// registers the running binary as the "coder" MCP server in
// <user config dir>/Claude/claude_desktop_config.json.
//
// Existing configuration is preserved: other top-level keys and other entries
// under "mcpServers" are left untouched; only the "coder" entry is added or
// replaced. A missing or empty config file is treated as an empty config.
func (*RootCmd) mcpConfigureClaudeDesktop() *serpent.Command {
	cmd := &serpent.Command{
		Use:   "claude-desktop",
		Short: "Configure the Claude Desktop server.",
		Handler: func(_ *serpent.Invocation) error {
			configDir, err := os.UserConfigDir()
			if err != nil {
				return err
			}
			configDir = filepath.Join(configDir, "Claude")
			if err := os.MkdirAll(configDir, 0o755); err != nil {
				return err
			}
			configPath := filepath.Join(configDir, "claude_desktop_config.json")

			contents := map[string]any{}
			data, err := os.ReadFile(configPath)
			switch {
			case err == nil:
				// The config can be empty, so we don't want to return an
				// error if it is.
				if len(data) > 0 {
					if err := json.Unmarshal(data, &contents); err != nil {
						return err
					}
				}
			case !os.IsNotExist(err):
				return err
			}
			// A literal JSON `null` unmarshals to a nil map; writing to it
			// below would panic.
			if contents == nil {
				contents = map[string]any{}
			}

			// Merge into any existing server list instead of replacing it,
			// so other configured MCP servers are not lost.
			mcpServers, ok := contents["mcpServers"].(map[string]any)
			if !ok {
				mcpServers = map[string]any{}
			}
			binPath, err := os.Executable()
			if err != nil {
				return err
			}
			mcpServers["coder"] = map[string]any{
				"command": binPath,
				"args":    []string{"exp", "mcp", "server"},
			}
			contents["mcpServers"] = mcpServers

			data, err = json.MarshalIndent(contents, "", " ")
			if err != nil {
				return err
			}
			return os.WriteFile(configPath, data, 0o600)
		},
	}
	return cmd
}
// mcpConfigureClaudeCode builds the `claude-code` subcommand. The handler is
// currently a placeholder that does nothing and reports success.
func (*RootCmd) mcpConfigureClaudeCode() *serpent.Command {
	return &serpent.Command{
		Use:     "claude-code",
		Short:   "Configure the Claude Code server.",
		Handler: func(*serpent.Invocation) error { return nil },
	}
}
// mcpConfigureCursor builds the `cursor` subcommand, which adds or replaces
// the "coder" entry under "mcpServers" in Cursor's .cursor/mcp.json.
//
// The file lives under the user's home directory by default, or under the
// current working directory when --project (CODER_MCP_CURSOR_PROJECT) is set.
// Other entries already present in the file are kept.
func (*RootCmd) mcpConfigureCursor() *serpent.Command {
	var project bool
	return &serpent.Command{
		Use:   "cursor",
		Short: "Configure Cursor to use Coder MCP.",
		Options: serpent.OptionSet{
			serpent.Option{
				Flag:        "project",
				Env:         "CODER_MCP_CURSOR_PROJECT",
				Description: "Use to configure a local project to use the Cursor MCP.",
				Value:       serpent.BoolOf(&project),
			},
		},
		Handler: func(_ *serpent.Invocation) error {
			// The working directory is always resolved first, even when the
			// home directory ends up being used.
			baseDir, err := os.Getwd()
			if err != nil {
				return err
			}
			if !project {
				if baseDir, err = os.UserHomeDir(); err != nil {
					return err
				}
			}

			cursorDir := filepath.Join(baseDir, ".cursor")
			if err = os.MkdirAll(cursorDir, 0o755); err != nil {
				return err
			}
			configPath := filepath.Join(cursorDir, "mcp.json")

			config := map[string]any{}
			switch _, statErr := os.Stat(configPath); {
			case statErr == nil:
				raw, readErr := os.ReadFile(configPath)
				if readErr != nil {
					return readErr
				}
				// An empty file is valid and simply means "no config yet".
				if len(raw) > 0 {
					if decodeErr := json.Unmarshal(raw, &config); decodeErr != nil {
						return decodeErr
					}
				}
			case !os.IsNotExist(statErr):
				return statErr
			}

			servers, isMap := config["mcpServers"].(map[string]any)
			if !isMap {
				servers = map[string]any{}
			}
			exePath, err := os.Executable()
			if err != nil {
				return err
			}
			servers["coder"] = map[string]any{
				"command": exePath,
				"args":    []string{"exp", "mcp", "server"},
			}
			config["mcpServers"] = servers

			encoded, err := json.MarshalIndent(config, "", " ")
			if err != nil {
				return err
			}
			if err = os.WriteFile(configPath, encoded, 0o600); err != nil {
				return err
			}
			return nil
		},
	}
}
// mcpServer builds the `server` subcommand. The client is initialized by the
// InitClient middleware before the handler runs; the --instructions and
// --allowed-tools flags are forwarded to mcpServerHandler.
func (r *RootCmd) mcpServer() *serpent.Command {
	var (
		client       = new(codersdk.Client)
		instructions string
		allowedTools []string
	)
	flags := []serpent.Option{
		{
			Name:        "instructions",
			Description: "The instructions to pass to the MCP server.",
			Flag:        "instructions",
			Value:       serpent.StringOf(&instructions),
		},
		{
			Name:        "allowed-tools",
			Description: "Comma-separated list of allowed tools. If not specified, all tools are allowed.",
			Flag:        "allowed-tools",
			Value:       serpent.StringArrayOf(&allowedTools),
		},
	}
	cmd := &serpent.Command{
		Use:        "server",
		Short:      "Start the Coder MCP server.",
		Middleware: serpent.Chain(r.InitClient(client)),
		Options:    flags,
		Handler: func(inv *serpent.Invocation) error {
			return mcpServerHandler(inv, client, instructions, allowedTools)
		},
	}
	return cmd
}
// mcpServerHandler verifies the client's credentials and then serves MCP over
// the invocation's stdin/stdout until the context is canceled or the server
// stops.
//
// stdout carries the JSON-RPC protocol stream, so every human-readable
// message (banner, logs, server errors) is written to stderr.
//
// It returns the credential-check error if the current user cannot be
// fetched, or the server's error if it stops for any reason other than
// context cancellation.
func mcpServerHandler(inv *serpent.Invocation, client *codersdk.Client, instructions string, allowedTools []string) error {
	ctx, cancel := context.WithCancel(inv.Context())
	defer cancel()

	// Log to stderr: anything written to stdout that is not a protocol
	// message would corrupt the stream the MCP client is parsing.
	logger := slog.Make(sloghuman.Sink(inv.Stderr))

	me, err := client.User(ctx, codersdk.Me)
	if err != nil {
		cliui.Errorf(inv.Stderr, "Failed to log in to the Coder deployment.")
		cliui.Errorf(inv.Stderr, "Please check your URL and credentials.")
		cliui.Errorf(inv.Stderr, "Tip: Run `coder whoami` to check your credentials.")
		return err
	}
	cliui.Infof(inv.Stderr, "Starting MCP server")
	cliui.Infof(inv.Stderr, "User : %s", me.Username)
	cliui.Infof(inv.Stderr, "URL : %s", client.URL)
	cliui.Infof(inv.Stderr, "Instructions : %q", instructions)
	if len(allowedTools) > 0 {
		cliui.Infof(inv.Stderr, "Allowed Tools : %v", allowedTools)
	}
	cliui.Infof(inv.Stderr, "Press Ctrl+C to stop the server")

	options := []codermcp.Option{
		codermcp.WithInstructions(instructions),
		codermcp.WithLogger(&logger),
	}
	// Add allowed tools option if specified.
	if len(allowedTools) > 0 {
		options = append(options, codermcp.WithAllowedTools(allowedTools))
	}

	srv := codermcp.NewStdio(client, options...)
	srv.SetErrorLogger(log.New(inv.Stderr, "", log.LstdFlags))

	// Listen blocks until ctx is canceled or the server stops, so it can run
	// on this goroutine directly.
	if err := srv.Listen(ctx, inv.Stdin, inv.Stdout); err != nil && !errors.Is(err, context.Canceled) {
		cliui.Errorf(inv.Stderr, "Failed to start the MCP server: %s", err)
		return err
	}
	return nil
}

142
cli/exp_mcp_test.go Normal file
View File

@ -0,0 +1,142 @@
package cli_test
import (
"context"
"encoding/json"
"runtime"
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
// TestExpMcp exercises `coder exp mcp server` end to end against an in-memory
// Coder deployment, talking JSON-RPC to the command over a PTY.
func TestExpMcp(t *testing.T) {
	t.Parallel()
	// Reading to / writing from the PTY is flaky on non-linux systems.
	if runtime.GOOS != "linux" {
		t.Skip("skipping on non-linux")
	}
	// AllowedTools checks that --allowed-tools restricts the tools/list
	// response to exactly the named tools.
	t.Run("AllowedTools", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		// cancelCtx lets the test stop the server once it has its response.
		cancelCtx, cancel := context.WithCancel(ctx)
		t.Cleanup(cancel)
		// Given: a running coder deployment
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)
		// Given: we run the exp mcp command with allowed tools set
		inv, root := clitest.New(t, "exp", "mcp", "server", "--allowed-tools=coder_whoami,coder_list_templates")
		inv = inv.WithContext(cancelCtx)
		pty := ptytest.New(t)
		inv.Stdin = pty.Input()
		inv.Stdout = pty.Output()
		clitest.SetupConfig(t, client, root)
		// Run the command in the background; cmdDone closes when it exits.
		cmdDone := make(chan struct{})
		go func() {
			defer close(cmdDone)
			err := inv.Run()
			assert.NoError(t, err)
		}()
		// When: we send a tools/list request
		toolsPayload := `{"jsonrpc":"2.0","id":2,"method":"tools/list"}`
		pty.WriteLine(toolsPayload)
		_ = pty.ReadLine(ctx) // ignore echoed output
		output := pty.ReadLine(ctx)
		// Stop the server and wait for the command goroutine to finish
		// before inspecting the output.
		cancel()
		<-cmdDone
		// Then: we should only see the allowed tools in the response
		var toolsResponse struct {
			Result struct {
				Tools []struct {
					Name string `json:"name"`
				} `json:"tools"`
			} `json:"result"`
		}
		err := json.Unmarshal([]byte(output), &toolsResponse)
		require.NoError(t, err)
		require.Len(t, toolsResponse.Result.Tools, 2, "should have exactly 2 tools")
		foundTools := make([]string, 0, 2)
		for _, tool := range toolsResponse.Result.Tools {
			foundTools = append(foundTools, tool.Name)
		}
		// Sort so the comparison does not depend on response order.
		slices.Sort(foundTools)
		require.Equal(t, []string{"coder_list_templates", "coder_whoami"}, foundTools)
	})
	// OK checks that an initialize request gets a well-formed JSON-RPC
	// response with the matching id and a non-nil result.
	t.Run("OK", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		cancelCtx, cancel := context.WithCancel(ctx)
		t.Cleanup(cancel)
		client := coderdtest.New(t, nil)
		_ = coderdtest.CreateFirstUser(t, client)
		inv, root := clitest.New(t, "exp", "mcp", "server")
		inv = inv.WithContext(cancelCtx)
		pty := ptytest.New(t)
		inv.Stdin = pty.Input()
		inv.Stdout = pty.Output()
		clitest.SetupConfig(t, client, root)
		cmdDone := make(chan struct{})
		go func() {
			defer close(cmdDone)
			err := inv.Run()
			assert.NoError(t, err)
		}()
		payload := `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
		pty.WriteLine(payload)
		_ = pty.ReadLine(ctx) // ignore echoed output
		output := pty.ReadLine(ctx)
		cancel()
		<-cmdDone
		// Ensure the initialize output is valid JSON
		t.Logf("/initialize output: %s", output)
		var initializeResponse map[string]interface{}
		err := json.Unmarshal([]byte(output), &initializeResponse)
		require.NoError(t, err)
		require.Equal(t, "2.0", initializeResponse["jsonrpc"])
		// encoding/json decodes JSON numbers into float64 for interface values.
		require.Equal(t, 1.0, initializeResponse["id"])
		require.NotNil(t, initializeResponse["result"])
	})
	// NoCredentials checks that the command fails with a session-expired
	// error. Unlike the other subtests, no first user is created before
	// SetupConfig is called.
	t.Run("NoCredentials", func(t *testing.T) {
		t.Parallel()
		ctx := testutil.Context(t, testutil.WaitShort)
		cancelCtx, cancel := context.WithCancel(ctx)
		t.Cleanup(cancel)
		client := coderdtest.New(t, nil)
		inv, root := clitest.New(t, "exp", "mcp", "server")
		inv = inv.WithContext(cancelCtx)
		pty := ptytest.New(t)
		inv.Stdin = pty.Input()
		inv.Stdout = pty.Output()
		clitest.SetupConfig(t, client, root)
		// The command is run synchronously here because it is expected to
		// return an error rather than keep serving.
		err := inv.Run()
		assert.ErrorContains(t, err, "your session has expired")
	})
}

4
go.mod
View File

@ -480,3 +480,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
)
require github.com/mark3labs/mcp-go v0.17.0
require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect

4
go.sum
View File

@ -658,6 +658,8 @@ github.com/makeworld-the-better-one/dither/v2 v2.4.0 h1:Az/dYXiTcwcRSe59Hzw4RI1r
github.com/makeworld-the-better-one/dither/v2 v2.4.0/go.mod h1:VBtN8DXO7SNtyGmLiGA7IsFeKrBkQPze1/iAeM95arc=
github.com/marekm4/color-extractor v1.2.1 h1:3Zb2tQsn6bITZ8MBVhc33Qn1k5/SEuZ18mrXGUqIwn0=
github.com/marekm4/color-extractor v1.2.1/go.mod h1:90VjmiHI6M8ez9eYUaXLdcKnS+BAOp7w+NpwBdkJmpA=
github.com/mark3labs/mcp-go v0.17.0 h1:5Ps6T7qXr7De/2QTqs9h6BKeZ/qdeUeGrgM5lPzi930=
github.com/mark3labs/mcp-go v0.17.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
@ -972,6 +974,8 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yashtewari/glob-intersection v0.2.0 h1:8iuHdN88yYuCzCdjt0gDe+6bAhUwBeEWqThExu54RFg=
github.com/yashtewari/glob-intersection v0.2.0/go.mod h1:LK7pIC3piUjovexikBbJ26Yml7g8xa5bsjfx2v1fwok=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA=
github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg=
github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 h1:BHyfKlQyqbsFN5p3IfnEUduWvb9is428/nNb5L3U01M=

643
mcp/mcp.go Normal file
View File

@ -0,0 +1,643 @@
package codermcp
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"os"
"slices"
"strings"
"time"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"golang.org/x/xerrors"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// mcpOptions collects the settings that Option functions apply in NewStdio.
type mcpOptions struct {
	// instructions is forwarded to the MCP server via server.WithInstructions.
	instructions string
	// logger is the logger set through WithLogger.
	logger *slog.Logger
	// allowedTools, when non-empty, limits which tools NewStdio registers.
	allowedTools []string
}
// Option is a function that configures the MCP server. NewStdio applies the
// options in the order given, so a later option overrides an earlier one
// that sets the same field.
type Option func(*mcpOptions)
// WithInstructions returns an Option that records the instructions text to
// hand to the MCP server.
func WithInstructions(instructions string) Option {
	apply := func(cfg *mcpOptions) {
		cfg.instructions = instructions
	}
	return apply
}
// WithLogger returns an Option that records the logger for the MCP server.
func WithLogger(logger *slog.Logger) Option {
	apply := func(cfg *mcpOptions) {
		cfg.logger = logger
	}
	return apply
}
// WithAllowedTools returns an Option that records the list of tool names the
// MCP server is allowed to expose. The slice is stored as given, not copied.
func WithAllowedTools(tools []string) Option {
	apply := func(cfg *mcpOptions) {
		cfg.allowedTools = tools
	}
	return apply
}
// NewStdio creates a new MCP stdio server with the given client and options.
// It is the responsibility of the caller to start and stop the server.
//
// Tools are registered from AllTools, narrowed to the names given through
// WithAllowedTools when that list is non-empty. Tool handlers log through the
// logger supplied with WithLogger. When no logger is supplied (or a nil one
// is), a human-readable logger on os.Stderr is used, because a stdio MCP
// server conventionally uses stdout for the protocol stream.
func NewStdio(client *codersdk.Client, opts ...Option) *server.StdioServer {
	options := &mcpOptions{
		instructions: ``,
		logger:       ptr.Ref(slog.Make(sloghuman.Sink(os.Stderr))),
	}
	for _, opt := range opts {
		opt(options)
	}
	// Guard against WithLogger(nil): handlers dereference the logger.
	if options.logger == nil {
		options.logger = ptr.Ref(slog.Make(sloghuman.Sink(os.Stderr)))
	}

	mcpSrv := server.NewMCPServer(
		"Coder Agent",
		buildinfo.Version(),
		server.WithInstructions(options.instructions),
	)

	// Register tools based on the allowed list (if specified).
	reg := AllTools()
	if len(options.allowedTools) > 0 {
		reg = reg.WithOnlyAllowed(options.allowedTools...)
	}
	reg.Register(mcpSrv, ToolDeps{
		Client: client,
		// Use the configured logger rather than building a separate one, so
		// WithLogger actually takes effect.
		Logger: options.logger,
	})

	return server.NewStdioServer(mcpSrv)
}
// allTools is the list of all available tools. When adding a new tool,
// make sure to update this list.
//
// The description strings are sent to the MCP client and guide the model's
// tool selection, so any tool name mentioned inside a description must match
// a tool that is actually registered here.
var allTools = ToolRegistry{
	{
		Tool: mcp.NewTool("coder_report_task",
			mcp.WithDescription(`Report progress on a user task in Coder.
Use this tool to keep the user informed about your progress with their request.
For long-running operations, call this periodically to provide status updates.
This is especially useful when performing multi-step operations like workspace creation or deployment.`),
			mcp.WithString("summary", mcp.Description(`A concise summary of your current progress on the task.
Good Summaries:
- "Taking a look at the login page..."
- "Found a bug! Fixing it now..."
- "Investigating the GitHub Issue..."
- "Waiting for workspace to start (1/3 resources ready)"
- "Downloading template files from repository"`), mcp.Required()),
			mcp.WithString("link", mcp.Description(`A relevant URL related to your work, such as:
- GitHub issue link
- Pull request URL
- Documentation reference
- Workspace URL
Use complete URLs (including https://) when possible.`), mcp.Required()),
			mcp.WithString("emoji", mcp.Description(`A relevant emoji that visually represents the current status:
- 🔍 for investigating/searching
- 🚀 for deploying/starting
- 🐛 for debugging
- ✅ for completion
- ⏳ for waiting
Choose an emoji that helps the user understand the current phase at a glance.`), mcp.Required()),
			mcp.WithBoolean("done", mcp.Description(`Whether the overall task the user requested is complete.
Set to true only when the entire requested operation is finished successfully.
For multi-step processes, use false until all steps are complete.`), mcp.Required()),
		),
		MakeHandler: handleCoderReportTask,
	},
	{
		Tool: mcp.NewTool("coder_whoami",
			mcp.WithDescription(`Get information about the currently logged-in Coder user.
Returns JSON with the user's profile including fields: id, username, email, created_at, status, roles, etc.
Use this to identify the current user context before performing workspace operations.
This tool is useful for verifying permissions and checking the user's identity.
Common errors:
- Authentication failure: The session may have expired
- Server unavailable: The Coder deployment may be unreachable`),
		),
		MakeHandler: handleCoderWhoami,
	},
	{
		Tool: mcp.NewTool("coder_list_templates",
			mcp.WithDescription(`List all templates available on the Coder deployment.
Returns JSON with detailed information about each template, including:
- Template name, ID, and description
- Creation/modification timestamps
- Version information
- Associated organization
Use this tool to discover available templates before creating workspaces.
Templates define the infrastructure and configuration for workspaces.
Common errors:
- Authentication failure: Check user permissions
- No templates available: The deployment may not have any templates configured`),
		),
		MakeHandler: handleCoderListTemplates,
	},
	{
		Tool: mcp.NewTool("coder_list_workspaces",
			mcp.WithDescription(`List workspaces available on the Coder deployment.
Returns JSON with workspace metadata including status, resources, and configurations.
Use this before other workspace operations to find valid workspace names/IDs.
Results are paginated - use offset and limit parameters for large deployments.
Common errors:
- Authentication failure: Check user permissions
- Invalid owner parameter: Ensure the owner exists`),
			mcp.WithString(`owner`, mcp.Description(`The username of the workspace owner to filter by.
Defaults to "me" which represents the currently authenticated user.
Use this to view workspaces belonging to other users (requires appropriate permissions).
Special value: "me" - List workspaces owned by the authenticated user.`), mcp.DefaultString(codersdk.Me)),
			mcp.WithNumber(`offset`, mcp.Description(`Pagination offset - the starting index for listing workspaces.
Used with the 'limit' parameter to implement pagination.
For example, to get the second page of results with 10 items per page, use offset=10.
Defaults to 0 (first page).`), mcp.DefaultNumber(0)),
			mcp.WithNumber(`limit`, mcp.Description(`Maximum number of workspaces to return in a single request.
Used with the 'offset' parameter to implement pagination.
Higher values return more results but may increase response time.
Valid range: 1-100. Defaults to 10.`), mcp.DefaultNumber(10)),
		),
		MakeHandler: handleCoderListWorkspaces,
	},
	{
		Tool: mcp.NewTool("coder_get_workspace",
			mcp.WithDescription(`Get detailed information about a specific Coder workspace.
Returns comprehensive JSON with the workspace's configuration, status, and resources.
Use this to check workspace status before performing operations like exec or start/stop.
The response includes the latest build status, agent connectivity, and resource details.
Common errors:
- Workspace not found: Check the workspace name or ID
- Permission denied: The user may not have access to this workspace`),
			mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to retrieve.
Can be specified as either:
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
- Workspace name: e.g., "dev", "python-project"
Use coder_list_workspaces first if you're not sure about available workspace names.`), mcp.Required()),
		),
		MakeHandler: handleCoderGetWorkspace,
	},
	{
		// The description points at coder_workspace_transition because that
		// is the registered tool for starting a workspace.
		Tool: mcp.NewTool("coder_workspace_exec",
			mcp.WithDescription(`Execute a shell command in a remote Coder workspace.
Runs the specified command and returns the complete output (stdout/stderr).
Use this for file operations, running build commands, or checking workspace state.
The workspace must be running with a connected agent for this to succeed.
Before using this tool:
1. Verify the workspace is running using coder_get_workspace
2. Start the workspace if needed using coder_workspace_transition
Common errors:
- Workspace not running: Start the workspace first
- Command not allowed: Check security restrictions
- Agent not connected: The workspace may still be starting up`),
			mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name where the command will execute.
Can be specified as either:
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
- Workspace name: e.g., "dev", "python-project"
The workspace must be running with a connected agent.
Use coder_get_workspace first to check the workspace status.`), mcp.Required()),
			mcp.WithString("command", mcp.Description(`The shell command to execute in the workspace.
Commands are executed in the default shell of the workspace.
Examples:
- "ls -la" - List files with details
- "cd /path/to/directory && command" - Execute in specific directory
- "cat ~/.bashrc" - View a file's contents
- "python -m pip list" - List installed Python packages
Note: Very long-running commands may time out.`), mcp.Required()),
		),
		MakeHandler: handleCoderWorkspaceExec,
	},
	{
		Tool: mcp.NewTool("coder_workspace_transition",
			mcp.WithDescription(`Start or stop a running Coder workspace.
If stopping, initiates the workspace stop transition.
Only works on workspaces that are currently running or failed.
If starting, initiates the workspace start transition.
Only works on workspaces that are currently stopped or failed.
Stopping or starting a workspace is an asynchronous operation - it may take several minutes to complete.
After calling this tool:
1. Use coder_report_task to inform the user that the workspace is stopping or starting
2. Use coder_get_workspace periodically to check for completion
Common errors:
- Workspace already started/starting/stopped/stopping: No action needed
- Cancellation failed: There may be issues with the underlying infrastructure
- User doesn't own workspace: Permission issues`),
			mcp.WithString("workspace", mcp.Description(`The workspace ID (UUID) or name to start or stop.
Can be specified as either:
- Full UUID: e.g., "8a0b9c7d-1e2f-3a4b-5c6d-7e8f9a0b1c2d"
- Workspace name: e.g., "dev", "python-project"
The workspace must be in a running state to be stopped, or in a stopped or failed state to be started.
Use coder_get_workspace first to check the current workspace status.`), mcp.Required()),
			// NOTE(review): "transition" is not marked mcp.Required();
			// confirm whether the handler accepts an empty value.
			mcp.WithString("transition", mcp.Description(`The transition to apply to the workspace.
Can be either "start" or "stop".`)),
		),
		MakeHandler: handleCoderWorkspaceTransition,
	},
}
// ToolDeps contains all dependencies needed by tool handlers. It is passed
// to each tool's MakeHandler function when the tool is registered.
type ToolDeps struct {
	// Client is the Coder API client; the handlers in this file return an
	// error when it is nil.
	Client *codersdk.Client
	// Logger is the logger made available to handlers.
	Logger *slog.Logger
}
// ToolHandler associates a tool with its handler creation function.
type ToolHandler struct {
	// Tool is the MCP tool definition (name, description, parameters).
	Tool mcp.Tool
	// MakeHandler builds the function that serves calls to Tool, given the
	// dependencies it needs.
	MakeHandler func(ToolDeps) server.ToolHandlerFunc
}
// ToolRegistry is an ordered list (a slice, not a map) of available tools
// with their handler creation functions.
type ToolRegistry []ToolHandler
// WithOnlyAllowed returns a new ToolRegistry holding just the entries whose
// tool name appears in allowed, in their original order. An empty allowed
// list yields an empty registry.
func (r ToolRegistry) WithOnlyAllowed(allowed ...string) ToolRegistry {
	if len(allowed) == 0 {
		return ToolRegistry{}
	}
	kept := make(ToolRegistry, 0, len(r))
	// With only a handful of tools, a linear scan of allowed is likely
	// cheaper than building and querying a map.
	for i := range r {
		if !slices.Contains(allowed, r[i].Tool.Name) {
			continue
		}
		kept = append(kept, r[i])
	}
	return kept
}
// Register adds every tool in the registry to srv, building each tool's
// handler from deps.
func (r ToolRegistry) Register(srv *server.MCPServer, deps ToolDeps) {
	for i := range r {
		handler := r[i].MakeHandler(deps)
		srv.AddTool(r[i].Tool, handler)
	}
}
// AllTools returns all available tools as a fresh slice, so callers can
// filter or reorder the result without touching the package-level list.
func AllTools() ToolRegistry {
	registry := slices.Clone(allTools)
	return registry
}
// handleCoderReportTaskArgs is the JSON argument payload of the
// coder_report_task tool.
type handleCoderReportTaskArgs struct {
	Summary string `json:"summary"` // short progress summary
	Link    string `json:"link"`    // URL related to the work
	Emoji   string `json:"emoji"`   // emoji representing the current status
	Done    bool   `json:"done"`    // whether the overall task is complete
}
// handleCoderReportTask builds the handler for the coder_report_task tool.
// Task reporting is not wired up yet, so the handler only logs the reported
// fields and acknowledges the call.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_report_task", "arguments": {"summary": "I'm working on the login page.", "link": "https://github.com/coder/coder/pull/1234", "emoji": "🔍", "done": false}}}
func handleCoderReportTask(deps ToolDeps) server.ToolHandlerFunc {
	handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}

		// Decode the loosely-typed request arguments into the typed struct.
		parsed, err := unmarshalArgs[handleCoderReportTaskArgs](request.Params.Arguments)
		if err != nil {
			return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
		}

		// TODO: Waiting on support for tasks. Once available, this should
		// post the task (reporter, summary, URL, completion, icon) to the
		// deployment instead of only logging it.
		deps.Logger.Info(ctx, "report task tool called",
			slog.F("summary", parsed.Summary),
			slog.F("link", parsed.Link),
			slog.F("done", parsed.Done),
			slog.F("emoji", parsed.Emoji),
		)

		ack := mcp.NewTextContent("Thanks for reporting!")
		return &mcp.CallToolResult{Content: []mcp.Content{ack}}, nil
	}
	return handler
}
// handleCoderWhoami builds the handler for the coder_whoami tool. The handler
// fetches the currently authenticated user and returns it as a single JSON
// text content item.
//
// Errors are wrapped with %w (matching the other handlers) so callers can
// inspect the underlying cause; the rendered message text is unchanged.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_whoami", "arguments": {}}}
func handleCoderWhoami(deps ToolDeps) server.ToolHandlerFunc {
	return func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}
		me, err := deps.Client.User(ctx, codersdk.Me)
		if err != nil {
			return nil, xerrors.Errorf("Failed to fetch the current user: %w", err)
		}

		var buf bytes.Buffer
		if err := json.NewEncoder(&buf).Encode(me); err != nil {
			return nil, xerrors.Errorf("Failed to encode the current user: %w", err)
		}

		return &mcp.CallToolResult{
			Content: []mcp.Content{
				// Encode appends a trailing newline; strip it.
				mcp.NewTextContent(strings.TrimSpace(buf.String())),
			},
		}, nil
	}
}
// handleCoderListWorkspacesArgs is the JSON argument payload of the
// coder_list_workspaces tool.
type handleCoderListWorkspacesArgs struct {
	Owner  string `json:"owner"`  // username of the workspace owner to filter by
	Offset int    `json:"offset"` // pagination offset
	Limit  int    `json:"limit"`  // maximum number of workspaces to return
}
// handleCoderListWorkspaces builds the handler for the coder_list_workspaces
// tool. The handler lists workspaces filtered by owner with offset/limit
// pagination and returns the result as a single JSON text content item.
//
// The tool schema advertises defaults of owner="me" and limit=10, but the
// declared schema defaults are not filled into the decoded arguments, so the
// handler applies them itself when the arguments are omitted (zero-valued).
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_workspaces", "arguments": {"owner": "me", "offset": 0, "limit": 10}}}
func handleCoderListWorkspaces(deps ToolDeps) server.ToolHandlerFunc {
	// Mirrors the mcp.DefaultNumber(10) declared for "limit" in the schema.
	const defaultLimit = 10

	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}
		args, err := unmarshalArgs[handleCoderListWorkspacesArgs](request.Params.Arguments)
		if err != nil {
			return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
		}

		// Apply the documented defaults for omitted arguments.
		if args.Owner == "" {
			args.Owner = codersdk.Me
		}
		if args.Limit == 0 {
			args.Limit = defaultLimit
		}

		workspaces, err := deps.Client.Workspaces(ctx, codersdk.WorkspaceFilter{
			Owner:  args.Owner,
			Offset: args.Offset,
			Limit:  args.Limit,
		})
		if err != nil {
			return nil, xerrors.Errorf("failed to fetch workspaces: %w", err)
		}

		// Encode it as JSON. TODO: It might be nicer for the agent to have a tabulated response.
		data, err := json.Marshal(workspaces)
		if err != nil {
			return nil, xerrors.Errorf("failed to encode workspaces: %w", err)
		}

		return &mcp.CallToolResult{
			Content: []mcp.Content{
				mcp.NewTextContent(string(data)),
			},
		}, nil
	}
}
// handleCoderGetWorkspaceArgs is the JSON argument payload of the
// coder_get_workspace tool.
type handleCoderGetWorkspaceArgs struct {
	Workspace string `json:"workspace"` // workspace ID (UUID) or name
}
// handleCoderGetWorkspace builds the handler for the coder_get_workspace
// tool. The handler resolves the workspace from the "workspace" argument and
// returns it as a single JSON text content item.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_get_workspace", "arguments": {"workspace": "dev"}}}
func handleCoderGetWorkspace(deps ToolDeps) server.ToolHandlerFunc {
	handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}

		parsed, err := unmarshalArgs[handleCoderGetWorkspaceArgs](request.Params.Arguments)
		if err != nil {
			return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
		}

		ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, parsed.Workspace)
		if err != nil {
			return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
		}

		encoded, err := json.Marshal(ws)
		if err != nil {
			return nil, xerrors.Errorf("failed to encode workspace: %w", err)
		}

		body := mcp.NewTextContent(string(encoded))
		return &mcp.CallToolResult{Content: []mcp.Content{body}}, nil
	}
	return handler
}
// handleCoderWorkspaceExecArgs is the JSON argument payload of the
// coder_workspace_exec tool.
type handleCoderWorkspaceExecArgs struct {
	Workspace string `json:"workspace"` // workspace ID (UUID) or name
	Command   string `json:"command"`   // shell command to run in the workspace
}
// handleCoderWorkspaceExec builds the handler for the coder_workspace_exec
// tool. The handler resolves the workspace, picks its first connected agent,
// runs the command over a reconnecting PTY, and returns a JSON object with
// the keys "connection_time", "execution_time" and "output".
//
// It returns an error if the workspace cannot be fetched, if no agent is
// connected, or if the PTY cannot be opened or read.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_workspace_exec", "arguments": {"workspace": "dev", "command": "ps -ef"}}}
func handleCoderWorkspaceExec(deps ToolDeps) server.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}
		args, err := unmarshalArgs[handleCoderWorkspaceExecArgs](request.Params.Arguments)
		if err != nil {
			return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
		}

		// Attempt to fetch the workspace. We may get a UUID or a name, so try to
		// handle both.
		ws, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
		if err != nil {
			return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
		}

		// Select the first connected agent of the workspace. The labeled
		// break leaves both loops; a plain break would only leave the inner
		// one and let a later resource overwrite the selection.
		var agt *codersdk.WorkspaceAgent
	findAgent:
		for _, r := range ws.LatestBuild.Resources {
			for _, a := range r.Agents {
				if a.Status != codersdk.WorkspaceAgentConnected {
					continue
				}
				agt = ptr.Ref(a)
				break findAgent
			}
		}
		if agt == nil {
			return nil, xerrors.Errorf("no connected agents for workspace %s", ws.ID)
		}

		startedAt := time.Now()
		conn, err := workspacesdk.New(deps.Client).AgentReconnectingPTY(ctx, workspacesdk.WorkspaceAgentReconnectingPTYOpts{
			AgentID:     agt.ID,
			Reconnect:   uuid.New(),
			Width:       80,
			Height:      24,
			Command:     args.Command,
			BackendType: "buffered", // the screen backend is annoying to use here.
		})
		if err != nil {
			return nil, xerrors.Errorf("failed to open reconnecting PTY: %w", err)
		}
		defer conn.Close()
		connectedAt := time.Now()

		var buf bytes.Buffer
		if _, err := io.Copy(&buf, conn); err != nil {
			// EOF is expected when the connection is closed.
			// We can ignore this error.
			if !errors.Is(err, io.EOF) {
				return nil, xerrors.Errorf("failed to read from reconnecting PTY: %w", err)
			}
		}
		completedAt := time.Now()

		connectionTime := connectedAt.Sub(startedAt)
		executionTime := completedAt.Sub(connectedAt)
		resp := map[string]string{
			"connection_time": connectionTime.String(),
			"execution_time":  executionTime.String(),
			"output":          buf.String(),
		}
		respJSON, err := json.Marshal(resp)
		if err != nil {
			return nil, xerrors.Errorf("failed to encode response: %w", err)
		}

		return &mcp.CallToolResult{
			Content: []mcp.Content{
				mcp.NewTextContent(string(respJSON)),
			},
		}, nil
	}
}
// handleCoderListTemplates builds the handler for the coder_list_templates
// tool. The tool ignores its arguments and responds with the JSON encoding of
// the templates returned for the configured client.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name": "coder_list_templates", "arguments": {}}}
func handleCoderListTemplates(deps ToolDeps) server.ToolHandlerFunc {
	listTemplates := func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		client := deps.Client
		if client == nil {
			return nil, xerrors.New("developer error: client is required")
		}

		// An empty filter asks for every template the client is allowed to see.
		found, err := client.Templates(ctx, codersdk.TemplateFilter{})
		if err != nil {
			return nil, xerrors.Errorf("failed to fetch templates: %w", err)
		}

		encoded, err := json.Marshal(found)
		if err != nil {
			return nil, xerrors.Errorf("failed to encode templates: %w", err)
		}

		content := []mcp.Content{mcp.NewTextContent(string(encoded))}
		return &mcp.CallToolResult{Content: content}, nil
	}
	return listTemplates
}
// handleCoderWorkspaceTransitionArgs holds the arguments of the
// coder_workspace_transition tool, decoded from the tool call's "arguments"
// object.
type handleCoderWorkspaceTransitionArgs struct {
// Workspace identifies the target workspace by UUID or by name.
Workspace string `json:"workspace"`
// Transition is the requested workspace transition. The handler only accepts
// the start and stop transitions and rejects anything else.
Transition string `json:"transition"`
}
// handleCoderWorkspaceTransition returns the handler for the
// coder_workspace_transition tool, which starts or stops a workspace by
// creating a new workspace build with the requested transition.
//
// The workspace may be given as a UUID or as a name. Only the start and stop
// transitions are accepted; any other value yields an "invalid transition"
// error. On success the handler responds with a JSON object holding the
// "status" and "transition" of the newly created build.
//
// Example payload:
// {"jsonrpc":"2.0","id":1,"method":"tools/call", "params": {"name":
// "coder_workspace_transition", "arguments": {"workspace": "dev", "transition": "stop"}}}
func handleCoderWorkspaceTransition(deps ToolDeps) server.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		if deps.Client == nil {
			return nil, xerrors.New("developer error: client is required")
		}

		args, err := unmarshalArgs[handleCoderWorkspaceTransitionArgs](request.Params.Arguments)
		if err != nil {
			return nil, xerrors.Errorf("failed to unmarshal arguments: %w", err)
		}

		workspace, err := getWorkspaceByIDOrOwnerName(ctx, deps.Client, args.Workspace)
		if err != nil {
			return nil, xerrors.Errorf("failed to fetch workspace: %w", err)
		}

		wsTransition := codersdk.WorkspaceTransition(args.Transition)
		switch wsTransition {
		case codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop:
			// Supported transitions.
		default:
			return nil, xerrors.New("invalid transition")
		}

		// We're not going to check the workspace status here as it is checked on the
		// server side.
		wb, err := deps.Client.CreateWorkspaceBuild(ctx, workspace.ID, codersdk.CreateWorkspaceBuildRequest{
			Transition: wsTransition,
		})
		if err != nil {
			// Name the transition that was actually requested; this handler serves
			// both start and stop.
			return nil, xerrors.Errorf("failed to %s workspace: %w", wsTransition, err)
		}

		resp := map[string]any{"status": wb.Status, "transition": wb.Transition}
		respJSON, err := json.Marshal(resp)
		if err != nil {
			return nil, xerrors.Errorf("failed to encode workspace build: %w", err)
		}
		return &mcp.CallToolResult{
			Content: []mcp.Content{
				mcp.NewTextContent(string(respJSON)),
			},
		}, nil
	}
}
// getWorkspaceByIDOrOwnerName resolves a workspace from a free-form
// identifier. An identifier that parses as a UUID is looked up by ID; anything
// else is treated as the name of a workspace owned by the calling user.
func getWorkspaceByIDOrOwnerName(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) {
	workspaceID, parseErr := uuid.Parse(identifier)
	if parseErr != nil {
		// Not a UUID: fall back to a lookup by name for the current user.
		return client.WorkspaceByOwnerAndName(ctx, codersdk.Me, identifier, codersdk.WorkspaceOptions{})
	}
	return client.Workspace(ctx, workspaceID)
}
// unmarshalArgs converts the untyped argument map received from the MCP server
// into the typed struct T. It does so by round-tripping through JSON: the map
// is marshaled and the resulting bytes are unmarshaled into a T, so T's json
// struct tags decide which keys are picked up.
func unmarshalArgs[T any](args map[string]interface{}) (t T, err error) {
	raw, marshalErr := json.Marshal(args)
	if marshalErr != nil {
		return t, xerrors.Errorf("failed to marshal arguments: %w", marshalErr)
	}

	unmarshalErr := json.Unmarshal(raw, &t)
	if unmarshalErr != nil {
		return t, xerrors.Errorf("failed to unmarshal arguments: %w", unmarshalErr)
	}

	return t, nil
}

361
mcp/mcp_test.go Normal file
View File

@ -0,0 +1,361 @@
package codermcp_test
import (
"context"
"encoding/json"
"io"
"runtime"
"testing"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/codersdk"
codermcp "github.com/coder/coder/v2/mcp"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/testutil"
)
// TestCoderTools exercises the registered MCP tools end to end. It starts a
// test coder server with one workspace owned by a member user, attaches an MCP
// stdio server to a pty, and then drives each tool by writing a JSON-RPC
// request line to the pty and comparing the response line that comes back.
//
// All subtests share the same MCP server and pty, and every request uses the
// same JSON-RPC ID, so the subtests must run sequentially and in order.
//
// These tests are dependent on the state of the coder server.
// Running them in parallel is prone to racy behavior.
// nolint:tparallel,paralleltest
func TestCoderTools(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("skipping on non-linux due to pty issues")
}
ctx := testutil.Context(t, testutil.WaitLong)
// Given: a coder server, workspace, and agent.
client, store := coderdtest.NewWithDatabase(t, nil)
owner := coderdtest.CreateFirstUser(t, client)
// Given: a member user with which to test the tools.
memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
// Given: a workspace with an agent.
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
}).WithAgent().Do()
// Note: we want to test the list_workspaces tool before starting the
// workspace agent. Starting the workspace agent will modify the workspace
// state, which will affect the results of the list_workspaces tool.
//
// The goroutine below blocks until the list_workspaces subtest closes
// listWorkspacesDone, then starts the agent, waits on the agent waiter and
// finally closes agentStarted to release the subtests that need the agent.
listWorkspacesDone := make(chan struct{})
agentStarted := make(chan struct{})
go func() {
defer close(agentStarted)
<-listWorkspacesDone
agt := agenttest.New(t, client.URL, r.AgentToken)
t.Cleanup(func() {
_ = agt.Close()
})
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
}()
// Given: a MCP server listening on a pty.
pty := ptytest.New(t)
mcpSrv, closeSrv := startTestMCPServer(ctx, t, pty.Input(), pty.Output())
t.Cleanup(func() {
_ = closeSrv()
})
// Register tools using our registry. The tools act as the member user.
logger := slogtest.Make(t, nil)
codermcp.AllTools().Register(mcpSrv, codermcp.ToolDeps{
Client: memberClient,
Logger: &logger,
})
t.Run("coder_list_templates", func(t *testing.T) {
// When: the coder_list_templates tool is called
ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_templates", map[string]any{})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Fetch the same data directly through the SDK to build the expectation.
templates, err := memberClient.Templates(ctx, codersdk.TemplateFilter{})
require.NoError(t, err)
templatesJSON, err := json.Marshal(templates)
require.NoError(t, err)
// Then: the response is a list of templates visible to the user.
expected := makeJSONRPCTextResponse(t, string(templatesJSON))
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
t.Run("coder_report_task", func(t *testing.T) {
// When: the coder_report_task tool is called
ctr := makeJSONRPCRequest(t, "tools/call", "coder_report_task", map[string]any{
"summary": "Test summary",
"link": "https://example.com",
"emoji": "🔍",
"done": false,
"coder_url": client.URL.String(),
"coder_session_token": client.SessionToken(),
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Then: the response is a success message.
// TODO: check the task was created. This functionality is not yet implemented.
expected := makeJSONRPCTextResponse(t, "Thanks for reporting!")
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
t.Run("coder_whoami", func(t *testing.T) {
// Given: the calling user as reported directly by the SDK.
me, err := memberClient.User(ctx, codersdk.Me)
require.NoError(t, err)
meJSON, err := json.Marshal(me)
require.NoError(t, err)
// When: the coder_whoami tool is called
ctr := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Then: the response is a valid JSON representation of the calling user.
expected := makeJSONRPCTextResponse(t, string(meJSON))
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
t.Run("coder_list_workspaces", func(t *testing.T) {
// Release the agent goroutine once this subtest returns; the subtests
// below block on agentStarted and would otherwise never proceed.
defer close(listWorkspacesDone)
// When: the coder_list_workspaces tool is called
ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_workspaces", map[string]any{
"coder_url": client.URL.String(),
"coder_session_token": client.SessionToken(),
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
ws, err := memberClient.Workspaces(ctx, codersdk.WorkspaceFilter{})
require.NoError(t, err)
wsJSON, err := json.Marshal(ws)
require.NoError(t, err)
// Then: the response is a valid JSON representation of the calling user's workspaces.
expected := makeJSONRPCTextResponse(t, string(wsJSON))
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
t.Run("coder_get_workspace", func(t *testing.T) {
// Given: the workspace agent is connected.
// The act of starting the agent will modify the workspace state.
<-agentStarted
// When: the coder_get_workspace tool is called
ctr := makeJSONRPCRequest(t, "tools/call", "coder_get_workspace", map[string]any{
"workspace": r.Workspace.ID.String(),
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
ws, err := memberClient.Workspace(ctx, r.Workspace.ID)
require.NoError(t, err)
wsJSON, err := json.Marshal(ws)
require.NoError(t, err)
// Then: the response is a valid JSON representation of the workspace.
expected := makeJSONRPCTextResponse(t, string(wsJSON))
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
// NOTE: this test runs after the list_workspaces tool is called.
t.Run("coder_workspace_exec", func(t *testing.T) {
// Given: the workspace agent is connected
<-agentStarted
// When: the coder_workspace_exec tool is called with a command
randString := testutil.GetRandomName(t)
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
"workspace": r.Workspace.ID.String(),
"command": "echo " + randString,
"coder_url": client.URL.String(),
"coder_session_token": client.SessionToken(),
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Then: the response is the output of the command.
// The response also carries timing fields, so only check that the echoed
// random string is present rather than comparing the whole payload.
actual := pty.ReadLine(ctx)
require.Contains(t, actual, randString)
})
// NOTE: this test runs after the list_workspaces tool is called.
t.Run("tool_restrictions", func(t *testing.T) {
// Given: the workspace agent is connected
<-agentStarted
// Given: a restricted MCP server with only allowed tools and commands.
// It gets its own pty so it does not interfere with the shared server.
restrictedPty := ptytest.New(t)
allowedTools := []string{"coder_workspace_exec"}
restrictedMCPSrv, closeRestrictedSrv := startTestMCPServer(ctx, t, restrictedPty.Input(), restrictedPty.Output())
t.Cleanup(func() {
_ = closeRestrictedSrv()
})
codermcp.AllTools().
WithOnlyAllowed(allowedTools...).
Register(restrictedMCPSrv, codermcp.ToolDeps{
Client: memberClient,
Logger: &logger,
})
// When: the tools/list command is called
toolsListCmd := makeJSONRPCRequest(t, "tools/list", "", nil)
restrictedPty.WriteLine(toolsListCmd)
_ = restrictedPty.ReadLine(ctx) // skip the echo
// Then: the response is a list of only the allowed tools.
toolsListResponse := restrictedPty.ReadLine(ctx)
require.Contains(t, toolsListResponse, "coder_workspace_exec")
require.NotContains(t, toolsListResponse, "coder_whoami")
// When: a disallowed tool is called
disallowedToolCmd := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
restrictedPty.WriteLine(disallowedToolCmd)
_ = restrictedPty.ReadLine(ctx) // skip the echo
// Then: the response is an error indicating the tool is not available.
disallowedToolResponse := restrictedPty.ReadLine(ctx)
require.Contains(t, disallowedToolResponse, "error")
require.Contains(t, disallowedToolResponse, "not found")
})
t.Run("coder_workspace_transition_stop", func(t *testing.T) {
// Given: a separate workspace in the running state
stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
}).WithAgent().Do()
// When: the coder_workspace_transition tool is called with a stop transition
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
"workspace": stopWs.Workspace.ID.String(),
"transition": "stop",
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Then: the response is as expected.
expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"stop"}`) // no provisionerd yet
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
t.Run("coder_workspace_transition_start", func(t *testing.T) {
// Given: a separate workspace in the stopped state.
// NOTE: despite its name, stopWs is the workspace that gets started here;
// it is seeded with a stop transition as its latest build.
stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: owner.OrganizationID,
OwnerID: member.ID,
}).Seed(database.WorkspaceBuild{
Transition: database.WorkspaceTransitionStop,
}).Do()
// When: the coder_workspace_transition tool is called with a start transition
ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
"workspace": stopWs.Workspace.ID.String(),
"transition": "start",
})
pty.WriteLine(ctr)
_ = pty.ReadLine(ctx) // skip the echo
// Then: the response is as expected
expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"start"}`) // no provisionerd yet
actual := pty.ReadLine(ctx)
testutil.RequireJSONEq(t, expected, actual)
})
}
// makeJSONRPCRequest builds a JSON RPC 2.0 request for the given method and
// returns it encoded as a single JSON string. The name and args end up in the
// request's params as "name" and "arguments"; every request uses the ID "1".
func makeJSONRPCRequest(t *testing.T, method, name string, args map[string]any) string {
	t.Helper()

	// Unfortunately, there is no type for this yet, so the params are spelled
	// out as an anonymous struct.
	params := struct {
		Name      string         `json:"name"`
		Arguments map[string]any `json:"arguments,omitempty"`
		Meta      *struct {
			ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
		} `json:"_meta,omitempty"`
	}{
		Name:      name,
		Arguments: args,
	}

	encoded, err := json.Marshal(mcp.JSONRPCRequest{
		JSONRPC: "2.0",
		ID:      "1",
		Request: mcp.Request{Method: method},
		Params:  params,
	})
	require.NoError(t, err, "failed to marshal JSON RPC request")
	return string(encoded)
}
// makeJSONRPCTextResponse builds the JSON RPC 2.0 response expected from a
// tool call that answers with a single text content item, and returns it
// encoded as a JSON string. The response always uses the ID "1".
func makeJSONRPCTextResponse(t *testing.T, text string) string {
	t.Helper()

	result := mcp.CallToolResult{
		Content: []mcp.Content{mcp.NewTextContent(text)},
	}
	encoded, err := json.Marshal(mcp.JSONRPCResponse{
		JSONRPC: "2.0",
		ID:      "1",
		Result:  result,
	})
	require.NoError(t, err, "failed to marshal JSON RPC response")
	return string(encoded)
}
// startTestMCPServer starts an MCP server that reads requests from stdin and
// writes responses to stdout (in these tests, the two ends of a pty). It
// returns the server, so that tools can be registered on it, and a function
// that cancels the server's context, waits for Listen to return and reports
// the error Listen returned.
//
// It is the responsibility of the caller to close the server by calling the
// returned function. Calling it more than once is safe; only the first call
// can observe Listen's error, later calls return nil.
func startTestMCPServer(ctx context.Context, t testing.TB, stdin io.Reader, stdout io.Writer) (*server.MCPServer, func() error) {
	t.Helper()

	mcpSrv := server.NewMCPServer(
		"Test Server",
		"0.0.0",
		server.WithInstructions(""),
		server.WithLogging(),
	)
	stdioSrv := server.NewStdioServer(mcpSrv)

	cancelCtx, cancel := context.WithCancel(ctx)
	// Buffered so that the listener goroutine can always hand over its result
	// and exit, even if the caller never invokes the returned function. This
	// also guarantees that the error is delivered to the closer rather than to
	// some intermediate goroutine.
	done := make(chan error, 1)
	go func() {
		defer close(done)
		// Release the derived context if the server stops on its own.
		defer cancel()
		done <- stdioSrv.Listen(cancelCtx, stdin, stdout)
	}()

	return mcpSrv, func() error {
		cancel()
		return <-done
	}
}

27
testutil/json.go Normal file
View File

@ -0,0 +1,27 @@
package testutil
import (
"encoding/json"
"testing"
"github.com/google/go-cmp/cmp"
)
// RequireJSONEq fails the test unless expected and actual decode to equal JSON
// values. It is similar in spirit to testify's require.JSONEq, but on mismatch
// it reports a readable diff of the decoded values.
// Note that this calls t.Fatalf under the hood, so it should never
// be called in a goroutine.
func RequireJSONEq(t *testing.T, expected, actual string) {
	t.Helper()

	var want any
	wantErr := json.Unmarshal([]byte(expected), &want)
	if wantErr != nil {
		t.Fatalf("failed to unmarshal expected JSON: %s", wantErr)
	}

	var got any
	gotErr := json.Unmarshal([]byte(actual), &got)
	if gotErr != nil {
		t.Fatalf("failed to unmarshal actual JSON: %s", gotErr)
	}

	diff := cmp.Diff(want, got)
	if diff == "" {
		return
	}
	t.Fatalf("JSON diff (-want +got):\n%s", diff)
}