diff --git a/agent/agent_test.go b/agent/agent_test.go index e38d91001d..9c5605824f 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -58,9 +58,12 @@ func TestAgent(t *testing.T) { t.Run("SSH", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() session, err := sshClient.NewSession() @@ -75,9 +78,12 @@ func TestAgent(t *testing.T) { t.Run("ReconnectingPTY", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) - ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash") + ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash") require.NoError(t, err) defer ptyConn.Close() @@ -217,6 +223,8 @@ func TestAgent(t *testing.T) { t.Run("SFTP", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() u, err := user.Current() require.NoError(t, err, "get current user") home := u.HomeDir @@ -224,7 +232,7 @@ func TestAgent(t *testing.T) { home = "/" + strings.ReplaceAll(home, "\\", "/") } conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() client, err := sftp.NewClient(sshClient) @@ -250,8 +258,11 @@ func TestAgent(t *testing.T) { t.Run("SCP", func(t *testing.T) { t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() scpClient, err := scp.NewClientBySSH(sshClient) @@ -386,9 +397,12 @@ func TestAgent(t *testing.T) { t.Skip("ConPTY appears to be inconsistent on Windows.") } + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) id := uuid.NewString() - netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash") + netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") require.NoError(t, err) bufRead := bufio.NewReader(netConn) @@ -426,7 +440,7 @@ func TestAgent(t *testing.T) { expectLine(matchEchoOutput) _ = netConn.Close() - netConn, err = conn.ReconnectingPTY(id, 100, 100, "/bin/bash") + netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash") require.NoError(t, err) bufRead = bufio.NewReader(netConn) @@ -504,12 +518,14 @@ func TestAgent(t *testing.T) { t.Run("Speedtest", func(t *testing.T) { t.Parallel() t.Skip("This test is relatively flakey because of Tailscale's speedtest code...") + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() derpMap := tailnettest.RunDERPAndSTUN(t) conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{ DERPMap: derpMap, }, 0) defer conn.Close() - res, err := conn.Speedtest(speedtest.Upload, 250*time.Millisecond) + res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond) require.NoError(t, err) t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond()) }) @@ -599,7 +615,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe if err != nil { return } - ssh, err := agentConn.SSH() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + ssh, err := agentConn.SSH(ctx) + cancel() if err != nil { _ = conn.Close() return @@ -626,8 +645,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe } func setupSSHSession(t *testing.T, options codersdk.WorkspaceAgentMetadata) *ssh.Session { + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() conn, _ := setupAgent(t, options, 0) - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) t.Cleanup(func() { _ = sshClient.Close() diff --git a/cli/agent_test.go b/cli/agent_test.go index a7ccd3b2b3..56da1dd554 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -198,7 +198,7 @@ func TestWorkspaceAgent(t *testing.T) { return err == nil }, testutil.WaitMedium, testutil.IntervalFast) - sshClient, err := dialer.SSHClient() + sshClient, err := dialer.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() session, err := sshClient.NewSession() diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 74faa19be5..98791f3031 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -28,6 +28,7 @@ import ( "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/pty/ptytest" + "github.com/coder/coder/testutil" ) func sshConfigFileName(t *testing.T) (sshConfig string) { @@ -131,7 +132,9 @@ func TestConfigSSH(t *testing.T) { if err != nil { break } - ssh, err := agentConn.SSH() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + ssh, err := agentConn.SSH(ctx) + cancel() assert.NoError(t, err) wg.Add(2) go func() { diff --git a/cli/loadtest.go b/cli/loadtest.go index fcec6019e8..b42b2b7604 100644 --- a/cli/loadtest.go +++ b/cli/loadtest.go @@ -8,20 +8,30 @@ import ( "os" "strconv" "strings" + "sync" "time" "github.com/spf13/cobra" + "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/codersdk" "github.com/coder/coder/loadtest/harness" ) +const loadtestTracerName = "coder_loadtest" + func loadtest() *cobra.Command { var ( configPath string outputSpecs []string + + traceEnable bool + traceCoder bool + traceHoneycombAPIKey string + tracePropagate bool ) cmd := &cobra.Command{ Use: "loadtest --config [--output json[:path]] [--output text[:path]]]", @@ -53,6 +63,8 @@ func loadtest() *cobra.Command { Hidden: true, Args: cobra.ExactArgs(0), RunE: func(cmd *cobra.Command, args []string) error { + ctx := tracing.SetTracerName(cmd.Context(), loadtestTracerName) + config, err := loadLoadTestConfigFile(configPath, cmd.InOrStdin()) if err != nil { return err @@ -67,7 +79,7 @@ func loadtest() *cobra.Command { return err } - me, err := client.User(cmd.Context(), codersdk.Me) + me, err := client.User(ctx, codersdk.Me) if err != nil { return xerrors.Errorf("fetch current user: %w", err) } @@ -84,11 +96,43 @@ func loadtest() *cobra.Command { } } if !ok { - return xerrors.Errorf("Not logged in as site owner. Load testing is only available to site owners.") + return xerrors.Errorf("Not logged in as a site owner. Load testing is only available to site owners.") } - // Disable ratelimits for future requests. + // Setup tracing and start a span. + var ( + shouldTrace = traceEnable || traceCoder || traceHoneycombAPIKey != "" + tracerProvider trace.TracerProvider = trace.NewNoopTracerProvider() + closeTracingOnce sync.Once + closeTracing = func(_ context.Context) error { + return nil + } + ) + if shouldTrace { + tracerProvider, closeTracing, err = tracing.TracerProvider(ctx, loadtestTracerName, tracing.TracerOpts{ + Default: traceEnable, + Coder: traceCoder, + Honeycomb: traceHoneycombAPIKey, + }) + if err != nil { + return xerrors.Errorf("initialize tracing: %w", err) + } + defer func() { + closeTracingOnce.Do(func() { + // Allow time for traces to flush even if command + // context is canceled. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = closeTracing(ctx) + }) + }() + } + tracer := tracerProvider.Tracer(loadtestTracerName) + + // Disable ratelimits and propagate tracing spans for future + // requests. Individual tests will setup their own loggers. client.BypassRatelimits = true + client.PropagateTracing = tracePropagate // Prepare the test. strategy := config.Strategy.ExecutionStrategy() @@ -99,18 +143,22 @@ func loadtest() *cobra.Command { for j := 0; j < t.Count; j++ { id := strconv.Itoa(j) - runner, err := t.NewRunner(client) + runner, err := t.NewRunner(client.Clone()) if err != nil { return xerrors.Errorf("create %q runner for %s/%s: %w", t.Type, name, id, err) } - th.AddRun(name, id, runner) + th.AddRun(name, id, &runnableTraceWrapper{ + tracer: tracer, + spanName: fmt.Sprintf("%s/%s", name, id), + runner: runner, + }) } } _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...") - testCtx := cmd.Context() + testCtx := ctx if config.Timeout > 0 { var cancel func() testCtx, cancel = context.WithTimeout(testCtx, time.Duration(config.Timeout)) @@ -158,11 +206,24 @@ func loadtest() *cobra.Command { // Cleanup. _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...") - err = th.Cleanup(cmd.Context()) + err = th.Cleanup(ctx) if err != nil { return xerrors.Errorf("cleanup tests: %w", err) } + // Upload traces. + if shouldTrace { + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...") + closeTracingOnce.Do(func() { + ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + err := closeTracing(ctx) + if err != nil { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err) + } + }) + } + if res.TotalFail > 0 { return xerrors.New("load test failed, see above for more details") } @@ -173,6 +234,12 @@ func loadtest() *cobra.Command { cliflag.StringVarP(cmd.Flags(), &configPath, "config", "", "CODER_LOADTEST_CONFIG_PATH", "", "Path to the load test configuration file, or - to read from stdin.") cliflag.StringArrayVarP(cmd.Flags(), &outputSpecs, "output", "", "CODER_LOADTEST_OUTPUTS", []string{"text"}, "Output formats, see usage for more information.") + + cliflag.BoolVarP(cmd.Flags(), &traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md") + cliflag.BoolVarP(cmd.Flags(), &traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.") + cliflag.StringVarP(cmd.Flags(), &traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.") + cliflag.BoolVarP(cmd.Flags(), &tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.") + return cmd } @@ -271,3 +338,53 @@ func parseLoadTestOutputs(outputs []string) ([]loadTestOutput, error) { return out, nil } + +type runnableTraceWrapper struct { + tracer trace.Tracer + spanName string + runner harness.Runnable + + span trace.Span +} + +var _ harness.Runnable = &runnableTraceWrapper{} +var _ harness.Cleanable = &runnableTraceWrapper{} + +func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error { + ctx, span := r.tracer.Start(ctx, r.spanName, trace.WithNewRoot()) + defer span.End() + r.span = span + + traceID := "unknown trace ID" + spanID := "unknown span ID" + if span.SpanContext().HasTraceID() { + traceID = span.SpanContext().TraceID().String() + } + if span.SpanContext().HasSpanID() { + spanID = span.SpanContext().SpanID().String() + } + _, _ = fmt.Fprintf(logs, "Trace ID: %s\n", traceID) + _, _ = fmt.Fprintf(logs, "Span ID: %s\n\n", spanID) + + // Make a separate span for the run itself so the sub-spans are grouped + // neatly. The cleanup span is also a child of the above span so this is + // important for readability. + ctx2, span2 := r.tracer.Start(ctx, r.spanName+" run") + defer span2.End() + return r.runner.Run(ctx2, id, logs) +} + +func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string) error { + c, ok := r.runner.(harness.Cleanable) + if !ok { + return nil + } + + if r.span != nil { + ctx = trace.ContextWithSpanContext(ctx, r.span.SpanContext()) + } + ctx, span := r.tracer.Start(ctx, r.spanName+" cleanup") + defer span.End() + + return c.Cleanup(ctx, id) +} diff --git a/cli/loadtest_test.go b/cli/loadtest_test.go index eda0084372..b20695cfc6 100644 --- a/cli/loadtest_test.go +++ b/cli/loadtest_test.go @@ -277,6 +277,8 @@ func TestLoadTest(t *testing.T) { require.NoError(t, err, msg) } + t.Logf("output %d:\n\n%s", i, string(b)) + switch output.format { case "text": require.Contains(t, string(b), "Test results:", msg) diff --git a/cli/server.go b/cli/server.go index 1685e668cd..da1dcfc28e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -128,8 +128,9 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co if cfg.Trace.Enable.Value || shouldCoderTrace { sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{ - Default: cfg.Trace.Enable.Value, - Coder: shouldCoderTrace, + Default: cfg.Trace.Enable.Value, + Coder: shouldCoderTrace, + Honeycomb: cfg.Trace.HoneycombAPIKey.Value, }) if err != nil { logger.Warn(ctx, "start telemetry exporter", slog.Error(err)) diff --git a/cli/speedtest.go b/cli/speedtest.go index bf87607ef7..873e5e2794 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -95,7 +95,7 @@ func speedtest() *cobra.Command { dir = tsspeedtest.Upload } cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), dir) - results, err := conn.Speedtest(dir, duration) + results, err := conn.Speedtest(ctx, dir, duration) if err != nil { return err } diff --git a/cli/ssh.go b/cli/ssh.go index f4dcce3180..b72ebd398d 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -100,7 +100,7 @@ func ssh() *cobra.Command { defer stopPolling() if stdio { - rawSSH, err := conn.SSH() + rawSSH, err := conn.SSH(ctx) if err != nil { return err } @@ -113,7 +113,7 @@ func ssh() *cobra.Command { return nil } - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) if err != nil { return err } diff --git a/coderd/activitybump_test.go b/coderd/activitybump_test.go index a67dc66e1c..f9c0736e0c 100644 --- a/coderd/activitybump_test.go +++ b/coderd/activitybump_test.go @@ -88,7 +88,7 @@ func TestWorkspaceActivityBump(t *testing.T) { require.NoError(t, err) defer conn.Close() - sshConn, err := conn.SSHClient() + sshConn, err := conn.SSHClient(ctx) require.NoError(t, err) _ = sshConn.Close() diff --git a/coderd/templates_test.go b/coderd/templates_test.go index 885d683fb0..787bec3db0 100644 --- a/coderd/templates_test.go +++ b/coderd/templates_test.go @@ -633,7 +633,7 @@ func TestTemplateMetrics(t *testing.T) { _ = conn.Close() }() - sshConn, err := conn.SSHClient() + sshConn, err := conn.SSHClient(ctx) require.NoError(t, err) _ = sshConn.Close() diff --git a/coderd/tracing/exporter.go b/coderd/tracing/exporter.go index de4579889f..b83a2b5a11 100644 --- a/coderd/tracing/exporter.go +++ b/coderd/tracing/exporter.go @@ -4,6 +4,7 @@ import ( "context" "github.com/go-logr/logr" + "github.com/hashicorp/go-multierror" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" @@ -82,11 +83,23 @@ func TracerProvider(ctx context.Context, service string, opts TracerOpts) (*sdkt otel.SetLogger(logr.Discard()) return tracerProvider, func(ctx context.Context) error { - for _, close := range closers { - _ = close(ctx) + var merr error + err := tracerProvider.ForceFlush(ctx) + if err != nil { + merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.ForceFlush(): %w", err)) } - _ = tracerProvider.Shutdown(ctx) - return nil + for i, closer := range closers { + err = closer(ctx) + if err != nil { + merr = multierror.Append(merr, xerrors.Errorf("closer() %d: %w", i, err)) + } + } + err = tracerProvider.Shutdown(ctx) + if err != nil { + merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.Shutdown(): %w", err)) + } + + return merr }, nil } diff --git a/coderd/tracing/httpmw.go b/coderd/tracing/httpmw.go index 9760909ee5..97f9755d2e 100644 --- a/coderd/tracing/httpmw.go +++ b/coderd/tracing/httpmw.go @@ -6,6 +6,8 @@ import ( "net/http" "github.com/go-chi/chi/v5" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.11.0" "go.opentelemetry.io/otel/trace" ) @@ -23,11 +25,27 @@ func Middleware(tracerProvider trace.TracerProvider) func(http.Handler) http.Han return } + // Extract the trace context from the request headers. + tmp := otel.GetTextMapPropagator() + hc := propagation.HeaderCarrier(r.Header) + ctx := tmp.Extract(r.Context(), hc) + // start span with default span name. Span name will be updated to "method route" format once request finishes. - ctx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI)) + ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI)) defer span.End() r = r.WithContext(ctx) + if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() { + // Technically these values are included in the Traceparent + // header, but they are easier to read for humans this way. + rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String()) + rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String()) + + // Inject the trace context into the response headers. + hc := propagation.HeaderCarrier(rw.Header()) + tmp.Inject(ctx, hc) + } + sw, ok := rw.(*StatusWriter) if !ok { panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw)) @@ -62,6 +80,37 @@ func EndHTTPSpan(r *http.Request, status int, span trace.Span) { span.End() } -func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) { - return trace.SpanFromContext(ctx).TracerProvider().Tracer(TracerName).Start(ctx, FuncNameSkip(1), opts...) +type tracerNameKey struct{} + +// SetTracerName sets the tracer name that will be used by all spans created +// from the context. +func SetTracerName(ctx context.Context, tracerName string) context.Context { + return context.WithValue(ctx, tracerNameKey{}, tracerName) +} + +// GetTracerName returns the tracer name from the context, or TracerName if none +// is set. +func GetTracerName(ctx context.Context) string { + if tracerName, ok := ctx.Value(tracerNameKey{}).(string); ok { + return tracerName + } + + return TracerName +} + +// StartSpan calls StartSpanWithName with the name set to the caller's function +// name. +func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + return StartSpanWithName(ctx, FuncNameSkip(1), opts...) +} + +// StartSpanWithName starts a new span with the given name from the context. If +// a tracer name was set on the context (or one of its parents), it will be used +// as the tracer name instead of the default TracerName. +func StartSpanWithName(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + tracerName := GetTracerName(ctx) + return trace.SpanFromContext(ctx). + TracerProvider(). + Tracer(tracerName). + Start(ctx, name, opts...) } diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 3a2ff65da1..5b1ea4b48a 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -247,7 +247,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { return } defer release() - ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command")) + ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command")) if err != nil { _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err)) return diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index a0b610379b..3e62311335 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -260,7 +260,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) { }) require.NoError(t, err) defer conn.Close() - sshClient, err := conn.SSHClient() + sshClient, err := conn.SSHClient(ctx) require.NoError(t, err) session, err := sshClient.NewSession() require.NoError(t, err) diff --git a/codersdk/agentconn.go b/codersdk/agentconn.go index ddfb9541a1..a68ab0672a 100644 --- a/codersdk/agentconn.go +++ b/codersdk/agentconn.go @@ -20,6 +20,7 @@ import ( "tailscale.com/net/speedtest" "tailscale.com/tailcfg" + "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/tailnet" ) @@ -133,6 +134,9 @@ type AgentConn struct { } func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + errCh := make(chan error, 1) durCh := make(chan time.Duration, 1) go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { @@ -171,8 +175,11 @@ type ReconnectingPTYInit struct { Command string } -func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) { - conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort))) +func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort))) if err != nil { return nil, err } @@ -197,14 +204,18 @@ func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command str return conn, nil } -func (c *AgentConn) SSH() (net.Conn, error) { - return c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort))) +func (c *AgentConn) SSH(ctx context.Context) (net.Conn, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + return c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort))) } // SSHClient calls SSH to create a client that uses a weak cipher // for high throughput. -func (c *AgentConn) SSHClient() (*ssh.Client, error) { - netConn, err := c.SSH() +func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + netConn, err := c.SSH(ctx) if err != nil { return nil, xerrors.Errorf("ssh: %w", err) } @@ -220,8 +231,10 @@ func (c *AgentConn) SSHClient() (*ssh.Client, error) { return ssh.NewClient(sshConn, channels, requests), nil } -func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { - speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort))) +func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort))) if err != nil { return nil, xerrors.Errorf("dial speedtest: %w", err) } @@ -233,6 +246,8 @@ func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Durat } func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() if network == "unix" { return nil, xerrors.New("network must be tcp or udp") } @@ -277,6 +292,8 @@ func (c *AgentConn) statisticsClient() *http.Client { } func (c *AgentConn) doStatisticsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() host := net.JoinHostPort(TailnetIP.String(), strconv.Itoa(TailnetStatisticsPort)) url := fmt.Sprintf("http://%s%s", host, path) @@ -309,6 +326,8 @@ type ListeningPort struct { } func (c *AgentConn) ListeningPorts(ctx context.Context) (ListeningPortsResponse, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() res, err := c.doStatisticsRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil) if err != nil { return ListeningPortsResponse{}, xerrors.Errorf("do request: %w", err) diff --git a/codersdk/client.go b/codersdk/client.go index a961844a0e..58c18cd296 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -12,7 +12,15 @@ import ( "net/url" "strings" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.11.0" + "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/tracing" + + "cdr.dev/slog" ) // These cookies are Coder-specific. If a new one is added or changed, the name @@ -30,6 +38,13 @@ const ( BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit" ) +var loggableMimeTypes = map[string]struct{}{ + "application/json": {}, + "text/plain": {}, + // lots of webserver error pages are HTML + "text/html": {}, +} + // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ @@ -45,9 +60,35 @@ type Client struct { SessionToken string URL *url.URL + // Logger can be provided to log requests. Request method, URL and response + // status code will be logged by default. + Logger slog.Logger + // LogBodies determines whether the request and response bodies are logged + // to the provided Logger. This is useful for debugging or testing. + LogBodies bool + // BypassRatelimits is an optional flag that can be set by the site owner to // disable ratelimit checks for the client. BypassRatelimits bool + + // PropagateTracing is an optional flag that can be set to propagate tracing + // spans to the Coder API. This is useful for seeing the entire request + // from end-to-end. + PropagateTracing bool +} + +func (c *Client) Clone() *Client { + hc := *c.HTTPClient + u := *c.URL + return &Client{ + HTTPClient: &hc, + SessionToken: c.SessionToken, + URL: &u, + Logger: c.Logger, + LogBodies: c.LogBodies, + BypassRatelimits: c.BypassRatelimits, + PropagateTracing: c.PropagateTracing, + } } type RequestOption func(*http.Request) @@ -63,30 +104,46 @@ func WithQueryParam(key, value string) RequestOption { } } -// Request performs an HTTP request with the body provided. -// The caller is responsible for closing the response body. +// Request performs a HTTP request with the body provided. The caller is +// responsible for closing the response body. func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) { + ctx, span := tracing.StartSpanWithName(ctx, tracing.FuncNameSkip(1)) + defer span.End() + serverURL, err := c.URL.Parse(path) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } - var buf bytes.Buffer + var r io.Reader if body != nil { if data, ok := body.([]byte); ok { - buf = *bytes.NewBuffer(data) + r = bytes.NewReader(data) } else { // Assume JSON if not bytes. - enc := json.NewEncoder(&buf) + buf := bytes.NewBuffer(nil) + enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) err = enc.Encode(body) if err != nil { return nil, xerrors.Errorf("encode body: %w", err) } + + r = buf } } - req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf) + // Copy the request body so we can log it. + var reqBody []byte + if r != nil && c.LogBodies { + reqBody, err = io.ReadAll(r) + if err != nil { + return nil, xerrors.Errorf("read request body: %w", err) + } + r = bytes.NewReader(reqBody) + } + + req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), r) if err != nil { return nil, xerrors.Errorf("create request: %w", err) } @@ -95,17 +152,61 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac req.Header.Set(BypassRatelimitHeader, "true") } - if body != nil { + if r != nil { req.Header.Set("Content-Type", "application/json") } for _, opt := range opts { opt(req) } + span.SetAttributes(semconv.NetAttributesFromHTTPRequest("tcp", req)...) + span.SetAttributes(semconv.HTTPClientAttributesFromHTTPRequest(req)...) + + // Inject tracing headers if enabled. + if c.PropagateTracing { + tmp := otel.GetTextMapPropagator() + hc := propagation.HeaderCarrier(req.Header) + tmp.Inject(ctx, hc) + } + + ctx = slog.With(ctx, + slog.F("method", req.Method), + slog.F("url", req.URL.String()), + ) + c.Logger.Debug(ctx, "sdk request", slog.F("body", string(reqBody))) + resp, err := c.HTTPClient.Do(req) if err != nil { return nil, xerrors.Errorf("do: %w", err) } + + span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode)) + span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(resp.StatusCode, trace.SpanKindClient)) + + // Copy the response body so we can log it if it's a loggable mime type. + var respBody []byte + if resp.Body != nil && c.LogBodies { + mimeType := parseMimeType(resp.Header.Get("Content-Type")) + if _, ok := loggableMimeTypes[mimeType]; ok { + respBody, err = io.ReadAll(resp.Body) + if err != nil { + return nil, xerrors.Errorf("copy response body for logs: %w", err) + } + err = resp.Body.Close() + if err != nil { + return nil, xerrors.Errorf("close response body: %w", err) + } + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + } + } + + c.Logger.Debug(ctx, "sdk response", + slog.F("status", resp.StatusCode), + slog.F("body", string(respBody)), + slog.F("trace_id", resp.Header.Get("X-Trace-Id")), + slog.F("span_id", resp.Header.Get("X-Span-Id")), + ) + return resp, err } @@ -138,10 +239,7 @@ func readBodyAsError(res *http.Response) error { return xerrors.Errorf("read body: %w", err) } - mimeType, _, err := mime.ParseMediaType(contentType) - if err != nil { - mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0]) - } + mimeType := parseMimeType(contentType) if mimeType != "application/json" { if len(resp) > 1024 { resp = append(resp[:1024], []byte("...")...) @@ -238,3 +336,12 @@ type closeFunc func() error func (c closeFunc) Close() error { return c() } + +func parseMimeType(contentType string) string { + mimeType, _, err := mime.ParseMediaType(contentType) + if err != nil { + mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0]) + } + + return mimeType +} diff --git a/codersdk/client_internal_test.go b/codersdk/client_internal_test.go index dbb96340f1..c855db734a 100644 --- a/codersdk/client_internal_test.go +++ b/codersdk/client_internal_test.go @@ -2,23 +2,114 @@ package codersdk import ( "bytes" + "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "net/url" "strconv" "strings" "testing" + "github.com/go-logr/logr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" "golang.org/x/xerrors" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + + "github.com/coder/coder/testutil" ) -const ( - jsonCT = "application/json" -) +const jsonCT = "application/json" + +func Test_Client(t *testing.T) { + t.Parallel() + + const method = http.MethodPost + const path = "/ok" + const token = "token" + const reqBody = `{"msg": "request body"}` + const resBody = `{"status": "ok"}` + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, method, r.Method) + assert.Equal(t, path, r.URL.Path) + assert.Equal(t, token, r.Header.Get(SessionCustomHeader)) + assert.Equal(t, "true", r.Header.Get(BypassRatelimitHeader)) + assert.NotEmpty(t, r.Header.Get("Traceparent")) + for k, v := range r.Header { + t.Logf("header %q: %q", k, strings.Join(v, ", ")) + } + + w.Header().Set("Content-Type", jsonCT) + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, resBody) + })) + + u, err := url.Parse(s.URL) + require.NoError(t, err) + client := New(u) + client.SessionToken = token + client.BypassRatelimits = true + + logBuf := bytes.NewBuffer(nil) + client.Logger = slog.Make(sloghuman.Sink(logBuf)).Leveled(slog.LevelDebug) + client.LogBodies = true + + // Setup tracing. + res := resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String("codersdk_test"), + ) + tracerOpts := []sdktrace.TracerProviderOption{ + sdktrace.WithResource(res), + } + tracerProvider := sdktrace.NewTracerProvider(tracerOpts...) + otel.SetTracerProvider(tracerProvider) + otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) {})) + otel.SetTextMapPropagator( + propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ), + ) + otel.SetLogger(logr.Discard()) + client.PropagateTracing = true + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + ctx, span := tracerProvider.Tracer("codersdk_test").Start(ctx, "codersdk client test 1") + defer span.End() + + resp, err := client.Request(ctx, method, path, []byte(reqBody)) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, jsonCT, resp.Header.Get("Content-Type")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, resBody, string(body)) + + logStr := logBuf.String() + require.Contains(t, logStr, "sdk request") + require.Contains(t, logStr, method) + require.Contains(t, logStr, path) + require.Contains(t, logStr, strings.ReplaceAll(reqBody, `"`, `\"`)) + require.Contains(t, logStr, "sdk response") + require.Contains(t, logStr, "200") + require.Contains(t, logStr, strings.ReplaceAll(resBody, `"`, `\"`)) +} func Test_readBodyAsError(t *testing.T) { t.Parallel() diff --git a/codersdk/sse.go b/codersdk/sse.go index 39aaf71dec..56457a0c92 100644 --- a/codersdk/sse.go +++ b/codersdk/sse.go @@ -2,11 +2,14 @@ package codersdk import ( "bufio" + "context" "fmt" "io" "strings" "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/tracing" ) type ServerSentEvent struct { @@ -22,7 +25,10 @@ const ( ServerSentEventTypeError ServerSentEventType = "error" ) -func ServerSentEventReader(rc io.ReadCloser) func() (*ServerSentEvent, error) { +func ServerSentEventReader(ctx context.Context, rc io.ReadCloser) func() (*ServerSentEvent, error) { + _, span := tracing.StartSpan(ctx) + defer span.End() + reader := bufio.NewReader(rc) nextLineValue := func(prefix string) ([]byte, error) { var ( diff --git a/codersdk/workspaces.go b/codersdk/workspaces.go index 6e217feed4..dc933f5196 100644 --- a/codersdk/workspaces.go +++ b/codersdk/workspaces.go @@ -10,6 +10,8 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/tracing" ) // Workspace is a deployment of a template. It references a specific @@ -137,6 +139,8 @@ func (c *Client) CreateWorkspaceBuild(ctx context.Context, workspace uuid.UUID, } func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Workspace, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() //nolint:bodyclose res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/watch", id), nil) if err != nil { @@ -145,7 +149,7 @@ func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Works if res.StatusCode != http.StatusOK { return nil, readBodyAsError(res) } - nextEvent := ServerSentEventReader(res.Body) + nextEvent := ServerSentEventReader(ctx, res.Body) wc := make(chan Workspace, 256) go func() { diff --git a/loadtest/agentconn/run.go b/loadtest/agentconn/run.go index 094df34c30..9577195c4e 100644 --- a/loadtest/agentconn/run.go +++ b/loadtest/agentconn/run.go @@ -9,7 +9,6 @@ import ( "net/netip" "net/url" "strconv" - "sync" "time" "golang.org/x/sync/errgroup" @@ -17,8 +16,10 @@ import ( "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/codersdk" "github.com/coder/coder/loadtest/harness" + "github.com/coder/coder/loadtest/loadtestutil" ) const defaultRequestTimeout = 5 * time.Second @@ -45,11 +46,13 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner { // Run implements Runnable. func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { - logs = syncWriter{ - mut: &sync.Mutex{}, - w: logs, - } + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + logs = loadtestutil.NewSyncWriter(logs) logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug) + r.client.Logger = logger + r.client.LogBodies = true _, _ = fmt.Fprintln(logs, "Opening connection to workspace agent") switch r.cfg.ConnectionMode { @@ -69,9 +72,72 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { } defer conn.Close() - // Wait for the disco connection to be established. + err = waitForDisco(ctx, logs, conn) + if err != nil { + return xerrors.Errorf("wait for discovery connection: %w", err) + } + + // Wait for a direct connection if requested. + if r.cfg.ConnectionMode == ConnectionModeDirect { + err = waitForDirectConnection(ctx, logs, conn) + if err != nil { + return xerrors.Errorf("wait for direct connection: %w", err) + } + } + + // Ensure DERP for completeness. + if r.cfg.ConnectionMode == ConnectionModeDerp { + status := conn.Status() + if len(status.Peers()) != 1 { + return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) + } + peer := status.Peer[status.Peers()[0]] + if peer.Relay == "" || peer.CurAddr != "" { + return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP") + } + } + + _, _ = fmt.Fprint(logs, "\nConnection established.\n\n") + + // HACK: even though the ping passed above, we still need to open a + // connection to the agent to ensure it's ready to accept connections. Not + // sure why this is the case but it seems to be necessary. + err = verifyConnection(ctx, logs, conn) + if err != nil { + return xerrors.Errorf("verify connection: %w", err) + } + + _, _ = fmt.Fprint(logs, "\nConnection verified.\n\n") + + // Make initial connections sequentially to ensure the services are + // reachable before we start spawning a bunch of goroutines and tickers. + err = performInitialConnections(ctx, logs, conn, r.cfg.Connections) + if err != nil { + return xerrors.Errorf("perform initial connections: %w", err) + } + + if r.cfg.HoldDuration > 0 { + err = holdConnection(ctx, logs, conn, time.Duration(r.cfg.HoldDuration), r.cfg.Connections) + if err != nil { + return xerrors.Errorf("hold connection: %w", err) + } + } + + err = conn.Close() + if err != nil { + return xerrors.Errorf("close connection: %w", err) + } + + return nil +} + +func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error { const pingAttempts = 10 const pingDelay = 1 * time.Second + + ctx, span := tracing.StartSpan(ctx) + defer span.End() + for i := 0; i < pingAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts) pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout) @@ -93,80 +159,59 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { } } - // Wait for a direct connection if requested. - if r.cfg.ConnectionMode == ConnectionModeDirect { - const directConnectionAttempts = 30 - const directConnectionDelay = 1 * time.Second - for i := 0; i < directConnectionAttempts; i++ { - _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) - status := conn.Status() + return nil +} - var err error - if len(status.Peers()) != 1 { - _, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers())) - err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers())) - } else { - peer := status.Peer[status.Peers()[0]] - _, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr) - _, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay) - if peer.Relay != "" && peer.CurAddr == "" { - err = xerrors.Errorf("peer is connected via DERP, not direct") - } - } - if err == nil { - break - } - if i == directConnectionAttempts-1 { - return xerrors.Errorf("wait for direct connection to agent: %w", err) - } +func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error { + const directConnectionAttempts = 30 + const directConnectionDelay = 1 * time.Second - select { - case <-ctx.Done(): - return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err()) - // We use time.After here since it's a very short duration so - // leaking a timer is fine. - case <-time.After(directConnectionDelay): - } - } - } + ctx, span := tracing.StartSpan(ctx) + defer span.End() - // Ensure DERP for completeness. - if r.cfg.ConnectionMode == ConnectionModeDerp { + for i := 0; i < directConnectionAttempts; i++ { + _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) status := conn.Status() + + var err error if len(status.Peers()) != 1 { - return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) + _, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers())) + err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers())) + } else { + peer := status.Peer[status.Peers()[0]] + _, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr) + _, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay) + if peer.Relay != "" && peer.CurAddr == "" { + err = xerrors.Errorf("peer is connected via DERP, not direct") + } } - peer := status.Peer[status.Peers()[0]] - if peer.Relay == "" || peer.CurAddr != "" { - return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP") + if err == nil { + break + } + if i == directConnectionAttempts-1 { + return xerrors.Errorf("wait for direct connection to agent: %w", err) + } + + select { + case <-ctx.Done(): + return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err()) + // We use time.After here since it's a very short duration so + // leaking a timer is fine. + case <-time.After(directConnectionDelay): } } - _, _ = fmt.Fprint(logs, "\nConnection established.\n\n") + return nil +} - client := &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - _, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, xerrors.Errorf("split host port %q: %w", addr, err) - } - - portUint, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return nil, xerrors.Errorf("parse port %q: %w", port, err) - } - return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint))) - }, - }, - } - - // HACK: even though the ping passed above, we still need to open a - // connection to the agent to ensure it's ready to accept connections. Not - // sure why this is the case but it seems to be necessary. +func verifyConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error { const verifyConnectionAttempts = 30 const verifyConnectionDelay = 1 * time.Second + + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + client := agentHTTPClient(conn) for i := 0; i < verifyConnectionAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tVerify connection attempt %d/%d...\n", i+1, verifyConnectionAttempts) verifyCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout) @@ -198,14 +243,20 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { } } - _, _ = fmt.Fprint(logs, "\nConnection verified.\n\n") + return nil +} - // Make initial connections sequentially to ensure the services are - // reachable before we start spawning a bunch of goroutines and tickers. - if len(r.cfg.Connections) > 0 { - _, _ = fmt.Fprintln(logs, "Performing initial service connections...") +func performInitialConnections(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, specs []Connection) error { + if len(specs) == 0 { + return nil } - for i, connSpec := range r.cfg.Connections { + + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + _, _ = fmt.Fprintln(logs, "Performing initial service connections...") + client := agentHTTPClient(conn) + for i, connSpec := range specs { _, _ = fmt.Fprintf(logs, "\t%d. %s\n", i, connSpec.URL) timeout := defaultRequestTimeout @@ -230,95 +281,102 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { _, _ = fmt.Fprintln(logs, "\t\tOK") } - if r.cfg.HoldDuration > 0 { - eg, egCtx := errgroup.WithContext(ctx) + return nil +} - if len(r.cfg.Connections) > 0 { - _, _ = fmt.Fprintln(logs, "\nStarting connection loops...") - } - for i, connSpec := range r.cfg.Connections { - i, connSpec := i, connSpec - if connSpec.Interval <= 0 { - continue - } +func holdConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, holdDur time.Duration, specs []Connection) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() - eg.Go(func() error { - t := time.NewTicker(time.Duration(connSpec.Interval)) - defer t.Stop() - - timeout := defaultRequestTimeout - if connSpec.Timeout > 0 { - timeout = time.Duration(connSpec.Timeout) - } - - for { - select { - case <-egCtx.Done(): - return egCtx.Err() - case <-t.C: - ctx, cancel := context.WithTimeout(ctx, timeout) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil) - if err != nil { - cancel() - return xerrors.Errorf("create request: %w", err) - } - - res, err := client.Do(req) - cancel() - if err != nil { - _, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err) - return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err) - } - res.Body.Close() - - _, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i) - t.Reset(time.Duration(connSpec.Interval)) - } - } - }) + eg, egCtx := errgroup.WithContext(ctx) + client := agentHTTPClient(conn) + if len(specs) > 0 { + _, _ = fmt.Fprintln(logs, "\nStarting connection loops...") + } + for i, connSpec := range specs { + i, connSpec := i, connSpec + if connSpec.Interval <= 0 { + continue } - // Wait for the hold duration to end. We use a fake error to signal that - // the hold duration has ended. - _, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", time.Duration(r.cfg.HoldDuration)) eg.Go(func() error { - t := time.NewTicker(time.Duration(r.cfg.HoldDuration)) + t := time.NewTicker(time.Duration(connSpec.Interval)) defer t.Stop() - select { - case <-egCtx.Done(): - return egCtx.Err() - case <-t.C: - // Returning an error here will cause the errgroup context to - // be canceled, which is what we want. This fake error is - // ignored below. - return holdDurationEndedError{} + timeout := defaultRequestTimeout + if connSpec.Timeout > 0 { + timeout = time.Duration(connSpec.Timeout) + } + + for { + select { + case <-egCtx.Done(): + return egCtx.Err() + case <-t.C: + ctx, cancel := context.WithTimeout(ctx, timeout) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil) + if err != nil { + cancel() + return xerrors.Errorf("create request: %w", err) + } + + res, err := client.Do(req) + cancel() + if err != nil { + _, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err) + return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err) + } + res.Body.Close() + + _, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i) + t.Reset(time.Duration(connSpec.Interval)) + } } }) - - err = eg.Wait() - if err != nil && !xerrors.Is(err, holdDurationEndedError{}) { - return xerrors.Errorf("run connections loop: %w", err) - } } - err = conn.Close() - if err != nil { - return xerrors.Errorf("close connection: %w", err) + // Wait for the hold duration to end. We use a fake error to signal that + // the hold duration has ended. + _, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", holdDur) + eg.Go(func() error { + t := time.NewTicker(holdDur) + defer t.Stop() + + select { + case <-egCtx.Done(): + return egCtx.Err() + case <-t.C: + // Returning an error here will cause the errgroup context to + // be canceled, which is what we want. This fake error is + // ignored below. + return holdDurationEndedError{} + } + }) + + err := eg.Wait() + if err != nil && !xerrors.Is(err, holdDurationEndedError{}) { + return xerrors.Errorf("run connections loop: %w", err) } return nil } -// syncWriter wraps an io.Writer in a sync.Mutex. -type syncWriter struct { - mut *sync.Mutex - w io.Writer -} +func agentHTTPClient(conn *codersdk.AgentConn) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, xerrors.Errorf("split host port %q: %w", addr, err) + } -// Write implements io.Writer. -func (sw syncWriter) Write(p []byte) (n int, err error) { - sw.mut.Lock() - defer sw.mut.Unlock() - return sw.w.Write(p) + portUint, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, xerrors.Errorf("parse port %q: %w", port, err) + } + return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint))) + }, + }, + } } diff --git a/loadtest/harness/harness.go b/loadtest/harness/harness.go index c9d200a7d8..863e3ed9c5 100644 --- a/loadtest/harness/harness.go +++ b/loadtest/harness/harness.go @@ -7,6 +7,8 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/tracing" ) // ExecutionStrategy defines how a TestHarness should execute a set of runs. It @@ -49,6 +51,9 @@ func NewTestHarness(strategy ExecutionStrategy) *TestHarness { // // Panics if called more than once. func (h *TestHarness) Run(ctx context.Context) (err error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + h.mut.Lock() if h.started { h.mut.Unlock() diff --git a/loadtest/harness/strategies.go b/loadtest/harness/strategies.go index 7bc52a3b25..019ee29126 100644 --- a/loadtest/harness/strategies.go +++ b/loadtest/harness/strategies.go @@ -95,6 +95,7 @@ type timeoutRunnerWrapper struct { } var _ Runnable = timeoutRunnerWrapper{} +var _ Cleanable = timeoutRunnerWrapper{} func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer) error { ctx, cancel := context.WithTimeout(ctx, t.timeout) @@ -103,6 +104,15 @@ func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer return t.inner.Run(ctx, id, logs) } +func (t timeoutRunnerWrapper) Cleanup(ctx context.Context, id string) error { + c, ok := t.inner.(Cleanable) + if !ok { + return nil + } + + return c.Cleanup(ctx, id) +} + // Execute implements ExecutionStrategy. func (t TimeoutExecutionStrategyWrapper) Execute(ctx context.Context, runs []*TestRun) error { for _, run := range runs { diff --git a/loadtest/loadtestutil/syncwriter.go b/loadtest/loadtestutil/syncwriter.go new file mode 100644 index 0000000000..caeb362af9 --- /dev/null +++ b/loadtest/loadtestutil/syncwriter.go @@ -0,0 +1,26 @@ +package loadtestutil + +import ( + "io" + "sync" +) + +// SyncWriter wraps an io.Writer in a sync.Mutex. +type SyncWriter struct { + mut *sync.Mutex + w io.Writer +} + +func NewSyncWriter(w io.Writer) *SyncWriter { + return &SyncWriter{ + mut: &sync.Mutex{}, + w: w, + } +} + +// Write implements io.Writer. +func (sw *SyncWriter) Write(p []byte) (n int, err error) { + sw.mut.Lock() + defer sw.mut.Unlock() + return sw.w.Write(p) +} diff --git a/loadtest/workspacebuild/run.go b/loadtest/workspacebuild/run.go index 473a3fa9eb..c45a3ffdeb 100644 --- a/loadtest/workspacebuild/run.go +++ b/loadtest/workspacebuild/run.go @@ -9,9 +9,14 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + + "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/codersdk" "github.com/coder/coder/cryptorand" "github.com/coder/coder/loadtest/harness" + "github.com/coder/coder/loadtest/loadtestutil" ) type Runner struct { @@ -32,6 +37,14 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner { // Run implements Runnable. func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + logs = loadtestutil.NewSyncWriter(logs) + logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug) + r.client.Logger = logger + r.client.LogBodies = true + req := r.cfg.Request if req.Name == "" { randName, err := cryptorand.HexString(8) @@ -66,6 +79,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error { if r.workspaceID == uuid.Nil { return nil } + ctx, span := tracing.StartSpan(ctx) + defer span.End() build, err := r.client.CreateWorkspaceBuild(ctx, r.workspaceID, codersdk.CreateWorkspaceBuildRequest{ Transition: codersdk.WorkspaceTransitionDelete, @@ -85,6 +100,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error { } func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, buildID uuid.UUID) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() _, _ = fmt.Fprint(w, "Build is currently queued...") // Wait for build to start. @@ -154,6 +171,8 @@ func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, bui } func waitForAgents(ctx context.Context, w io.Writer, client *codersdk.Client, workspaceID uuid.UUID) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() _, _ = fmt.Fprint(w, "Waiting for agents to connect...\n\n") for {