diff --git a/coderd/httpapi/httpapi.go b/coderd/httpapi/httpapi.go index a42b79eaa4..5e4648828f 100644 --- a/coderd/httpapi/httpapi.go +++ b/coderd/httpapi/httpapi.go @@ -6,14 +6,13 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "reflect" "strings" - "sync" "time" "github.com/go-playground/validator/v10" + "golang.org/x/xerrors" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/codersdk" @@ -174,8 +173,7 @@ func WebsocketCloseSprintf(format string, vars ...any) string { return msg } -func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (func(ctx context.Context, sse codersdk.ServerSentEvent) error, error) { - var mu sync.Mutex +func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (sendEvent func(ctx context.Context, sse codersdk.ServerSentEvent) error, closed chan struct{}, err error) { h := rw.Header() h.Set("Content-Type", "text/event-stream") h.Set("Cache-Control", "no-cache") @@ -187,37 +185,50 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (func(ctx co panic("http.ResponseWriter is not http.Flusher") } - // Send a heartbeat every 15 seconds to avoid the connection being killed. + closed = make(chan struct{}) + type sseEvent struct { + payload []byte + errC chan error + } + eventC := make(chan sseEvent) + + // Synchronized handling of events (no guarantee of order). go func() { + defer close(closed) + + // Send a heartbeat every 15 seconds to avoid the connection being killed. ticker := time.NewTicker(time.Second * 15) defer ticker.Stop() for { + var event sseEvent + select { case <-r.Context().Done(): return + case event = <-eventC: case <-ticker.C: - mu.Lock() - _, err := io.WriteString(rw, fmt.Sprintf("event: %s\n\n", codersdk.ServerSentEventTypePing)) - if err != nil { - mu.Unlock() - return + event = sseEvent{ + payload: []byte(fmt.Sprintf("event: %s\n\n", codersdk.ServerSentEventTypePing)), } - f.Flush() - mu.Unlock() } + + _, err := rw.Write(event.payload) + if event.errC != nil { + event.errC <- err + } + if err != nil { + return + } + f.Flush() } }() - sendEvent := func(ctx context.Context, sse codersdk.ServerSentEvent) error { - if ctx.Err() != nil { - return ctx.Err() - } - + sendEvent = func(ctx context.Context, sse codersdk.ServerSentEvent) error { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) - _, err := buf.Write([]byte(fmt.Sprintf("event: %s\ndata: ", sse.Type))) + _, err := buf.WriteString(fmt.Sprintf("event: %s\ndata: ", sse.Type)) if err != nil { return err } @@ -232,16 +243,32 @@ func ServerSentEventSender(rw http.ResponseWriter, r *http.Request) (func(ctx co return err } - mu.Lock() - defer mu.Unlock() - _, err = rw.Write(buf.Bytes()) - if err != nil { - return err + event := sseEvent{ + payload: buf.Bytes(), + errC: make(chan error, 1), // Buffered to prevent deadlock. } - f.Flush() - return nil + select { + case <-r.Context().Done(): + return r.Context().Err() + case <-ctx.Done(): + return ctx.Err() + case <-closed: + return xerrors.New("server sent event sender closed") + case eventC <- event: + // Re-check closure signals after sending the event to allow + // for early exit. We don't check closed here because it + // can't happen while processing the event. + select { + case <-r.Context().Done(): + return r.Context().Err() + case <-ctx.Done(): + return ctx.Err() + case err := <-event.errC: + return err + } + } } - return sendEvent, nil + return sendEvent, closed, nil } diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 0295dc29d5..2b3d077d10 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -867,7 +867,7 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { return } - sendEvent, err := httpapi.ServerSentEventSender(rw, r) + sendEvent, senderClosed, err := httpapi.ServerSentEventSender(rw, r) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error setting up server-sent events.", @@ -875,6 +875,10 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { }) return } + // Prevent handler from returning until the sender is closed. + defer func() { + <-senderClosed + }() // Ignore all trace spans after this, they're not too useful. ctx = trace.ContextWithSpan(ctx, tracing.NoopSpan) @@ -885,6 +889,8 @@ func (api *API) watchWorkspace(rw http.ResponseWriter, r *http.Request) { select { case <-ctx.Done(): return + case <-senderClosed: + return case <-t.C: workspace, err := api.Database.GetWorkspaceByID(ctx, workspace.ID) if err != nil {