mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: add panic recovery middleware (#3687)
This commit is contained in:
58
coderd/httpmw/logger.go
Normal file
58
coderd/httpmw/logger.go
Normal file
@ -0,0 +1,58 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &httpapi.StatusWriter{ResponseWriter: w}
|
||||
|
||||
httplog := log.With(
|
||||
slog.F("host", httpapi.RequestHost(r)),
|
||||
slog.F("path", r.URL.Path),
|
||||
slog.F("proto", r.Proto),
|
||||
slog.F("remote_addr", r.RemoteAddr),
|
||||
)
|
||||
|
||||
next.ServeHTTP(sw, r)
|
||||
|
||||
// Don't log successful health check requests.
|
||||
if r.URL.Path == "/api/v2" && sw.Status == 200 {
|
||||
return
|
||||
}
|
||||
|
||||
httplog = httplog.With(
|
||||
slog.F("took", time.Since(start)),
|
||||
slog.F("status_code", sw.Status),
|
||||
slog.F("latency_ms", float64(time.Since(start)/time.Millisecond)),
|
||||
)
|
||||
|
||||
// For status codes 400 and higher we
|
||||
// want to log the response body.
|
||||
if sw.Status >= 400 {
|
||||
httplog = httplog.With(
|
||||
slog.F("response_body", string(sw.ResponseBody())),
|
||||
)
|
||||
}
|
||||
|
||||
logLevelFn := httplog.Debug
|
||||
if sw.Status >= 400 {
|
||||
logLevelFn = httplog.Warn
|
||||
}
|
||||
if sw.Status >= 500 {
|
||||
// Server errors should be treated as an ERROR
|
||||
// log level.
|
||||
logLevelFn = httplog.Error
|
||||
}
|
||||
|
||||
logLevelFn(r.Context(), r.Method)
|
||||
})
|
||||
}
|
||||
}
|
@ -6,7 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimw "github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
@ -66,9 +67,9 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
rctx = chi.RouteContext(r.Context())
|
||||
)
|
||||
|
||||
sw, ok := w.(chimw.WrapResponseWriter)
|
||||
sw, ok := w.(*httpapi.StatusWriter)
|
||||
if !ok {
|
||||
panic("dev error: http.ResponseWriter is not chimw.WrapResponseWriter")
|
||||
panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter")
|
||||
}
|
||||
|
||||
var (
|
||||
@ -76,7 +77,7 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
distOpts []string
|
||||
)
|
||||
// We want to count WebSockets separately.
|
||||
if isWebsocketUpgrade(r) {
|
||||
if httpapi.IsWebsocketUpgrade(r) {
|
||||
websocketsConcurrent.Inc()
|
||||
defer websocketsConcurrent.Dec()
|
||||
|
||||
@ -93,20 +94,10 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
|
||||
path := rctx.RoutePattern()
|
||||
distOpts = append(distOpts, path)
|
||||
statusStr := strconv.Itoa(sw.Status())
|
||||
statusStr := strconv.Itoa(sw.Status)
|
||||
|
||||
requestsProcessed.WithLabelValues(statusStr, method, path).Inc()
|
||||
dist.WithLabelValues(distOpts...).Observe(float64(time.Since(start)) / 1e6)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func isWebsocketUpgrade(r *http.Request) bool {
|
||||
vs := r.Header.Values("Upgrade")
|
||||
for _, v := range vs {
|
||||
if v == "websocket" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -7,10 +7,10 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
chimw "github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
@ -20,7 +20,7 @@ func TestPrometheus(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
res := chimw.NewWrapResponseWriter(httptest.NewRecorder(), 0)
|
||||
res := &httpapi.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
reg := prometheus.NewRegistry()
|
||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
40
coderd/httpmw/recover.go
Normal file
40
coderd/httpmw/recover.go
Normal file
@ -0,0 +1,40 @@
|
||||
package httpmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
func Recover(log slog.Logger) func(h http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r != nil {
|
||||
log.Warn(context.Background(),
|
||||
"panic serving http request (recovered)",
|
||||
slog.F("panic", r),
|
||||
slog.F("stack", string(debug.Stack())),
|
||||
)
|
||||
|
||||
var hijacked bool
|
||||
if sw, ok := w.(*httpapi.StatusWriter); ok {
|
||||
hijacked = sw.Hijacked
|
||||
}
|
||||
|
||||
// Only try to write errors on
|
||||
// non-hijacked responses.
|
||||
if !hijacked {
|
||||
httpapi.InternalServerError(w, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
74
coderd/httpmw/recover_test.go
Normal file
74
coderd/httpmw/recover_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package httpmw_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
)
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := func(isPanic, hijack bool) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if isPanic {
|
||||
panic("Oh no!")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Name string
|
||||
Code int
|
||||
Panic bool
|
||||
Hijack bool
|
||||
}{
|
||||
{
|
||||
Name: "OK",
|
||||
Code: http.StatusOK,
|
||||
Panic: false,
|
||||
Hijack: false,
|
||||
},
|
||||
{
|
||||
Name: "Panic",
|
||||
Code: http.StatusInternalServerError,
|
||||
Panic: true,
|
||||
Hijack: false,
|
||||
},
|
||||
{
|
||||
Name: "Hijack",
|
||||
Code: 0,
|
||||
Panic: true,
|
||||
Hijack: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
c := c
|
||||
|
||||
t.Run(c.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
log = slogtest.Make(t, nil)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = &httpapi.StatusWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
Hijacked: c.Hijack,
|
||||
}
|
||||
)
|
||||
|
||||
httpmw.Recover(log)(handler(c.Panic, c.Hijack)).ServeHTTP(w, r)
|
||||
|
||||
require.Equal(t, c.Code, w.Status)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user