feat: make trace provider in loadtest, add tracing to sdk (#4939)

Author: Dean Sheather
Date: 2022-11-09 08:10:48 +10:00
Committed by: GitHub
Parent: fa844d0878
Commit: d82364b9b5

24 changed files with 757 additions and 206 deletions

View File

@@ -58,9 +58,12 @@ func TestAgent(t *testing.T) {
 t.Run("SSH", func(t *testing.T) {
 t.Parallel()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 defer sshClient.Close()
 session, err := sshClient.NewSession()
@@ -75,9 +78,12 @@ func TestAgent(t *testing.T) {
 t.Run("ReconnectingPTY", func(t *testing.T) {
 t.Parallel()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 conn, stats := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
-ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
+ptyConn, err := conn.ReconnectingPTY(ctx, uuid.NewString(), 128, 128, "/bin/bash")
 require.NoError(t, err)
 defer ptyConn.Close()
@@ -217,6 +223,8 @@ func TestAgent(t *testing.T) {
 t.Run("SFTP", func(t *testing.T) {
 t.Parallel()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 u, err := user.Current()
 require.NoError(t, err, "get current user")
 home := u.HomeDir
@@ -224,7 +232,7 @@ func TestAgent(t *testing.T) {
 home = "/" + strings.ReplaceAll(home, "\\", "/")
 }
 conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 defer sshClient.Close()
 client, err := sftp.NewClient(sshClient)
@@ -250,8 +258,11 @@ func TestAgent(t *testing.T) {
 t.Run("SCP", func(t *testing.T) {
 t.Parallel()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 defer sshClient.Close()
 scpClient, err := scp.NewClientBySSH(sshClient)
@@ -386,9 +397,12 @@ func TestAgent(t *testing.T) {
 t.Skip("ConPTY appears to be inconsistent on Windows.")
 }
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0)
 id := uuid.NewString()
-netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
+netConn, err := conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
 require.NoError(t, err)
 bufRead := bufio.NewReader(netConn)
@@ -426,7 +440,7 @@ func TestAgent(t *testing.T) {
 expectLine(matchEchoOutput)
 _ = netConn.Close()
-netConn, err = conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
+netConn, err = conn.ReconnectingPTY(ctx, id, 100, 100, "/bin/bash")
 require.NoError(t, err)
 bufRead = bufio.NewReader(netConn)
@@ -504,12 +518,14 @@ func TestAgent(t *testing.T) {
 t.Run("Speedtest", func(t *testing.T) {
 t.Parallel()
 t.Skip("This test is relatively flakey because of Tailscale's speedtest code...")
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 derpMap := tailnettest.RunDERPAndSTUN(t)
 conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{
 DERPMap: derpMap,
 }, 0)
 defer conn.Close()
-res, err := conn.Speedtest(speedtest.Upload, 250*time.Millisecond)
+res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
 require.NoError(t, err)
 t.Logf("%.2f MBits/s", res[len(res)-1].MBitsPerSecond())
 })
@@ -599,7 +615,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
 if err != nil {
 return
 }
-ssh, err := agentConn.SSH()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+ssh, err := agentConn.SSH(ctx)
+cancel()
 if err != nil {
 _ = conn.Close()
 return
@@ -626,8 +645,10 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
 }
 func setupSSHSession(t *testing.T, options codersdk.WorkspaceAgentMetadata) *ssh.Session {
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
 conn, _ := setupAgent(t, options, 0)
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 t.Cleanup(func() {
 _ = sshClient.Close()

View File

@@ -198,7 +198,7 @@ func TestWorkspaceAgent(t *testing.T) {
 return err == nil
 }, testutil.WaitMedium, testutil.IntervalFast)
-sshClient, err := dialer.SSHClient()
+sshClient, err := dialer.SSHClient(ctx)
 require.NoError(t, err)
 defer sshClient.Close()
 session, err := sshClient.NewSession()

View File

@@ -28,6 +28,7 @@ import (
 "github.com/coder/coder/provisioner/echo"
 "github.com/coder/coder/provisionersdk/proto"
 "github.com/coder/coder/pty/ptytest"
+"github.com/coder/coder/testutil"
 )
 func sshConfigFileName(t *testing.T) (sshConfig string) {
@@ -131,7 +132,9 @@ func TestConfigSSH(t *testing.T) {
 if err != nil {
 break
 }
-ssh, err := agentConn.SSH()
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+ssh, err := agentConn.SSH(ctx)
+cancel()
 assert.NoError(t, err)
 wg.Add(2)
 go func() {

View File

@@ -8,20 +8,30 @@ import (
 "os"
 "strconv"
 "strings"
+"sync"
 "time"
 "github.com/spf13/cobra"
+"go.opentelemetry.io/otel/trace"
 "golang.org/x/xerrors"
 "github.com/coder/coder/cli/cliflag"
+"github.com/coder/coder/coderd/tracing"
 "github.com/coder/coder/codersdk"
 "github.com/coder/coder/loadtest/harness"
 )
+const loadtestTracerName = "coder_loadtest"
 func loadtest() *cobra.Command {
 var (
 configPath string
 outputSpecs []string
+traceEnable bool
+traceCoder bool
+traceHoneycombAPIKey string
+tracePropagate bool
 )
 cmd := &cobra.Command{
 Use: "loadtest --config <path> [--output json[:path]] [--output text[:path]]]",
@@ -53,6 +63,8 @@ func loadtest() *cobra.Command {
 Hidden: true,
 Args: cobra.ExactArgs(0),
 RunE: func(cmd *cobra.Command, args []string) error {
+ctx := tracing.SetTracerName(cmd.Context(), loadtestTracerName)
 config, err := loadLoadTestConfigFile(configPath, cmd.InOrStdin())
 if err != nil {
 return err
@@ -67,7 +79,7 @@ func loadtest() *cobra.Command {
 return err
 }
-me, err := client.User(cmd.Context(), codersdk.Me)
+me, err := client.User(ctx, codersdk.Me)
 if err != nil {
 return xerrors.Errorf("fetch current user: %w", err)
 }
@@ -84,11 +96,43 @@ func loadtest() *cobra.Command {
 }
 }
 if !ok {
-return xerrors.Errorf("Not logged in as site owner. Load testing is only available to site owners.")
+return xerrors.Errorf("Not logged in as a site owner. Load testing is only available to site owners.")
 }
-// Disable ratelimits for future requests.
+// Setup tracing and start a span.
+var (
+shouldTrace = traceEnable || traceCoder || traceHoneycombAPIKey != ""
+tracerProvider trace.TracerProvider = trace.NewNoopTracerProvider()
+closeTracingOnce sync.Once
+closeTracing = func(_ context.Context) error {
+return nil
+}
+)
+if shouldTrace {
+tracerProvider, closeTracing, err = tracing.TracerProvider(ctx, loadtestTracerName, tracing.TracerOpts{
+Default: traceEnable,
+Coder: traceCoder,
+Honeycomb: traceHoneycombAPIKey,
+})
+if err != nil {
+return xerrors.Errorf("initialize tracing: %w", err)
+}
+defer func() {
+closeTracingOnce.Do(func() {
+// Allow time for traces to flush even if command
+// context is canceled.
+ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+defer cancel()
+_ = closeTracing(ctx)
+})
+}()
+}
+tracer := tracerProvider.Tracer(loadtestTracerName)
+
+// Disable ratelimits and propagate tracing spans for future
+// requests. Individual tests will setup their own loggers.
 client.BypassRatelimits = true
+client.PropagateTracing = tracePropagate
 // Prepare the test.
 strategy := config.Strategy.ExecutionStrategy()
@@ -99,18 +143,22 @@ func loadtest() *cobra.Command {
 for j := 0; j < t.Count; j++ {
 id := strconv.Itoa(j)
-runner, err := t.NewRunner(client)
+runner, err := t.NewRunner(client.Clone())
 if err != nil {
 return xerrors.Errorf("create %q runner for %s/%s: %w", t.Type, name, id, err)
 }
-th.AddRun(name, id, runner)
+th.AddRun(name, id, &runnableTraceWrapper{
+tracer: tracer,
+spanName: fmt.Sprintf("%s/%s", name, id),
+runner: runner,
+})
 }
 }
 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...")
-testCtx := cmd.Context()
+testCtx := ctx
 if config.Timeout > 0 {
 var cancel func()
 testCtx, cancel = context.WithTimeout(testCtx, time.Duration(config.Timeout))
@@ -158,11 +206,24 @@ func loadtest() *cobra.Command {
 // Cleanup.
 _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...")
-err = th.Cleanup(cmd.Context())
+err = th.Cleanup(ctx)
 if err != nil {
 return xerrors.Errorf("cleanup tests: %w", err)
 }
+// Upload traces.
+if shouldTrace {
+_, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...")
+closeTracingOnce.Do(func() {
+ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
+defer cancel()
+err := closeTracing(ctx)
+if err != nil {
+_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err)
+}
+})
+}
 if res.TotalFail > 0 {
 return xerrors.New("load test failed, see above for more details")
 }
@@ -173,6 +234,12 @@ func loadtest() *cobra.Command {
 cliflag.StringVarP(cmd.Flags(), &configPath, "config", "", "CODER_LOADTEST_CONFIG_PATH", "", "Path to the load test configuration file, or - to read from stdin.")
 cliflag.StringArrayVarP(cmd.Flags(), &outputSpecs, "output", "", "CODER_LOADTEST_OUTPUTS", []string{"text"}, "Output formats, see usage for more information.")
+cliflag.BoolVarP(cmd.Flags(), &traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md")
+cliflag.BoolVarP(cmd.Flags(), &traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.")
+cliflag.StringVarP(cmd.Flags(), &traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.")
+cliflag.BoolVarP(cmd.Flags(), &tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.")
 return cmd
 }
@@ -271,3 +338,53 @@ func parseLoadTestOutputs(outputs []string) ([]loadTestOutput, error) {
 return out, nil
 }
+
+type runnableTraceWrapper struct {
+tracer trace.Tracer
+spanName string
+runner harness.Runnable
+
+span trace.Span
+}
+
+var _ harness.Runnable = &runnableTraceWrapper{}
+var _ harness.Cleanable = &runnableTraceWrapper{}
+
+func (r *runnableTraceWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
+ctx, span := r.tracer.Start(ctx, r.spanName, trace.WithNewRoot())
+defer span.End()
+r.span = span
+
+traceID := "unknown trace ID"
+spanID := "unknown span ID"
+if span.SpanContext().HasTraceID() {
+traceID = span.SpanContext().TraceID().String()
+}
+if span.SpanContext().HasSpanID() {
+spanID = span.SpanContext().SpanID().String()
+}
+_, _ = fmt.Fprintf(logs, "Trace ID: %s\n", traceID)
+_, _ = fmt.Fprintf(logs, "Span ID: %s\n\n", spanID)
+
+// Make a separate span for the run itself so the sub-spans are grouped
+// neatly. The cleanup span is also a child of the above span so this is
+// important for readability.
+ctx2, span2 := r.tracer.Start(ctx, r.spanName+" run")
+defer span2.End()
+return r.runner.Run(ctx2, id, logs)
+}
+
+func (r *runnableTraceWrapper) Cleanup(ctx context.Context, id string) error {
+c, ok := r.runner.(harness.Cleanable)
+if !ok {
+return nil
+}
+
+if r.span != nil {
+ctx = trace.ContextWithSpanContext(ctx, r.span.SpanContext())
+}
+ctx, span := r.tracer.Start(ctx, r.spanName+" cleanup")
+defer span.End()
+
+return c.Cleanup(ctx, id)
+}
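For context, the tracing setup added to the loadtest command above boils down to the pattern sketched below: fall back to a no-op tracer provider when no tracing flag is set, and guard the exporter flush with a sync.Once so it runs at most once, whether it fires from the deferred cleanup or from the explicit "Uploading traces..." step. This is a distilled sketch rather than code from the commit; the helper name setupLoadtestTracing and its parameters are stand-ins for the --trace, --trace-coder and --trace-honeycomb-api-key flags.

package main

import (
	"context"
	"sync"
	"time"

	"go.opentelemetry.io/otel/trace"

	"github.com/coder/coder/coderd/tracing"
)

// setupLoadtestTracing mirrors the command's approach: a real provider only
// when tracing is requested, otherwise a no-op provider that is safe to use.
func setupLoadtestTracing(ctx context.Context, enable, coder bool, honeycombAPIKey string) (trace.Tracer, func(), error) {
	var (
		shouldTrace                         = enable || coder || honeycombAPIKey != ""
		tracerProvider trace.TracerProvider = trace.NewNoopTracerProvider()
		closeTracing                        = func(context.Context) error { return nil }
	)
	if shouldTrace {
		var err error
		tracerProvider, closeTracing, err = tracing.TracerProvider(ctx, "coder_loadtest", tracing.TracerOpts{
			Default:   enable,
			Coder:     coder,
			Honeycomb: honeycombAPIKey,
		})
		if err != nil {
			return nil, nil, err
		}
	}

	var once sync.Once
	flush := func() {
		once.Do(func() {
			// Allow time for traces to flush even if the command context is
			// already canceled.
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			_ = closeTracing(ctx)
		})
	}
	return tracerProvider.Tracer("coder_loadtest"), flush, nil
}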

View File

@@ -277,6 +277,8 @@ func TestLoadTest(t *testing.T) {
 require.NoError(t, err, msg)
 }
+t.Logf("output %d:\n\n%s", i, string(b))
 switch output.format {
 case "text":
 require.Contains(t, string(b), "Test results:", msg)

View File

@@ -128,8 +128,9 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
 if cfg.Trace.Enable.Value || shouldCoderTrace {
 sdkTracerProvider, closeTracing, err := tracing.TracerProvider(ctx, "coderd", tracing.TracerOpts{
 Default: cfg.Trace.Enable.Value,
 Coder: shouldCoderTrace,
+Honeycomb: cfg.Trace.HoneycombAPIKey.Value,
 })
 if err != nil {
 logger.Warn(ctx, "start telemetry exporter", slog.Error(err))

View File

@@ -95,7 +95,7 @@ func speedtest() *cobra.Command {
 dir = tsspeedtest.Upload
 }
 cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), dir)
-results, err := conn.Speedtest(dir, duration)
+results, err := conn.Speedtest(ctx, dir, duration)
 if err != nil {
 return err
 }

View File

@@ -100,7 +100,7 @@ func ssh() *cobra.Command {
 defer stopPolling()
 if stdio {
-rawSSH, err := conn.SSH()
+rawSSH, err := conn.SSH(ctx)
 if err != nil {
 return err
 }
@@ -113,7 +113,7 @@ func ssh() *cobra.Command {
 return nil
 }
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 if err != nil {
 return err
 }

View File

@@ -88,7 +88,7 @@ func TestWorkspaceActivityBump(t *testing.T) {
 require.NoError(t, err)
 defer conn.Close()
-sshConn, err := conn.SSHClient()
+sshConn, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 _ = sshConn.Close()

View File

@@ -633,7 +633,7 @@ func TestTemplateMetrics(t *testing.T) {
 _ = conn.Close()
 }()
-sshConn, err := conn.SSHClient()
+sshConn, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 _ = sshConn.Close()

View File

@@ -4,6 +4,7 @@ import (
 "context"
 "github.com/go-logr/logr"
+"github.com/hashicorp/go-multierror"
 "go.opentelemetry.io/otel"
 "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
 "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
@@ -82,11 +83,23 @@ func TracerProvider(ctx context.Context, service string, opts TracerOpts) (*sdkt
 otel.SetLogger(logr.Discard())
 return tracerProvider, func(ctx context.Context) error {
-for _, close := range closers {
-_ = close(ctx)
-}
-_ = tracerProvider.Shutdown(ctx)
-return nil
+var merr error
+err := tracerProvider.ForceFlush(ctx)
+if err != nil {
+merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.ForceFlush(): %w", err))
+}
+for i, closer := range closers {
+err = closer(ctx)
+if err != nil {
+merr = multierror.Append(merr, xerrors.Errorf("closer() %d: %w", i, err))
+}
+}
+err = tracerProvider.Shutdown(ctx)
+if err != nil {
+merr = multierror.Append(merr, xerrors.Errorf("tracerProvider.Shutdown(): %w", err))
+}
+return merr
 }, nil
 }

View File

@@ -6,6 +6,8 @@ import (
 "net/http"
 "github.com/go-chi/chi/v5"
+"go.opentelemetry.io/otel"
+"go.opentelemetry.io/otel/propagation"
 semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
 "go.opentelemetry.io/otel/trace"
 )
@@ -23,11 +25,27 @@ func Middleware(tracerProvider trace.TracerProvider) func(http.Handler) http.Han
 return
 }
+// Extract the trace context from the request headers.
+tmp := otel.GetTextMapPropagator()
+hc := propagation.HeaderCarrier(r.Header)
+ctx := tmp.Extract(r.Context(), hc)
+
 // start span with default span name. Span name will be updated to "method route" format once request finishes.
-ctx, span := tracer.Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI))
+ctx, span := tracer.Start(ctx, fmt.Sprintf("%s %s", r.Method, r.RequestURI))
 defer span.End()
 r = r.WithContext(ctx)
+
+if span.SpanContext().HasTraceID() && span.SpanContext().HasSpanID() {
+// Technically these values are included in the Traceparent
+// header, but they are easier to read for humans this way.
+rw.Header().Set("X-Trace-ID", span.SpanContext().TraceID().String())
+rw.Header().Set("X-Span-ID", span.SpanContext().SpanID().String())
+
+// Inject the trace context into the response headers.
+hc := propagation.HeaderCarrier(rw.Header())
+tmp.Inject(ctx, hc)
+}
 sw, ok := rw.(*StatusWriter)
 if !ok {
 panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
@@ -62,6 +80,37 @@ func EndHTTPSpan(r *http.Request, status int, span trace.Span) {
 span.End()
 }
-func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
-return trace.SpanFromContext(ctx).TracerProvider().Tracer(TracerName).Start(ctx, FuncNameSkip(1), opts...)
+type tracerNameKey struct{}
+
+// SetTracerName sets the tracer name that will be used by all spans created
+// from the context.
+func SetTracerName(ctx context.Context, tracerName string) context.Context {
+return context.WithValue(ctx, tracerNameKey{}, tracerName)
+}
+
+// GetTracerName returns the tracer name from the context, or TracerName if none
+// is set.
+func GetTracerName(ctx context.Context) string {
+if tracerName, ok := ctx.Value(tracerNameKey{}).(string); ok {
+return tracerName
+}
+return TracerName
+}
+
+// StartSpan calls StartSpanWithName with the name set to the caller's function
+// name.
+func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
+return StartSpanWithName(ctx, FuncNameSkip(1), opts...)
+}
+
+// StartSpanWithName starts a new span with the given name from the context. If
+// a tracer name was set on the context (or one of its parents), it will be used
+// as the tracer name instead of the default TracerName.
+func StartSpanWithName(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
+tracerName := GetTracerName(ctx)
+return trace.SpanFromContext(ctx).
+TracerProvider().
+Tracer(tracerName).
+Start(ctx, name, opts...)
 }
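A minimal sketch of how the tracer-name helpers above are meant to be combined by callers such as the loadtest command. The function doWork and the "coder_loadtest" name are illustrative, and a tracer provider is assumed to be configured elsewhere (for example via otel.SetTracerProvider or a parent span already present in the context).

package main

import (
	"context"

	"github.com/coder/coder/coderd/tracing"
)

func doWork(ctx context.Context) {
	// Spans created below use this tracer name instead of the package
	// default TracerName.
	ctx = tracing.SetTracerName(ctx, "coder_loadtest")

	// Named after the calling function via FuncNameSkip(1).
	ctx, span := tracing.StartSpan(ctx)
	defer span.End()

	// Explicit span name.
	ctx, child := tracing.StartSpanWithName(ctx, "doWork/step")
	defer child.End()
	_ = ctx
}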

View File

@@ -247,7 +247,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
 return
 }
 defer release()
-ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
+ptNetConn, err := agentConn.ReconnectingPTY(ctx, reconnect.String(), uint16(height), uint16(width), r.URL.Query().Get("command"))
 if err != nil {
 _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
 return

View File

@@ -260,7 +260,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
 })
 require.NoError(t, err)
 defer conn.Close()
-sshClient, err := conn.SSHClient()
+sshClient, err := conn.SSHClient(ctx)
 require.NoError(t, err)
 session, err := sshClient.NewSession()
 require.NoError(t, err)

View File

@@ -20,6 +20,7 @@ import (
 "tailscale.com/net/speedtest"
 "tailscale.com/tailcfg"
+"github.com/coder/coder/coderd/tracing"
 "github.com/coder/coder/tailnet"
 )
@@ -133,6 +134,9 @@ type AgentConn struct {
 }
 func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 errCh := make(chan error, 1)
 durCh := make(chan time.Duration, 1)
 go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) {
@@ -171,8 +175,11 @@ type ReconnectingPTYInit struct {
 Command string
 }
-func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command string) (net.Conn, error) {
-conn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
+func (c *AgentConn) ReconnectingPTY(ctx context.Context, id string, height, width uint16, command string) (net.Conn, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+conn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetReconnectingPTYPort)))
 if err != nil {
 return nil, err
 }
@@ -197,14 +204,18 @@ func (c *AgentConn) ReconnectingPTY(id string, height, width uint16, command str
 return conn, nil
 }
-func (c *AgentConn) SSH() (net.Conn, error) {
-return c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
+func (c *AgentConn) SSH(ctx context.Context) (net.Conn, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+return c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSSHPort)))
 }
 // SSHClient calls SSH to create a client that uses a weak cipher
 // for high throughput.
-func (c *AgentConn) SSHClient() (*ssh.Client, error) {
-netConn, err := c.SSH()
+func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+netConn, err := c.SSH(ctx)
 if err != nil {
 return nil, xerrors.Errorf("ssh: %w", err)
 }
@@ -220,8 +231,10 @@ func (c *AgentConn) SSHClient() (*ssh.Client, error) {
 return ssh.NewClient(sshConn, channels, requests), nil
 }
-func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
-speedConn, err := c.DialContextTCP(context.Background(), netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
+func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+speedConn, err := c.DialContextTCP(ctx, netip.AddrPortFrom(TailnetIP, uint16(TailnetSpeedtestPort)))
 if err != nil {
 return nil, xerrors.Errorf("dial speedtest: %w", err)
 }
@@ -233,6 +246,8 @@ func (c *AgentConn) Speedtest(direction speedtest.Direction, duration time.Durat
 }
 func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 if network == "unix" {
 return nil, xerrors.New("network must be tcp or udp")
 }
@@ -277,6 +292,8 @@ func (c *AgentConn) statisticsClient() *http.Client {
 }
 func (c *AgentConn) doStatisticsRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 host := net.JoinHostPort(TailnetIP.String(), strconv.Itoa(TailnetStatisticsPort))
 url := fmt.Sprintf("http://%s%s", host, path)
@@ -309,6 +326,8 @@ type ListeningPort struct {
 }
 func (c *AgentConn) ListeningPorts(ctx context.Context) (ListeningPortsResponse, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 res, err := c.doStatisticsRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil)
 if err != nil {
 return ListeningPortsResponse{}, xerrors.Errorf("do request: %w", err)
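All AgentConn helpers now take a context, as the hunks above show. A short usage sketch, assuming an already-dialed *codersdk.AgentConn and using only signatures introduced in this diff (Ping, SSHClient and Speedtest); the exercise function and the 30-second timeout are illustrative.

package main

import (
	"context"
	"fmt"
	"time"

	"tailscale.com/net/speedtest"

	"github.com/coder/coder/codersdk"
)

func exercise(ctx context.Context, conn *codersdk.AgentConn) error {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	// Latency check over the tailnet; the span is started inside Ping.
	dur, err := conn.Ping(ctx)
	if err != nil {
		return err
	}
	fmt.Println("ping:", dur)

	// SSH client now reuses the caller's context for dialing.
	sshClient, err := conn.SSHClient(ctx)
	if err != nil {
		return err
	}
	defer sshClient.Close()

	// Upload speedtest, mirroring the new Speedtest(ctx, ...) signature.
	res, err := conn.Speedtest(ctx, speedtest.Upload, 250*time.Millisecond)
	if err != nil {
		return err
	}
	fmt.Printf("%.2f MBits/s\n", res[len(res)-1].MBitsPerSecond())
	return nil
}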

View File

@@ -12,7 +12,15 @@ import (
 "net/url"
 "strings"
+"go.opentelemetry.io/otel"
+"go.opentelemetry.io/otel/propagation"
+semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
+"go.opentelemetry.io/otel/trace"
 "golang.org/x/xerrors"
+"github.com/coder/coder/coderd/tracing"
+
+"cdr.dev/slog"
 )
 // These cookies are Coder-specific. If a new one is added or changed, the name
@@ -30,6 +38,13 @@ const (
 BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
 )
+var loggableMimeTypes = map[string]struct{}{
+"application/json": {},
+"text/plain": {},
+// lots of webserver error pages are HTML
+"text/html": {},
+}
+
 // New creates a Coder client for the provided URL.
 func New(serverURL *url.URL) *Client {
 return &Client{
@@ -45,9 +60,35 @@ type Client struct {
 SessionToken string
 URL *url.URL
+// Logger can be provided to log requests. Request method, URL and response
+// status code will be logged by default.
+Logger slog.Logger
+
+// LogBodies determines whether the request and response bodies are logged
+// to the provided Logger. This is useful for debugging or testing.
+LogBodies bool
+
 // BypassRatelimits is an optional flag that can be set by the site owner to
 // disable ratelimit checks for the client.
 BypassRatelimits bool
+
+// PropagateTracing is an optional flag that can be set to propagate tracing
+// spans to the Coder API. This is useful for seeing the entire request
+// from end-to-end.
+PropagateTracing bool
+}
+
+func (c *Client) Clone() *Client {
+hc := *c.HTTPClient
+u := *c.URL
+return &Client{
+HTTPClient: &hc,
+SessionToken: c.SessionToken,
+URL: &u,
+Logger: c.Logger,
+LogBodies: c.LogBodies,
+BypassRatelimits: c.BypassRatelimits,
+PropagateTracing: c.PropagateTracing,
+}
 }
 type RequestOption func(*http.Request)
@@ -63,30 +104,46 @@ func WithQueryParam(key, value string) RequestOption {
 }
 }
-// Request performs an HTTP request with the body provided.
-// The caller is responsible for closing the response body.
+// Request performs a HTTP request with the body provided. The caller is
+// responsible for closing the response body.
 func (c *Client) Request(ctx context.Context, method, path string, body interface{}, opts ...RequestOption) (*http.Response, error) {
+ctx, span := tracing.StartSpanWithName(ctx, tracing.FuncNameSkip(1))
+defer span.End()
+
 serverURL, err := c.URL.Parse(path)
 if err != nil {
 return nil, xerrors.Errorf("parse url: %w", err)
 }
-var buf bytes.Buffer
+var r io.Reader
 if body != nil {
 if data, ok := body.([]byte); ok {
-buf = *bytes.NewBuffer(data)
+r = bytes.NewReader(data)
 } else {
 // Assume JSON if not bytes.
-enc := json.NewEncoder(&buf)
+buf := bytes.NewBuffer(nil)
+enc := json.NewEncoder(buf)
 enc.SetEscapeHTML(false)
 err = enc.Encode(body)
 if err != nil {
 return nil, xerrors.Errorf("encode body: %w", err)
 }
+r = buf
 }
 }
-req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf)
+// Copy the request body so we can log it.
+var reqBody []byte
+if r != nil && c.LogBodies {
+reqBody, err = io.ReadAll(r)
+if err != nil {
+return nil, xerrors.Errorf("read request body: %w", err)
+}
+r = bytes.NewReader(reqBody)
+}
+
+req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), r)
 if err != nil {
 return nil, xerrors.Errorf("create request: %w", err)
 }
@@ -95,17 +152,61 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
 req.Header.Set(BypassRatelimitHeader, "true")
 }
-if body != nil {
+if r != nil {
 req.Header.Set("Content-Type", "application/json")
 }
 for _, opt := range opts {
 opt(req)
 }
+span.SetAttributes(semconv.NetAttributesFromHTTPRequest("tcp", req)...)
+span.SetAttributes(semconv.HTTPClientAttributesFromHTTPRequest(req)...)
+
+// Inject tracing headers if enabled.
+if c.PropagateTracing {
+tmp := otel.GetTextMapPropagator()
+hc := propagation.HeaderCarrier(req.Header)
+tmp.Inject(ctx, hc)
+}
+
+ctx = slog.With(ctx,
+slog.F("method", req.Method),
+slog.F("url", req.URL.String()),
+)
+c.Logger.Debug(ctx, "sdk request", slog.F("body", string(reqBody)))
+
 resp, err := c.HTTPClient.Do(req)
 if err != nil {
 return nil, xerrors.Errorf("do: %w", err)
 }
+span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
+span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(resp.StatusCode, trace.SpanKindClient))
+
+// Copy the response body so we can log it if it's a loggable mime type.
+var respBody []byte
+if resp.Body != nil && c.LogBodies {
+mimeType := parseMimeType(resp.Header.Get("Content-Type"))
+if _, ok := loggableMimeTypes[mimeType]; ok {
+respBody, err = io.ReadAll(resp.Body)
+if err != nil {
+return nil, xerrors.Errorf("copy response body for logs: %w", err)
+}
+err = resp.Body.Close()
+if err != nil {
+return nil, xerrors.Errorf("close response body: %w", err)
+}
+resp.Body = io.NopCloser(bytes.NewReader(respBody))
+}
+}
+
+c.Logger.Debug(ctx, "sdk response",
+slog.F("status", resp.StatusCode),
+slog.F("body", string(respBody)),
+slog.F("trace_id", resp.Header.Get("X-Trace-Id")),
+slog.F("span_id", resp.Header.Get("X-Span-Id")),
+)
+
 return resp, err
 }
@@ -138,10 +239,7 @@ func readBodyAsError(res *http.Response) error {
 return xerrors.Errorf("read body: %w", err)
 }
-mimeType, _, err := mime.ParseMediaType(contentType)
-if err != nil {
-mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
-}
+mimeType := parseMimeType(contentType)
 if mimeType != "application/json" {
 if len(resp) > 1024 {
 resp = append(resp[:1024], []byte("...")...)
@@ -238,3 +336,12 @@ type closeFunc func() error
 func (c closeFunc) Close() error {
 return c()
 }
+
+func parseMimeType(contentType string) string {
+mimeType, _, err := mime.ParseMediaType(contentType)
+if err != nil {
+mimeType = strings.TrimSpace(strings.Split(contentType, ";")[0])
+}
+return mimeType
+}
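A sketch of wiring up the new Client fields (Logger, LogBodies, PropagateTracing) together with BypassRatelimits and the new Clone method, mirroring how the loadtest command configures per-runner clients. The server URL, session-token environment variable and the /api/v2/buildinfo path are placeholders for illustration.

package main

import (
	"context"
	"net/http"
	"net/url"
	"os"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"

	"github.com/coder/coder/codersdk"
)

func newTracedClient(ctx context.Context) (*codersdk.Client, error) {
	u, err := url.Parse("https://coder.example.com") // placeholder deployment URL
	if err != nil {
		return nil, err
	}
	client := codersdk.New(u)
	client.SessionToken = os.Getenv("CODER_SESSION_TOKEN") // placeholder token source
	client.BypassRatelimits = true                         // site owners only
	client.PropagateTracing = true                         // inject Traceparent into requests
	client.Logger = slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelDebug)
	client.LogBodies = true // log request/response bodies at debug level

	// Each runner gets an independent copy so per-run loggers don't clash.
	perRun := client.Clone()

	// Any request now emits "sdk request"/"sdk response" debug logs and a span.
	res, err := perRun.Request(ctx, http.MethodGet, "/api/v2/buildinfo", nil)
	if err != nil {
		return nil, err
	}
	_ = res.Body.Close()
	return perRun, nil
}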

View File

@@ -2,23 +2,114 @@ package codersdk
 import (
 "bytes"
+"context"
 "encoding/json"
 "fmt"
 "io"
 "net/http"
 "net/http/httptest"
+"net/url"
 "strconv"
 "strings"
 "testing"
+"github.com/go-logr/logr"
 "github.com/stretchr/testify/assert"
 "github.com/stretchr/testify/require"
+"go.opentelemetry.io/otel"
+"go.opentelemetry.io/otel/propagation"
+"go.opentelemetry.io/otel/sdk/resource"
+sdktrace "go.opentelemetry.io/otel/sdk/trace"
+semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
 "golang.org/x/xerrors"
+"cdr.dev/slog"
+"cdr.dev/slog/sloggers/sloghuman"
+"github.com/coder/coder/testutil"
 )
-const (
-jsonCT = "application/json"
-)
+const jsonCT = "application/json"
+
+func Test_Client(t *testing.T) {
+t.Parallel()
+
+const method = http.MethodPost
+const path = "/ok"
+const token = "token"
+const reqBody = `{"msg": "request body"}`
+const resBody = `{"status": "ok"}`
+
+s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+assert.Equal(t, method, r.Method)
+assert.Equal(t, path, r.URL.Path)
+assert.Equal(t, token, r.Header.Get(SessionCustomHeader))
+assert.Equal(t, "true", r.Header.Get(BypassRatelimitHeader))
+assert.NotEmpty(t, r.Header.Get("Traceparent"))
+for k, v := range r.Header {
+t.Logf("header %q: %q", k, strings.Join(v, ", "))
+}
+
+w.Header().Set("Content-Type", jsonCT)
+w.WriteHeader(http.StatusOK)
+_, _ = io.WriteString(w, resBody)
+}))
+
+u, err := url.Parse(s.URL)
+require.NoError(t, err)
+
+client := New(u)
+client.SessionToken = token
+client.BypassRatelimits = true
+
+logBuf := bytes.NewBuffer(nil)
+client.Logger = slog.Make(sloghuman.Sink(logBuf)).Leveled(slog.LevelDebug)
+client.LogBodies = true
+
+// Setup tracing.
+res := resource.NewWithAttributes(
+semconv.SchemaURL,
+semconv.ServiceNameKey.String("codersdk_test"),
+)
+tracerOpts := []sdktrace.TracerProviderOption{
+sdktrace.WithResource(res),
+}
+tracerProvider := sdktrace.NewTracerProvider(tracerOpts...)
+otel.SetTracerProvider(tracerProvider)
+otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) {}))
+otel.SetTextMapPropagator(
+propagation.NewCompositeTextMapPropagator(
+propagation.TraceContext{},
+propagation.Baggage{},
+),
+)
+otel.SetLogger(logr.Discard())
+client.PropagateTracing = true
+
+ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+defer cancel()
+ctx, span := tracerProvider.Tracer("codersdk_test").Start(ctx, "codersdk client test 1")
+defer span.End()
+
+resp, err := client.Request(ctx, method, path, []byte(reqBody))
+require.NoError(t, err)
+defer resp.Body.Close()
+require.Equal(t, http.StatusOK, resp.StatusCode)
+require.Equal(t, jsonCT, resp.Header.Get("Content-Type"))
+
+body, err := io.ReadAll(resp.Body)
+require.NoError(t, err)
+require.Equal(t, resBody, string(body))
+
+logStr := logBuf.String()
+require.Contains(t, logStr, "sdk request")
+require.Contains(t, logStr, method)
+require.Contains(t, logStr, path)
+require.Contains(t, logStr, strings.ReplaceAll(reqBody, `"`, `\"`))
+require.Contains(t, logStr, "sdk response")
+require.Contains(t, logStr, "200")
+require.Contains(t, logStr, strings.ReplaceAll(resBody, `"`, `\"`))
+}
 func Test_readBodyAsError(t *testing.T) {
 t.Parallel()

View File

@@ -2,11 +2,14 @@ package codersdk
 import (
 "bufio"
+"context"
 "fmt"
 "io"
 "strings"
 "golang.org/x/xerrors"
+"github.com/coder/coder/coderd/tracing"
 )
 type ServerSentEvent struct {
@@ -22,7 +25,10 @@ const (
 ServerSentEventTypeError ServerSentEventType = "error"
 )
-func ServerSentEventReader(rc io.ReadCloser) func() (*ServerSentEvent, error) {
+func ServerSentEventReader(ctx context.Context, rc io.ReadCloser) func() (*ServerSentEvent, error) {
+_, span := tracing.StartSpan(ctx)
+defer span.End()
+
 reader := bufio.NewReader(rc)
 nextLineValue := func(prefix string) ([]byte, error) {
 var (

View File

@@ -10,6 +10,8 @@ import (
 "github.com/google/uuid"
 "golang.org/x/xerrors"
+"github.com/coder/coder/coderd/tracing"
 )
 // Workspace is a deployment of a template. It references a specific
@@ -137,6 +139,8 @@ func (c *Client) CreateWorkspaceBuild(ctx context.Context, workspace uuid.UUID,
 }
 func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Workspace, error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
 //nolint:bodyclose
 res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaces/%s/watch", id), nil)
 if err != nil {
@@ -145,7 +149,7 @@ func (c *Client) WatchWorkspace(ctx context.Context, id uuid.UUID) (<-chan Works
 if res.StatusCode != http.StatusOK {
 return nil, readBodyAsError(res)
 }
-nextEvent := ServerSentEventReader(res.Body)
+nextEvent := ServerSentEventReader(ctx, res.Body)
 wc := make(chan Workspace, 256)
 go func() {

View File

@@ -9,7 +9,6 @@ import (
 "net/netip"
 "net/url"
 "strconv"
-"sync"
 "time"
 "golang.org/x/sync/errgroup"
@@ -17,8 +16,10 @@ import (
 "cdr.dev/slog"
 "cdr.dev/slog/sloggers/sloghuman"
+"github.com/coder/coder/coderd/tracing"
 "github.com/coder/coder/codersdk"
 "github.com/coder/coder/loadtest/harness"
+"github.com/coder/coder/loadtest/loadtestutil"
 )
 const defaultRequestTimeout = 5 * time.Second
@@ -45,11 +46,13 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
 // Run implements Runnable.
 func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
-logs = syncWriter{
-mut: &sync.Mutex{},
-w: logs,
-}
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+logs = loadtestutil.NewSyncWriter(logs)
 logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
+r.client.Logger = logger
+r.client.LogBodies = true
 _, _ = fmt.Fprintln(logs, "Opening connection to workspace agent")
 switch r.cfg.ConnectionMode {
@@ -69,9 +72,72 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
 }
 defer conn.Close()
-// Wait for the disco connection to be established.
+err = waitForDisco(ctx, logs, conn)
+if err != nil {
+return xerrors.Errorf("wait for discovery connection: %w", err)
+}
+
+// Wait for a direct connection if requested.
+if r.cfg.ConnectionMode == ConnectionModeDirect {
+err = waitForDirectConnection(ctx, logs, conn)
+if err != nil {
+return xerrors.Errorf("wait for direct connection: %w", err)
+}
+}
+
+// Ensure DERP for completeness.
+if r.cfg.ConnectionMode == ConnectionModeDerp {
+status := conn.Status()
+if len(status.Peers()) != 1 {
+return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
+}
+peer := status.Peer[status.Peers()[0]]
+if peer.Relay == "" || peer.CurAddr != "" {
+return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
+}
+}
+
+_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
+
+// HACK: even though the ping passed above, we still need to open a
+// connection to the agent to ensure it's ready to accept connections. Not
+// sure why this is the case but it seems to be necessary.
+err = verifyConnection(ctx, logs, conn)
+if err != nil {
+return xerrors.Errorf("verify connection: %w", err)
+}
+
+_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
+
+// Make initial connections sequentially to ensure the services are
+// reachable before we start spawning a bunch of goroutines and tickers.
+err = performInitialConnections(ctx, logs, conn, r.cfg.Connections)
+if err != nil {
+return xerrors.Errorf("perform initial connections: %w", err)
+}
+
+if r.cfg.HoldDuration > 0 {
+err = holdConnection(ctx, logs, conn, time.Duration(r.cfg.HoldDuration), r.cfg.Connections)
+if err != nil {
+return xerrors.Errorf("hold connection: %w", err)
+}
+}
+
+err = conn.Close()
+if err != nil {
+return xerrors.Errorf("close connection: %w", err)
+}
+
+return nil
+}
+
+func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
 const pingAttempts = 10
 const pingDelay = 1 * time.Second
+
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 for i := 0; i < pingAttempts; i++ {
 _, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts)
 pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
@@ -93,80 +159,59 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
 }
 }
-// Wait for a direct connection if requested.
-if r.cfg.ConnectionMode == ConnectionModeDirect {
-const directConnectionAttempts = 30
-const directConnectionDelay = 1 * time.Second
-for i := 0; i < directConnectionAttempts; i++ {
-_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
-status := conn.Status()
-var err error
-if len(status.Peers()) != 1 {
-_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
-err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
-} else {
-peer := status.Peer[status.Peers()[0]]
-_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
-_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
-if peer.Relay != "" && peer.CurAddr == "" {
-err = xerrors.Errorf("peer is connected via DERP, not direct")
-}
-}
-if err == nil {
-break
-}
-if i == directConnectionAttempts-1 {
-return xerrors.Errorf("wait for direct connection to agent: %w", err)
-}
-select {
-case <-ctx.Done():
-return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
-// We use time.After here since it's a very short duration so
-// leaking a timer is fine.
-case <-time.After(directConnectionDelay):
-}
-}
-}
-// Ensure DERP for completeness.
-if r.cfg.ConnectionMode == ConnectionModeDerp {
-status := conn.Status()
-if len(status.Peers()) != 1 {
-return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers()))
-}
-peer := status.Peer[status.Peers()[0]]
-if peer.Relay == "" || peer.CurAddr != "" {
-return xerrors.Errorf("check connection mode: peer is connected directly, not via DERP")
-}
-}
-_, _ = fmt.Fprint(logs, "\nConnection established.\n\n")
-client := &http.Client{
-Transport: &http.Transport{
-DisableKeepAlives: true,
-DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-_, port, err := net.SplitHostPort(addr)
-if err != nil {
-return nil, xerrors.Errorf("split host port %q: %w", addr, err)
-}
-portUint, err := strconv.ParseUint(port, 10, 16)
-if err != nil {
-return nil, xerrors.Errorf("parse port %q: %w", port, err)
-}
-return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
-},
-},
-}
-// HACK: even though the ping passed above, we still need to open a
-// connection to the agent to ensure it's ready to accept connections. Not
-// sure why this is the case but it seems to be necessary.
+return nil
+}
+
+func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
+const directConnectionAttempts = 30
+const directConnectionDelay = 1 * time.Second
+
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+for i := 0; i < directConnectionAttempts; i++ {
+_, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts)
+status := conn.Status()
+
+var err error
+if len(status.Peers()) != 1 {
+_, _ = fmt.Fprintf(logs, "\t\tExpected 1 peer, found %d", len(status.Peers()))
+err = xerrors.Errorf("expected 1 peer, got %d", len(status.Peers()))
+} else {
+peer := status.Peer[status.Peers()[0]]
+_, _ = fmt.Fprintf(logs, "\t\tCurAddr: %s\n", peer.CurAddr)
+_, _ = fmt.Fprintf(logs, "\t\tRelay: %s\n", peer.Relay)
+if peer.Relay != "" && peer.CurAddr == "" {
+err = xerrors.Errorf("peer is connected via DERP, not direct")
+}
+}
+if err == nil {
+break
+}
+if i == directConnectionAttempts-1 {
+return xerrors.Errorf("wait for direct connection to agent: %w", err)
+}
+
+select {
+case <-ctx.Done():
+return xerrors.Errorf("wait for direct connection to agent: %w", ctx.Err())
+// We use time.After here since it's a very short duration so
+// leaking a timer is fine.
+case <-time.After(directConnectionDelay):
+}
+}
+
+return nil
+}
+
+func verifyConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) error {
 const verifyConnectionAttempts = 30
 const verifyConnectionDelay = 1 * time.Second
+
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+client := agentHTTPClient(conn)
 for i := 0; i < verifyConnectionAttempts; i++ {
 _, _ = fmt.Fprintf(logs, "\tVerify connection attempt %d/%d...\n", i+1, verifyConnectionAttempts)
 verifyCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout)
@@ -198,14 +243,20 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
 }
 }
-_, _ = fmt.Fprint(logs, "\nConnection verified.\n\n")
-
-// Make initial connections sequentially to ensure the services are
-// reachable before we start spawning a bunch of goroutines and tickers.
-if len(r.cfg.Connections) > 0 {
-_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
-}
-for i, connSpec := range r.cfg.Connections {
+return nil
+}
+
+func performInitialConnections(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, specs []Connection) error {
+if len(specs) == 0 {
+return nil
+}
+
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+_, _ = fmt.Fprintln(logs, "Performing initial service connections...")
+client := agentHTTPClient(conn)
+for i, connSpec := range specs {
 _, _ = fmt.Fprintf(logs, "\t%d. %s\n", i, connSpec.URL)
 timeout := defaultRequestTimeout
@@ -230,95 +281,102 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
 _, _ = fmt.Fprintln(logs, "\t\tOK")
 }
-if r.cfg.HoldDuration > 0 {
-eg, egCtx := errgroup.WithContext(ctx)
-if len(r.cfg.Connections) > 0 {
-_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
-}
-for i, connSpec := range r.cfg.Connections {
-i, connSpec := i, connSpec
-if connSpec.Interval <= 0 {
-continue
-}
+return nil
+}
+
+func holdConnection(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn, holdDur time.Duration, specs []Connection) error {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+eg, egCtx := errgroup.WithContext(ctx)
+client := agentHTTPClient(conn)
+if len(specs) > 0 {
+_, _ = fmt.Fprintln(logs, "\nStarting connection loops...")
+}
+for i, connSpec := range specs {
+i, connSpec := i, connSpec
+if connSpec.Interval <= 0 {
+continue
+}
 eg.Go(func() error {
 t := time.NewTicker(time.Duration(connSpec.Interval))
 defer t.Stop()
 timeout := defaultRequestTimeout
 if connSpec.Timeout > 0 {
 timeout = time.Duration(connSpec.Timeout)
 }
 for {
 select {
 case <-egCtx.Done():
 return egCtx.Err()
 case <-t.C:
 ctx, cancel := context.WithTimeout(ctx, timeout)
 req, err := http.NewRequestWithContext(ctx, http.MethodGet, connSpec.URL, nil)
 if err != nil {
 cancel()
 return xerrors.Errorf("create request: %w", err)
 }
 res, err := client.Do(req)
 cancel()
 if err != nil {
 _, _ = fmt.Fprintf(logs, "\tERR: %s (%d): %+v\n", connSpec.URL, i, err)
 return xerrors.Errorf("make connection to conn spec %d %q: %w", i, connSpec.URL, err)
 }
 res.Body.Close()
 _, _ = fmt.Fprintf(logs, "\tOK: %s (%d)\n", connSpec.URL, i)
 t.Reset(time.Duration(connSpec.Interval))
 }
 }
 })
 }
 // Wait for the hold duration to end. We use a fake error to signal that
 // the hold duration has ended.
-_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", time.Duration(r.cfg.HoldDuration))
+_, _ = fmt.Fprintf(logs, "\nWaiting for %s...\n", holdDur)
 eg.Go(func() error {
-t := time.NewTicker(time.Duration(r.cfg.HoldDuration))
+t := time.NewTicker(holdDur)
 defer t.Stop()
 select {
 case <-egCtx.Done():
 return egCtx.Err()
 case <-t.C:
 // Returning an error here will cause the errgroup context to
 // be canceled, which is what we want. This fake error is
 // ignored below.
 return holdDurationEndedError{}
 }
 })
-err = eg.Wait()
+err := eg.Wait()
 if err != nil && !xerrors.Is(err, holdDurationEndedError{}) {
 return xerrors.Errorf("run connections loop: %w", err)
 }
-}
-
-err = conn.Close()
-if err != nil {
-return xerrors.Errorf("close connection: %w", err)
-}
 return nil
 }
-// syncWriter wraps an io.Writer in a sync.Mutex.
-type syncWriter struct {
-mut *sync.Mutex
-w io.Writer
-}
-
-// Write implements io.Writer.
-func (sw syncWriter) Write(p []byte) (n int, err error) {
-sw.mut.Lock()
-defer sw.mut.Unlock()
-return sw.w.Write(p)
-}
+func agentHTTPClient(conn *codersdk.AgentConn) *http.Client {
+return &http.Client{
+Transport: &http.Transport{
+DisableKeepAlives: true,
+DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+_, port, err := net.SplitHostPort(addr)
+if err != nil {
+return nil, xerrors.Errorf("split host port %q: %w", addr, err)
+}
+
+portUint, err := strconv.ParseUint(port, 10, 16)
+if err != nil {
+return nil, xerrors.Errorf("parse port %q: %w", port, err)
+}
+return conn.DialContextTCP(ctx, netip.AddrPortFrom(codersdk.TailnetIP, uint16(portUint)))
+},
+},
+}
+}

View File

@@ -7,6 +7,8 @@ import (
 "github.com/hashicorp/go-multierror"
 "golang.org/x/xerrors"
+"github.com/coder/coder/coderd/tracing"
 )
 // ExecutionStrategy defines how a TestHarness should execute a set of runs. It
@@ -49,6 +51,9 @@ func NewTestHarness(strategy ExecutionStrategy) *TestHarness {
 //
 // Panics if called more than once.
 func (h *TestHarness) Run(ctx context.Context) (err error) {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
 h.mut.Lock()
 if h.started {
 h.mut.Unlock()

View File

@@ -95,6 +95,7 @@ type timeoutRunnerWrapper struct {
 }
 var _ Runnable = timeoutRunnerWrapper{}
+var _ Cleanable = timeoutRunnerWrapper{}
 func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer) error {
 ctx, cancel := context.WithTimeout(ctx, t.timeout)
@@ -103,6 +104,15 @@ func (t timeoutRunnerWrapper) Run(ctx context.Context, id string, logs io.Writer
 return t.inner.Run(ctx, id, logs)
 }
+func (t timeoutRunnerWrapper) Cleanup(ctx context.Context, id string) error {
+c, ok := t.inner.(Cleanable)
+if !ok {
+return nil
+}
+
+return c.Cleanup(ctx, id)
+}
+
 // Execute implements ExecutionStrategy.
 func (t TimeoutExecutionStrategyWrapper) Execute(ctx context.Context, runs []*TestRun) error {
 for _, run := range runs {

View File

@@ -0,0 +1,26 @@
+package loadtestutil
+
+import (
+"io"
+"sync"
+)
+
+// SyncWriter wraps an io.Writer in a sync.Mutex.
+type SyncWriter struct {
+mut *sync.Mutex
+w io.Writer
+}
+
+func NewSyncWriter(w io.Writer) *SyncWriter {
+return &SyncWriter{
+mut: &sync.Mutex{},
+w: w,
+}
+}
+
+// Write implements io.Writer.
+func (sw *SyncWriter) Write(p []byte) (n int, err error) {
+sw.mut.Lock()
+defer sw.mut.Unlock()
+return sw.w.Write(p)
+}
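A small usage sketch for the new loadtestutil.SyncWriter above: several goroutines can share one log destination, and each Write call is serialized by the internal mutex. The runner-style log lines are illustrative.

package main

import (
	"bytes"
	"fmt"
	"sync"

	"github.com/coder/coder/loadtest/loadtestutil"
)

func main() {
	var buf bytes.Buffer
	logs := loadtestutil.NewSyncWriter(&buf)

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		i := i
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Each Write is serialized, so lines never interleave mid-write.
			fmt.Fprintf(logs, "runner %d: done\n", i)
		}()
	}
	wg.Wait()
	fmt.Print(buf.String())
}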

View File

@@ -9,9 +9,14 @@ import (
 "github.com/google/uuid"
 "golang.org/x/xerrors"
+"cdr.dev/slog"
+"cdr.dev/slog/sloggers/sloghuman"
+
+"github.com/coder/coder/coderd/tracing"
 "github.com/coder/coder/codersdk"
 "github.com/coder/coder/cryptorand"
 "github.com/coder/coder/loadtest/harness"
+"github.com/coder/coder/loadtest/loadtestutil"
 )
 type Runner struct {
@@ -32,6 +37,14 @@ func NewRunner(client *codersdk.Client, cfg Config) *Runner {
 // Run implements Runnable.
 func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
+
+logs = loadtestutil.NewSyncWriter(logs)
+logger := slog.Make(sloghuman.Sink(logs)).Leveled(slog.LevelDebug)
+r.client.Logger = logger
+r.client.LogBodies = true
+
 req := r.cfg.Request
 if req.Name == "" {
 randName, err := cryptorand.HexString(8)
@@ -66,6 +79,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
 if r.workspaceID == uuid.Nil {
 return nil
 }
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
 build, err := r.client.CreateWorkspaceBuild(ctx, r.workspaceID, codersdk.CreateWorkspaceBuildRequest{
 Transition: codersdk.WorkspaceTransitionDelete,
@@ -85,6 +100,8 @@ func (r *Runner) Cleanup(ctx context.Context, _ string) error {
 }
 func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, buildID uuid.UUID) error {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
 _, _ = fmt.Fprint(w, "Build is currently queued...")
 // Wait for build to start.
@@ -154,6 +171,8 @@ func waitForBuild(ctx context.Context, w io.Writer, client *codersdk.Client, bui
 }
 func waitForAgents(ctx context.Context, w io.Writer, client *codersdk.Client, workspaceID uuid.UUID) error {
+ctx, span := tracing.StartSpan(ctx)
+defer span.End()
 _, _ = fmt.Fprint(w, "Waiting for agents to connect...\n\n")
 for {