Files
coder/mcp/mcp_test.go
2025-04-01 15:02:08 +01:00

398 lines
14 KiB
Go

package codermcp_test
import (
	"context"
	"encoding/json"
	"io"
	"runtime"
	"sync"
	"testing"

	"github.com/mark3labs/mcp-go/mcp"
	"github.com/mark3labs/mcp-go/server"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/v2/agent/agenttest"
	"github.com/coder/coder/v2/coderd/coderdtest"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbfake"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/agentsdk"
	codermcp "github.com/coder/coder/v2/mcp"
	"github.com/coder/coder/v2/provisionersdk/proto"
	"github.com/coder/coder/v2/pty/ptytest"
	"github.com/coder/coder/v2/testutil"
)
// TestCoderTools exercises the coder MCP tools end to end. Each subtest writes
// a JSON-RPC request to an MCP server listening on a pty and checks the
// response (and, where relevant, the resulting coder server state).
//
// These tests are dependent on the state of the coder server.
// Running them in parallel is prone to racy behavior.
// nolint:tparallel,paralleltest
func TestCoderTools(t *testing.T) {
	if runtime.GOOS != "linux" {
		t.Skip("skipping on non-linux due to pty issues")
	}
	ctx := testutil.Context(t, testutil.WaitLong)
	// Given: a coder server, workspace, and agent.
	client, store := coderdtest.NewWithDatabase(t, nil)
	owner := coderdtest.CreateFirstUser(t, client)
	// Given: a member user with which to test the tools.
	memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID)
	// Given: a workspace with an agent.
	r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
		OrganizationID: owner.OrganizationID,
		OwnerID:        member.ID,
	}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
		agents[0].Apps = []*proto.App{
			{
				Slug: "some-agent-app",
			},
		}
		return agents
	}).Do()

	// Note: we want to test the list_workspaces tool before starting the
	// workspace agent. Starting the workspace agent will modify the workspace
	// state, which will affect the results of the list_workspaces tool.
	listWorkspacesDone := make(chan struct{})
	// signalListWorkspacesDone releases the agent goroutine below. It is safe to
	// call more than once. The coder_list_workspaces subtest calls it, and so
	// does every subtest that needs the agent. That way the agent still starts
	// when coder_list_workspaces is filtered out (e.g. with -run).
	var listWorkspacesOnce sync.Once
	signalListWorkspacesDone := func() {
		listWorkspacesOnce.Do(func() { close(listWorkspacesDone) })
	}
	agentStarted := make(chan struct{})
	go func() {
		select {
		case <-listWorkspacesDone:
		case <-ctx.Done():
			// Nothing asked for the agent before ctx was done: return rather
			// than blocking forever. agentStarted is intentionally left open
			// so waiters report the failure via ctx.
			return
		}
		defer close(agentStarted)
		agt := agenttest.New(t, client.URL, r.AgentToken)
		t.Cleanup(func() {
			_ = agt.Close()
		})
		_ = coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait()
	}()
	// waitForAgent ensures the agent goroutine has been released and blocks
	// until it reports the agent as started. If ctx is done first, it fails the
	// calling subtest instead of hanging.
	waitForAgent := func(t *testing.T) {
		t.Helper()
		signalListWorkspacesDone()
		select {
		case <-agentStarted:
		case <-ctx.Done():
			t.Fatal("timed out waiting for the workspace agent to start")
		}
	}

	// Given: a MCP server listening on a pty.
	pty := ptytest.New(t)
	mcpSrv, closeSrv := startTestMCPServer(ctx, t, pty.Input(), pty.Output())
	t.Cleanup(func() {
		_ = closeSrv()
	})

	// Register tools using our registry
	logger := slogtest.Make(t, nil)
	agentClient := agentsdk.New(memberClient.URL)
	codermcp.AllTools().Register(mcpSrv, codermcp.ToolDeps{
		Client:        memberClient,
		Logger:        &logger,
		AppStatusSlug: "some-agent-app",
		AgentClient:   agentClient,
	})

	t.Run("coder_list_templates", func(t *testing.T) {
		// When: the coder_list_templates tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_templates", map[string]any{})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is the list of templates visible to the user.
		expected, err := memberClient.Templates(ctx, codersdk.TemplateFilter{})
		require.NoError(t, err)
		actual := unmarshalFromCallToolResult[[]codersdk.Template](t, pty.ReadLine(ctx))
		require.Len(t, actual, 1)
		require.Equal(t, expected[0].ID, actual[0].ID)
	})

	t.Run("coder_report_task", func(t *testing.T) {
		// Given: the MCP server has an agent token.
		oldAgentToken := agentClient.SDK.SessionToken()
		agentClient.SetSessionToken(r.AgentToken)
		t.Cleanup(func() {
			agentClient.SDK.SetSessionToken(oldAgentToken)
		})
		// When: the coder_report_task tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_report_task", map[string]any{
			"summary":             "Test summary",
			"link":                "https://example.com",
			"emoji":               "🔍",
			"done":                false,
			"need_user_attention": true,
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: positive feedback is given to the reporting agent.
		actual := pty.ReadLine(ctx)
		require.Contains(t, actual, "Thanks for reporting!")

		// Then: the reported status is recorded on the workspace agent's app.
		ws, err := memberClient.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err, "failed to get workspace")
		agt, err := memberClient.WorkspaceAgent(ctx, ws.LatestBuild.Resources[0].Agents[0].ID)
		require.NoError(t, err, "failed to get workspace agent")
		require.NotEmpty(t, agt.Apps, "workspace agent should have an app")
		require.NotEmpty(t, agt.Apps[0].Statuses, "workspace agent app should have a status")
		st := agt.Apps[0].Statuses[0]
		// require.Equal(t, ws.ID, st.WorkspaceID, "workspace app status should have the correct workspace id")
		require.Equal(t, agt.ID, st.AgentID, "workspace app status should have the correct agent id")
		require.Equal(t, agt.Apps[0].ID, st.AppID, "workspace app status should have the correct app id")
		require.Equal(t, codersdk.WorkspaceAppStatusStateFailure, st.State, "workspace app status should be in the failure state")
		require.Equal(t, "Test summary", st.Message, "workspace app status should have the correct message")
		require.Equal(t, "https://example.com", st.URI, "workspace app status should have the correct uri")
		require.Equal(t, "🔍", st.Icon, "workspace app status should have the correct icon")
		require.True(t, st.NeedsUserAttention, "workspace app status should need user attention")
	})

	t.Run("coder_whoami", func(t *testing.T) {
		// When: the coder_whoami tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is a valid JSON representation of the calling user.
		expected, err := memberClient.User(ctx, codersdk.Me)
		require.NoError(t, err)
		actual := unmarshalFromCallToolResult[codersdk.User](t, pty.ReadLine(ctx))
		require.Equal(t, expected.ID, actual.ID)
	})

	t.Run("coder_list_workspaces", func(t *testing.T) {
		// Release the agent goroutine once this subtest is done, even on failure.
		defer signalListWorkspacesDone()
		// When: the coder_list_workspaces tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_list_workspaces", map[string]any{
			"coder_url":           client.URL.String(),
			"coder_session_token": client.SessionToken(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is a valid JSON representation of the calling user's workspaces.
		actual := unmarshalFromCallToolResult[codersdk.WorkspacesResponse](t, pty.ReadLine(ctx))
		require.Len(t, actual.Workspaces, 1, "expected 1 workspace")
		require.Equal(t, r.Workspace.ID, actual.Workspaces[0].ID, "expected the workspace to be the one we created in setup")
	})

	t.Run("coder_get_workspace", func(t *testing.T) {
		// Given: the workspace agent is connected.
		// The act of starting the agent will modify the workspace state.
		waitForAgent(t)
		// When: the coder_get_workspace tool is called
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_get_workspace", map[string]any{
			"workspace": r.Workspace.ID.String(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		expected, err := memberClient.Workspace(ctx, r.Workspace.ID)
		require.NoError(t, err)

		// Then: the response is a valid JSON representation of the workspace.
		actual := unmarshalFromCallToolResult[codersdk.Workspace](t, pty.ReadLine(ctx))
		require.Equal(t, expected.ID, actual.ID)
	})

	// NOTE: this test runs after the list_workspaces tool is called.
	t.Run("coder_workspace_exec", func(t *testing.T) {
		// Given: the workspace agent is connected
		waitForAgent(t)

		// When: the coder_workspace_exec tool is called with a command
		randString := testutil.GetRandomName(t)
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_exec", map[string]any{
			"workspace":           r.Workspace.ID.String(),
			"command":             "echo " + randString,
			"coder_url":           client.URL.String(),
			"coder_session_token": client.SessionToken(),
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is the output of the command.
		actual := pty.ReadLine(ctx)
		require.Contains(t, actual, randString)
	})

	// NOTE: this test runs after the list_workspaces tool is called.
	t.Run("tool_restrictions", func(t *testing.T) {
		// Given: the workspace agent is connected
		waitForAgent(t)

		// Given: a restricted MCP server with only allowed tools and commands
		restrictedPty := ptytest.New(t)
		allowedTools := []string{"coder_workspace_exec"}
		restrictedMCPSrv, closeRestrictedSrv := startTestMCPServer(ctx, t, restrictedPty.Input(), restrictedPty.Output())
		t.Cleanup(func() {
			_ = closeRestrictedSrv()
		})
		codermcp.AllTools().
			WithOnlyAllowed(allowedTools...).
			Register(restrictedMCPSrv, codermcp.ToolDeps{
				Client: memberClient,
				Logger: &logger,
			})

		// When: the tools/list command is called
		toolsListCmd := makeJSONRPCRequest(t, "tools/list", "", nil)
		restrictedPty.WriteLine(toolsListCmd)
		_ = restrictedPty.ReadLine(ctx) // skip the echo

		// Then: the response is a list of only the allowed tools.
		toolsListResponse := restrictedPty.ReadLine(ctx)
		require.Contains(t, toolsListResponse, "coder_workspace_exec")
		require.NotContains(t, toolsListResponse, "coder_whoami")

		// When: a disallowed tool is called
		disallowedToolCmd := makeJSONRPCRequest(t, "tools/call", "coder_whoami", map[string]any{})
		restrictedPty.WriteLine(disallowedToolCmd)
		_ = restrictedPty.ReadLine(ctx) // skip the echo

		// Then: the response is an error indicating the tool is not available.
		disallowedToolResponse := restrictedPty.ReadLine(ctx)
		require.Contains(t, disallowedToolResponse, "error")
		require.Contains(t, disallowedToolResponse, "not found")
	})

	t.Run("coder_workspace_transition_stop", func(t *testing.T) {
		// Given: a separate workspace in the running state
		stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        member.ID,
		}).WithAgent().Do()

		// When: the coder_workspace_transition tool is called with a stop transition
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
			"workspace":  stopWs.Workspace.ID.String(),
			"transition": "stop",
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is as expected.
		expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"stop"}`) // no provisionerd yet
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})

	t.Run("coder_workspace_transition_start", func(t *testing.T) {
		// Given: a separate workspace in the stopped state
		stopWs := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
			OrganizationID: owner.OrganizationID,
			OwnerID:        member.ID,
		}).Seed(database.WorkspaceBuild{
			Transition: database.WorkspaceTransitionStop,
		}).Do()

		// When: the coder_workspace_transition tool is called with a start transition
		ctr := makeJSONRPCRequest(t, "tools/call", "coder_workspace_transition", map[string]any{
			"workspace":  stopWs.Workspace.ID.String(),
			"transition": "start",
		})

		pty.WriteLine(ctr)
		_ = pty.ReadLine(ctx) // skip the echo

		// Then: the response is as expected
		expected := makeJSONRPCTextResponse(t, `{"status":"pending","transition":"start"}`) // no provisionerd yet
		actual := pty.ReadLine(ctx)
		testutil.RequireJSONEq(t, expected, actual)
	})
}
// makeJSONRPCRequest is a helper function that makes a JSON RPC request.
// It returns the JSON encoding of a JSON-RPC 2.0 request with ID "1" that
// invokes method. The params object always carries name and carries args only
// when args is non-empty.
func makeJSONRPCRequest(t *testing.T, method, name string, args map[string]any) string {
	t.Helper()

	// Unfortunately, there is no named type for these params yet.
	params := struct {
		Name      string         `json:"name"`
		Arguments map[string]any `json:"arguments,omitempty"`
		Meta      *struct {
			ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
		} `json:"_meta,omitempty"`
	}{Name: name, Arguments: args}

	encoded, err := json.Marshal(mcp.JSONRPCRequest{
		ID:      "1",
		JSONRPC: "2.0",
		Request: mcp.Request{Method: method},
		Params:  params,
	})
	require.NoError(t, err, "failed to marshal JSON RPC request")
	return string(encoded)
}
// makeJSONRPCTextResponse returns the JSON encoding of a JSON-RPC 2.0 response
// with ID "1". Its result is a tool-call result holding a single text content
// item containing text. It is used to build the exact expected reply for a
// tool call.
func makeJSONRPCTextResponse(t *testing.T, text string) string {
	t.Helper()

	result := mcp.CallToolResult{
		Content: []mcp.Content{mcp.NewTextContent(text)},
	}
	encoded, err := json.Marshal(mcp.JSONRPCResponse{
		ID:      "1",
		JSONRPC: "2.0",
		Result:  result,
	})
	require.NoError(t, err, "failed to marshal JSON RPC response")
	return string(encoded)
}
// unmarshalFromCallToolResult decodes raw as a JSON-RPC response whose result
// is an MCP tool-call result with exactly one text content item. It then
// JSON-decodes that text into a T and returns it.
//
// MCP reports tool failures inside a successful JSON-RPC result with isError
// set. Such a result fails the test with the tool's error text, instead of
// an opaque "failed to unmarshal content" error. Raw payloads are included in
// failure messages to aid debugging.
func unmarshalFromCallToolResult[T any](t *testing.T, raw string) T {
	t.Helper()

	var resp map[string]any
	require.NoError(t, json.Unmarshal([]byte(raw), &resp), "failed to unmarshal JSON RPC response: %s", raw)
	res, ok := resp["result"].(map[string]any)
	require.True(t, ok, "expected a result field in the response: %s", raw)
	ct, ok := res["content"].([]any)
	require.True(t, ok, "expected a content field in the result: %s", raw)
	require.Len(t, ct, 1, "expected a single content item in the result")
	ct0, ok := ct[0].(map[string]any)
	require.True(t, ok, "expected a content item in the result")
	txt, ok := ct0["text"].(string)
	require.True(t, ok, "expected a text field in the content item")

	// A missing or non-bool isError is treated as success.
	isError, _ := res["isError"].(bool)
	require.False(t, isError, "tool call returned an error: %s", txt)

	var actual T
	require.NoError(t, json.Unmarshal([]byte(txt), &actual), "failed to unmarshal content: %s", txt)
	return actual
}
// startTestMCPServer is a helper function that starts a MCP server serving
// JSON-RPC over stdin/stdout (typically a pty). The stdio listener runs in a
// goroutine with a context derived from ctx.
//
// It is the responsibility of the caller to call the returned close function.
// That function cancels the listener's context, waits for the listener to
// return, and returns its error. This includes an error from a listener that
// already exited on its own. Calling it more than once is safe; later calls
// return nil.
func startTestMCPServer(ctx context.Context, t testing.TB, stdin io.Reader, stdout io.Writer) (*server.MCPServer, func() error) {
	t.Helper()

	mcpSrv := server.NewMCPServer(
		"Test Server",
		"0.0.0",
		server.WithInstructions(""),
		server.WithLogging(),
	)
	stdioSrv := server.NewStdioServer(mcpSrv)

	cancelCtx, cancel := context.WithCancel(ctx)
	// Buffered so the listener goroutine can always deliver its result and
	// exit, even if the close function is never called.
	done := make(chan error, 1)
	go func() {
		// Release the derived context even if Listen returns on its own.
		defer cancel()
		defer close(done)
		done <- stdioSrv.Listen(cancelCtx, stdin, stdout)
	}()

	return mcpSrv, func() error {
		cancel()
		// The first receive gets the listener's error. Once done is closed,
		// further receives yield nil instead of blocking or panicking.
		return <-done
	}
}