feat: extend request logs with auth & DB info (#17304)

Closes #16903
This commit is contained in:
Michael Suchacz
2025-04-15 13:27:23 +02:00
committed by GitHub
parent 979687c37f
commit 06d39151dc
19 changed files with 336 additions and 35 deletions

View File

@ -0,0 +1,203 @@
package loggermw
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/go-chi/chi/v5"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/tracing"
)
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
start := time.Now()
sw, ok := rw.(*tracing.StatusWriter)
if !ok {
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
}
httplog := log.With(
slog.F("host", httpapi.RequestHost(r)),
slog.F("path", r.URL.Path),
slog.F("proto", r.Proto),
slog.F("remote_addr", r.RemoteAddr),
// Include the start timestamp in the log so that we have the
// source of truth. There is at least a theoretical chance that
// there can be a delay between `next.ServeHTTP` ending and us
// actually logging the request. This can also be useful when
// filtering logs that started at a certain time (compared to
// trying to compute the value).
slog.F("start", start),
)
logContext := NewRequestLogger(httplog, r.Method, start)
ctx := WithRequestLogger(r.Context(), logContext)
next.ServeHTTP(sw, r.WithContext(ctx))
// Don't log successful health check requests.
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
return
}
// For status codes 500 and higher we
// want to log the response body.
if sw.Status >= http.StatusInternalServerError {
logContext.WithFields(
slog.F("response_body", string(sw.ResponseBody())),
)
}
logContext.WriteLog(r.Context(), sw.Status)
})
}
}
type RequestLogger interface {
WithFields(fields ...slog.Field)
WriteLog(ctx context.Context, status int)
WithAuthContext(actor rbac.Subject)
}
type SlogRequestLogger struct {
log slog.Logger
written bool
message string
start time.Time
// Protects actors map for concurrent writes.
mu sync.RWMutex
actors map[rbac.SubjectType]rbac.Subject
}
var _ RequestLogger = &SlogRequestLogger{}
func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger {
return &SlogRequestLogger{
log: log,
written: false,
message: message,
start: start,
actors: make(map[rbac.SubjectType]rbac.Subject),
}
}
func (c *SlogRequestLogger) WithFields(fields ...slog.Field) {
c.log = c.log.With(fields...)
}
func (c *SlogRequestLogger) WithAuthContext(actor rbac.Subject) {
c.mu.Lock()
defer c.mu.Unlock()
c.actors[actor.Type] = actor
}
func (c *SlogRequestLogger) addAuthContextFields() {
c.mu.RLock()
defer c.mu.RUnlock()
usr, ok := c.actors[rbac.SubjectTypeUser]
if ok {
c.log = c.log.With(
slog.F("requestor_id", usr.ID),
slog.F("requestor_name", usr.FriendlyName),
slog.F("requestor_email", usr.Email),
)
} else {
// If there is no user, we log the requestor name for the first
// actor in a defined order.
for _, v := range actorLogOrder {
subj, ok := c.actors[v]
if !ok {
continue
}
c.log = c.log.With(
slog.F("requestor_name", subj.FriendlyName),
)
break
}
}
}
var actorLogOrder = []rbac.SubjectType{
rbac.SubjectTypeAutostart,
rbac.SubjectTypeCryptoKeyReader,
rbac.SubjectTypeCryptoKeyRotator,
rbac.SubjectTypeHangDetector,
rbac.SubjectTypeNotifier,
rbac.SubjectTypePrebuildsOrchestrator,
rbac.SubjectTypeProvisionerd,
rbac.SubjectTypeResourceMonitor,
rbac.SubjectTypeSystemReadProvisionerDaemons,
rbac.SubjectTypeSystemRestricted,
}
func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) {
if c.written {
return
}
c.written = true
end := time.Now()
// Right before we write the log, we try to find the user in the actors
// and add the fields to the log.
c.addAuthContextFields()
logger := c.log.With(
slog.F("took", end.Sub(c.start)),
slog.F("status_code", status),
slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)),
)
// If the request is routed, add the route parameters to the log.
if chiCtx := chi.RouteContext(ctx); chiCtx != nil {
urlParams := chiCtx.URLParams
routeParamsFields := make([]slog.Field, 0, len(urlParams.Keys))
for k, v := range urlParams.Keys {
if urlParams.Values[k] != "" {
routeParamsFields = append(routeParamsFields, slog.F("params_"+v, urlParams.Values[k]))
}
}
if len(routeParamsFields) > 0 {
logger = logger.With(routeParamsFields...)
}
}
// We already capture most of this information in the span (minus
// the response body which we don't want to capture anyways).
tracing.RunWithoutSpan(ctx, func(ctx context.Context) {
// We should not log at level ERROR for 5xx status codes because 5xx
// includes proxy errors etc. It also causes slogtest to fail
// instantly without an error message by default.
if status >= http.StatusInternalServerError {
logger.Warn(ctx, c.message)
} else {
logger.Debug(ctx, c.message)
}
})
}
type logContextKey struct{}
func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context {
return context.WithValue(ctx, logContextKey{}, rl)
}
func RequestLoggerFromContext(ctx context.Context) RequestLogger {
val := ctx.Value(logContextKey{})
if logCtx, ok := val.(RequestLogger); ok {
return logCtx
}
return nil
}

