mirror of
https://github.com/coder/coder.git
synced 2025-07-21 01:28:49 +00:00
feat: add session token injection to provisioner (#7461)
This commit is contained in:
118
coderd/apikey.go
118
coderd/apikey.go
@ -2,9 +2,7 @@ package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
@ -12,9 +10,9 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/moby/moby/pkg/namesgenerator"
|
||||
"github.com/tabbed/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
@ -22,7 +20,6 @@ import (
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
// Creates a new token API key that effectively doesn't expire.
|
||||
@ -83,13 +80,14 @@ func (api *API) postToken(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
cookie, key, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeToken,
|
||||
ExpiresAt: database.Now().Add(lifeTime),
|
||||
Scope: scope,
|
||||
LifetimeSeconds: int64(lifeTime.Seconds()),
|
||||
TokenName: tokenName,
|
||||
cookie, key, err := api.createAPIKey(ctx, apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypeToken,
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
ExpiresAt: database.Now().Add(lifeTime),
|
||||
Scope: scope,
|
||||
LifetimeSeconds: int64(lifeTime.Seconds()),
|
||||
TokenName: tokenName,
|
||||
})
|
||||
if err != nil {
|
||||
if database.IsUniqueViolation(err, database.UniqueIndexApiKeyName) {
|
||||
@ -127,10 +125,11 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) {
|
||||
user := httpmw.UserParam(r)
|
||||
|
||||
lifeTime := time.Hour * 24 * 7
|
||||
cookie, _, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
cookie, _, err := api.createAPIKey(ctx, apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
// All api generated keys will last 1 week. Browser login tokens have
|
||||
// a shorter life.
|
||||
ExpiresAt: database.Now().Add(lifeTime),
|
||||
@ -359,33 +358,6 @@ func (api *API) tokenConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
}
|
||||
|
||||
// Generates a new ID and secret for an API key.
|
||||
func GenerateAPIKeyIDSecret() (id string, secret string, err error) {
|
||||
// Length of an API Key ID.
|
||||
id, err = cryptorand.String(10)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
// Length of an API Key secret.
|
||||
secret, err = cryptorand.String(22)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return id, secret, nil
|
||||
}
|
||||
|
||||
type createAPIKeyParams struct {
|
||||
UserID uuid.UUID
|
||||
RemoteAddr string
|
||||
LoginType database.LoginType
|
||||
|
||||
// Optional.
|
||||
ExpiresAt time.Time
|
||||
LifetimeSeconds int64
|
||||
Scope database.APIKeyScope
|
||||
TokenName string
|
||||
}
|
||||
|
||||
func (api *API) validateAPIKeyLifetime(lifetime time.Duration) error {
|
||||
if lifetime <= 0 {
|
||||
return xerrors.New("lifetime must be positive number greater than 0")
|
||||
@ -401,73 +373,21 @@ func (api *API) validateAPIKeyLifetime(lifetime time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (api *API) createAPIKey(ctx context.Context, params createAPIKeyParams) (*http.Cookie, *database.APIKey, error) {
|
||||
keyID, keySecret, err := GenerateAPIKeyIDSecret()
|
||||
func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*http.Cookie, *database.APIKey, error) {
|
||||
key, sessionToken, err := apikey.Generate(params)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
// Default expires at to now+lifetime, or use the configured value if not
|
||||
// set.
|
||||
if params.ExpiresAt.IsZero() {
|
||||
if params.LifetimeSeconds != 0 {
|
||||
params.ExpiresAt = database.Now().Add(time.Duration(params.LifetimeSeconds) * time.Second)
|
||||
} else {
|
||||
params.ExpiresAt = database.Now().Add(api.DeploymentValues.SessionDuration.Value())
|
||||
params.LifetimeSeconds = int64(api.DeploymentValues.SessionDuration.Value().Seconds())
|
||||
}
|
||||
}
|
||||
if params.LifetimeSeconds == 0 {
|
||||
params.LifetimeSeconds = int64(time.Until(params.ExpiresAt).Seconds())
|
||||
}
|
||||
|
||||
ip := net.ParseIP(params.RemoteAddr)
|
||||
if ip == nil {
|
||||
ip = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
bitlen := len(ip) * 8
|
||||
|
||||
scope := database.APIKeyScopeAll
|
||||
if params.Scope != "" {
|
||||
scope = params.Scope
|
||||
}
|
||||
switch scope {
|
||||
case database.APIKeyScopeAll, database.APIKeyScopeApplicationConnect:
|
||||
default:
|
||||
return nil, nil, xerrors.Errorf("invalid API key scope: %q", scope)
|
||||
}
|
||||
|
||||
key, err := api.Database.InsertAPIKey(ctx, database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: params.UserID,
|
||||
LifetimeSeconds: params.LifetimeSeconds,
|
||||
IPAddress: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
// Make sure in UTC time for common time zone
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
Scope: scope,
|
||||
TokenName: params.TokenName,
|
||||
})
|
||||
newkey, err := api.Database.InsertAPIKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, nil, xerrors.Errorf("insert API key: %w", err)
|
||||
}
|
||||
|
||||
api.Telemetry.Report(&telemetry.Snapshot{
|
||||
APIKeys: []telemetry.APIKey{telemetry.ConvertAPIKey(key)},
|
||||
APIKeys: []telemetry.APIKey{telemetry.ConvertAPIKey(newkey)},
|
||||
})
|
||||
|
||||
// This format is consumed by the APIKey middleware.
|
||||
sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
return &http.Cookie{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: sessionToken,
|
||||
@ -475,5 +395,5 @@ func (api *API) createAPIKey(ctx context.Context, params createAPIKeyParams) (*h
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: api.SecureAuthCookie,
|
||||
}, &key, nil
|
||||
}, &newkey, nil
|
||||
}
|
||||
|
110
coderd/apikey/apikey.go
Normal file
110
coderd/apikey/apikey.go
Normal file
@ -0,0 +1,110 @@
|
||||
package apikey
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
)
|
||||
|
||||
type CreateParams struct {
|
||||
UserID uuid.UUID
|
||||
LoginType database.LoginType
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
|
||||
// Optional.
|
||||
ExpiresAt time.Time
|
||||
LifetimeSeconds int64
|
||||
Scope database.APIKeyScope
|
||||
TokenName string
|
||||
RemoteAddr string
|
||||
}
|
||||
|
||||
// Generate generates an API key, returning the key as a string as well as the
|
||||
// database representation. It is the responsibility of the caller to insert it
|
||||
// into the database.
|
||||
func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error) {
|
||||
keyID, keySecret, err := generateKey()
|
||||
if err != nil {
|
||||
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(keySecret))
|
||||
|
||||
// Default expires at to now+lifetime, or use the configured value if not
|
||||
// set.
|
||||
if params.ExpiresAt.IsZero() {
|
||||
if params.LifetimeSeconds != 0 {
|
||||
params.ExpiresAt = database.Now().Add(time.Duration(params.LifetimeSeconds) * time.Second)
|
||||
} else {
|
||||
params.ExpiresAt = database.Now().Add(params.DeploymentValues.SessionDuration.Value())
|
||||
params.LifetimeSeconds = int64(params.DeploymentValues.SessionDuration.Value().Seconds())
|
||||
}
|
||||
}
|
||||
if params.LifetimeSeconds == 0 {
|
||||
params.LifetimeSeconds = int64(time.Until(params.ExpiresAt).Seconds())
|
||||
}
|
||||
|
||||
ip := net.ParseIP(params.RemoteAddr)
|
||||
if ip == nil {
|
||||
ip = net.IPv4(0, 0, 0, 0)
|
||||
}
|
||||
|
||||
bitlen := len(ip) * 8
|
||||
|
||||
scope := database.APIKeyScopeAll
|
||||
if params.Scope != "" {
|
||||
scope = params.Scope
|
||||
}
|
||||
switch scope {
|
||||
case database.APIKeyScopeAll, database.APIKeyScopeApplicationConnect:
|
||||
default:
|
||||
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("invalid API key scope: %q", scope)
|
||||
}
|
||||
|
||||
token := fmt.Sprintf("%s-%s", keyID, keySecret)
|
||||
|
||||
return database.InsertAPIKeyParams{
|
||||
ID: keyID,
|
||||
UserID: params.UserID,
|
||||
LifetimeSeconds: params.LifetimeSeconds,
|
||||
IPAddress: pqtype.Inet{
|
||||
IPNet: net.IPNet{
|
||||
IP: ip,
|
||||
Mask: net.CIDRMask(bitlen, bitlen),
|
||||
},
|
||||
Valid: true,
|
||||
},
|
||||
// Make sure in UTC time for common time zone
|
||||
ExpiresAt: params.ExpiresAt.UTC(),
|
||||
CreatedAt: database.Now(),
|
||||
UpdatedAt: database.Now(),
|
||||
HashedSecret: hashed[:],
|
||||
LoginType: params.LoginType,
|
||||
Scope: scope,
|
||||
TokenName: params.TokenName,
|
||||
}, token, nil
|
||||
}
|
||||
|
||||
// generateKey a new ID and secret for an API key.
|
||||
func generateKey() (id string, secret string, err error) {
|
||||
// Length of an API Key ID.
|
||||
id, err = cryptorand.String(10)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
// Length of an API Key secret.
|
||||
secret, err = cryptorand.String(22)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return id, secret, nil
|
||||
}
|
164
coderd/apikey/apikey_test.go
Normal file
164
coderd/apikey/apikey_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package apikey_test
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testcase struct {
|
||||
name string
|
||||
params apikey.CreateParams
|
||||
fail bool
|
||||
}
|
||||
|
||||
cases := []testcase{
|
||||
{
|
||||
name: "OK",
|
||||
params: apikey.CreateParams{
|
||||
UserID: uuid.New(),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
LifetimeSeconds: int64(time.Hour.Seconds()),
|
||||
TokenName: "hello",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
Scope: database.APIKeyScopeApplicationConnect,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "InvalidScope",
|
||||
params: apikey.CreateParams{
|
||||
UserID: uuid.New(),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
LifetimeSeconds: int64(time.Hour.Seconds()),
|
||||
TokenName: "hello",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
Scope: database.APIKeyScope("test"),
|
||||
},
|
||||
fail: true,
|
||||
},
|
||||
{
|
||||
name: "DeploymentSessionDuration",
|
||||
params: apikey.CreateParams{
|
||||
UserID: uuid.New(),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
DeploymentValues: &codersdk.DeploymentValues{
|
||||
SessionDuration: clibase.Duration(time.Hour),
|
||||
},
|
||||
LifetimeSeconds: 0,
|
||||
ExpiresAt: time.Time{},
|
||||
TokenName: "hello",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
Scope: database.APIKeyScopeApplicationConnect,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DefaultIP",
|
||||
params: apikey.CreateParams{
|
||||
UserID: uuid.New(),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
LifetimeSeconds: int64(time.Hour.Seconds()),
|
||||
TokenName: "hello",
|
||||
RemoteAddr: "",
|
||||
Scope: database.APIKeyScopeApplicationConnect,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DefaultScope",
|
||||
params: apikey.CreateParams{
|
||||
UserID: uuid.New(),
|
||||
LoginType: database.LoginTypeOIDC,
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
LifetimeSeconds: int64(time.Hour.Seconds()),
|
||||
TokenName: "hello",
|
||||
RemoteAddr: "1.2.3.4",
|
||||
Scope: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key, keystr, err := apikey.Generate(tc.params)
|
||||
if tc.fail {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, keystr)
|
||||
require.NotEmpty(t, key.ID)
|
||||
require.NotEmpty(t, key.HashedSecret)
|
||||
|
||||
// Assert the string secret is formatted correctly
|
||||
keytokens := strings.Split(keystr, "-")
|
||||
require.Len(t, keytokens, 2)
|
||||
require.Equal(t, key.ID, keytokens[0])
|
||||
|
||||
// Assert that the hashed secret is correct.
|
||||
hashed := sha256.Sum256([]byte(keytokens[1]))
|
||||
assert.ElementsMatch(t, hashed, key.HashedSecret[:])
|
||||
|
||||
assert.Equal(t, tc.params.UserID, key.UserID)
|
||||
assert.WithinDuration(t, database.Now(), key.CreatedAt, time.Second*5)
|
||||
assert.WithinDuration(t, database.Now(), key.UpdatedAt, time.Second*5)
|
||||
|
||||
if tc.params.LifetimeSeconds > 0 {
|
||||
assert.Equal(t, tc.params.LifetimeSeconds, key.LifetimeSeconds)
|
||||
} else if !tc.params.ExpiresAt.IsZero() {
|
||||
// Should not be a delta greater than 5 seconds.
|
||||
assert.InDelta(t, time.Until(tc.params.ExpiresAt).Seconds(), key.LifetimeSeconds, 5)
|
||||
} else {
|
||||
assert.Equal(t, int64(tc.params.DeploymentValues.SessionDuration.Value().Seconds()), key.LifetimeSeconds)
|
||||
}
|
||||
|
||||
if !tc.params.ExpiresAt.IsZero() {
|
||||
assert.Equal(t, tc.params.ExpiresAt.UTC(), key.ExpiresAt)
|
||||
} else if tc.params.LifetimeSeconds > 0 {
|
||||
assert.WithinDuration(t, database.Now().Add(time.Duration(tc.params.LifetimeSeconds)), key.ExpiresAt, time.Second*5)
|
||||
} else {
|
||||
assert.WithinDuration(t, database.Now().Add(tc.params.DeploymentValues.SessionDuration.Value()), key.ExpiresAt, time.Second*5)
|
||||
}
|
||||
|
||||
if tc.params.RemoteAddr != "" {
|
||||
assert.Equal(t, tc.params.RemoteAddr, key.IPAddress.IPNet.IP.String())
|
||||
} else {
|
||||
assert.Equal(t, "0.0.0.0", key.IPAddress.IPNet.IP.String())
|
||||
}
|
||||
|
||||
if tc.params.Scope != "" {
|
||||
assert.Equal(t, tc.params.Scope, key.Scope)
|
||||
} else {
|
||||
assert.Equal(t, database.APIKeyScopeAll, key.Scope)
|
||||
}
|
||||
|
||||
if tc.params.TokenName != "" {
|
||||
assert.Equal(t, tc.params.TokenName, key.TokenName)
|
||||
}
|
||||
if tc.params.LoginType != "" {
|
||||
assert.Equal(t, tc.params.LoginType, key.LoginType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -964,6 +964,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
|
||||
TemplateScheduleStore: api.TemplateScheduleStore,
|
||||
AcquireJobDebounce: debounce,
|
||||
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -138,6 +138,7 @@ var (
|
||||
rbac.ResourceUser.Type: {rbac.ActionRead},
|
||||
rbac.ResourceWorkspace.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
|
||||
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
|
||||
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
|
||||
}),
|
||||
Org: map[string][]rbac.Permission{},
|
||||
User: []rbac.Permission{},
|
||||
|
@ -27,6 +27,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
@ -62,6 +63,7 @@ type Server struct {
|
||||
QuotaCommitter *atomic.Pointer[proto.QuotaCommitter]
|
||||
Auditor *atomic.Pointer[audit.Auditor]
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
|
||||
AcquireJobDebounce time.Duration
|
||||
OIDCConfig httpmw.OAuth2Config
|
||||
@ -193,6 +195,20 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
}
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
switch workspaceBuild.Transition {
|
||||
case database.WorkspaceTransitionStart:
|
||||
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
|
||||
}
|
||||
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
|
||||
err = deleteSessionToken(ctx, server.Database, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Compute parameters for the workspace to consume.
|
||||
parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{
|
||||
TemplateImportJobID: templateVersion.JobID,
|
||||
@ -286,6 +302,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
WorkspaceOwnerId: owner.ID.String(),
|
||||
TemplateName: template.Name,
|
||||
TemplateVersion: templateVersion.Name,
|
||||
WorkspaceOwnerSessionToken: sessionToken,
|
||||
},
|
||||
LogLevel: input.LogLevel,
|
||||
},
|
||||
@ -1410,6 +1427,64 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
return nil
|
||||
}
|
||||
|
||||
func workspaceSessionTokenName(workspace database.Workspace) string {
|
||||
return fmt.Sprintf("%s_%s_session_token", workspace.OwnerID, workspace.ID)
|
||||
}
|
||||
|
||||
func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
|
||||
newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
DeploymentValues: server.DeploymentValues,
|
||||
TokenName: workspaceSessionTokenName(workspace),
|
||||
LifetimeSeconds: int64(server.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
|
||||
err = server.Database.InTx(func(tx database.Store) error {
|
||||
err := deleteSessionToken(ctx, tx, workspace)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete session token: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.InsertAPIKey(ctx, newkey)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert API key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("create API key: %w", err)
|
||||
}
|
||||
|
||||
return sessionToken, nil
|
||||
}
|
||||
|
||||
func deleteSessionToken(ctx context.Context, db database.Store, workspace database.Workspace) error {
|
||||
err := db.InTx(func(tx database.Store) error {
|
||||
key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
|
||||
UserID: workspace.OwnerID,
|
||||
TokenName: workspaceSessionTokenName(workspace),
|
||||
})
|
||||
if err == nil {
|
||||
err = tx.DeleteAPIKeyByID(ctx, key.ID)
|
||||
}
|
||||
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get api key by name: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("in tx: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// obtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// for the user if it's able to obtain one, otherwise it returns an empty string.
|
||||
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@ -15,6 +16,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/cli/clibase"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbfake"
|
||||
@ -61,6 +63,7 @@ func TestAcquireJob(t *testing.T) {
|
||||
Auditor: mockAuditor(),
|
||||
TemplateScheduleStore: testTemplateScheduleStore(),
|
||||
Tracer: trace.NewNoopTracerProvider().Tracer("noop"),
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
}
|
||||
job, err := srv.AcquireJob(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
@ -102,6 +105,10 @@ func TestAcquireJob(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := setup(t, false)
|
||||
gitAuthProvider := "github"
|
||||
// Set the max session token lifetime so we can assert we
|
||||
// create an API key with an expiration within the bounds of the
|
||||
// deployment config.
|
||||
srv.DeploymentValues.MaxTokenLifetime = clibase.Duration(time.Hour)
|
||||
srv.GitAuthConfigs = []*gitauth.Config{{
|
||||
ID: gitAuthProvider,
|
||||
OAuth2Config: &testutil.OAuth2Config{},
|
||||
@ -192,12 +199,16 @@ func TestAcquireJob(t *testing.T) {
|
||||
})),
|
||||
})
|
||||
|
||||
published := make(chan struct{})
|
||||
closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
||||
close(published)
|
||||
startPublished := make(chan struct{})
|
||||
var closed bool
|
||||
closeStartSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
||||
if !closed {
|
||||
close(startPublished)
|
||||
closed = true
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer closeSubscribe()
|
||||
defer closeStartSubscribe()
|
||||
|
||||
var job *proto.AcquiredJob
|
||||
|
||||
@ -211,11 +222,21 @@ func TestAcquireJob(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
<-published
|
||||
<-startPublished
|
||||
|
||||
got, err := json.Marshal(job.Type)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate that a session token is generated during the job.
|
||||
sessionToken := job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
|
||||
require.NotEmpty(t, sessionToken)
|
||||
toks := strings.Split(sessionToken, "-")
|
||||
require.Len(t, toks, 2, "invalid api key")
|
||||
key, err := srv.Database.GetAPIKeyByID(ctx, toks[0])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(srv.DeploymentValues.MaxTokenLifetime.Value().Seconds()), key.LifetimeSeconds)
|
||||
require.WithinDuration(t, time.Now().Add(srv.DeploymentValues.MaxTokenLifetime.Value()), key.ExpiresAt, time.Minute)
|
||||
|
||||
want, err := json.Marshal(&proto.AcquiredJob_WorkspaceBuild_{
|
||||
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
|
||||
WorkspaceBuildId: build.ID.String(),
|
||||
@ -247,13 +268,59 @@ func TestAcquireJob(t *testing.T) {
|
||||
WorkspaceOwnerId: user.ID.String(),
|
||||
TemplateName: template.Name,
|
||||
TemplateVersion: version.Name,
|
||||
WorkspaceOwnerSessionToken: sessionToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.JSONEq(t, string(want), string(got))
|
||||
|
||||
// Assert that we delete the session token whenever
|
||||
// a stop is issued.
|
||||
stopbuild := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{
|
||||
WorkspaceID: workspace.ID,
|
||||
BuildNumber: 2,
|
||||
JobID: uuid.New(),
|
||||
TemplateVersionID: version.ID,
|
||||
Transition: database.WorkspaceTransitionStop,
|
||||
Reason: database.BuildReasonInitiator,
|
||||
})
|
||||
_ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
|
||||
ID: stopbuild.ID,
|
||||
InitiatorID: user.ID,
|
||||
Provisioner: database.ProvisionerTypeEcho,
|
||||
StorageMethod: database.ProvisionerStorageMethodFile,
|
||||
FileID: file.ID,
|
||||
Type: database.ProvisionerJobTypeWorkspaceBuild,
|
||||
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
|
||||
WorkspaceBuildID: stopbuild.ID,
|
||||
})),
|
||||
})
|
||||
|
||||
stopPublished := make(chan struct{})
|
||||
closeStopSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
|
||||
close(stopPublished)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer closeStopSubscribe()
|
||||
|
||||
// Grab jobs until we find the workspace build job. There is also
|
||||
// an import version job that we need to ignore.
|
||||
job, err = srv.AcquireJob(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
|
||||
require.True(t, ok, "acquired job not a workspace build?")
|
||||
|
||||
<-stopPublished
|
||||
|
||||
// Validate that a session token is deleted during a stop job.
|
||||
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.WorkspaceOwnerSessionToken
|
||||
require.Empty(t, sessionToken)
|
||||
_, err = srv.Database.GetAPIKeyByID(ctx, key.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
})
|
||||
|
||||
t.Run("TemplateVersionDryRun", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
srv := setup(t, false)
|
||||
@ -1205,6 +1272,7 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
|
||||
Auditor: mockAuditor(),
|
||||
TemplateScheduleStore: testTemplateScheduleStore(),
|
||||
Tracer: trace.NewNoopTracerProvider().Tracer("noop"),
|
||||
DeploymentValues: &codersdk.DeploymentValues{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
@ -129,10 +130,11 @@ func (api *API) postLogin(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
//nolint:gocritic // Creating the API key as the user instead of as system.
|
||||
cookie, key, err := api.createAPIKey(dbauthz.As(ctx, userSubj), createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
cookie, key, err := api.createAPIKey(dbauthz.As(ctx, userSubj), apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
@ -1011,10 +1013,11 @@ func (api *API) oauthLogin(r *http.Request, params oauthLoginParams) (*http.Cook
|
||||
}
|
||||
|
||||
//nolint:gocritic
|
||||
cookie, key, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), createAPIKeyParams{
|
||||
UserID: user.ID,
|
||||
LoginType: params.LoginType,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
cookie, key, err := api.createAPIKey(dbauthz.AsSystemRestricted(ctx), apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: params.LoginType,
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
RemoteAddr: r.RemoteAddr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, database.APIKey{}, xerrors.Errorf("create API key: %w", err)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
@ -109,12 +110,13 @@ func (api *API) workspaceApplicationAuth(rw http.ResponseWriter, r *http.Request
|
||||
exp = database.Now().Add(api.DeploymentValues.SessionDuration.Value())
|
||||
lifetimeSeconds = int64(api.DeploymentValues.SessionDuration.Value().Seconds())
|
||||
}
|
||||
cookie, _, err := api.createAPIKey(ctx, createAPIKeyParams{
|
||||
UserID: apiKey.UserID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
ExpiresAt: exp,
|
||||
LifetimeSeconds: lifetimeSeconds,
|
||||
Scope: database.APIKeyScopeApplicationConnect,
|
||||
cookie, _, err := api.createAPIKey(ctx, apikey.CreateParams{
|
||||
UserID: apiKey.UserID,
|
||||
LoginType: database.LoginTypePassword,
|
||||
DeploymentValues: api.DeploymentValues,
|
||||
ExpiresAt: exp,
|
||||
LifetimeSeconds: lifetimeSeconds,
|
||||
Scope: database.APIKeyScopeApplicationConnect,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
|
Reference in New Issue
Block a user