mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat: add rbac tracing (#4093)
This commit is contained in:
@ -13,8 +13,8 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/features"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type RequestParams struct {
|
||||
@ -93,9 +93,9 @@ func ResourceType[T Auditable](tgt T) database.ResourceType {
|
||||
// that should be deferred, causing the audit log to be committed when the
|
||||
// handler returns.
|
||||
func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request[T], func()) {
|
||||
sw, ok := w.(*httpapi.StatusWriter)
|
||||
sw, ok := w.(*tracing.StatusWriter)
|
||||
if !ok {
|
||||
panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter")
|
||||
panic("dev error: http.ResponseWriter is not *tracing.StatusWriter")
|
||||
}
|
||||
|
||||
req := &Request[T]{
|
||||
|
@ -199,7 +199,7 @@ func New(options *Options) *API {
|
||||
apps := func(r chi.Router) {
|
||||
r.Use(
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
||||
tracing.HTTPMW(api.TracerProvider),
|
||||
httpmw.ExtractAPIKey(options.Database, oauthConfigs, true),
|
||||
httpmw.ExtractUserParam(api.Database),
|
||||
// Extracts the <workspace.agent> from the url
|
||||
@ -229,7 +229,7 @@ func New(options *Options) *API {
|
||||
r.Use(
|
||||
// Specific routes can specify smaller limits.
|
||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
||||
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
|
||||
tracing.HTTPMW(api.TracerProvider),
|
||||
)
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(w, http.StatusOK, codersdk.Response{
|
||||
|
@ -6,13 +6,14 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
func Logger(log slog.Logger) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &httpapi.StatusWriter{ResponseWriter: w}
|
||||
sw := &tracing.StatusWriter{ResponseWriter: w}
|
||||
|
||||
httplog := log.With(
|
||||
slog.F("host", httpapi.RequestHost(r)),
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
@ -67,9 +68,9 @@ func Prometheus(register prometheus.Registerer) func(http.Handler) http.Handler
|
||||
rctx = chi.RouteContext(r.Context())
|
||||
)
|
||||
|
||||
sw, ok := w.(*httpapi.StatusWriter)
|
||||
sw, ok := w.(*tracing.StatusWriter)
|
||||
if !ok {
|
||||
panic("dev error: http.ResponseWriter is not *httpapi.StatusWriter")
|
||||
panic("dev error: http.ResponseWriter is not *tracing.StatusWriter")
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -10,8 +10,8 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
func TestPrometheus(t *testing.T) {
|
||||
@ -20,7 +20,7 @@ func TestPrometheus(t *testing.T) {
|
||||
t.Parallel()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chi.NewRouteContext()))
|
||||
res := &httpapi.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
res := &tracing.StatusWriter{ResponseWriter: httptest.NewRecorder()}
|
||||
reg := prometheus.NewRegistry()
|
||||
httpmw.Prometheus(reg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
func Recover(log slog.Logger) func(h http.Handler) http.Handler {
|
||||
@ -22,7 +23,7 @@ func Recover(log slog.Logger) func(h http.Handler) http.Handler {
|
||||
)
|
||||
|
||||
var hijacked bool
|
||||
if sw, ok := w.(*httpapi.StatusWriter); ok {
|
||||
if sw, ok := w.(*tracing.StatusWriter); ok {
|
||||
hijacked = sw.Hijacked
|
||||
}
|
||||
|
||||
|
@ -8,8 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
@ -60,7 +60,7 @@ func TestRecover(t *testing.T) {
|
||||
var (
|
||||
log = slogtest.Make(t, nil)
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
w = &httpapi.StatusWriter{
|
||||
w = &tracing.StatusWriter{
|
||||
ResponseWriter: httptest.NewRecorder(),
|
||||
Hijacked: c.Hijack,
|
||||
}
|
||||
|
@ -4,9 +4,12 @@ import (
|
||||
"context"
|
||||
_ "embed"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type Authorizer interface {
|
||||
@ -22,6 +25,13 @@ type PreparedAuthorized interface {
|
||||
// the elements the subject does not have permission for. All objects must be
|
||||
// of the same type.
|
||||
func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, subjRoles []string, action Action, objects []O) ([]O, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, trace.WithAttributes(
|
||||
attribute.String("subject_id", subjID),
|
||||
attribute.StringSlice("subject_roles", subjRoles),
|
||||
attribute.Int("num_objects", len(objects)),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
if len(objects) == 0 {
|
||||
// Nothing to filter
|
||||
return objects, nil
|
||||
@ -34,8 +44,7 @@ func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, sub
|
||||
return nil, xerrors.Errorf("prepare: %w", err)
|
||||
}
|
||||
|
||||
for i := range objects {
|
||||
object := objects[i]
|
||||
for _, object := range objects {
|
||||
rbacObj := object.RBACObject()
|
||||
if rbacObj.Type != objectType {
|
||||
return nil, xerrors.Errorf("object types must be uniform across the set (%s), found %s", objectType, object.RBACObject().Type)
|
||||
@ -45,6 +54,7 @@ func Filter[O Objecter](ctx context.Context, auth Authorizer, subjID string, sub
|
||||
filtered = append(filtered, object)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
@ -93,6 +103,9 @@ func (a RegoAuthorizer) ByRoleName(ctx context.Context, subjectID string, roleNa
|
||||
// Authorize allows passing in custom Roles.
|
||||
// This is really helpful for unit testing, as we can create custom roles to exercise edge cases.
|
||||
func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles []Role, action Action, object Object) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"subject": authSubject{
|
||||
ID: subjectID,
|
||||
@ -117,6 +130,9 @@ func (a RegoAuthorizer) Authorize(ctx context.Context, subjectID string, roles [
|
||||
// Prepare will partially execute the rego policy leaving the object fields unknown (except for the type).
|
||||
// This will vastly speed up performance if batch authorization on the same type of objects is needed.
|
||||
func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Role, action Action, objectType string) (*PartialAuthorizer, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
auth, err := newPartialAuthorizer(ctx, subjectID, roles, action, objectType)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("new partial authorizer: %w", err)
|
||||
@ -126,6 +142,9 @@ func (RegoAuthorizer) Prepare(ctx context.Context, subjectID string, roles []Rol
|
||||
}
|
||||
|
||||
func (a RegoAuthorizer) PrepareByRoleName(ctx context.Context, subjectID string, roleNames []string, action Action, objectType string) (PreparedAuthorized, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
roles, err := RolesByNames(roleNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -6,6 +6,8 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/open-policy-agent/opa/rego"
|
||||
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
type PartialAuthorizer struct {
|
||||
@ -24,6 +26,9 @@ type PartialAuthorizer struct {
|
||||
}
|
||||
|
||||
func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, action Action, objectType string) (*PartialAuthorizer, error) {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"subject": authSubject{
|
||||
ID: subjectID,
|
||||
@ -83,6 +88,9 @@ func newPartialAuthorizer(ctx context.Context, subjectID string, roles []Role, a
|
||||
|
||||
// Authorize authorizes a single object using the partially prepared queries.
|
||||
func (a PartialAuthorizer) Authorize(ctx context.Context, object Object) error {
|
||||
ctx, span := tracing.StartSpan(ctx)
|
||||
defer span.End()
|
||||
|
||||
if a.alwaysTrue {
|
||||
return nil
|
||||
}
|
||||
|
@ -1,18 +1,17 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
)
|
||||
|
||||
// HTTPMW adds tracing to http routes.
|
||||
func HTTPMW(tracerProvider trace.TracerProvider, name string) func(http.Handler) http.Handler {
|
||||
func HTTPMW(tracerProvider trace.TracerProvider) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if tracerProvider == nil {
|
||||
@ -21,13 +20,13 @@ func HTTPMW(tracerProvider trace.TracerProvider, name string) func(http.Handler)
|
||||
}
|
||||
|
||||
// start span with default span name. Span name will be updated to "method route" format once request finishes.
|
||||
ctx, span := tracerProvider.Tracer(name).Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI))
|
||||
ctx, span := tracerProvider.Tracer("").Start(r.Context(), fmt.Sprintf("%s %s", r.Method, r.RequestURI))
|
||||
defer span.End()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
sw, ok := rw.(*httpapi.StatusWriter)
|
||||
sw, ok := rw.(*StatusWriter)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("ResponseWriter not a *httpapi.StatusWriter; got %T", rw))
|
||||
panic(fmt.Sprintf("ResponseWriter not a *tracing.StatusWriter; got %T", rw))
|
||||
}
|
||||
|
||||
// pass the span through the request context and serve the request to the next middleware
|
||||
@ -53,9 +52,12 @@ func EndHTTPSpan(r *http.Request, status int, span trace.Span) {
|
||||
status = http.StatusOK
|
||||
}
|
||||
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(status))
|
||||
spanStatus, spanMessage := semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(status, trace.SpanKindServer)
|
||||
span.SetStatus(spanStatus, spanMessage)
|
||||
span.SetStatus(semconv.SpanStatusFromHTTPStatusCodeAndSpanKind(status, trace.SpanKindServer))
|
||||
|
||||
// finally end span
|
||||
span.End()
|
||||
}
|
||||
|
||||
func StartSpan(ctx context.Context, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
return trace.SpanFromContext(ctx).TracerProvider().Tracer("").Start(ctx, FuncNameSkip(1), opts...)
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
package httpapi
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"bufio"
|
@ -1,4 +1,4 @@
|
||||
package httpapi_test
|
||||
package tracing_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@ -11,7 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
)
|
||||
|
||||
func TestStatusWriter(t *testing.T) {
|
||||
@ -22,7 +22,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
|
||||
var (
|
||||
rec = httptest.NewRecorder()
|
||||
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||
w = &tracing.StatusWriter{ResponseWriter: rec}
|
||||
)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@ -36,7 +36,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
|
||||
var (
|
||||
rec = httptest.NewRecorder()
|
||||
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||
w = &tracing.StatusWriter{ResponseWriter: rec}
|
||||
code = http.StatusNotFound
|
||||
)
|
||||
|
||||
@ -52,7 +52,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
rec = httptest.NewRecorder()
|
||||
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||
w = &tracing.StatusWriter{ResponseWriter: rec}
|
||||
body = []byte("hello")
|
||||
)
|
||||
|
||||
@ -70,7 +70,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
rec = httptest.NewRecorder()
|
||||
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||
w = &tracing.StatusWriter{ResponseWriter: rec}
|
||||
body = []byte("hello")
|
||||
code = http.StatusInternalServerError
|
||||
)
|
||||
@ -88,7 +88,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
var (
|
||||
rec = httptest.NewRecorder()
|
||||
w = &httpapi.StatusWriter{ResponseWriter: rec}
|
||||
w = &tracing.StatusWriter{ResponseWriter: rec}
|
||||
// 8kb body.
|
||||
body = make([]byte, 8<<10)
|
||||
code = http.StatusInternalServerError
|
||||
@ -112,7 +112,7 @@ func TestStatusWriter(t *testing.T) {
|
||||
rec = httptest.NewRecorder()
|
||||
)
|
||||
|
||||
w := &httpapi.StatusWriter{ResponseWriter: hijacker{rec}}
|
||||
w := &tracing.StatusWriter{ResponseWriter: hijacker{rec}}
|
||||
|
||||
_, _, err := w.Hijack()
|
||||
require.Error(t, err)
|
@ -17,3 +17,16 @@ func FuncName() string {
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func FuncNameSkip(skip int) string {
|
||||
fnpc, _, _, ok := runtime.Caller(1 + skip)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
fn := runtime.FuncForPC(fnpc)
|
||||
name := fn.Name()
|
||||
if i := strings.LastIndex(name, "/"); i > 0 {
|
||||
name = name[i+1:]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
Reference in New Issue
Block a user