mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat: reinitialize agents when a prebuilt workspace is claimed (#17475)
This pull request allows coder workspace agents to be reinitialized when a prebuilt workspace is claimed by a user. This facilitates the transfer of ownership between the anonymous prebuilds system user and the new owner of the workspace. Only a single agent per prebuilt workspace is supported for now, but plumbing has already been done to facilitate the seamless transition to multi-agent support. --------- Signed-off-by: Danny Kopping <dannykopping@gmail.com> Co-authored-by: Danny Kopping <dannykopping@gmail.com>
This commit is contained in:
@ -19,12 +19,15 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// ExternalLogSourceID is the statically-defined ID of a log-source that
|
||||
@ -686,3 +689,188 @@ func LogsNotifyChannel(agentID uuid.UUID) string {
|
||||
type LogsNotifyMessage struct {
|
||||
CreatedAfter int64 `json:"created_after"`
|
||||
}
|
||||
|
||||
type ReinitializationReason string
|
||||
|
||||
const (
|
||||
ReinitializeReasonPrebuildClaimed ReinitializationReason = "prebuild_claimed"
|
||||
)
|
||||
|
||||
type ReinitializationEvent struct {
|
||||
WorkspaceID uuid.UUID
|
||||
Reason ReinitializationReason `json:"reason"`
|
||||
}
|
||||
|
||||
func PrebuildClaimedChannel(id uuid.UUID) string {
|
||||
return fmt.Sprintf("prebuild_claimed_%s", id)
|
||||
}
|
||||
|
||||
// WaitForReinit polls a SSE endpoint, and receives an event back under the following conditions:
|
||||
// - ping: ignored, keepalive
|
||||
// - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
|
||||
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/reinit")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rpcURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("build request: %w", err)
|
||||
}
|
||||
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
reinitEvent, err := NewSSEAgentReinitReceiver(res.Body).Receive(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
|
||||
}
|
||||
return reinitEvent, nil
|
||||
}
|
||||
|
||||
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
|
||||
reinitEvents := make(chan ReinitializationEvent)
|
||||
|
||||
go func() {
|
||||
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
logger.Debug(ctx, "waiting for agent reinitialization instructions")
|
||||
reinitEvent, err := client.WaitForReinit(ctx)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
retrier.Reset()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(reinitEvents)
|
||||
return
|
||||
case reinitEvents <- *reinitEvent:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return reinitEvents
|
||||
}
|
||||
|
||||
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
|
||||
return &SSEAgentReinitTransmitter{logger: logger, rw: rw, r: r}
|
||||
}
|
||||
|
||||
type SSEAgentReinitTransmitter struct {
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTransmissionSourceClosed = xerrors.New("transmission source closed")
|
||||
ErrTransmissionTargetClosed = xerrors.New("transmission target closed")
|
||||
)
|
||||
|
||||
// Transmit will read from the given chan and send events for as long as:
|
||||
// * the chan remains open
|
||||
// * the context has not been canceled
|
||||
// * not timed out
|
||||
// * the connection to the receiver remains open
|
||||
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create sse transmitter: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Block returning until the ServerSentEventSender is closed
|
||||
// to avoid a race condition where we might write or flush to rw after the handler returns.
|
||||
<-sseSenderClosed
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-sseSenderClosed:
|
||||
return ErrTransmissionTargetClosed
|
||||
case reinitEvent, ok := <-reinitEvents:
|
||||
if !ok {
|
||||
return ErrTransmissionSourceClosed
|
||||
}
|
||||
err := sseSendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: reinitEvent,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
|
||||
return &SSEAgentReinitReceiver{r: r}
|
||||
}
|
||||
|
||||
type SSEAgentReinitReceiver struct {
|
||||
r io.ReadCloser
|
||||
}
|
||||
|
||||
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
nextEvent := codersdk.ServerSentEventReader(ctx, s.r)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sse, err := nextEvent()
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, xerrors.Errorf("failed to read server-sent event: %w", err)
|
||||
case sse.Type == codersdk.ServerSentEventTypeError:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type error")
|
||||
case sse.Type == codersdk.ServerSentEventTypePing:
|
||||
continue
|
||||
case sse.Type != codersdk.ServerSentEventTypeData:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type: %s", sse.Type)
|
||||
}
|
||||
|
||||
// At this point we know that the sent event is of type codersdk.ServerSentEventTypeData
|
||||
var reinitEvent ReinitializationEvent
|
||||
b, ok := sse.Data.([]byte)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected data as []byte, got %T", sse.Data)
|
||||
}
|
||||
err = json.Unmarshal(b, &reinitEvent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal reinit response: %w", err)
|
||||
}
|
||||
return &reinitEvent, nil
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user