// Mirror of https://github.com/coder/coder.git
// (synced 2025-07-03 16:13:58 +00:00; 311 lines, 8.7 KiB, Go).
package loggermw
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog"
|
|
"github.com/coder/coder/v2/coderd/tracing"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
func TestRequestLogger_WriteLog(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
|
|
sink := &fakeSink{}
|
|
logger := slog.Make(sink)
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
logCtx := NewRequestLogger(logger, "GET", time.Now())
|
|
|
|
// Add custom fields
|
|
logCtx.WithFields(
|
|
slog.F("custom_field", "custom_value"),
|
|
)
|
|
|
|
// Write log for 200 status
|
|
logCtx.WriteLog(ctx, http.StatusOK)
|
|
|
|
require.Len(t, sink.entries, 1, "log was written twice")
|
|
|
|
require.Equal(t, sink.entries[0].Message, "GET")
|
|
|
|
require.Equal(t, sink.entries[0].Fields[0].Value, "custom_value")
|
|
|
|
// Attempt to write again (should be skipped).
|
|
logCtx.WriteLog(ctx, http.StatusInternalServerError)
|
|
|
|
require.Len(t, sink.entries, 1, "log was written twice")
|
|
}
|
|
|
|
func TestLoggerMiddleware_SingleRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sink := &fakeSink{}
|
|
logger := slog.Make(sink)
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
// Create a test handler to simulate an HTTP request
|
|
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
rw.WriteHeader(http.StatusOK)
|
|
_, _ = rw.Write([]byte("OK"))
|
|
})
|
|
|
|
// Wrap the test handler with the Logger middleware
|
|
loggerMiddleware := Logger(logger)
|
|
wrappedHandler := loggerMiddleware(testHandler)
|
|
|
|
// Create a test HTTP request
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path", nil)
|
|
require.NoError(t, err, "failed to create request")
|
|
|
|
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
|
|
|
// Serve the request
|
|
wrappedHandler.ServeHTTP(sw, req)
|
|
|
|
require.Len(t, sink.entries, 1, "log was written twice")
|
|
|
|
require.Equal(t, sink.entries[0].Message, "GET")
|
|
|
|
fieldsMap := make(map[string]any)
|
|
for _, field := range sink.entries[0].Fields {
|
|
fieldsMap[field.Name] = field.Value
|
|
}
|
|
|
|
// Check that the log contains the expected fields
|
|
requiredFields := []string{"host", "path", "proto", "remote_addr", "start", "took", "status_code", "latency_ms"}
|
|
for _, field := range requiredFields {
|
|
_, exists := fieldsMap[field]
|
|
require.True(t, exists, "field %q is missing in log fields", field)
|
|
}
|
|
|
|
require.Len(t, sink.entries[0].Fields, len(requiredFields), "log should contain only the required fields")
|
|
|
|
// Check value of the status code
|
|
require.Equal(t, fieldsMap["status_code"], http.StatusOK)
|
|
}
|
|
|
|
func TestLoggerMiddleware_WebSocket(t *testing.T) {
|
|
t.Parallel()
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
sink := &fakeSink{
|
|
newEntries: make(chan slog.SinkEntry, 2),
|
|
}
|
|
logger := slog.Make(sink)
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
done := make(chan struct{})
|
|
wg := sync.WaitGroup{}
|
|
// Create a test handler to simulate a WebSocket connection
|
|
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
conn, err := websocket.Accept(rw, r, nil)
|
|
if !assert.NoError(t, err, "failed to accept websocket") {
|
|
return
|
|
}
|
|
defer conn.Close(websocket.StatusGoingAway, "")
|
|
|
|
requestLgr := RequestLoggerFromContext(r.Context())
|
|
requestLgr.WriteLog(r.Context(), http.StatusSwitchingProtocols)
|
|
// Block so we can be sure the end of the middleware isn't being called.
|
|
wg.Wait()
|
|
})
|
|
|
|
// Wrap the test handler with the Logger middleware
|
|
loggerMiddleware := Logger(logger)
|
|
wrappedHandler := loggerMiddleware(testHandler)
|
|
|
|
// RequestLogger expects the ResponseWriter to be *tracing.StatusWriter
|
|
customHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
defer close(done)
|
|
sw := &tracing.StatusWriter{ResponseWriter: rw}
|
|
wrappedHandler.ServeHTTP(sw, r)
|
|
})
|
|
|
|
srv := httptest.NewServer(customHandler)
|
|
defer srv.Close()
|
|
wg.Add(1)
|
|
// nolint: bodyclose
|
|
conn, _, err := websocket.Dial(ctx, srv.URL, nil)
|
|
require.NoError(t, err, "failed to dial WebSocket")
|
|
defer conn.Close(websocket.StatusNormalClosure, "")
|
|
|
|
// Wait for the log from within the handler
|
|
newEntry := testutil.TryReceive(ctx, t, sink.newEntries)
|
|
require.Equal(t, newEntry.Message, "GET")
|
|
|
|
// Signal the websocket handler to return (and read to handle the close frame)
|
|
wg.Done()
|
|
_, _, err = conn.Read(ctx)
|
|
require.ErrorAs(t, err, &websocket.CloseError{}, "websocket read should fail with close error")
|
|
|
|
// Wait for the request to finish completely and verify we only logged once
|
|
_ = testutil.TryReceive(ctx, t, done)
|
|
require.Len(t, sink.entries, 1, "log was written twice")
|
|
}
|
|
|
|
func TestRequestLogger_HTTPRouteParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sink := &fakeSink{}
|
|
logger := slog.Make(sink)
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
|
|
defer cancel()
|
|
|
|
chiCtx := chi.NewRouteContext()
|
|
chiCtx.URLParams.Add("workspace", "test-workspace")
|
|
chiCtx.URLParams.Add("agent", "test-agent")
|
|
|
|
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
|
|
|
|
// Create a test handler to simulate an HTTP request
|
|
testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
|
rw.WriteHeader(http.StatusOK)
|
|
_, _ = rw.Write([]byte("OK"))
|
|
})
|
|
|
|
// Wrap the test handler with the Logger middleware
|
|
loggerMiddleware := Logger(logger)
|
|
wrappedHandler := loggerMiddleware(testHandler)
|
|
|
|
// Create a test HTTP request
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/test-path/}", nil)
|
|
require.NoError(t, err, "failed to create request")
|
|
|
|
sw := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
|
|
|
// Serve the request
|
|
wrappedHandler.ServeHTTP(sw, req)
|
|
|
|
fieldsMap := make(map[string]any)
|
|
for _, field := range sink.entries[0].Fields {
|
|
fieldsMap[field.Name] = field.Value
|
|
}
|
|
|
|
// Check that the log contains the expected fields
|
|
requiredFields := []string{"workspace", "agent"}
|
|
for _, field := range requiredFields {
|
|
_, exists := fieldsMap["params_"+field]
|
|
require.True(t, exists, "field %q is missing in log fields", field)
|
|
}
|
|
}
|
|
|
|
func TestRequestLogger_RouteParamsLogging(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
params map[string]string
|
|
expectedFields []string
|
|
}{
|
|
{
|
|
name: "EmptyParams",
|
|
params: map[string]string{},
|
|
expectedFields: []string{},
|
|
},
|
|
{
|
|
name: "SingleParam",
|
|
params: map[string]string{
|
|
"workspace": "test-workspace",
|
|
},
|
|
expectedFields: []string{"params_workspace"},
|
|
},
|
|
{
|
|
name: "MultipleParams",
|
|
params: map[string]string{
|
|
"workspace": "test-workspace",
|
|
"agent": "test-agent",
|
|
"user": "test-user",
|
|
},
|
|
expectedFields: []string{"params_workspace", "params_agent", "params_user"},
|
|
},
|
|
{
|
|
name: "EmptyValueParam",
|
|
params: map[string]string{
|
|
"workspace": "test-workspace",
|
|
"agent": "",
|
|
},
|
|
expectedFields: []string{"params_workspace"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sink := &fakeSink{}
|
|
logger := slog.Make(sink)
|
|
logger = logger.Leveled(slog.LevelDebug)
|
|
|
|
// Create a route context with the test parameters
|
|
chiCtx := chi.NewRouteContext()
|
|
for key, value := range tt.params {
|
|
chiCtx.URLParams.Add(key, value)
|
|
}
|
|
|
|
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
|
logCtx := NewRequestLogger(logger, "GET", time.Now())
|
|
|
|
// Write the log
|
|
logCtx.WriteLog(ctx, http.StatusOK)
|
|
|
|
require.Len(t, sink.entries, 1, "expected exactly one log entry")
|
|
|
|
// Convert fields to map for easier checking
|
|
fieldsMap := make(map[string]any)
|
|
for _, field := range sink.entries[0].Fields {
|
|
fieldsMap[field.Name] = field.Value
|
|
}
|
|
|
|
// Verify expected fields are present
|
|
for _, field := range tt.expectedFields {
|
|
value, exists := fieldsMap[field]
|
|
require.True(t, exists, "field %q should be present in log", field)
|
|
require.Equal(t, tt.params[strings.TrimPrefix(field, "params_")], value, "field %q has incorrect value", field)
|
|
}
|
|
|
|
// Verify no unexpected fields are present
|
|
for field := range fieldsMap {
|
|
if field == "took" || field == "status_code" || field == "latency_ms" {
|
|
continue // Skip standard fields
|
|
}
|
|
require.True(t, slices.Contains(tt.expectedFields, field), "unexpected field %q in log", field)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
type fakeSink struct {
|
|
entries []slog.SinkEntry
|
|
newEntries chan slog.SinkEntry
|
|
}
|
|
|
|
func (s *fakeSink) LogEntry(_ context.Context, e slog.SinkEntry) {
|
|
s.entries = append(s.entries, e)
|
|
if s.newEntries != nil {
|
|
select {
|
|
case s.newEntries <- e:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (*fakeSink) Sync() {}
|