chore(coderd/httpmw): remove dbmem usage from tests (#18146)

Related to https://github.com/coder/coder/issues/15109
This commit is contained in:
Hugo Dutka
2025-06-02 13:57:56 +02:00
committed by GitHub
parent f986d13a9c
commit 782d01bae2
14 changed files with 235 additions and 175 deletions

View File

@ -6,10 +6,8 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync/atomic"
"testing"
@ -18,12 +16,13 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/oauth2"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
@ -83,9 +82,9 @@ func TestAPIKey(t *testing.T) {
t.Run("NoCookie", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
@ -99,9 +98,9 @@ func TestAPIKey(t *testing.T) {
t.Run("NoCookieRedirects", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
@ -118,9 +117,9 @@ func TestAPIKey(t *testing.T) {
t.Run("InvalidFormat", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, "test-wow-hello")
@ -136,9 +135,9 @@ func TestAPIKey(t *testing.T) {
t.Run("InvalidIDLength", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, "test-wow")
@ -154,9 +153,9 @@ func TestAPIKey(t *testing.T) {
t.Run("InvalidSecretLength", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
r.Header.Set(codersdk.SessionTokenHeader, "testtestid-wow")
@ -172,7 +171,7 @@ func TestAPIKey(t *testing.T) {
t.Run("NotFound", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
id, secret = randomAPIKeyParts()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
@ -191,10 +190,10 @@ func TestAPIKey(t *testing.T) {
t.Run("UserLinkNotFound", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
user = dbgen.User(t, db, database.User{
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
user = dbgen.User(t, db, database.User{
LoginType: database.LoginTypeGithub,
})
// Intentionally not inserting any user link
@ -219,10 +218,10 @@ func TestAPIKey(t *testing.T) {
t.Run("InvalidSecret", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
user = dbgen.User(t, db, database.User{})
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
user = dbgen.User(t, db, database.User{})
// Use a different secret so they don't match!
hashed = sha256.Sum256([]byte("differentsecret"))
@ -244,7 +243,7 @@ func TestAPIKey(t *testing.T) {
t.Run("Expired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
_, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -273,7 +272,7 @@ func TestAPIKey(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -309,7 +308,7 @@ func TestAPIKey(t *testing.T) {
t.Run("ValidWithScope", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
_, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -347,7 +346,7 @@ func TestAPIKey(t *testing.T) {
t.Run("QueryParameter", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
_, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -381,7 +380,7 @@ func TestAPIKey(t *testing.T) {
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -412,7 +411,7 @@ func TestAPIKey(t *testing.T) {
t.Run("ValidUpdateExpiry", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -443,7 +442,7 @@ func TestAPIKey(t *testing.T) {
t.Run("NoRefresh", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -475,7 +474,7 @@ func TestAPIKey(t *testing.T) {
t.Run("OAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -511,7 +510,7 @@ func TestAPIKey(t *testing.T) {
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -561,7 +560,7 @@ func TestAPIKey(t *testing.T) {
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -607,7 +606,7 @@ func TestAPIKey(t *testing.T) {
t.Run("OAuthRefresh", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -630,7 +629,7 @@ func TestAPIKey(t *testing.T) {
oauthToken := &oauth2.Token{
AccessToken: "wow",
RefreshToken: "moo",
Expiry: dbtime.Now().AddDate(0, 0, 1),
Expiry: dbtestutil.NowInDefaultTimezone().AddDate(0, 0, 1),
}
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
DB: db,
@ -665,7 +664,7 @@ func TestAPIKey(t *testing.T) {
t.Parallel()
var (
ctx = testutil.Context(t, testutil.WaitShort)
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -715,7 +714,7 @@ func TestAPIKey(t *testing.T) {
t.Run("RemoteIPUpdates", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -740,15 +739,15 @@ func TestAPIKey(t *testing.T) {
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
require.NoError(t, err)
require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP)
require.Equal(t, "1.1.1.1", gotAPIKey.IPAddress.IPNet.IP.String())
})
t.Run("RedirectToLogin", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
)
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@ -767,9 +766,9 @@ func TestAPIKey(t *testing.T) {
t.Run("Optional", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
db, _ = dbtestutil.NewDB(t)
r = httptest.NewRequest("GET", "/", nil)
rw = httptest.NewRecorder()
count int64
handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -798,7 +797,7 @@ func TestAPIKey(t *testing.T) {
t.Run("Tokens", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -831,7 +830,7 @@ func TestAPIKey(t *testing.T) {
t.Run("MissingConfig", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
user = dbgen.User(t, db, database.User{})
_, token = dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
@ -866,7 +865,7 @@ func TestAPIKey(t *testing.T) {
t.Run("CustomRoles", func(t *testing.T) {
t.Parallel()
var (
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
org = dbgen.Organization(t, db, database.Organization{})
customRole = dbgen.CustomRole(t, db, database.CustomRole{
Name: "custom-role",
@ -933,7 +932,7 @@ func TestAPIKey(t *testing.T) {
t.Parallel()
var (
roleNotExistsName = "role-not-exists"
db = dbmem.New()
db, _ = dbtestutil.NewDB(t)
org = dbgen.Organization(t, db, database.Organization{})
user = dbgen.User(t, db, database.User{
RBACRoles: []string{