feat: reinitialize agents when a prebuilt workspace is claimed (#17475)

This pull request allows coder workspace agents to be reinitialized when
a prebuilt workspace is claimed by a user. This facilitates the transfer
of ownership between the anonymous prebuilds system user and the new
owner of the workspace.

Only a single agent per prebuilt workspace is supported for now, but
plumbing has already been done to facilitate the seamless transition to
multi-agent support.

---------

Signed-off-by: Danny Kopping <dannykopping@gmail.com>
Co-authored-by: Danny Kopping <dannykopping@gmail.com>
This commit is contained in:
Sas Swart
2025-05-14 14:15:36 +02:00
committed by GitHub
parent fcbdd1a28e
commit 425ee6fa55
38 changed files with 2184 additions and 449 deletions

View File

@ -19,12 +19,15 @@ import (
"tailscale.com/tailcfg"
"cdr.dev/slog"
"github.com/coder/retry"
"github.com/coder/websocket"
"github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/apiversion"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/drpcsdk"
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
"github.com/coder/websocket"
)
// ExternalLogSourceID is the statically-defined ID of a log-source that
@ -686,3 +689,188 @@ func LogsNotifyChannel(agentID uuid.UUID) string {
// LogsNotifyMessage is a JSON-serializable notification payload.
// NOTE(review): presumably published on the channel named by
// LogsNotifyChannel; the publisher is not visible here, so confirm.
type LogsNotifyMessage struct {
// CreatedAfter is encoded under the JSON key "created_after".
// NOTE(review): looks like a log ID or timestamp cursor; confirm against the producer.
CreatedAfter int64 `json:"created_after"`
}
// ReinitializationReason is a string identifier describing why an agent is
// being asked to reinitialize. It is carried in ReinitializationEvent.Reason.
type ReinitializationReason string
const (
// ReinitializeReasonPrebuildClaimed is the reason used when a prebuilt
// workspace has been claimed and its agent must reinitialize.
ReinitializeReasonPrebuildClaimed ReinitializationReason = "prebuild_claimed"
)
// ReinitializationEvent is the payload sent by SSEAgentReinitTransmitter and
// decoded by SSEAgentReinitReceiver to tell an agent to reinitialize.
type ReinitializationEvent struct {
// WorkspaceID carries no JSON tag, so encoding/json uses the Go field name
// ("WorkspaceID") as its key. Adding a tag would change the wire format.
WorkspaceID uuid.UUID
// Reason states why reinitialization was requested; encoded as "reason".
Reason ReinitializationReason `json:"reason"`
}
// PrebuildClaimedChannel returns the channel name for the given ID, formed by
// appending the ID's string form to the "prebuild_claimed_" prefix.
func PrebuildClaimedChannel(id uuid.UUID) string {
	const channelPrefix = "prebuild_claimed_"
	return channelPrefix + id.String()
}
// WaitForReinit issues a GET request to the agent reinit SSE endpoint
// (/api/v2/workspaceagents/me/reinit), authenticating with the SDK session
// token as a cookie, and blocks until the stream yields a result:
//   - ping: ignored, keepalive
//   - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
//
// A non-200 response is converted to an error via codersdk.ReadBodyAsError.
// The response body is closed before returning. Cancellation of ctx aborts
// both the request and the wait for an event.
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
	endpoint, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/reinit")
	if err != nil {
		return nil, xerrors.Errorf("parse url: %w", err)
	}

	// Use a dedicated jar holding only the session token cookie, scoped to
	// the reinit endpoint.
	cookies, err := cookiejar.New(nil)
	if err != nil {
		return nil, xerrors.Errorf("create cookie jar: %w", err)
	}
	sessionCookie := &http.Cookie{
		Name:  codersdk.SessionTokenCookie,
		Value: c.SDK.SessionToken(),
	}
	cookies.SetCookies(endpoint, []*http.Cookie{sessionCookie})

	// Reuse the SDK's transport, but no other settings of its HTTP client.
	sseClient := &http.Client{
		Jar:       cookies,
		Transport: c.SDK.HTTPClient.Transport,
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
	if err != nil {
		return nil, xerrors.Errorf("build request: %w", err)
	}

	response, err := sseClient.Do(request)
	if err != nil {
		return nil, xerrors.Errorf("execute request: %w", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, codersdk.ReadBodyAsError(response)
	}

	receiver := NewSSEAgentReinitReceiver(response.Body)
	event, err := receiver.Receive(ctx)
	if err != nil {
		return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
	}
	return event, nil
}
// WaitForReinitLoop starts a background goroutine that repeatedly calls
// client.WaitForReinit and forwards every received event on the returned
// channel. Failed attempts are logged and retried using a retrier configured
// with 100ms and 10s bounds; the retrier is reset after each successful
// receive.
//
// The goroutine stops once ctx is canceled, and the returned channel is
// closed on every exit path. Receivers should therefore use the two-value
// receive form (or range) to tell a real event from the close signal.
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
	reinitEvents := make(chan ReinitializationEvent)

	go func() {
		// Close unconditionally: cancellation may be observed either by
		// retrier.Wait (loop condition) or by the send select below, and
		// receivers must see the channel close in both cases.
		defer close(reinitEvents)

		for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
			logger.Debug(ctx, "waiting for agent reinitialization instructions")
			reinitEvent, err := client.WaitForReinit(ctx)
			if err != nil {
				logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
				continue
			}
			retrier.Reset()

			// The channel is unbuffered, so do not block forever on a
			// receiver that has gone away after cancellation.
			select {
			case <-ctx.Done():
				return
			case reinitEvents <- *reinitEvent:
			}
		}
	}()

	return reinitEvents
}
// NewSSEAgentReinitTransmitter builds a transmitter bound to the given
// response writer and request, which Transmit later uses to open the
// server-sent event stream.
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
	transmitter := &SSEAgentReinitTransmitter{
		rw:     rw,
		r:      r,
		logger: logger,
	}
	return transmitter
}
// SSEAgentReinitTransmitter sends ReinitializationEvents to a client as
// server-sent events over a single HTTP response.
type SSEAgentReinitTransmitter struct {
// rw and r are handed to httpapi.ServerSentEventSender by Transmit.
rw http.ResponseWriter
r *http.Request
// logger is stored by the constructor.
logger slog.Logger
}
var (
// ErrTransmissionSourceClosed is returned by Transmit when the channel of
// reinitialization events it reads from has been closed.
ErrTransmissionSourceClosed = xerrors.New("transmission source closed")
// ErrTransmissionTargetClosed is returned by Transmit when the server-sent
// event sender reports that it has closed.
ErrTransmissionTargetClosed = xerrors.New("transmission target closed")
)
// Transmit forwards events from reinitEvents to the client as server-sent
// events of type data. It keeps running until one of the following happens:
//   - ctx is canceled or times out (returns ctx.Err())
//   - the receiving side of the SSE stream closes (returns ErrTransmissionTargetClosed)
//   - reinitEvents is closed (returns ErrTransmissionSourceClosed)
//   - sending an event fails (returns that error as-is)
//
// Once the SSE sender has been created, Transmit does not return until the
// sender has closed.
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
	// Bail out before creating the SSE sender if we are already canceled.
	if err := ctx.Err(); err != nil {
		return err
	}

	send, senderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
	if err != nil {
		return xerrors.Errorf("failed to create sse transmitter: %w", err)
	}
	// Hold the return until the sender is closed; otherwise a write or flush
	// to rw could race with the HTTP handler returning.
	defer func() { <-senderClosed }()

	for {
		var (
			event ReinitializationEvent
			open  bool
		)
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-senderClosed:
			return ErrTransmissionTargetClosed
		case event, open = <-reinitEvents:
		}

		if !open {
			return ErrTransmissionSourceClosed
		}

		outgoing := codersdk.ServerSentEvent{
			Type: codersdk.ServerSentEventTypeData,
			Data: event,
		}
		if sendErr := send(outgoing); sendErr != nil {
			return sendErr
		}
	}
}
// NewSSEAgentReinitReceiver wraps r, the body of a server-sent event stream,
// in a receiver that decodes reinitialization events from it.
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
	receiver := new(SSEAgentReinitReceiver)
	receiver.r = r
	return receiver
}
// SSEAgentReinitReceiver reads ReinitializationEvents from a server-sent
// event stream.
type SSEAgentReinitReceiver struct {
// r is the stream passed to codersdk.ServerSentEventReader by Receive.
r io.ReadCloser
}
// Receive reads server-sent events from the underlying stream until it can
// return a decoded ReinitializationEvent or an error.
//
// Ping events are skipped. An error event, any event type other than data,
// a read failure, a data payload that is not []byte, or a JSON decoding
// failure all end the call with an error. ctx is checked before each read;
// if it is done, ctx.Err() is returned.
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
	readEvent := codersdk.ServerSentEventReader(ctx, s.r)
	for {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}

		sse, err := readEvent()
		if err != nil {
			return nil, xerrors.Errorf("failed to read server-sent event: %w", err)
		}

		switch sse.Type {
		case codersdk.ServerSentEventTypeError:
			return nil, xerrors.Errorf("unexpected server sent event type error")
		case codersdk.ServerSentEventTypePing:
			// Keepalive only; wait for the next event.
			continue
		case codersdk.ServerSentEventTypeData:
			// Decoded below.
		default:
			return nil, xerrors.Errorf("unexpected server sent event type: %s", sse.Type)
		}

		// Only data events reach this point.
		payload, isBytes := sse.Data.([]byte)
		if !isBytes {
			return nil, xerrors.Errorf("expected data as []byte, got %T", sse.Data)
		}

		var event ReinitializationEvent
		if err := json.Unmarshal(payload, &event); err != nil {
			return nil, xerrors.Errorf("unmarshal reinit response: %w", err)
		}
		return &event, nil
	}
}