fix pubsub/poll race on provisioner job logs (#2783)

* fix pubsub/poll race on provisioner job logs

Signed-off-by: Spike Curtis <spike@coder.com>

* only cancel on non-error

Signed-off-by: Spike Curtis <spike@coder.com>

* Improve logging & comments

Signed-off-by: spikecurtis <spike@spikecurtis.com>
Spike Curtis
2022-07-01 14:07:18 -07:00
committed by GitHub
parent c1b3080162
commit b1e4cfe6c8
4 changed files with 320 additions and 82 deletions
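
For context, the race this commit closes is an ordering one: the old handler queried the database for historical logs and only then subscribed to the pubsub (while polling the job row on a ticker to detect completion), so any log published between the query and the subscription could be dropped. The fix subscribes first, queries second, de-duplicates by log ID, and has the publisher send an explicit end-of-logs marker so followers no longer poll. A minimal, self-contained sketch of the pattern; the log, subscribe, and queryLogs names below are hypothetical stand-ins, not the commit's actual API:

package main

import "fmt"

type log struct {
	id    int
	stage string
}

// followThenQuery subscribes before querying, so no log can fall into the gap
// between the two; rows returned by both paths are de-duplicated by ID.
func followThenQuery(subscribe func() <-chan log, queryLogs func() []log) []log {
	updates := subscribe() // 1. subscribe first
	seen := map[int]bool{}
	var out []log
	for _, l := range queryLogs() { // 2. then fetch history
		seen[l.id] = true
		out = append(out, l)
	}
	for l := range updates { // 3. stream until the publisher signals end-of-logs by closing
		if seen[l.id] {
			continue // duplicate of a row the query already returned
		}
		out = append(out, l)
	}
	return out
}

func main() {
	ch := make(chan log, 4)
	ch <- log{1, "Stage0"} // also returned by the query below: a duplicate
	ch <- log{2, "Stage1"}
	close(ch) // end of logs
	fmt.Println(followThenQuery(
		func() <-chan log { return ch },
		func() []log { return []log{{1, "Stage0"}} },
	)) // [{1 Stage0} {2 Stage1}]
}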


@@ -380,7 +380,7 @@ func (server *provisionerdServer) UpdateJob(ctx context.Context, request *proto.
		return nil, xerrors.Errorf("insert job logs: %w", err)
	}
	server.Logger.Debug(ctx, "inserted job logs", slog.F("job_id", parsedID))
	data, err := json.Marshal(logs)
	data, err := json.Marshal(provisionerJobLogsMessage{Logs: logs})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
@@ -549,6 +549,16 @@ func (server *provisionerdServer) FailJob(ctx context.Context, failJob *proto.Fa
		}
	case *proto.FailedJob_TemplateImport_:
	}

	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
	if err != nil {
		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
		return nil, xerrors.Errorf("publish end of job logs: %w", err)
	}
	return &proto.Empty{}, nil
}
@@ -711,6 +721,16 @@ func (server *provisionerdServer) CompleteJob(ctx context.Context, completed *pr
			reflect.TypeOf(completed.Type).String())
	}

	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
	if err != nil {
		return nil, xerrors.Errorf("marshal job log: %w", err)
	}
	err = server.Pubsub.Publish(provisionerJobLogsChannel(jobID), data)
	if err != nil {
		server.Logger.Error(ctx, "failed to publish end of job logs", slog.F("job_id", jobID), slog.Error(err))
		return nil, xerrors.Errorf("publish end of job logs: %w", err)
	}
	server.Logger.Debug(ctx, "CompleteJob done", slog.F("job_id", jobID))
	return &proto.Empty{}, nil
}
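
Both terminal RPCs, FailJob and CompleteJob, now publish the same end-of-logs marker on the job's log channel, so followers learn the stream is finished without polling the job row. For illustration, this is roughly what that marker looks like on the wire; the struct here is a local copy of the provisionerJobLogsMessage shape from this diff, with the Logs field simplified:

package main

import (
	"encoding/json"
	"fmt"
)

// Local, simplified copy of the message shape added in this commit.
type provisionerJobLogsMessage struct {
	EndOfLogs bool     `json:"end_of_logs,omitempty"`
	Logs      []string `json:"logs,omitempty"` // really []database.ProvisionerJobLog
}

func main() {
	data, err := json.Marshal(provisionerJobLogsMessage{EndOfLogs: true})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(data)) // {"end_of_logs":true}
}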


@@ -28,6 +28,7 @@ import (
// The combination of these responses should provide all current logs
// to the consumer, and future logs are streamed in the follow request.
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
	logger := api.Logger.With(slog.F("job_id", job.ID))
	follow := r.URL.Query().Has("follow")
	afterRaw := r.URL.Query().Get("after")
	beforeRaw := r.URL.Query().Get("before")
@@ -38,6 +39,37 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
		return
	}

	// if we are following logs, start the subscription before we query the database, so that we don't miss any logs
	// between the end of our query and the start of the subscription. We might get duplicates, so we'll keep track
	// of processed IDs.
	var bufferedLogs <-chan database.ProvisionerJobLog
	if follow {
		bl, closeFollow, err := api.followLogs(job.ID)
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error watching provisioner logs.",
				Detail:  err.Error(),
			})
			return
		}
		defer closeFollow()
		bufferedLogs = bl

		// Next query the job itself to see if it is complete. If so, the historical query to the database will return
		// the full set of logs. It's a little sad to have to query the job again, given that our caller definitely
		// has, but we need to query it *after* we start following the pubsub to avoid a race condition where the job
		// completes between the prior query and the start of following the pubsub. A more substantial refactor could
		// avoid this, but not worth it for one fewer query at this point.
		job, err = api.Database.GetProvisionerJobByID(r.Context(), job.ID)
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error querying job.",
				Detail:  err.Error(),
			})
			return
		}
	}

	var after time.Time
	// Only fetch logs created after the time provided.
	if afterRaw != "" {
@@ -78,26 +110,27 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
		}
	}
	if !follow {
		logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
			JobID:         job.ID,
			CreatedAfter:  after,
			CreatedBefore: before,
	logs, err := api.Database.GetProvisionerLogsByIDBetween(r.Context(), database.GetProvisionerLogsByIDBetweenParams{
		JobID:         job.ID,
		CreatedAfter:  after,
		CreatedBefore: before,
	})
		if errors.Is(err, sql.ErrNoRows) {
			err = nil
		}
		if err != nil {
			httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
				Message: "Internal error fetching provisioner logs.",
				Detail:  err.Error(),
			})
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
		if logs == nil {
			logs = []database.ProvisionerJobLog{}
		}
		return
	}
	if logs == nil {
		logs = []database.ProvisionerJobLog{}
	}
	if !follow {
		logger.Debug(r.Context(), "Finished non-follow job logs")
		httpapi.Write(rw, http.StatusOK, convertProvisionerJobLogs(logs))
		return
	}
@@ -118,82 +151,43 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
	ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageText)
	defer wsNetConn.Close() // Also closes conn.
	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(job.ID), func(ctx context.Context, message []byte) {
		var logs []database.ProvisionerJobLog
		err := json.Unmarshal(message, &logs)
		if err != nil {
			api.Logger.Warn(ctx, fmt.Sprintf("invalid provisioner job log on channel %q: %s", provisionerJobLogsChannel(job.ID), err.Error()))
			return
		}
		for _, log := range logs {
			select {
			case bufferedLogs <- log:
				api.Logger.Debug(r.Context(), "subscribe buffered log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
			default:
				// If this overflows, users could miss streamed logs. This can happen
				// if a database request takes a long time and we get a lot of logs.
				api.Logger.Warn(ctx, "provisioner job log overflowing channel")
			}
		}
	})
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error watching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	defer closeSubscribe()
	provisionerJobLogs, err := api.Database.GetProvisionerLogsByIDBetween(ctx, database.GetProvisionerLogsByIDBetweenParams{
		JobID:         job.ID,
		CreatedAfter:  after,
		CreatedBefore: before,
	})
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	logIdsDone := make(map[uuid.UUID]bool)

	// The Go stdlib JSON encoder appends a newline character after message write.
	encoder := json.NewEncoder(wsNetConn)
	for _, provisionerJobLog := range provisionerJobLogs {
	for _, provisionerJobLog := range logs {
		logIdsDone[provisionerJobLog.ID] = true
		err = encoder.Encode(convertProvisionerJobLog(provisionerJobLog))
		if err != nil {
			return
		}
	}
	if job.CompletedAt.Valid {
		// job was complete before we queried the database for historical logs, meaning we got everything. No need
		// to stream anything from the bufferedLogs.
		return
	}

	ticker := time.NewTicker(250 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-r.Context().Done():
			api.Logger.Debug(context.Background(), "job logs context canceled", slog.F("job_id", job.ID))
		case <-ctx.Done():
			logger.Debug(context.Background(), "job logs context canceled")
			return
		case log := <-bufferedLogs:
			api.Logger.Debug(r.Context(), "subscribe encoding log", slog.F("job_id", job.ID), slog.F("stage", log.Stage))
			err = encoder.Encode(convertProvisionerJobLog(log))
			if err != nil {
		case log, ok := <-bufferedLogs:
			if !ok {
				logger.Debug(context.Background(), "done with published logs")
				return
			}
		case <-ticker.C:
			job, err := api.Database.GetProvisionerJobByID(r.Context(), job.ID)
			if err != nil {
				api.Logger.Warn(r.Context(), "streaming job logs; checking if completed", slog.Error(err), slog.F("job_id", job.ID.String()))
				continue
			}
			if job.CompletedAt.Valid {
				api.Logger.Debug(context.Background(), "streaming job logs done; job done", slog.F("job_id", job.ID))
				return
			if logIdsDone[log.ID] {
				logger.Debug(r.Context(), "subscribe duplicated log",
					slog.F("stage", log.Stage))
			} else {
				logger.Debug(r.Context(), "subscribe encoding log",
					slog.F("stage", log.Stage))
				err = encoder.Encode(convertProvisionerJobLog(log))
				if err != nil {
					return
				}
			}
		}
	}
@@ -343,3 +337,43 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
func provisionerJobLogsChannel(jobID uuid.UUID) string {
	return fmt.Sprintf("provisioner-log-logs:%s", jobID)
}

// provisionerJobLogsMessage is the message type published on the provisionerJobLogsChannel() channel.
type provisionerJobLogsMessage struct {
	EndOfLogs bool                         `json:"end_of_logs,omitempty"`
	Logs      []database.ProvisionerJobLog `json:"logs,omitempty"`
}

func (api *API) followLogs(jobID uuid.UUID) (<-chan database.ProvisionerJobLog, func(), error) {
	logger := api.Logger.With(slog.F("job_id", jobID))
	bufferedLogs := make(chan database.ProvisionerJobLog, 128)
	closeSubscribe, err := api.Pubsub.Subscribe(provisionerJobLogsChannel(jobID),
		func(ctx context.Context, message []byte) {
			jlMsg := provisionerJobLogsMessage{}
			err := json.Unmarshal(message, &jlMsg)
			if err != nil {
				logger.Warn(ctx, "invalid provisioner job log on channel", slog.Error(err))
				return
			}
			for _, log := range jlMsg.Logs {
				select {
				case bufferedLogs <- log:
					logger.Debug(ctx, "subscribe buffered log", slog.F("stage", log.Stage))
				default:
					// If this overflows, users could miss streamed logs. This can happen when
					// we get a lot of logs and the consumer isn't keeping up. We don't want to
					// block the pubsub, so just drop them.
					logger.Warn(ctx, "provisioner job log overflowing channel")
				}
			}
			if jlMsg.EndOfLogs {
				logger.Debug(ctx, "got End of Logs")
				close(bufferedLogs)
			}
		})
	if err != nil {
		return nil, nil, err
	}
	return bufferedLogs, closeSubscribe, nil
}
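
Two details of followLogs carry the design: the pubsub callback must never block the dispatcher, so it writes into a buffered channel and drops logs on overflow; and the end-of-logs marker closes that channel, which is what terminates the consumer's case log, ok := <-bufferedLogs branch in the handler above. A condensed sketch of that callback shape, with a hypothetical msg type standing in for provisionerJobLogsMessage:

package main

import "fmt"

type msg struct {
	endOfLogs bool
	logs      []string
}

// handler mimics the subscription callback: buffer what fits, drop the rest,
// and close the channel on the end-of-logs marker so consumers can range to completion.
func handler(buffered chan string) func(m msg) {
	return func(m msg) {
		for _, l := range m.logs {
			select {
			case buffered <- l:
			default:
				// Never block the pubsub dispatcher; dropping is the lesser evil.
				fmt.Println("dropped log:", l)
			}
		}
		if m.endOfLogs {
			close(buffered)
		}
	}
}

func main() {
	ch := make(chan string, 2)
	h := handler(ch)
	h(msg{logs: []string{"a", "b", "c"}}) // "c" is dropped: the buffer holds two
	h(msg{endOfLogs: true})
	for l := range ch { // terminates because the callback closed ch
		fmt.Println(l)
	}
}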


@@ -0,0 +1,183 @@
package coderd

import (
	"context"
	"crypto/sha256"
	"encoding/json"
	"net/http/httptest"
	"net/url"
	"sync"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/databasefake"
	"github.com/coder/coder/codersdk"
)

func TestProvisionerJobLogs_Unit(t *testing.T) {
	t.Parallel()
	t.Run("QueryPubSubDupes", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		t.Cleanup(cancel)
		logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
		// mDB := mocks.NewStore(t)
		fDB := databasefake.New()
		fPubsub := &fakePubSub{t: t, cond: sync.NewCond(&sync.Mutex{})}
		opts := Options{
			Logger:   logger,
			Database: fDB,
			Pubsub:   fPubsub,
		}
		api := New(&opts)
		server := httptest.NewServer(api.Handler)
		t.Cleanup(server.Close)

		userID := uuid.New()
		keyID, keySecret, err := generateAPIKeyIDSecret()
		require.NoError(t, err)
		hashed := sha256.Sum256([]byte(keySecret))

		u, err := url.Parse(server.URL)
		require.NoError(t, err)
		client := codersdk.Client{
			HTTPClient:   server.Client(),
			SessionToken: keyID + "-" + keySecret,
			URL:          u,
		}

		buildID := uuid.New()
		workspaceID := uuid.New()
		jobID := uuid.New()

		expectedLogs := []database.ProvisionerJobLog{
			{ID: uuid.New(), JobID: jobID, Stage: "Stage0"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage1"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage2"},
			{ID: uuid.New(), JobID: jobID, Stage: "Stage3"},
		}

		// wow there are a lot of DB rows we touch...
		_, err = fDB.InsertAPIKey(ctx, database.InsertAPIKeyParams{
			ID:           keyID,
			HashedSecret: hashed[:],
			UserID:       userID,
			ExpiresAt:    time.Now().Add(5 * time.Hour),
		})
		require.NoError(t, err)
		_, err = fDB.InsertUser(ctx, database.InsertUserParams{
			ID:        userID,
			RBACRoles: []string{"admin"},
		})
		require.NoError(t, err)
		_, err = fDB.InsertWorkspaceBuild(ctx, database.InsertWorkspaceBuildParams{
			ID:          buildID,
			WorkspaceID: workspaceID,
			JobID:       jobID,
		})
		require.NoError(t, err)
		_, err = fDB.InsertWorkspace(ctx, database.InsertWorkspaceParams{
			ID: workspaceID,
		})
		require.NoError(t, err)
		_, err = fDB.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{
			ID: jobID,
		})
		require.NoError(t, err)
		for _, l := range expectedLogs[:2] {
			_, err := fDB.InsertProvisionerJobLogs(ctx, database.InsertProvisionerJobLogsParams{
				ID:    []uuid.UUID{l.ID},
				JobID: jobID,
				Stage: []string{l.Stage},
			})
			require.NoError(t, err)
		}

		logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now())
		require.NoError(t, err)

		// when the endpoint calls subscribe, we get the listener here.
		fPubsub.cond.L.Lock()
		for fPubsub.listener == nil {
			fPubsub.cond.Wait()
		}

		// endpoint should now be listening
		assert.False(t, fPubsub.canceled)
		assert.False(t, fPubsub.closed)

		// send all the logs in two batches, duplicating what we already returned on the DB query.
		msg := provisionerJobLogsMessage{}
		msg.Logs = expectedLogs[:2]
		data, err := json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)
		msg.Logs = expectedLogs[2:]
		data, err = json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)

		// send end of logs
		msg.Logs = nil
		msg.EndOfLogs = true
		data, err = json.Marshal(msg)
		require.NoError(t, err)
		fPubsub.listener(ctx, data)

		var stages []string
		for l := range logs {
			logger.Info(ctx, "got log",
				slog.F("id", l.ID),
				slog.F("stage", l.Stage))
			stages = append(stages, l.Stage)
		}
		assert.Equal(t, []string{"Stage0", "Stage1", "Stage2", "Stage3"}, stages)

		for !fPubsub.canceled {
			fPubsub.cond.Wait()
		}
		assert.False(t, fPubsub.closed)
	})
}

type fakePubSub struct {
	t        *testing.T
	cond     *sync.Cond
	listener database.Listener
	canceled bool
	closed   bool
}

func (f *fakePubSub) Subscribe(_ string, listener database.Listener) (cancel func(), err error) {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.listener = listener
	f.cond.Signal()
	return f.cancel, nil
}

func (f *fakePubSub) Publish(_ string, _ []byte) error {
	f.t.Fail()
	return nil
}

func (f *fakePubSub) Close() error {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.closed = true
	f.cond.Signal()
	return nil
}

func (f *fakePubSub) cancel() {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.canceled = true
	f.cond.Signal()
}
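
The fake coordinates with the test through a sync.Cond rather than channels: every state change (Subscribe, Close, cancel) signals the condition, and the test waits in the standard lock-and-recheck loop until the predicate it cares about holds. A stripped-down sketch of that handshake, using a hypothetical fake type rather than the test's actual fakePubSub:

package main

import (
	"fmt"
	"sync"
)

type fake struct {
	cond     *sync.Cond
	listener func([]byte)
}

func (f *fake) Subscribe(listener func([]byte)) {
	f.cond.L.Lock()
	defer f.cond.L.Unlock()
	f.listener = listener
	f.cond.Signal() // wake any goroutine waiting for the listener to appear
}

func main() {
	f := &fake{cond: sync.NewCond(&sync.Mutex{})}
	go f.Subscribe(func(b []byte) { fmt.Println("got:", string(b)) })

	// Standard cond wait loop: hold the lock, re-check the predicate after each wake.
	f.cond.L.Lock()
	for f.listener == nil {
		f.cond.Wait()
	}
	listener := f.listener
	f.cond.L.Unlock()

	listener([]byte("hello"))
}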


@@ -45,7 +45,8 @@ func TestProvisionerJobLogs(t *testing.T) {
	logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before)
	require.NoError(t, err)
	for {
		_, ok := <-logs
		log, ok := <-logs
		t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output)
		if !ok {
			return
		}