View File

@ -0,0 +1,311 @@
package loggermw
import (
"context"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
)
func TestRequestLogger_WriteLog(t *testing.T) {
t.Parallel()
ctx := context.Background()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
logCtx := NewRequestLogger(logger, "GET", time.Now())
// Add custom fields
logCtx.WithFields(
slog.F("custom_field", "custom_value"),
)
// Write log for 200 status
logCtx.WriteLog(ctx, http.StatusOK)
require.Len(t, sink.entries, 1, "log was written twice")
require.Equal(t, sink.entries[0].Message, "GET")
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
// Attempt to write again (should be skipped).
logCtx.WriteLog(ctx, http.StatusInternalServerError)
require.Len(t, sink.entries, 1, "log was written twice")
}
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
// Create a test handler to simulate an HTTP request
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("OK"))
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// Create a test HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
require.NoError(t, err, "failed to create request")
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
// Serve the request
wrappedHandler.ServeHTTP(sw, req)
require.Len(t, sink.entries, 1, "log was written twice")
require.Equal(t, sink.entries[0].Message, "GET")
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Check that the log contains the expected fields
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
for _, field := range requiredFields {
_, exists := fieldsMap[field]
require.True(t, exists, "field %q is missing in log fields", field)
}
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
// Check value of the status code
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
}
func TestLoggerMiddleware_WebSocket(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
sink := &fakeSink{
newEntries: make(chan slog.SinkEntry, 2),
}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
done := make(chan struct{})
wg := sync.WaitGroup{}
// Create a test handler to simulate a WebSocket connection
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(rw, r, nil)
if !assert.NoError(t, err, "failed to accept websocket") {
return
}
defer conn.Close(websocket.StatusGoingAway, "")
requestLgr := RequestLoggerFromContext(r.Context())
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
// Block so we can be sure the end of the middleware isn't being called.
wg.Wait()
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
defer close(done)
sw := &tracing.StatusWriter{ResponseWriter: rw}
wrappedHandler.ServeHTTP(sw, r)
})
srv := httptest.NewServer(customHandler)
defer srv.Close()
wg.Add(1)
// nolint: bodyclose
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
require.NoError(t, err, "failed to dial WebSocket")
defer conn.Close(websocket.StatusNormalClosure, "")
// Wait for the log from within the handler
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
require.Equal(t, newEntry.Message, "GET")
// Signal the websocket handler to return (and read to handle the close frame)
wg.Done()
_, _, err = conn.Read(ctx)
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")
// Wait for the request to finish completely and verify we only logged once
_ = testutil.RequireRecvCtx(ctx, t, done)
require.Len(t, sink.entries, 1, "log was written twice")
}
func TestRequestLogger_HTTPRouteParams(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("workspace", "test-workspace")
chiCtx.URLParams.Add("agent", "test-agent")
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
// Create a test handler to simulate an HTTP request
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusOK)
_, _ = rw.Write([]byte("OK"))
})
// Wrap the test handler with the Logger middleware
loggerMiddleware := Logger(logger)
wrappedHandler := loggerMiddleware(testHandler)
// Create a test HTTP request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path/}", nil)
require.NoError(t, err, "failed to create request")
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
// Serve the request
wrappedHandler.ServeHTTP(sw, req)
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Check that the log contains the expected fields
requiredFields := []string{"workspace", "agent"}
for _, field := range requiredFields {
_, exists := fieldsMap["params_"+field]
require.True(t, exists, "field %q is missing in log fields", field)
}
}
func TestRequestLogger_RouteParamsLogging(t *testing.T) {
t.Parallel()
tests := []struct {
name string
params map[string]string
expectedFields []string
}{
{
name: "EmptyParams",
params: map[string]string{},
expectedFields: []string{},
},
{
name: "SingleParam",
params: map[string]string{
"workspace": "test-workspace",
},
expectedFields: []string{"params_workspace"},
},
{
name: "MultipleParams",
params: map[string]string{
"workspace": "test-workspace",
"agent": "test-agent",
"user": "test-user",
},
expectedFields: []string{"params_workspace", "params_agent", "params_user"},
},
{
name: "EmptyValueParam",
params: map[string]string{
"workspace": "test-workspace",
"agent": "",
},
expectedFields: []string{"params_workspace"},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
sink := &fakeSink{}
logger := slog.Make(sink)
logger = logger.Leveled(slog.LevelDebug)
// Create a route context with the test parameters
chiCtx := chi.NewRouteContext()
for key, value := range tt.params {
chiCtx.URLParams.Add(key, value)
}
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
logCtx := NewRequestLogger(logger, "GET", time.Now())
// Write the log
logCtx.WriteLog(ctx, http.StatusOK)
require.Len(t, sink.entries, 1, "expected exactly one log entry")
// Convert fields to map for easier checking
fieldsMap := make(map[string]any)
for _, field := range sink.entries[0].Fields {
fieldsMap[field.Name] = field.Value
}
// Verify expected fields are present
for _, field := range tt.expectedFields {
value, exists := fieldsMap[field]
require.True(t, exists, "field %q should be present in log", field)
require.Equal(t, tt.params[strings.TrimPrefix(field, "params_")], value, "field %q has incorrect value", field)
}
// Verify no unexpected fields are present
for field := range fieldsMap {
if field == "took" || field == "status_code" || field == "latency_ms" {
continue // Skip standard fields
}
require.True(t, slices.Contains(tt.expectedFields, field), "unexpected field %q in log", field)
}
})
}
}
type fakeSink struct {
entries []slog.SinkEntry
newEntries chan slog.SinkEntry
}
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
s.entries = append(s.entries, e)
if s.newEntries != nil {
select {
case s.newEntries <- e:
default:
}
}
}
func (*fakeSink) Sync() {}

