mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
@ -35,42 +35,93 @@ func Logger(log slog.Logger) func(next http.Handler) http.Handler {
|
||||
slog.F("start", start),
|
||||
)
|
||||
|
||||
next.ServeHTTP(sw, r)
|
||||
logContext := NewRequestLogger(httplog, r.Method, start)
|
||||
|
||||
end := time.Now()
|
||||
ctx := WithRequestLogger(r.Context(), logContext)
|
||||
|
||||
next.ServeHTTP(sw, r.WithContext(ctx))
|
||||
|
||||
// Don't log successful health check requests.
|
||||
if r.URL.Path == "/api/v2" && sw.Status == http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
httplog = httplog.With(
|
||||
slog.F("took", end.Sub(start)),
|
||||
slog.F("status_code", sw.Status),
|
||||
slog.F("latency_ms", float64(end.Sub(start)/time.Millisecond)),
|
||||
)
|
||||
|
||||
// For status codes 400 and higher we
|
||||
// For status codes 500 and higher we
|
||||
// want to log the response body.
|
||||
if sw.Status >= http.StatusInternalServerError {
|
||||
httplog = httplog.With(
|
||||
logContext.WithFields(
|
||||
slog.F("response_body", string(sw.ResponseBody())),
|
||||
)
|
||||
}
|
||||
|
||||
// We should not log at level ERROR for 5xx status codes because 5xx
|
||||
// includes proxy errors etc. It also causes slogtest to fail
|
||||
// instantly without an error message by default.
|
||||
logLevelFn := httplog.Debug
|
||||
if sw.Status >= http.StatusInternalServerError {
|
||||
logLevelFn = httplog.Warn
|
||||
}
|
||||
|
||||
// We already capture most of this information in the span (minus
|
||||
// the response body which we don't want to capture anyways).
|
||||
tracing.RunWithoutSpan(r.Context(), func(ctx context.Context) {
|
||||
logLevelFn(ctx, r.Method)
|
||||
})
|
||||
logContext.WriteLog(r.Context(), sw.Status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type RequestLogger interface {
|
||||
WithFields(fields ...slog.Field)
|
||||
WriteLog(ctx context.Context, status int)
|
||||
}
|
||||
|
||||
type SlogRequestLogger struct {
|
||||
log slog.Logger
|
||||
written bool
|
||||
message string
|
||||
start time.Time
|
||||
}
|
||||
|
||||
var _ RequestLogger = &SlogRequestLogger{}
|
||||
|
||||
func NewRequestLogger(log slog.Logger, message string, start time.Time) RequestLogger {
|
||||
return &SlogRequestLogger{
|
||||
log: log,
|
||||
written: false,
|
||||
message: message,
|
||||
start: start,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SlogRequestLogger) WithFields(fields ...slog.Field) {
|
||||
c.log = c.log.With(fields...)
|
||||
}
|
||||
|
||||
func (c *SlogRequestLogger) WriteLog(ctx context.Context, status int) {
|
||||
if c.written {
|
||||
return
|
||||
}
|
||||
c.written = true
|
||||
end := time.Now()
|
||||
|
||||
logger := c.log.With(
|
||||
slog.F("took", end.Sub(c.start)),
|
||||
slog.F("status_code", status),
|
||||
slog.F("latency_ms", float64(end.Sub(c.start)/time.Millisecond)),
|
||||
)
|
||||
// We already capture most of this information in the span (minus
|
||||
// the response body which we don't want to capture anyways).
|
||||
tracing.RunWithoutSpan(ctx, func(ctx context.Context) {
|
||||
// We should not log at level ERROR for 5xx status codes because 5xx
|
||||
// includes proxy errors etc. It also causes slogtest to fail
|
||||
// instantly without an error message by default.
|
||||
if status >= http.StatusInternalServerError {
|
||||
logger.Warn(ctx, c.message)
|
||||
} else {
|
||||
logger.Debug(ctx, c.message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type logContextKey struct{}
|
||||
|
||||
func WithRequestLogger(ctx context.Context, rl RequestLogger) context.Context {
|
||||
return context.WithValue(ctx, logContextKey{}, rl)
|
||||
}
|
||||
|
||||
func RequestLoggerFromContext(ctx context.Context) RequestLogger {
|
||||
val := ctx.Value(logContextKey{})
|
||||
if logCtx, ok := val.(RequestLogger); ok {
|
||||
return logCtx
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
174
coderd/httpmw/logger_internal_test.go
Normal file
174
coderd/httpmw/logger_internal_test.go
Normal file
@ -0,0 +1,174 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/coderd/tracing"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func TestRequestLogger_WriteLog(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
sink := &fakeSink{}
|
||||
logger := slog.Make(sink)
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
logCtx := NewRequestLogger(logger, "GET", time.Now())
|
||||
|
||||
// Add custom fields
|
||||
logCtx.WithFields(
|
||||
slog.F("custom_field", "custom_value"),
|
||||
)
|
||||
|
||||
// Write log for 200 status
|
||||
logCtx.WriteLog(ctx, http.StatusOK)
|
||||
|
||||
require.Len(t, sink.entries, 1, "log was written twice")
|
||||
|
||||
require.Equal(t, sink.entries[0].Message, "GET")
|
||||
|
||||
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
|
||||
|
||||
// Attempt to write again (should be skipped).
|
||||
logCtx.WriteLog(ctx, http.StatusInternalServerError)
|
||||
|
||||
require.Len(t, sink.entries, 1, "log was written twice")
|
||||
}
|
||||
|
||||
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sink := &fakeSink{}
|
||||
logger := slog.Make(sink)
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
// Create a test handler to simulate an HTTP request
|
||||
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
// Wrap the test handler with the Logger middleware
|
||||
loggerMiddleware := Logger(logger)
|
||||
wrappedHandler := loggerMiddleware(testHandler)
|
||||
|
||||
// Create a test HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
|
||||
require.NoError(t, err, "failed to create request")
|
||||
|
||||
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
|
||||
// Serve the request
|
||||
wrappedHandler.ServeHTTP(sw, req)
|
||||
|
||||
require.Len(t, sink.entries, 1, "log was written twice")
|
||||
|
||||
require.Equal(t, sink.entries[0].Message, "GET")
|
||||
|
||||
fieldsMap := make(map[string]interface{})
|
||||
for _, field := range sink.entries[0].Fields {
|
||||
fieldsMap[field.Name] = field.Value
|
||||
}
|
||||
|
||||
// Check that the log contains the expected fields
|
||||
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
|
||||
for _, field := range requiredFields {
|
||||
_, exists := fieldsMap[field]
|
||||
require.True(t, exists, "field %q is missing in log fields", field)
|
||||
}
|
||||
|
||||
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
|
||||
|
||||
// Check value of the status code
|
||||
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
|
||||
}
|
||||
|
||||
func TestLoggerMiddleware_WebSocket(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
||||
defer cancel()
|
||||
|
||||
sink := &fakeSink{
|
||||
newEntries: make(chan slog.SinkEntry, 2),
|
||||
}
|
||||
logger := slog.Make(sink)
|
||||
logger = logger.Leveled(slog.LevelDebug)
|
||||
done := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
// Create a test handler to simulate a WebSocket connection
|
||||
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
conn, err := websocket.Accept(rw, r, nil)
|
||||
if !assert.NoError(t, err, "failed to accept websocket") {
|
||||
return
|
||||
}
|
||||
defer conn.Close(websocket.StatusGoingAway, "")
|
||||
|
||||
requestLgr := RequestLoggerFromContext(r.Context())
|
||||
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
|
||||
// Block so we can be sure the end of the middleware isn't being called.
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
// Wrap the test handler with the Logger middleware
|
||||
loggerMiddleware := Logger(logger)
|
||||
wrappedHandler := loggerMiddleware(testHandler)
|
||||
|
||||
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
|
||||
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
||||
wrappedHandler.ServeHTTP(sw, r)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(customHandler)
|
||||
defer srv.Close()
|
||||
wg.Add(1)
|
||||
// nolint: bodyclose
|
||||
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
|
||||
require.NoError(t, err, "failed to dial WebSocket")
|
||||
defer conn.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
// Wait for the log from within the handler
|
||||
newEntry := testutil.RequireRecvCtx(ctx, t, sink.newEntries)
|
||||
require.Equal(t, newEntry.Message, "GET")
|
||||
|
||||
// Signal the websocket handler to return (and read to handle the close frame)
|
||||
wg.Done()
|
||||
_, _, err = conn.Read(ctx)
|
||||
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")
|
||||
|
||||
// Wait for the request to finish completely and verify we only logged once
|
||||
_ = testutil.RequireRecvCtx(ctx, t, done)
|
||||
require.Len(t, sink.entries, 1, "log was written twice")
|
||||
}
|
||||
|
||||
type fakeSink struct {
|
||||
entries []slog.SinkEntry
|
||||
newEntries chan slog.SinkEntry
|
||||
}
|
||||
|
||||
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
||||
s.entries = append(s.entries, e)
|
||||
if s.newEntries != nil {
|
||||
select {
|
||||
case s.newEntries <- e:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (*fakeSink) Sync() {}
|
70
coderd/httpmw/loggermock/loggermock.go
Normal file
70
coderd/httpmw/loggermock/loggermock.go
Normal file
@ -0,0 +1,70 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/coder/coder/v2/coderd/httpmw (interfaces: RequestLogger)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -destination=loggermock/loggermock.go -package=loggermock . RequestLogger
|
||||
//
|
||||
|
||||
// Package loggermock is a generated GoMock package.
|
||||
package loggermock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
slog "cdr.dev/slog"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRequestLogger is a mock of RequestLogger interface.
|
||||
type MockRequestLogger struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRequestLoggerMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockRequestLoggerMockRecorder is the mock recorder for MockRequestLogger.
|
||||
type MockRequestLoggerMockRecorder struct {
|
||||
mock *MockRequestLogger
|
||||
}
|
||||
|
||||
// NewMockRequestLogger creates a new mock instance.
|
||||
func NewMockRequestLogger(ctrl *gomock.Controller) *MockRequestLogger {
|
||||
mock := &MockRequestLogger{ctrl: ctrl}
|
||||
mock.recorder = &MockRequestLoggerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRequestLogger) EXPECT() *MockRequestLoggerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// WithFields mocks base method.
|
||||
func (m *MockRequestLogger) WithFields(fields ...slog.Field) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range fields {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
m.ctrl.Call(m, "WithFields", varargs...)
|
||||
}
|
||||
|
||||
// WithFields indicates an expected call of WithFields.
|
||||
func (mr *MockRequestLoggerMockRecorder) WithFields(fields ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockRequestLogger)(nil).WithFields), fields...)
|
||||
}
|
||||
|
||||
// WriteLog mocks base method.
|
||||
func (m *MockRequestLogger) WriteLog(ctx context.Context, status int) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "WriteLog", ctx, status)
|
||||
}
|
||||
|
||||
// WriteLog indicates an expected call of WriteLog.
|
||||
func (mr *MockRequestLoggerMockRecorder) WriteLog(ctx, status any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLog", reflect.TypeOf((*MockRequestLogger)(nil).WriteLog), ctx, status)
|
||||
}
|
Reference in New Issue
Block a user