mirror of
https://github.com/coder/coder.git
synced 2025-07-12 00:14:10 +00:00
feat: reinitialize agents when a prebuilt workspace is claimed (#17475)
This pull request allows Coder workspace agents to be reinitialized when a prebuilt workspace is claimed by a user. This facilitates the transfer of ownership between the anonymous prebuilds system user and the new owner of the workspace. Only a single agent per prebuilt workspace is supported for now, but plumbing has already been done to facilitate a seamless transition to multi-agent support.
---------
Signed-off-by: Danny Kopping <dannykopping@gmail.com>
Co-authored-by: Danny Kopping <dannykopping@gmail.com>
This commit is contained in:
@ -19,12 +19,15 @@ import (
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/retry"
|
||||
"github.com/coder/websocket"
|
||||
|
||||
"github.com/coder/coder/v2/agent/proto"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/drpcsdk"
|
||||
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// ExternalLogSourceID is the statically-defined ID of a log-source that
|
||||
@ -686,3 +689,188 @@ func LogsNotifyChannel(agentID uuid.UUID) string {
|
||||
// LogsNotifyMessage is a JSON message with a single integer field that is
// encoded under the key "created_after".
// NOTE(review): presumably this is the payload published on the channel named
// by LogsNotifyChannel — confirm against the publisher.
type LogsNotifyMessage struct {
	// CreatedAfter is serialized as "created_after".
	// NOTE(review): the unit/meaning of the value is not visible here — confirm.
	CreatedAfter int64 `json:"created_after"`
}
|
||||
|
||||
// ReinitializationReason is a string that names why an agent is being asked
// to reinitialize.
type ReinitializationReason string

// ReinitializeReasonPrebuildClaimed is the reason carried by a
// reinitialization event when a prebuilt workspace has been claimed.
const ReinitializeReasonPrebuildClaimed ReinitializationReason = "prebuild_claimed"
|
||||
|
||||
type ReinitializationEvent struct {
|
||||
WorkspaceID uuid.UUID
|
||||
Reason ReinitializationReason `json:"reason"`
|
||||
}
|
||||
|
||||
func PrebuildClaimedChannel(id uuid.UUID) string {
|
||||
return fmt.Sprintf("prebuild_claimed_%s", id)
|
||||
}
|
||||
|
||||
// WaitForReinit polls a SSE endpoint, and receives an event back under the following conditions:
|
||||
// - ping: ignored, keepalive
|
||||
// - prebuild claimed: a prebuilt workspace is claimed, so the agent must reinitialize.
|
||||
func (c *Client) WaitForReinit(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/reinit")
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create cookie jar: %w", err)
|
||||
}
|
||||
jar.SetCookies(rpcURL, []*http.Cookie{{
|
||||
Name: codersdk.SessionTokenCookie,
|
||||
Value: c.SDK.SessionToken(),
|
||||
}})
|
||||
httpClient := &http.Client{
|
||||
Jar: jar,
|
||||
Transport: c.SDK.HTTPClient.Transport,
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rpcURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("build request: %w", err)
|
||||
}
|
||||
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("execute request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, codersdk.ReadBodyAsError(res)
|
||||
}
|
||||
|
||||
reinitEvent, err := NewSSEAgentReinitReceiver(res.Body).Receive(ctx)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("listening for reinitialization events: %w", err)
|
||||
}
|
||||
return reinitEvent, nil
|
||||
}
|
||||
|
||||
func WaitForReinitLoop(ctx context.Context, logger slog.Logger, client *Client) <-chan ReinitializationEvent {
|
||||
reinitEvents := make(chan ReinitializationEvent)
|
||||
|
||||
go func() {
|
||||
for retrier := retry.New(100*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
|
||||
logger.Debug(ctx, "waiting for agent reinitialization instructions")
|
||||
reinitEvent, err := client.WaitForReinit(ctx)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "failed to wait for agent reinitialization instructions", slog.Error(err))
|
||||
continue
|
||||
}
|
||||
retrier.Reset()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(reinitEvents)
|
||||
return
|
||||
case reinitEvents <- *reinitEvent:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return reinitEvents
|
||||
}
|
||||
|
||||
func NewSSEAgentReinitTransmitter(logger slog.Logger, rw http.ResponseWriter, r *http.Request) *SSEAgentReinitTransmitter {
|
||||
return &SSEAgentReinitTransmitter{logger: logger, rw: rw, r: r}
|
||||
}
|
||||
|
||||
type SSEAgentReinitTransmitter struct {
|
||||
rw http.ResponseWriter
|
||||
r *http.Request
|
||||
logger slog.Logger
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTransmissionSourceClosed = xerrors.New("transmission source closed")
|
||||
ErrTransmissionTargetClosed = xerrors.New("transmission target closed")
|
||||
)
|
||||
|
||||
// Transmit will read from the given chan and send events for as long as:
|
||||
// * the chan remains open
|
||||
// * the context has not been canceled
|
||||
// * not timed out
|
||||
// * the connection to the receiver remains open
|
||||
func (s *SSEAgentReinitTransmitter) Transmit(ctx context.Context, reinitEvents <-chan ReinitializationEvent) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sseSendEvent, sseSenderClosed, err := httpapi.ServerSentEventSender(s.rw, s.r)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to create sse transmitter: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Block returning until the ServerSentEventSender is closed
|
||||
// to avoid a race condition where we might write or flush to rw after the handler returns.
|
||||
<-sseSenderClosed
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-sseSenderClosed:
|
||||
return ErrTransmissionTargetClosed
|
||||
case reinitEvent, ok := <-reinitEvents:
|
||||
if !ok {
|
||||
return ErrTransmissionSourceClosed
|
||||
}
|
||||
err := sseSendEvent(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: reinitEvent,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewSSEAgentReinitReceiver wraps r, the body of a server-sent event stream,
// in a receiver. The receiver does not close r; that is left to the caller.
func NewSSEAgentReinitReceiver(r io.ReadCloser) *SSEAgentReinitReceiver {
	receiver := new(SSEAgentReinitReceiver)
	receiver.r = r
	return receiver
}

// SSEAgentReinitReceiver reads reinitialization events from a server-sent
// event stream; see Receive.
type SSEAgentReinitReceiver struct {
	// r is the stream the events are read from.
	r io.ReadCloser
}
|
||||
|
||||
func (s *SSEAgentReinitReceiver) Receive(ctx context.Context) (*ReinitializationEvent, error) {
|
||||
nextEvent := codersdk.ServerSentEventReader(ctx, s.r)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
sse, err := nextEvent()
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, xerrors.Errorf("failed to read server-sent event: %w", err)
|
||||
case sse.Type == codersdk.ServerSentEventTypeError:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type error")
|
||||
case sse.Type == codersdk.ServerSentEventTypePing:
|
||||
continue
|
||||
case sse.Type != codersdk.ServerSentEventTypeData:
|
||||
return nil, xerrors.Errorf("unexpected server sent event type: %s", sse.Type)
|
||||
}
|
||||
|
||||
// At this point we know that the sent event is of type codersdk.ServerSentEventTypeData
|
||||
var reinitEvent ReinitializationEvent
|
||||
b, ok := sse.Data.([]byte)
|
||||
if !ok {
|
||||
return nil, xerrors.Errorf("expected data as []byte, got %T", sse.Data)
|
||||
}
|
||||
err = json.Unmarshal(b, &reinitEvent)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("unmarshal reinit response: %w", err)
|
||||
}
|
||||
return &reinitEvent, nil
|
||||
}
|
||||
}
|
||||
|
122
codersdk/agentsdk/agentsdk_test.go
Normal file
122
codersdk/agentsdk/agentsdk_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
package agentsdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/codersdk/agentsdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestStreamAgentReinitEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("transmitted events are received", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx := testutil.Context(t, testutil.WaitShort)
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx := testutil.Context(t, testutil.WaitShort)
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, receiveErr)
|
||||
require.Equal(t, eventToSend, *sentEvent)
|
||||
})
|
||||
|
||||
t.Run("doesn't transmit events if the transmitter context is canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx, cancelTransmit := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
cancelTransmit()
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx := testutil.Context(t, testutil.WaitShort)
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, sentEvent)
|
||||
require.ErrorIs(t, receiveErr, io.EOF)
|
||||
})
|
||||
|
||||
t.Run("does not receive events if the receiver context is canceled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eventToSend := agentsdk.ReinitializationEvent{
|
||||
WorkspaceID: uuid.New(),
|
||||
Reason: agentsdk.ReinitializeReasonPrebuildClaimed,
|
||||
}
|
||||
|
||||
events := make(chan agentsdk.ReinitializationEvent, 1)
|
||||
events <- eventToSend
|
||||
|
||||
transmitCtx := testutil.Context(t, testutil.WaitShort)
|
||||
transmitErrCh := make(chan error, 1)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
transmitter := agentsdk.NewSSEAgentReinitTransmitter(slogtest.Make(t, nil), w, r)
|
||||
transmitErrCh <- transmitter.Transmit(transmitCtx, events)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
requestCtx := testutil.Context(t, testutil.WaitShort)
|
||||
req, err := http.NewRequestWithContext(requestCtx, "GET", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
receiveCtx, cancelReceive := context.WithCancel(context.Background())
|
||||
cancelReceive()
|
||||
receiver := agentsdk.NewSSEAgentReinitReceiver(resp.Body)
|
||||
sentEvent, receiveErr := receiver.Receive(receiveCtx)
|
||||
require.Nil(t, sentEvent)
|
||||
require.ErrorIs(t, receiveErr, context.Canceled)
|
||||
})
|
||||
}
|
@ -631,7 +631,7 @@ func (h *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
}
|
||||
if h.Transport == nil {
|
||||
h.Transport = http.DefaultTransport
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
return h.Transport.RoundTrip(req)
|
||||
}
|
||||
|
Reference in New Issue
Block a user