mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
coderd: tighten /login rate limiting (#4432)
* coderd: tighten /login rate limit * coderd: add Bypass rate limit header
This commit is contained in:
@ -235,10 +235,15 @@ linters:
|
|||||||
- noctx
|
- noctx
|
||||||
- paralleltest
|
- paralleltest
|
||||||
- revive
|
- revive
|
||||||
- rowserrcheck
|
|
||||||
- sqlclosecheck
|
# These don't work until the following issue is solved.
|
||||||
|
# https://github.com/golangci/golangci-lint/issues/2649
|
||||||
|
# - rowserrcheck
|
||||||
|
# - sqlclosecheck
|
||||||
|
# - structcheck
|
||||||
|
# - wastedassign
|
||||||
|
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- structcheck
|
|
||||||
- tenv
|
- tenv
|
||||||
# In Go, it's possible for a package to test it's internal functionality
|
# In Go, it's possible for a package to test it's internal functionality
|
||||||
# without testing any exported functions. This is enabled to promote
|
# without testing any exported functions. This is enabled to promote
|
||||||
@ -253,4 +258,3 @@ linters:
|
|||||||
- unconvert
|
- unconvert
|
||||||
- unused
|
- unused
|
||||||
- varcheck
|
- varcheck
|
||||||
- wastedassign
|
|
||||||
|
@ -204,7 +204,7 @@ func New(options *Options) *API {
|
|||||||
// app URL. If it is, it will serve that application.
|
// app URL. If it is, it will serve that application.
|
||||||
api.handleSubdomainApplications(
|
api.handleSubdomainApplications(
|
||||||
// Middleware to impose on the served application.
|
// Middleware to impose on the served application.
|
||||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
@ -229,7 +229,7 @@ func New(options *Options) *API {
|
|||||||
apps := func(r chi.Router) {
|
apps := func(r chi.Router) {
|
||||||
r.Use(
|
r.Use(
|
||||||
tracing.Middleware(api.TracerProvider),
|
tracing.Middleware(api.TracerProvider),
|
||||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||||
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||||
DB: options.Database,
|
DB: options.Database,
|
||||||
OAuth2Configs: oauthConfigs,
|
OAuth2Configs: oauthConfigs,
|
||||||
@ -267,7 +267,7 @@ func New(options *Options) *API {
|
|||||||
r.Use(
|
r.Use(
|
||||||
tracing.Middleware(api.TracerProvider),
|
tracing.Middleware(api.TracerProvider),
|
||||||
// Specific routes can specify smaller limits.
|
// Specific routes can specify smaller limits.
|
||||||
httpmw.RateLimitPerMinute(options.APIRateLimit),
|
httpmw.RateLimit(options.APIRateLimit, time.Minute),
|
||||||
)
|
)
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{
|
httpapi.Write(r.Context(), w, http.StatusOK, codersdk.Response{
|
||||||
@ -304,7 +304,7 @@ func New(options *Options) *API {
|
|||||||
apiKeyMiddleware,
|
apiKeyMiddleware,
|
||||||
// This number is arbitrary, but reading/writing
|
// This number is arbitrary, but reading/writing
|
||||||
// file content is expensive so it should be small.
|
// file content is expensive so it should be small.
|
||||||
httpmw.RateLimitPerMinute(12),
|
httpmw.RateLimit(12, time.Minute),
|
||||||
)
|
)
|
||||||
r.Get("/{fileID}", api.fileByID)
|
r.Get("/{fileID}", api.fileByID)
|
||||||
r.Post("/", api.postFile)
|
r.Post("/", api.postFile)
|
||||||
@ -391,7 +391,15 @@ func New(options *Options) *API {
|
|||||||
r.Route("/users", func(r chi.Router) {
|
r.Route("/users", func(r chi.Router) {
|
||||||
r.Get("/first", api.firstUser)
|
r.Get("/first", api.firstUser)
|
||||||
r.Post("/first", api.postFirstUser)
|
r.Post("/first", api.postFirstUser)
|
||||||
r.Post("/login", api.postLogin)
|
r.Group(func(r chi.Router) {
|
||||||
|
// We use a tight limit for password login to protect
|
||||||
|
// against audit-log write DoS, pbkdf2 DoS, and simple
|
||||||
|
// brute-force attacks.
|
||||||
|
//
|
||||||
|
// Making this too small can break tests.
|
||||||
|
r.Use(httpmw.RateLimit(60, time.Minute))
|
||||||
|
r.Post("/login", api.postLogin)
|
||||||
|
})
|
||||||
r.Get("/authmethods", api.userAuthMethods)
|
r.Get("/authmethods", api.userAuthMethods)
|
||||||
r.Route("/oauth2", func(r chi.Router) {
|
r.Route("/oauth2", func(r chi.Router) {
|
||||||
r.Route("/github", func(r chi.Router) {
|
r.Route("/github", func(r chi.Router) {
|
||||||
|
@ -631,8 +631,8 @@ func TestAPIKey(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUser(ctx context.Context, t *testing.T, db database.Store) database.User {
|
func createUser(ctx context.Context, t *testing.T, db database.Store, opts ...func(u *database.InsertUserParams)) database.User {
|
||||||
user, err := db.InsertUser(ctx, database.InsertUserParams{
|
insert := database.InsertUserParams{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
Email: "email@coder.com",
|
Email: "email@coder.com",
|
||||||
Username: "username",
|
Username: "username",
|
||||||
@ -640,7 +640,11 @@ func createUser(ctx context.Context, t *testing.T, db database.Store) database.U
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
RBACRoles: []string{},
|
RBACRoles: []string{},
|
||||||
})
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&insert)
|
||||||
|
}
|
||||||
|
user, err := db.InsertUser(ctx, insert)
|
||||||
require.NoError(t, err, "create user")
|
require.NoError(t, err, "create user")
|
||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
@ -1,39 +1,71 @@
|
|||||||
package httpmw
|
package httpmw
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/httprate"
|
"github.com/go-chi/httprate"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
"github.com/coder/coder/coderd/database"
|
"github.com/coder/coder/coderd/database"
|
||||||
"github.com/coder/coder/coderd/httpapi"
|
"github.com/coder/coder/coderd/httpapi"
|
||||||
|
"github.com/coder/coder/coderd/rbac"
|
||||||
"github.com/coder/coder/codersdk"
|
"github.com/coder/coder/codersdk"
|
||||||
|
"github.com/coder/coder/cryptorand"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RateLimitPerMinute returns a handler that limits requests per-minute based
|
// RateLimit returns a handler that limits requests per-minute based
|
||||||
// on IP, endpoint, and user ID (if available).
|
// on IP, endpoint, and user ID (if available).
|
||||||
func RateLimitPerMinute(count int) func(http.Handler) http.Handler {
|
func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler {
|
||||||
// -1 is no rate limit
|
// -1 is no rate limit
|
||||||
if count <= 0 {
|
if count <= 0 {
|
||||||
return func(handler http.Handler) http.Handler {
|
return func(handler http.Handler) http.Handler {
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return httprate.Limit(
|
return httprate.Limit(
|
||||||
count,
|
count,
|
||||||
1*time.Minute,
|
window,
|
||||||
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
|
||||||
// Prioritize by user, but fallback to IP.
|
// Prioritize by user, but fallback to IP.
|
||||||
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
apiKey, ok := r.Context().Value(apiKeyContextKey{}).(database.APIKey)
|
||||||
if ok {
|
if !ok {
|
||||||
|
return httprate.KeyByIP(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok, _ := strconv.ParseBool(r.Header.Get(codersdk.BypassRatelimitHeader)); !ok {
|
||||||
|
// No bypass attempt, just ratelimit.
|
||||||
return apiKey.UserID.String(), nil
|
return apiKey.UserID.String(), nil
|
||||||
}
|
}
|
||||||
return httprate.KeyByIP(r)
|
|
||||||
|
// Allow Owner to bypass rate limiting for load tests
|
||||||
|
// and automation.
|
||||||
|
auth := UserAuthorization(r)
|
||||||
|
|
||||||
|
// We avoid using rbac.Authorizer since rego is CPU-intensive
|
||||||
|
// and undermines the DoS-prevention goal of the rate limiter.
|
||||||
|
for _, role := range auth.Roles {
|
||||||
|
if role == rbac.RoleOwner() {
|
||||||
|
// HACK: use a random key each time to
|
||||||
|
// de facto disable rate limiting. The
|
||||||
|
// `httprate` package has no
|
||||||
|
// support for selectively changing the limit
|
||||||
|
// for particular keys.
|
||||||
|
return cryptorand.String(16)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiKey.UserID.String(), xerrors.Errorf(
|
||||||
|
"%q provided but user is not %v",
|
||||||
|
codersdk.BypassRatelimitHeader, rbac.RoleOwner(),
|
||||||
|
)
|
||||||
}, httprate.KeyByEndpoint),
|
}, httprate.KeyByEndpoint),
|
||||||
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
|
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||||
httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{
|
httpapi.Write(r.Context(), w, http.StatusTooManyRequests, codersdk.Response{
|
||||||
Message: "You've been rate limited for sending too many requests!",
|
Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", count, window),
|
||||||
})
|
})
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1,23 +1,60 @@
|
|||||||
package httpmw_test
|
package httpmw_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/database"
|
||||||
|
"github.com/coder/coder/coderd/database/databasefake"
|
||||||
"github.com/coder/coder/coderd/httpmw"
|
"github.com/coder/coder/coderd/httpmw"
|
||||||
|
"github.com/coder/coder/coderd/rbac"
|
||||||
|
"github.com/coder/coder/codersdk"
|
||||||
"github.com/coder/coder/testutil"
|
"github.com/coder/coder/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func insertAPIKey(ctx context.Context, t *testing.T, db database.Store, userID uuid.UUID) string {
|
||||||
|
id, secret := randomAPIKeyParts()
|
||||||
|
hashed := sha256.Sum256([]byte(secret))
|
||||||
|
|
||||||
|
_, err := db.InsertAPIKey(ctx, database.InsertAPIKeyParams{
|
||||||
|
ID: id,
|
||||||
|
HashedSecret: hashed[:],
|
||||||
|
LastUsed: database.Now().AddDate(0, 0, -1),
|
||||||
|
ExpiresAt: database.Now().AddDate(0, 0, 1),
|
||||||
|
UserID: userID,
|
||||||
|
LoginType: database.LoginTypePassword,
|
||||||
|
Scope: database.APIKeyScopeAll,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s-%s", id, secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randRemoteAddr() string {
|
||||||
|
var b [4]byte
|
||||||
|
// nolint:gosec
|
||||||
|
rand.Read(b[:])
|
||||||
|
// nolint:gosec
|
||||||
|
return fmt.Sprintf("%s:%v", net.IP(b[:]).String(), rand.Int31()%(1<<16))
|
||||||
|
}
|
||||||
|
|
||||||
func TestRateLimit(t *testing.T) {
|
func TestRateLimit(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("NoUser", func(t *testing.T) {
|
t.Run("NoUserSucceeds", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
rtr := chi.NewRouter()
|
rtr := chi.NewRouter()
|
||||||
rtr.Use(httpmw.RateLimitPerMinute(5))
|
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||||
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
rw.WriteHeader(http.StatusOK)
|
rw.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
@ -31,4 +68,107 @@ func TestRateLimit(t *testing.T) {
|
|||||||
return resp.StatusCode == http.StatusTooManyRequests
|
return resp.StatusCode == http.StatusTooManyRequests
|
||||||
}, testutil.WaitShort, testutil.IntervalFast)
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("RandomIPs", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rtr := chi.NewRouter()
|
||||||
|
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||||
|
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Never(t, func() bool {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req.RemoteAddr = randRemoteAddr()
|
||||||
|
rtr.ServeHTTP(rec, req)
|
||||||
|
resp := rec.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return resp.StatusCode == http.StatusTooManyRequests
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RegularUser", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
db := databasefake.New()
|
||||||
|
|
||||||
|
u := createUser(ctx, t, db)
|
||||||
|
key := insertAPIKey(ctx, t, db, u.ID)
|
||||||
|
|
||||||
|
rtr := chi.NewRouter()
|
||||||
|
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||||
|
DB: db,
|
||||||
|
Optional: false,
|
||||||
|
}))
|
||||||
|
|
||||||
|
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||||
|
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Bypass must fail
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||||
|
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
// Assert we're not using IP address.
|
||||||
|
req.RemoteAddr = randRemoteAddr()
|
||||||
|
rtr.ServeHTTP(rec, req)
|
||||||
|
resp := rec.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusPreconditionRequired, resp.StatusCode)
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
// Assert we're not using IP address.
|
||||||
|
req.RemoteAddr = randRemoteAddr()
|
||||||
|
rtr.ServeHTTP(rec, req)
|
||||||
|
resp := rec.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return resp.StatusCode == http.StatusTooManyRequests
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OwnerBypass", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
db := databasefake.New()
|
||||||
|
|
||||||
|
u := createUser(ctx, t, db, func(u *database.InsertUserParams) {
|
||||||
|
u.RBACRoles = []string{rbac.RoleOwner()}
|
||||||
|
})
|
||||||
|
|
||||||
|
key := insertAPIKey(ctx, t, db, u.ID)
|
||||||
|
|
||||||
|
rtr := chi.NewRouter()
|
||||||
|
rtr.Use(httpmw.ExtractAPIKey(httpmw.ExtractAPIKeyConfig{
|
||||||
|
DB: db,
|
||||||
|
Optional: false,
|
||||||
|
}))
|
||||||
|
|
||||||
|
rtr.Use(httpmw.RateLimit(5, time.Second))
|
||||||
|
rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Never(t, func() bool {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set(codersdk.SessionCustomHeader, key)
|
||||||
|
req.Header.Set(codersdk.BypassRatelimitHeader, "true")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
// Assert we're not using IP address.
|
||||||
|
req.RemoteAddr = randRemoteAddr()
|
||||||
|
rtr.ServeHTTP(rec, req)
|
||||||
|
resp := rec.Result()
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return resp.StatusCode == http.StatusTooManyRequests
|
||||||
|
}, testutil.WaitShort, testutil.IntervalFast)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,9 @@ const (
|
|||||||
SessionCustomHeader = "Coder-Session-Token"
|
SessionCustomHeader = "Coder-Session-Token"
|
||||||
OAuth2StateKey = "oauth_state"
|
OAuth2StateKey = "oauth_state"
|
||||||
OAuth2RedirectKey = "oauth_redirect"
|
OAuth2RedirectKey = "oauth_redirect"
|
||||||
|
|
||||||
|
// nolint: gosec
|
||||||
|
BypassRatelimitHeader = "X-Coder-Bypass-Ratelimit"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New creates a Coder client for the provided URL.
|
// New creates a Coder client for the provided URL.
|
||||||
|
Reference in New Issue
Block a user