fix pubsub/poll race on provisioner job logs (#2783)

* fix pubsub/poll race on provisioner job logs
  Signed-off-by: Spike Curtis <spike@coder.com>
* only cancel on non-error
  Signed-off-by: Spike Curtis <spike@coder.com>
* Improve logging & comments
  Signed-off-by: spikecurtis <spike@spikecurtis.com>
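The race being fixed: if the logs endpoint queries the database for existing logs and only then subscribes to the pubsub channel, any log published between the query and the subscription is silently lost. Subscribing first and querying second closes that gap at the cost of possible duplicates, which are then filtered by ID. A minimal runnable sketch of the ordering, using generic stand-in types rather than the actual coderd ones:

package main

import "fmt"

type log struct {
	id    int
	stage string
}

// subscribeThenQuery sketches the fixed ordering: subscribe before querying
// history, then drop pubsub messages whose IDs were already sent.
func subscribeThenQuery(history []log, live <-chan log, send func(log)) {
	// 1. already subscribed: `live` was obtained before the history query
	// 2. send historical logs, remembering their IDs
	seen := map[int]bool{}
	for _, l := range history {
		seen[l.id] = true
		send(l)
	}
	// 3. stream live logs, skipping the overlap
	for l := range live {
		if !seen[l.id] {
			send(l)
		}
	}
}

func main() {
	live := make(chan log, 4)
	live <- log{1, "Stage0"} // duplicate of history, as the pubsub may replay
	live <- log{2, "Stage1"}
	close(live)
	history := []log{{1, "Stage0"}}
	subscribeThenQuery(history, live, func(l log) { fmt.Println(l.stage) })
	// Output: Stage0, Stage1, each exactly once
}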
@@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.
		return nil, xerrors.Errorf("insert job logs: %w", err)
	}
	server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
	data, err := json.Marshal(logs)
	data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
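The change above wraps the published payload in an envelope type (provisionerJobLogsMessage, added later in this diff) so the same channel can carry both log batches and an end-of-logs marker. A small standalone illustration of the two payload shapes, with a trimmed stand-in for database.ProvisionerJobLog:

package main

import (
	"encoding/json"
	"fmt"
)

// jobLog stands in for database.ProvisionerJobLog in this sketch.
type jobLog struct {
	Stage string `json:"stage"`
}

// message mirrors provisionerJobLogsMessage's JSON tags.
type message struct {
	EndOfLogs bool     `json:"end_of_logs,omitempty"`
	Logs      []jobLog `json:"logs,omitempty"`
}

func main() {
	logsPayload, _ := json.Marshal(message{Logs: []jobLog{{Stage: "Stage0"}}})
	endPayload, _ := json.Marshal(message{EndOfLogs: true})
	fmt.Println(string(logsPayload)) // {"logs":[{"stage":"Stage0"}]}
	fmt.Println(string(endPayload))  // {"end_of_logs":true}
}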
@@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa
		}
	case *proto.FailedJob_TemplateImport_:
	}

	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
	if err != nil {
		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
		return nil, xerrors.Errorf("publish end of job logs: %w", err)
	}
	return &proto.Empty{}, nil
}
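With this hunk (and the matching one in CompleteJob below), every terminal job state publishes an EndOfLogs marker, and the subscriber closes its channel when it sees one, so followers exit without polling the job row. A minimal standalone sketch of the sentinel-close pattern, not the coderd implementation itself:

package main

import "fmt"

type message struct {
	endOfLogs bool
	logs      []string
}

func main() {
	out := make(chan string, 8)
	msgs := []message{
		{logs: []string{"Stage0", "Stage1"}},
		{endOfLogs: true}, // published by the terminal state (fail or complete)
	}
	// subscriber: forward logs, close on the sentinel
	for _, m := range msgs {
		for _, l := range m.logs {
			out <- l
		}
		if m.endOfLogs {
			close(out) // consumers' range / two-value receives now terminate
		}
	}
	for l := range out {
		fmt.Println(l)
	}
}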
@@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
			reflect.TypeOf(completed.Type).String())
	}

	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
	if err != nil {
		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
		return nil, xerrors.Errorf("publish end of job logs: %w", err)
	}

	server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID))
	return &proto.Empty{}, nil
}
@@ -28,6 +28,7 @@ import (
// The combination of these responses should provide all current logs
// to the consumer, and future logs are streamed in the follow request.
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
	logger := api.Logger.With(slog.F("job_id", job.ID))
	follow := r.URL.Query().Has("follow")
	afterRaw := r.URL.Query().Get("after")
	beforeRaw := r.URL.Query().Get("before")
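The new logger variable relies on slog's field-scoped loggers: With returns a logger that stamps the given fields onto every subsequent message, so job_id does not have to be repeated at each call site. A small standalone example, assuming cdr.dev/slog's sloghuman sink:

package main

import (
	"context"
	"os"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/sloghuman"
)

func main() {
	ctx := context.Background()
	root := slog.Make(sloghuman.Sink(os.Stderr)).Leveled(slog.LevelDebug)
	// With returns a logger that attaches job_id to every message.
	logger := root.With(slog.F("job_id", "1b4e28ba"))
	logger.Debug(ctx, "subscribe buffered log", slog.F("stage", "Stage0"))
}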
@@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
		return
	}

	// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
	// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
	// of processed IDs.
	var bufferedLogs <-chan database.ProvisionerJobLog
	if follow {
		bl, closeFollow, err := api.followLogs(job.ID)
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error watching provisioner logs.",
				Detail:  err.Error(),
			})
			return
		}
		defer closeFollow()
		bufferedLogs = bl

		// Next query the job itself to see if it is complete. If so, the historical query to the database will return
		// the full set of logs. It's a little sad to have to query the job again, given that our caller definitely
		// has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job
		// completes between the prior query and the start of following the pubsub. A more substantial refactor could
		// avoid this, but not worth it for one fewer query at this point.
		job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID)
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error querying job.",
				Detail:  err.Error(),
			})
			return
		}
	}

	var after time.Time
	// Only fetch logs created after the time provided.
	if afterRaw != "" {
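The re-query inside the follow branch guards a second race: the job can complete after the caller's earlier read of the job row but before the subscription starts, in which case the EndOfLogs marker may already have been published to nobody. Fetching the row again after subscribing means either CompletedAt is set (the historical query has everything) or the marker is guaranteed to arrive on the subscription. A compressed, compilable sketch of that guard with stub types, not the coderd ones:

package main

import (
	"database/sql"
	"fmt"
)

type job struct{ CompletedAt sql.NullTime }

// followAfterSubscribe sketches the guard: the job row passed in must be
// fetched *after* the subscription starts. If it is already complete, the
// historical query holds every log and streaming can stop immediately;
// otherwise the still-open subscription is guaranteed to see EndOfLogs.
func followAfterSubscribe(j job, live <-chan string, send func(string)) {
	if j.CompletedAt.Valid {
		return // everything was in the DB already
	}
	for l := range live { // terminates when EndOfLogs closes the channel
		send(l)
	}
}

func main() {
	live := make(chan string, 1)
	live <- "Stage0"
	close(live)
	followAfterSubscribe(job{}, live, func(s string) { fmt.Println(s) })
}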
@@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
		}
	}

	if !follow {
		logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
			JobID:         job.ID,
			CreatedAfter:  after,
			CreatedBefore: before,
	logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
		JobID:         job.ID,
		CreatedAfter:  after,
		CreatedBefore: before,
	})
		if errors.Is(err, sql.ErrNoRows) {
			err = nil
		}
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error fetching provisioner logs.",
				Detail:  err.Error(),
			})
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
			return
		}
		if logs == nil {
			logs = []database.ProvisionerJobLog{}
		}
		return
	}
	if logs == nil {
		logs = []database.ProvisionerJobLog{}
	}

	if !follow {
		logger.Debug(r.Context(), "Finished non-follow job logs")
		httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs))
		return
	}
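One detail in the now-shared query path: the logs == nil normalization exists because encoding/json marshals a nil slice as null but an empty slice as [], and the endpoint should return a JSON array either way. A standalone demonstration:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var logs []string // nil slice
	out, _ := json.Marshal(logs)
	fmt.Println(string(out)) // null

	logs = []string{} // empty, non-nil slice
	out, _ = json.Marshal(logs)
	fmt.Println(string(out)) // []
}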
@@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
	ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText)
	defer wsNetConn.Close() // Also closes conn.

	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
		var logs []database.ProvisionerJobLog
		err := json.Unmarshal(message, &logs)
		if err != nil {
			api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
			return
		}

		for _, log := range logs {
			select {
			case bufferedLogs <- log:
				api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
			default:
				// If this overflows users could miss logs streaming. This can happen
				// if a database request takes a long amount of time, and we get a lot of logs.
				api.Logger.Warn(ctx, "provisioner job log overflowing channel")
			}
		}
	})
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error watching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	defer closeSubscribe()

	provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
		JobID:         job.ID,
		CreatedAfter:  after,
		CreatedBefore: before,
	})
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	logIdsDone := make(map[uuid.UUID]bool)

	// The Go stdlib JSON encoder appends a newline character after message write.
	encoder := json.NewEncoder(wsNetConn)
	for _, provisionerJobLog := range provisionerJobLogs {
	for _, provisionerJobLog := range logs {
		logIdsDone[provisionerJobLog.ID] = true
		err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
		if err != nil {
			return
		}
	}
	if job.CompletedAt.Valid {
		// job was complete before we queried the database for historical logs, meaning we got everything. No need
		// to stream anything from the bufferedLogs.
		return
	}

	ticker := time.NewTicker(250 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID))
		case <-ctx.Done():
			logger.Debug(context.Background(), "job logs context canceled")
			return
		case log := <-bufferedLogs:
			api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
			err = encoder.Encode(convertProvisionerJobLog(log))
			if err != nil {
		case log, ok := <-bufferedLogs:
			if !ok {
				logger.Debug(context.Background(), "done with published logs")
				return
			}
		case <-ticker.C:
			job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
			if err != nil {
				api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
				continue
			}
			if job.CompletedAt.Valid {
				api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID))
				return
			if logIdsDone[log.ID] {
				logger.Debug(r.Context(), "subscribe duplicated log",
					slog.F("stage", log.Stage))
			} else {
				logger.Debug(r.Context(), "subscribe encoding log",
					slog.F("stage", log.Stage))
				err = encoder.Encode(convertProvisionerJobLog(log))
				if err != nil {
					return
				}
			}
		}
	}
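The rewritten streaming loop drops the 250 ms ticker that polled the job row and instead relies on channel-close semantics: the two-value receive reports ok == false once followLogs closes the channel on EndOfLogs, and buffered values are still drained before that. A standalone demonstration of those semantics:

package main

import "fmt"

func main() {
	ch := make(chan string, 2)
	ch <- "Stage0"
	close(ch) // closing does not discard the buffered value

	v, ok := <-ch
	fmt.Println(v, ok) // Stage0 true: buffered value drained first

	v, ok = <-ch
	fmt.Printf("%q %v\n", v, ok) // "" false: channel closed and empty
}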
@@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
func provisionerJobLogsChannel(jobID uuid.UUID) string {
	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
}

// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel
type provisionerJobLogsMessage struct {
	EndOfLogs bool                         `json:"end_of_logs,omitempty"`
	Logs      []database.ProvisionerJobLog `json:"logs,omitempty"`
}

func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
	logger := api.Logger.With(slog.F("job_id", jobID))
	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID),
		func(ctx context.Context, message []byte) {
			jlMsg := provisionerJobLogsMessage{}
			err := json.Unmarshal(message, &jlMsg)
			if err != nil {
				logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
				return
			}

			for _, log := range jlMsg.Logs {
				select {
				case bufferedLogs <- log:
					logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
				default:
					// If this overflows users could miss logs streaming. This can happen if
					// we get a lot of logs and the consumer isn't keeping up. We don't want
					// to block the pubsub, so just drop them.
					logger.Warn(ctx, "provisioner job log overflowing channel")
				}
			}
			if jlMsg.EndOfLogs {
				logger.Debug(ctx, "got End of Logs")
				close(bufferedLogs)
			}
		})
	if err != nil {
		return nil, nil, err
	}
	return bufferedLogs, closeSubscribe, nil
}
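The select with a default case in followLogs is Go's standard non-blocking send: when the 128-entry buffer is full, the log is dropped rather than blocking the pubsub callback. The idiom in isolation:

package main

import "fmt"

func main() {
	buffered := make(chan int, 2) // stands in for the 128-entry log buffer
	for i := 0; i < 3; i++ {
		select {
		case buffered <- i:
			fmt.Println("buffered", i)
		default:
			// buffer full: drop instead of blocking the pubsub callback
			fmt.Println("dropped", i)
		}
	}
}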
coderd/provisionerjobs_internal_test.go (new file, 183 lines)
@@ -0,0 +1,183 @@
package coderd

import (
	"context"
	"crypto/sha256"
	"encoding/json"
	"net/http/httptest"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/databasefake"
	"github.com/coder/coder/codersdk"
)

func TestProvisionerJobLogs_Unit(t *testing.T) {
	t.Parallel()

	t.Run("QueryPubSubDupes", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		t.Cleanup(cancel)
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		// mDB := mocks.NewStore(t)
		fDB := databasefake.New()
		fPubsub := &fakePubSub{t: t, cond: sync.NewCond(&sync.Mutex{})}
		opts := Options{
			Logger:   logger,
			Database: fDB,
			Pubsub:   fPubsub,
		}
		api := New(&opts)
		server := httptest.NewServer(api.Handler)
		t.Cleanup(server.Close)
		userID := uuid.New()
		keyID, keySecret, err := generateAPIKeyIDSecret()
		require.NoError(t, err)
		hashed := sha256.Sum256([]byte(keySecret))

		u, err := url.Parse(server.URL)
		require.NoError(t, err)
		client := codersdk.Client{
			HTTPClient:   server.Client(),
			SessionToken: keyID + "-" + keySecret,
			URL:          u,
		}

		buildID := uuid.New()
		workspaceID := uuid.New()
		jobID := uuid.New()

		expectedLogs := []database.ProvisionerJobLog{
			{ID: uuid.New(), JobID: jobID, Stage: "Stage0"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage1"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage2"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage3"},
		}

		// wow there are a lot of DB rows we touch...
		_, err = fDB.InsertAPIKey(ctx, database.InsertAPIKeyParams{
			ID:           keyID,
			HashedSecret: hashed[:],
			UserID:       userID,
			ExpiresAt:    time.Now().Add(5 * time.Hour),
		})
		require.NoError(t, err)
		_, err = fDB.InsertUser(ctx, database.InsertUserParams{
			ID:        userID,
			RBACRoles: []string{"admin"},
		})
		require.NoError(t, err)
		_, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
			ID:          buildID,
			WorkspaceID: workspaceID,
			JobID:       jobID,
		})
		require.NoError(t, err)
		_, err = fDB.InsertWorkspace(ctx, database.InsertWorkspaceParams{
			ID: workspaceID,
		})
		require.NoError(t, err)
		_, err = fDB.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID: jobID,
		})
		require.NoError(t, err)
		for _, l := range expectedLogs[:2] {
			_, err := fDB.InsertProvisionerJobLogs(ctx, database.InsertProvisionerJobLogsParams{
				ID:    []uuid.UUID{l.ID},
				JobID: jobID,
				Stage: []string{l.Stage},
			})
			require.NoError(t, err)
		}

		logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now())
		require.NoError(t, err)

		// when the endpoint calls subscribe, we get the listener here.
		fPubsub.cond.L.Lock()
		for fPubsub.listener == nil {
			fPubsub.cond.Wait()
		}

		// endpoint should now be listening
		assert.False(t, fPubsub.canceled)
		assert.False(t, fPubsub.closed)

		// send all the logs in two batches, duplicating what we already returned on the DB query.
		msg := provisionerJobLogsMessage{}
		msg.Logs = expectedLogs[:2]
		data, err := json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)
		msg.Logs = expectedLogs[2:]
		data, err = json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)

		// send end of logs
		msg.Logs = nil
		msg.EndOfLogs = true
		data, err = json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)

		var stages []string
		for l := range logs {
			logger.Info(ctx, "got log",
				slog.F("id", l.ID),
				slog.F("stage", l.Stage))
			stages = append(stages, l.Stage)
		}
		assert.Equal(t, []string{"Stage0", "Stage1", "Stage2", "Stage3"}, stages)
		for !fPubsub.canceled {
			fPubsub.cond.Wait()
		}
		assert.False(t, fPubsub.closed)
	})
}

type fakePubSub struct {
	t        *testing.T
	cond     *sync.Cond
	listener database.Listener
	canceled bool
	closed   bool
}

func (f *fakePubSub) Subscribe(_ string, listener database.Listener) (cancel func(), err error) {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.listener = listener
	f.cond.Signal()
	return f.cancel, nil
}

func (f *fakePubSub) Publish(_ string, _ []byte) error {
	f.t.Fail()
	return nil
}

func (f *fakePubSub) Close() error {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.closed = true
	f.cond.Signal()
	return nil
}

func (f *fakePubSub) cancel() {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.canceled = true
	f.cond.Signal()
}
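The fakePubSub above synchronizes test and handler with a sync.Cond: waiters hold the lock and re-check their condition in a loop around Wait, and writers Signal after mutating shared state. The same idiom in isolation:

package main

import (
	"fmt"
	"sync"
)

func main() {
	cond := sync.NewCond(&sync.Mutex{})
	ready := false

	go func() {
		cond.L.Lock()
		defer cond.L.Unlock()
		ready = true
		cond.Signal() // wake the waiter after mutating state
	}()

	cond.L.Lock()
	for !ready { // always re-check the condition in a loop
		cond.Wait() // atomically unlocks, sleeps, relocks on wake
	}
	cond.L.Unlock()
	fmt.Println("listener registered")
}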
@@ -45,7 +45,8 @@ func TestProvisionerJobLogs(t *testing.T) {
	logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before)
	require.NoError(t, err)
	for {
		_, ok := <-logs
		log, ok := <-logs
		t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output)
		if !ok {
			return
		}