View File

@ -0,0 +1,83 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coder/coder/v2/coderd/httpmw/loggermw (interfaces: RequestLogger)
//
// Generated by this command:
//
// mockgen -destination=loggermock/loggermock.go -package=loggermock . RequestLogger
//
// Package loggermock is a generated GoMock package.
package loggermock
import (
context "context"
reflect "reflect"
slog "cdr.dev/slog"
rbac "github.com/coder/coder/v2/coderd/rbac"
gomock "go.uber.org/mock/gomock"
)
// MockRequestLogger is a mock of RequestLogger interface.
type MockRequestLogger struct {
ctrl *gomock.Controller
recorder *MockRequestLoggerMockRecorder
isgomock struct{}
}
// MockRequestLoggerMockRecorder is the mock recorder for MockRequestLogger.
type MockRequestLoggerMockRecorder struct {
mock *MockRequestLogger
}
// NewMockRequestLogger creates a new mock instance.
func NewMockRequestLogger(ctrl *gomock.Controller) *MockRequestLogger {
mock := &MockRequestLogger{ctrl: ctrl}
mock.recorder = &MockRequestLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestLogger) EXPECT() *MockRequestLoggerMockRecorder {
return m.recorder
}
// WithAuthContext mocks base method.
func (m *MockRequestLogger) WithAuthContext(actor rbac.Subject) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "WithAuthContext", actor)
}
// WithAuthContext indicates an expected call of WithAuthContext.
func (mr *MockRequestLoggerMockRecorder) WithAuthContext(actor any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithAuthContext", reflect.TypeOf((*MockRequestLogger)(nil).WithAuthContext), actor)
}
// WithFields mocks base method.
func (m *MockRequestLogger) WithFields(fields ...slog.Field) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range fields {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "WithFields", varargs...)
}
// WithFields indicates an expected call of WithFields.
func (mr *MockRequestLoggerMockRecorder) WithFields(fields ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockRequestLogger)(nil).WithFields), fields...)
}
// WriteLog mocks base method.
func (m *MockRequestLogger) WriteLog(ctx context.Context, status int) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "WriteLog", ctx, status)
}
// WriteLog indicates an expected call of WriteLog.
func (mr *MockRequestLoggerMockRecorder) WriteLog(ctx, status any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLog", reflect.TypeOf((*MockRequestLogger)(nil).WriteLog), ctx, status)
}