coder/coderd/provisionerjobs.go
Spike Curtis 148a5a3593 fix: fix goroutine leak in log streaming over websocket (#15709)
fixes #14881

Our handlers for streaming logs don't read from the websocket. We don't allow the client to send us any data, but the websocket library we use requires reading from the websocket to properly handle pings and closing. Not doing so [can cause the websocket to hang on write](https://github.com/coder/websocket/issues/405), leaking goroutines, as noticed in #14881.
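
For context, with this library a write-only handler still has to drain the read side, e.g. via CloseRead, which starts a goroutine that reads and discards incoming frames so pings and the close handshake keep being serviced. A minimal sketch of that pattern; writeOnlyLogs and the logs channel are hypothetical stand-ins, not our actual handlers:

package example

import (
	"net/http"

	"nhooyr.io/websocket"
)

// writeOnlyLogs streams messages to the client and never expects to receive
// any. CloseRead reads and discards incoming frames in a goroutine; the
// returned context is canceled once the connection closes, so a client
// disconnect can't strand this goroutine blocked on Write.
func writeOnlyLogs(rw http.ResponseWriter, r *http.Request, logs <-chan []byte) {
	conn, err := websocket.Accept(rw, r, nil)
	if err != nil {
		return
	}
	defer conn.Close(websocket.StatusNormalClosure, "done")

	ctx := conn.CloseRead(r.Context())
	for {
		select {
		case <-ctx.Done():
			// Connection closed or read failed.
			return
		case msg := <-logs:
			if err := conn.Write(ctx, websocket.MessageText, msg); err != nil {
				return
			}
		}
	}
}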

This fixes the issue, and in the process refactors our log streaming to an encoder/decoder package which provides generic types for sending JSON over websocket.
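
The new package is codersdk/wsjson, used below via wsjson.NewEncoder. A rough sketch of the shape the generic encoder takes; the names match how it's used in this file, but the bodies, in particular draining the read side in the constructor, are an assumption based on this commit's description rather than the verified source:

package wsjson

import (
	"context"
	"encoding/json"

	"golang.org/x/xerrors"
	"nhooyr.io/websocket"
)

// Encoder marshals values of type T to JSON and writes each one as a single
// websocket message.
type Encoder[T any] struct {
	conn *websocket.Conn
	typ  websocket.MessageType
}

// NewEncoder creates a write-only JSON encoder over the websocket.
// Assumption: closing the read side here is what lets the library keep
// handling pings and close frames, which is the leak this commit fixes.
func NewEncoder[T any](conn *websocket.Conn, typ websocket.MessageType) *Encoder[T] {
	_ = conn.CloseRead(context.Background())
	return &Encoder[T]{conn: conn, typ: typ}
}

// Encode marshals v and writes it to the websocket.
func (e *Encoder[T]) Encode(v T) error {
	b, err := json.Marshal(v)
	if err != nil {
		return xerrors.Errorf("marshal message: %w", err)
	}
	if err := e.conn.Write(context.Background(), e.typ, b); err != nil {
		return xerrors.Errorf("write message: %w", err)
	}
	return nil
}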

I'd also like us to upgrade to the latest https://github.com/coder/websocket, but we should upgrade our tailscale fork first to avoid including two copies of the websocket library.
2024-12-03 10:12:30 +04:00

package coderd

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"sort"
	"strconv"

	"github.com/google/uuid"
	"golang.org/x/xerrors"
	"nhooyr.io/websocket"

	"cdr.dev/slog"

	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/db2sdk"
	"github.com/coder/coder/v2/coderd/database/dbauthz"
	"github.com/coder/coder/v2/coderd/database/pubsub"
	"github.com/coder/coder/v2/coderd/httpapi"
	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/coder/v2/codersdk/wsjson"
	"github.com/coder/coder/v2/provisionersdk"
)

// Returns provisioner logs based on query parameters.
// The intended usage for a client to stream all logs (with JS API):
//
//	GET /logs
//	GET /logs?after=<id>&follow
//
// The combination of these responses should provide all current logs
// to the consumer, and future logs are streamed in the follow request.
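//
// For example, a client might first call GET /logs and receive logs with
// IDs 1 through 20, then open GET /logs?after=20&follow, which upgrades to
// a websocket and streams each subsequent log as a JSON text message. (The
// IDs here are illustrative.)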
func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
	var (
		ctx      = r.Context()
		logger   = api.Logger.With(slog.F("job_id", job.ID))
		follow   = r.URL.Query().Has("follow")
		afterRaw = r.URL.Query().Get("after")
	)

	var after int64
	// Only fetch logs created after the given log ID.
	if afterRaw != "" {
		var err error
		after, err = strconv.ParseInt(afterRaw, 10, 64)
		if err != nil {
			httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
				Message: "Query param \"after\" must be an integer.",
				Validations: []codersdk.ValidationError{
					{Field: "after", Detail: "Must be an integer"},
				},
			})
			return
		}
	}

	if !follow {
		fetchAndWriteLogs(ctx, api.Database, job.ID, after, rw)
		return
	}

	follower := newLogFollower(ctx, logger, api.Database, api.Pubsub, rw, r, job, after)
	api.WebsocketWaitMutex.Lock()
	api.WebsocketWaitGroup.Add(1)
	api.WebsocketWaitMutex.Unlock()
	defer api.WebsocketWaitGroup.Done()
	follower.follow()
}

func (api *API) provisionerJobResources(rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob) {
	ctx := r.Context()
	if !job.CompletedAt.Valid {
		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
			Message: "Job hasn't completed!",
		})
		return
	}

	// nolint:gocritic // GetWorkspaceResourcesByJobID is a system function.
	resources, err := api.Database.GetWorkspaceResourcesByJobID(dbauthz.AsSystemRestricted(ctx), job.ID)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching job resources.",
			Detail:  err.Error(),
		})
		return
	}
	resourceIDs := make([]uuid.UUID, 0)
	for _, resource := range resources {
		resourceIDs = append(resourceIDs, resource.ID)
	}

	// nolint:gocritic // GetWorkspaceAgentsByResourceIDs is a system function.
	resourceAgents, err := api.Database.GetWorkspaceAgentsByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace agent.",
			Detail:  err.Error(),
		})
		return
	}
	resourceAgentIDs := make([]uuid.UUID, 0)
	for _, agent := range resourceAgents {
		resourceAgentIDs = append(resourceAgentIDs, agent.ID)
	}

	// nolint:gocritic // GetWorkspaceAppsByAgentIDs is a system function.
	apps, err := api.Database.GetWorkspaceAppsByAgentIDs(dbauthz.AsSystemRestricted(ctx), resourceAgentIDs)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace applications.",
			Detail:  err.Error(),
		})
		return
	}

	// nolint:gocritic // GetWorkspaceAgentScriptsByAgentIDs is a system function.
	scripts, err := api.Database.GetWorkspaceAgentScriptsByAgentIDs(dbauthz.AsSystemRestricted(ctx), resourceAgentIDs)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace agent scripts.",
			Detail:  err.Error(),
		})
		return
	}

	// nolint:gocritic // GetWorkspaceAgentLogSourcesByAgentIDs is a system function.
	logSources, err := api.Database.GetWorkspaceAgentLogSourcesByAgentIDs(dbauthz.AsSystemRestricted(ctx), resourceAgentIDs)
	if errors.Is(err, sql.ErrNoRows) {
		err = nil
	}
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace agent log sources.",
			Detail:  err.Error(),
		})
		return
	}

	// nolint:gocritic // GetWorkspaceResourceMetadataByResourceIDs is a system function.
	resourceMetadata, err := api.Database.GetWorkspaceResourceMetadataByResourceIDs(dbauthz.AsSystemRestricted(ctx), resourceIDs)
	if err != nil {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching workspace metadata.",
			Detail:  err.Error(),
		})
		return
	}

	apiResources := make([]codersdk.WorkspaceResource, 0)
	for _, resource := range resources {
		agents := make([]codersdk.WorkspaceAgent, 0)
		for _, agent := range resourceAgents {
			if agent.ResourceID != resource.ID {
				continue
			}
			dbApps := make([]database.WorkspaceApp, 0)
			for _, app := range apps {
				if app.AgentID == agent.ID {
					dbApps = append(dbApps, app)
				}
			}
			dbScripts := make([]database.WorkspaceAgentScript, 0)
			for _, script := range scripts {
				if script.WorkspaceAgentID == agent.ID {
					dbScripts = append(dbScripts, script)
				}
			}
			dbLogSources := make([]database.WorkspaceAgentLogSource, 0)
			for _, logSource := range logSources {
				if logSource.WorkspaceAgentID == agent.ID {
					dbLogSources = append(dbLogSources, logSource)
				}
			}

			apiAgent, err := db2sdk.WorkspaceAgent(
				api.DERPMap(), *api.TailnetCoordinator.Load(), agent, convertProvisionedApps(dbApps), convertScripts(dbScripts), convertLogSources(dbLogSources), api.AgentInactiveDisconnectTimeout,
				api.DeploymentValues.AgentFallbackTroubleshootingURL.String(),
			)
			if err != nil {
				httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
					Message: "Internal error reading job agent.",
					Detail:  err.Error(),
				})
				return
			}
			agents = append(agents, apiAgent)
		}
		metadata := make([]database.WorkspaceResourceMetadatum, 0)
		for _, field := range resourceMetadata {
			if field.WorkspaceResourceID == resource.ID {
				metadata = append(metadata, field)
			}
		}
		apiResources = append(apiResources, convertWorkspaceResource(resource, agents, metadata))
	}
	sort.Slice(apiResources, func(i, j int) bool {
		return apiResources[i].Name < apiResources[j].Name
	})

	httpapi.Write(ctx, rw, http.StatusOK, apiResources)
}

func convertProvisionerJobLogs(provisionerJobLogs []database.ProvisionerJobLog) []codersdk.ProvisionerJobLog {
	sdk := make([]codersdk.ProvisionerJobLog, 0, len(provisionerJobLogs))
	for _, log := range provisionerJobLogs {
		sdk = append(sdk, convertProvisionerJobLog(log))
	}
	return sdk
}

func convertProvisionerJobLog(provisionerJobLog database.ProvisionerJobLog) codersdk.ProvisionerJobLog {
	return codersdk.ProvisionerJobLog{
		ID:        provisionerJobLog.ID,
		CreatedAt: provisionerJobLog.CreatedAt,
		Source:    codersdk.LogSource(provisionerJobLog.Source),
		Level:     codersdk.LogLevel(provisionerJobLog.Level),
		Stage:     provisionerJobLog.Stage,
		Output:    provisionerJobLog.Output,
	}
}

func convertProvisionerJob(pj database.GetProvisionerJobsByIDsWithQueuePositionRow) codersdk.ProvisionerJob {
	provisionerJob := pj.ProvisionerJob
	job := codersdk.ProvisionerJob{
		ID:            provisionerJob.ID,
		CreatedAt:     provisionerJob.CreatedAt,
		Error:         provisionerJob.Error.String,
		ErrorCode:     codersdk.JobErrorCode(provisionerJob.ErrorCode.String),
		FileID:        provisionerJob.FileID,
		Tags:          provisionerJob.Tags,
		QueuePosition: int(pj.QueuePosition),
		QueueSize:     int(pj.QueueSize),
	}
	// Apply optional values to the struct.
	if provisionerJob.StartedAt.Valid {
		job.StartedAt = &provisionerJob.StartedAt.Time
	}
	if provisionerJob.CompletedAt.Valid {
		job.CompletedAt = &provisionerJob.CompletedAt.Time
	}
	if provisionerJob.CanceledAt.Valid {
		job.CanceledAt = &provisionerJob.CanceledAt.Time
	}
	if provisionerJob.WorkerID.Valid {
		job.WorkerID = &provisionerJob.WorkerID.UUID
	}
	job.Status = codersdk.ProvisionerJobStatus(pj.ProvisionerJob.JobStatus)

	return job
}

func fetchAndWriteLogs(ctx context.Context, db database.Store, jobID uuid.UUID, after int64, rw http.ResponseWriter) {
	logs, err := db.GetProvisionerLogsAfterID(ctx, database.GetProvisionerLogsAfterIDParams{
		JobID:        jobID,
		CreatedAfter: after,
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
			Message: "Internal error fetching provisioner logs.",
			Detail:  err.Error(),
		})
		return
	}
	if logs == nil {
		logs = []database.ProvisionerJobLog{}
	}
	httpapi.Write(ctx, rw, http.StatusOK, convertProvisionerJobLogs(logs))
}

func jobIsComplete(logger slog.Logger, job database.ProvisionerJob) bool {
	status := codersdk.ProvisionerJobStatus(job.JobStatus)
	switch status {
	case codersdk.ProvisionerJobCanceled:
		return true
	case codersdk.ProvisionerJobFailed:
		return true
	case codersdk.ProvisionerJobSucceeded:
		return true
	case codersdk.ProvisionerJobPending:
		return false
	case codersdk.ProvisionerJobCanceling:
		return false
	case codersdk.ProvisionerJobRunning:
		return false
	default:
		logger.Error(context.Background(),
			"can't convert the provisioner job status",
			slog.F("job_id", job.ID), slog.F("status", status))
		return false
	}
}

type logFollower struct {
	ctx    context.Context
	logger slog.Logger
	db     database.Store
	pubsub pubsub.Pubsub
	r      *http.Request
	rw     http.ResponseWriter
	conn   *websocket.Conn
	enc    *wsjson.Encoder[codersdk.ProvisionerJobLog]

	jobID         uuid.UUID
	after         int64
	complete      bool
	notifications chan provisionersdk.ProvisionerJobLogsNotifyMessage
	errors        chan error
}

func newLogFollower(
	ctx context.Context, logger slog.Logger, db database.Store, ps pubsub.Pubsub,
	rw http.ResponseWriter, r *http.Request, job database.ProvisionerJob, after int64,
) *logFollower {
	return &logFollower{
		ctx:           ctx,
		logger:        logger,
		db:            db,
		pubsub:        ps,
		r:             r,
		rw:            rw,
		jobID:         job.ID,
		after:         after,
		complete:      jobIsComplete(logger, job),
		notifications: make(chan provisionersdk.ProvisionerJobLogsNotifyMessage),
		errors:        make(chan error),
	}
}

func (f *logFollower) follow() {
	var cancel context.CancelFunc
	f.ctx, cancel = context.WithCancel(f.ctx)
	defer cancel()

	// note that we only need to subscribe to updates if the job is not yet
	// complete.
	if !f.complete {
		subCancel, err := f.pubsub.SubscribeWithErr(
			provisionersdk.ProvisionerJobLogsNotifyChannel(f.jobID),
			f.listener,
		)
		if err != nil {
			httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
				Message: "failed to subscribe to job updates",
				Detail:  err.Error(),
			})
			return
		}
		defer subCancel()

		// Move cancel up the stack so it happens before unsubscribing,
		// otherwise we can end up in a deadlock due to how the
		// in-memory pubsub does mutex locking on send/unsubscribe.
		defer cancel()

		// we were provided `complete` prior to starting this subscription, so
		// we also need to check whether the job is now complete, in case the
		// job completed between the last time we queried the job and the start
		// of the subscription. If the job completes after this, we will get
		// a notification on the subscription.
		job, err := f.db.GetProvisionerJobByID(f.ctx, f.jobID)
		if err != nil {
			httpapi.Write(f.ctx, f.rw, http.StatusInternalServerError, codersdk.Response{
				Message: "failed to query job",
				Detail:  err.Error(),
			})
			return
		}
		f.complete = jobIsComplete(f.logger, job)
		f.logger.Debug(f.ctx, "queried job after subscribe", slog.F("complete", f.complete))
	}

	var err error
	f.conn, err = websocket.Accept(f.rw, f.r, nil)
	if err != nil {
		httpapi.Write(f.ctx, f.rw, http.StatusBadRequest, codersdk.Response{
			Message: "Failed to accept websocket.",
			Detail:  err.Error(),
		})
		return
	}
	defer f.conn.Close(websocket.StatusNormalClosure, "done")

	go httpapi.Heartbeat(f.ctx, f.conn)
	f.enc = wsjson.NewEncoder[codersdk.ProvisionerJobLog](f.conn, websocket.MessageText)

	// query for logs once right away, so we can get historical data from
	// before the subscription
	if err := f.query(); err != nil {
		if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
			// neither context expiry, nor EOF, close and log
			f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
			err = f.conn.Close(websocket.StatusInternalError, err.Error())
			if err != nil {
				f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
			}
		}
		return
	}

	// no need to wait if the job is done
	if f.complete {
		return
	}

	for {
		select {
		case err := <-f.errors:
			// we've dropped at least one notification. This can happen if we
			// lose database connectivity. We don't know whether the job is
			// now complete since we could have missed the end of logs message.
			// We could soldier on and retry, but loss of database connectivity
			// is fairly serious, so instead just 500 and bail out. Client
			// can retry and hopefully find a healthier node.
			f.logger.Error(f.ctx, "dropped or corrupted notification", slog.Error(err))
			err = f.conn.Close(websocket.StatusInternalError, err.Error())
			if err != nil {
				f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
			}
			return
		case <-f.ctx.Done():
			// client disconnect
			return
		case n := <-f.notifications:
			if n.EndOfLogs {
				// safe to return here because we started the subscription,
				// and then queried at least once, so we will have already
				// gotten all logs prior to the start of our subscription.
				return
			}
			err = f.query()
			if err != nil {
				if f.ctx.Err() == nil && !xerrors.Is(err, io.EOF) {
					// neither context expiry, nor EOF, close and log
					f.logger.Error(f.ctx, "failed to query logs", slog.Error(err))
					err = f.conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("%s", err.Error()))
					if err != nil {
						f.logger.Warn(f.ctx, "failed to close websocket", slog.Error(err))
					}
				}
				return
			}
		}
	}
}

func (f *logFollower) listener(_ context.Context, message []byte, err error) {
	// in this function we always pair writes to channels with a select on the
	// context, otherwise we could block a goroutine if the follow() method exits.
	if err != nil {
		select {
		case <-f.ctx.Done():
		case f.errors <- err:
		}
		return
	}

	var n provisionersdk.ProvisionerJobLogsNotifyMessage
	err = json.Unmarshal(message, &n)
	if err != nil {
		select {
		case <-f.ctx.Done():
		case f.errors <- err:
		}
		return
	}
	select {
	case <-f.ctx.Done():
	case f.notifications <- n:
	}
}

// query fetches the latest job logs from the database and writes them to the
// connection.
func (f *logFollower) query() error {
	f.logger.Debug(f.ctx, "querying logs", slog.F("after", f.after))
	logs, err := f.db.GetProvisionerLogsAfterID(f.ctx, database.GetProvisionerLogsAfterIDParams{
		JobID:        f.jobID,
		CreatedAfter: f.after,
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return xerrors.Errorf("error fetching logs: %w", err)
	}
	for _, log := range logs {
		err := f.enc.Encode(convertProvisionerJobLog(log))
		if err != nil {
			return xerrors.Errorf("error writing to websocket: %w", err)
		}
		f.after = log.ID
		f.logger.Debug(f.ctx, "wrote log to websocket", slog.F("id", log.ID))
	}
	return nil
}