mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: implement MCP HTTP server endpoint with authentication (#18670)
# Add MCP HTTP server with streamable transport support - Add MCP HTTP server with streamable transport support - Integrate with existing toolsdk for Coder workspace operations - Add comprehensive E2E tests with OAuth2 bearer token support - Register MCP endpoint at /api/experimental/mcp/http with authentication - Support RFC 6750 Bearer token authentication for MCP clients Change-Id: Ib9024569ae452729908797c42155006aa04330af Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
68
coderd/apidoc/docs.go
generated
68
coderd/apidoc/docs.go
generated
@ -11711,7 +11711,73 @@ const docTemplate = `{
|
||||
}
|
||||
},
|
||||
"codersdk.CreateTestAuditLogRequest": {
|
||||
"type": "object"
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"enum": [
|
||||
"create",
|
||||
"write",
|
||||
"delete",
|
||||
"start",
|
||||
"stop"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.AuditAction"
|
||||
}
|
||||
]
|
||||
},
|
||||
"additional_fields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"build_reason": {
|
||||
"enum": [
|
||||
"autostart",
|
||||
"autostop",
|
||||
"initiator"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.BuildReason"
|
||||
}
|
||||
]
|
||||
},
|
||||
"organization_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"request_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"resource_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"resource_type": {
|
||||
"enum": [
|
||||
"template",
|
||||
"template_version",
|
||||
"user",
|
||||
"workspace",
|
||||
"workspace_build",
|
||||
"git_ssh_key",
|
||||
"auditable_group"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ResourceType"
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateTokenRequest": {
|
||||
"type": "object",
|
||||
|
58
coderd/apidoc/swagger.json
generated
58
coderd/apidoc/swagger.json
generated
@ -10427,7 +10427,63 @@
|
||||
}
|
||||
},
|
||||
"codersdk.CreateTestAuditLogRequest": {
|
||||
"type": "object"
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"enum": ["create", "write", "delete", "start", "stop"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.AuditAction"
|
||||
}
|
||||
]
|
||||
},
|
||||
"additional_fields": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"build_reason": {
|
||||
"enum": ["autostart", "autostop", "initiator"],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.BuildReason"
|
||||
}
|
||||
]
|
||||
},
|
||||
"organization_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"request_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"resource_id": {
|
||||
"type": "string",
|
||||
"format": "uuid"
|
||||
},
|
||||
"resource_type": {
|
||||
"enum": [
|
||||
"template",
|
||||
"template_version",
|
||||
"user",
|
||||
"workspace",
|
||||
"workspace_build",
|
||||
"git_ssh_key",
|
||||
"auditable_group"
|
||||
],
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/definitions/codersdk.ResourceType"
|
||||
}
|
||||
]
|
||||
},
|
||||
"time": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.CreateTokenRequest": {
|
||||
"type": "object",
|
||||
|
@ -972,6 +972,10 @@ func New(options *Options) *API {
|
||||
r.Route("/aitasks", func(r chi.Router) {
|
||||
r.Get("/prompts", api.aiTasksPrompts)
|
||||
})
|
||||
r.Route("/mcp", func(r chi.Router) {
|
||||
// MCP HTTP transport endpoint with mandatory authentication
|
||||
r.Mount("/http", api.mcpHTTPHandler())
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
|
135
coderd/mcp/mcp.go
Normal file
135
coderd/mcp/mcp.go
Normal file
@ -0,0 +1,135 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/buildinfo"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// MCPServerName is the name used for the MCP server.
|
||||
MCPServerName = "Coder"
|
||||
// MCPServerInstructions is the instructions text for the MCP server.
|
||||
MCPServerInstructions = "Coder MCP Server providing workspace and template management tools"
|
||||
)
|
||||
|
||||
// Server represents an MCP HTTP server instance
|
||||
type Server struct {
|
||||
Logger slog.Logger
|
||||
|
||||
// mcpServer is the underlying MCP server
|
||||
mcpServer *server.MCPServer
|
||||
|
||||
// streamableServer handles HTTP transport
|
||||
streamableServer *server.StreamableHTTPServer
|
||||
}
|
||||
|
||||
// NewServer creates a new MCP HTTP server
|
||||
func NewServer(logger slog.Logger) (*Server, error) {
|
||||
// Create the core MCP server
|
||||
mcpSrv := server.NewMCPServer(
|
||||
MCPServerName,
|
||||
buildinfo.Version(),
|
||||
server.WithInstructions(MCPServerInstructions),
|
||||
)
|
||||
|
||||
// Create logger adapter for mcp-go
|
||||
mcpLogger := &mcpLoggerAdapter{logger: logger}
|
||||
|
||||
// Create streamable HTTP server with configuration
|
||||
streamableServer := server.NewStreamableHTTPServer(mcpSrv,
|
||||
server.WithHeartbeatInterval(30*time.Second),
|
||||
server.WithLogger(mcpLogger),
|
||||
)
|
||||
|
||||
return &Server{
|
||||
Logger: logger,
|
||||
mcpServer: mcpSrv,
|
||||
streamableServer: streamableServer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler interface
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.streamableServer.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// RegisterTools registers all available MCP tools with the server
|
||||
func (s *Server) RegisterTools(client *codersdk.Client) error {
|
||||
if client == nil {
|
||||
return xerrors.New("client cannot be nil: MCP HTTP server requires authenticated client")
|
||||
}
|
||||
|
||||
// Create tool dependencies
|
||||
toolDeps, err := toolsdk.NewDeps(client)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to initialize tool dependencies: %w", err)
|
||||
}
|
||||
|
||||
// Register all available tools
|
||||
for _, tool := range toolsdk.All {
|
||||
s.mcpServer.AddTools(mcpFromSDK(tool, toolDeps))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mcpFromSDK adapts a toolsdk.Tool to go-mcp's server.ServerTool
|
||||
func mcpFromSDK(sdkTool toolsdk.GenericTool, tb toolsdk.Deps) server.ServerTool {
|
||||
if sdkTool.Schema.Properties == nil {
|
||||
panic("developer error: schema properties cannot be nil")
|
||||
}
|
||||
|
||||
return server.ServerTool{
|
||||
Tool: mcp.Tool{
|
||||
Name: sdkTool.Name,
|
||||
Description: sdkTool.Description,
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: sdkTool.Schema.Properties,
|
||||
Required: sdkTool.Schema.Required,
|
||||
},
|
||||
},
|
||||
Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(request.Params.Arguments); err != nil {
|
||||
return nil, xerrors.Errorf("failed to encode request arguments: %w", err)
|
||||
}
|
||||
result, err := sdkTool.Handler(ctx, tb, buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.NewTextContent(string(result)),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mcpLoggerAdapter adapts slog.Logger to the mcp-go util.Logger interface
|
||||
type mcpLoggerAdapter struct {
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
func (l *mcpLoggerAdapter) Infof(format string, v ...any) {
|
||||
l.logger.Info(context.Background(), fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func (l *mcpLoggerAdapter) Errorf(format string, v ...any) {
|
||||
l.logger.Error(context.Background(), fmt.Sprintf(format, v...))
|
||||
}
|
1223
coderd/mcp/mcp_e2e_test.go
Normal file
1223
coderd/mcp/mcp_e2e_test.go
Normal file
File diff suppressed because it is too large
Load Diff
133
coderd/mcp/mcp_test.go
Normal file
133
coderd/mcp/mcp_test.go
Normal file
@ -0,0 +1,133 @@
|
||||
package mcp_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
mcpserver "github.com/coder/coder/v2/coderd/mcp"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/toolsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestMCPServer_Creation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
server, err := mcpserver.NewServer(logger)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, server)
|
||||
}
|
||||
|
||||
func TestMCPServer_Handler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
server, err := mcpserver.NewServer(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test that server implements http.Handler interface
|
||||
var handler http.Handler = server
|
||||
require.NotNil(t, handler)
|
||||
}
|
||||
|
||||
func TestMCPHTTP_InitializeRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
server, err := mcpserver.NewServer(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use server directly as http.Handler
|
||||
handler := server
|
||||
|
||||
// Create initialize request
|
||||
initRequest := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": mcp.LATEST_PROTOCOL_VERSION,
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{
|
||||
"name": "test-client",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(initRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json,text/event-stream")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Logf("Response body: %s", recorder.Body.String())
|
||||
}
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
// Check that a session ID was returned
|
||||
sessionID := recorder.Header().Get("Mcp-Session-Id")
|
||||
assert.NotEmpty(t, sessionID)
|
||||
|
||||
// Parse response
|
||||
var response map[string]any
|
||||
err = json.Unmarshal(recorder.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "2.0", response["jsonrpc"])
|
||||
assert.Equal(t, float64(1), response["id"])
|
||||
|
||||
result, ok := response["result"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, mcp.LATEST_PROTOCOL_VERSION, result["protocolVersion"])
|
||||
assert.Contains(t, result, "capabilities")
|
||||
assert.Contains(t, result, "serverInfo")
|
||||
}
|
||||
|
||||
func TestMCPHTTP_ToolRegistration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := testutil.Logger(t)
|
||||
|
||||
server, err := mcpserver.NewServer(logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test registering tools with nil client should return error
|
||||
err = server.RegisterTools(nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "client cannot be nil", "Should reject nil client with appropriate error message")
|
||||
|
||||
// Test registering tools with valid client should succeed
|
||||
client := &codersdk.Client{}
|
||||
err = server.RegisterTools(client)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify that all expected tools are available in the toolsdk
|
||||
expectedToolCount := len(toolsdk.All)
|
||||
require.Greater(t, expectedToolCount, 0, "Should have some tools available")
|
||||
|
||||
// Verify specific tools are present by checking tool names
|
||||
toolNames := make([]string, len(toolsdk.All))
|
||||
for i, tool := range toolsdk.All {
|
||||
toolNames[i] = tool.Name
|
||||
}
|
||||
require.Contains(t, toolNames, toolsdk.ToolNameReportTask, "Should include ReportTask (UserClientOptional)")
|
||||
require.Contains(t, toolNames, toolsdk.ToolNameGetAuthenticatedUser, "Should include GetAuthenticatedUser (requires auth)")
|
||||
}
|
39
coderd/mcp_http.go
Normal file
39
coderd/mcp_http.go
Normal file
@ -0,0 +1,39 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
"github.com/coder/coder/v2/coderd/mcp"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// mcpHTTPHandler creates the MCP HTTP transport handler
|
||||
func (api *API) mcpHTTPHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Create MCP server instance for each request
|
||||
mcpServer, err := mcp.NewServer(api.Logger.Named("mcp"))
|
||||
if err != nil {
|
||||
api.Logger.Error(r.Context(), "failed to create MCP server", slog.Error(err))
|
||||
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "MCP server initialization failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
authenticatedClient := codersdk.New(api.AccessURL)
|
||||
// Extract the original session token from the request
|
||||
authenticatedClient.SetSessionToken(httpmw.APITokenFromRequest(r))
|
||||
|
||||
// Register tools with authenticated client
|
||||
if err := mcpServer.RegisterTools(authenticatedClient); err != nil {
|
||||
api.Logger.Warn(r.Context(), "failed to register MCP tools", slog.Error(err))
|
||||
}
|
||||
|
||||
// Handle the MCP request
|
||||
mcpServer.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
@ -536,12 +536,13 @@ func (api *API) postOAuth2ClientRegistration(rw http.ResponseWriter, r *http.Req
|
||||
|
||||
// Store in database - use system context since this is a public endpoint
|
||||
now := dbtime.Now()
|
||||
clientName := req.GenerateClientName()
|
||||
//nolint:gocritic // Dynamic client registration is a public endpoint, system access required
|
||||
app, err := api.Database.InsertOAuth2ProviderApp(dbauthz.AsSystemRestricted(ctx), database.InsertOAuth2ProviderAppParams{
|
||||
ID: clientID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Name: req.GenerateClientName(),
|
||||
Name: clientName,
|
||||
Icon: req.LogoURI,
|
||||
CallbackURL: req.RedirectURIs[0], // Primary redirect URI
|
||||
RedirectUris: req.RedirectURIs,
|
||||
@ -566,7 +567,11 @@ func (api *API) postOAuth2ClientRegistration(rw http.ResponseWriter, r *http.Req
|
||||
RegistrationClientUri: sql.NullString{String: fmt.Sprintf("%s/oauth2/clients/%s", api.AccessURL.String(), clientID), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to store oauth2 client registration", slog.Error(err))
|
||||
api.Logger.Error(ctx, "failed to store oauth2 client registration",
|
||||
slog.Error(err),
|
||||
slog.F("client_name", clientName),
|
||||
slog.F("client_id", clientID.String()),
|
||||
slog.F("redirect_uris", req.RedirectURIs))
|
||||
writeOAuth2RegistrationError(ctx, rw, http.StatusInternalServerError,
|
||||
"server_error", "Failed to store client registration")
|
||||
return
|
||||
|
Reference in New Issue
Block a user