mirror of
https://github.com/coder/coder.git
synced 2025-07-15 22:20:27 +00:00
feat: add session token injection to provisioner (#7461)
This commit is contained in:
@ -27,6 +27,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/apikey"
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbauthz"
|
||||
@ -62,6 +63,7 @@ type Server struct {
|
||||
QuotaCommitter *atomic.Pointer[proto.QuotaCommitter]
|
||||
Auditor *atomic.Pointer[audit.Auditor]
|
||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||
DeploymentValues *codersdk.DeploymentValues
|
||||
|
||||
AcquireJobDebounce time.Duration
|
||||
OIDCConfig httpmw.OAuth2Config
|
||||
@ -193,6 +195,20 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
}
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
switch workspaceBuild.Transition {
|
||||
case database.WorkspaceTransitionStart:
|
||||
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
|
||||
}
|
||||
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
|
||||
err = deleteSessionToken(ctx, server.Database, workspace)
|
||||
if err != nil {
|
||||
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// Compute parameters for the workspace to consume.
|
||||
parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{
|
||||
TemplateImportJobID: templateVersion.JobID,
|
||||
@ -286,6 +302,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
|
||||
WorkspaceOwnerId: owner.ID.String(),
|
||||
TemplateName: template.Name,
|
||||
TemplateVersion: templateVersion.Name,
|
||||
WorkspaceOwnerSessionToken: sessionToken,
|
||||
},
|
||||
LogLevel: input.LogLevel,
|
||||
},
|
||||
@ -1410,6 +1427,64 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
return nil
|
||||
}
|
||||
|
||||
func workspaceSessionTokenName(workspace database.Workspace) string {
|
||||
return fmt.Sprintf("%s_%s_session_token", workspace.OwnerID, workspace.ID)
|
||||
}
|
||||
|
||||
func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
|
||||
newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{
|
||||
UserID: user.ID,
|
||||
LoginType: user.LoginType,
|
||||
DeploymentValues: server.DeploymentValues,
|
||||
TokenName: workspaceSessionTokenName(workspace),
|
||||
LifetimeSeconds: int64(server.DeploymentValues.MaxTokenLifetime.Value().Seconds()),
|
||||
})
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate API key: %w", err)
|
||||
}
|
||||
|
||||
err = server.Database.InTx(func(tx database.Store) error {
|
||||
err := deleteSessionToken(ctx, tx, workspace)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete session token: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.InsertAPIKey(ctx, newkey)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert API key: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("create API key: %w", err)
|
||||
}
|
||||
|
||||
return sessionToken, nil
|
||||
}
|
||||
|
||||
func deleteSessionToken(ctx context.Context, db database.Store, workspace database.Workspace) error {
|
||||
err := db.InTx(func(tx database.Store) error {
|
||||
key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
|
||||
UserID: workspace.OwnerID,
|
||||
TokenName: workspaceSessionTokenName(workspace),
|
||||
})
|
||||
if err == nil {
|
||||
err = tx.DeleteAPIKeyByID(ctx, key.ID)
|
||||
}
|
||||
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return xerrors.Errorf("get api key by name: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("in tx: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// obtainOIDCAccessToken returns a valid OpenID Connect access token
|
||||
// for the user if it's able to obtain one, otherwise it returns an empty string.
|
||||
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) {
|
||||
|
Reference in New Issue
Block a user