fix: Don't use StatusAbnormalClosure (#4155)

This commit is contained in:
Kyle Carberry
2022-09-22 13:26:05 -05:00
committed by GitHub
parent 9e099b543f
commit a7ee8b31e0
17 changed files with 62 additions and 34 deletions

View File

@ -490,7 +490,7 @@ func TestAgent(t *testing.T) {
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
_, err := conn.Ping() _, err := conn.Ping()
return err == nil return err == nil
}, testutil.WaitMedium, testutil.IntervalFast) }, testutil.WaitLong, testutil.IntervalFast)
}) })
t.Run("Speedtest", func(t *testing.T) { t.Run("Speedtest", func(t *testing.T) {

View File

@ -22,7 +22,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie
build, err := client.WorkspaceBuild(ctx, build) build, err := client.WorkspaceBuild(ctx, build)
return build.Job, err return build.Job, err
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.WorkspaceBuildLogsAfter(ctx, build, before) return client.WorkspaceBuildLogsAfter(ctx, build, before)
}, },
}) })
@ -31,7 +31,7 @@ func WorkspaceBuild(ctx context.Context, writer io.Writer, client *codersdk.Clie
type ProvisionerJobOptions struct { type ProvisionerJobOptions struct {
Fetch func() (codersdk.ProvisionerJob, error) Fetch func() (codersdk.ProvisionerJob, error)
Cancel func() error Cancel func() error
Logs func() (<-chan codersdk.ProvisionerJobLog, error) Logs func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error)
FetchInterval time.Duration FetchInterval time.Duration
// Verbose determines whether debug and trace logs will be shown. // Verbose determines whether debug and trace logs will be shown.
@ -132,10 +132,11 @@ func ProvisionerJob(ctx context.Context, writer io.Writer, opts ProvisionerJobOp
// The initial stage needs to print after the signal handler has been registered. // The initial stage needs to print after the signal handler has been registered.
printStage() printStage()
logs, err := opts.Logs() logs, closer, err := opts.Logs()
if err != nil { if err != nil {
return xerrors.Errorf("logs: %w", err) return xerrors.Errorf("logs: %w", err)
} }
defer closer.Close()
var ( var (
// logOutput is where log output is written // logOutput is where log output is written

View File

@ -2,6 +2,7 @@ package cliui_test
import ( import (
"context" "context"
"io"
"os" "os"
"runtime" "runtime"
"sync" "sync"
@ -136,8 +137,10 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
Cancel: func() error { Cancel: func() error {
return nil return nil
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return logs, nil return logs, closeFunc(func() error {
return nil
}), nil
}, },
}) })
}, },
@ -164,3 +167,9 @@ func newProvisionerJob(t *testing.T) provisionerJobTest {
PTY: ptty, PTY: ptty,
} }
} }
type closeFunc func() error
func (c closeFunc) Close() error {
return c()
}

View File

@ -2,6 +2,7 @@ package cli
import ( import (
"fmt" "fmt"
"io"
"time" "time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -253,7 +254,7 @@ PromptParamLoop:
Cancel: func() error { Cancel: func() error {
return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID)
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, after) return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, after)
}, },
// Don't show log output for the dry-run unless there's an error. // Don't show log output for the dry-run unless there's an error.

View File

@ -2,6 +2,7 @@ package cli
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -182,7 +183,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
Cancel: func() error { Cancel: func() error {
return client.CancelTemplateVersion(cmd.Context(), version.ID) return client.CancelTemplateVersion(cmd.Context(), version.ID)
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, before) return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, before)
}, },
}) })

View File

@ -66,10 +66,11 @@ func update() *cobra.Command {
if err != nil { if err != nil {
return err return err
} }
logs, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before) logs, closer, err := client.WorkspaceBuildLogsAfter(cmd.Context(), build.ID, before)
if err != nil { if err != nil {
return err return err
} }
defer closer.Close()
for { for {
log, ok := <-logs log, ok := <-logs
if !ok { if !ok {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"strings" "strings"
"time" "time"
@ -100,7 +101,7 @@ func main() {
Fetch: func() (codersdk.ProvisionerJob, error) { Fetch: func() (codersdk.ProvisionerJob, error) {
return job, nil return job, nil
}, },
Logs: func() (<-chan codersdk.ProvisionerJobLog, error) { Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) {
logs := make(chan codersdk.ProvisionerJobLog) logs := make(chan codersdk.ProvisionerJobLog)
go func() { go func() {
defer close(logs) defer close(logs)
@ -143,7 +144,7 @@ func main() {
} }
} }
}() }()
return logs, nil return logs, io.NopCloser(strings.NewReader("")), nil
}, },
Cancel: func() error { Cancel: func() error {
job.Status = codersdk.ProvisionerJobCanceling job.Status = codersdk.ProvisionerJobCanceling

View File

@ -108,8 +108,9 @@ func TestProvisionerJobLogs_Unit(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
logs, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now()) logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, buildID, time.Now())
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
// when the endpoint calls subscribe, we get the listener here. // when the endpoint calls subscribe, we get the listener here.
fPubsub.cond.L.Lock() fPubsub.cond.L.Lock()

View File

@ -44,8 +44,9 @@ func TestProvisionerJobLogs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before)
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
for { for {
log, ok := <-logs log, ok := <-logs
t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output) t.Logf("got log: [%s] %s %s", log.Level, log.Stage, log.Output)
@ -82,8 +83,9 @@ func TestProvisionerJobLogs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before) logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before)
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
for { for {
_, ok := <-logs _, ok := <-logs
if !ok { if !ok {

View File

@ -447,8 +447,9 @@ func TestTemplateVersionLogs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
logs, err := client.TemplateVersionLogsAfter(ctx, version.ID, before) logs, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, before)
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
for { for {
_, ok := <-logs _, ok := <-logs
if !ok { if !ok {
@ -618,8 +619,9 @@ func TestTemplateVersionDryRun(t *testing.T) {
require.Equal(t, job.ID, newJob.ID) require.Equal(t, job.ID, newJob.ID)
// Stream logs // Stream logs
logs, err := client.TemplateVersionDryRunLogsAfter(ctx, version.ID, job.ID, after) logs, closer, err := client.TemplateVersionDryRunLogsAfter(ctx, version.ID, job.ID, after)
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
logsDone := make(chan struct{}) logsDone := make(chan struct{})
go func() { go func() {

View File

@ -347,7 +347,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
err = updateConnectionTimes() err = updateConnectionTimes()
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) _ = conn.Close(websocket.StatusGoingAway, err.Error())
return return
} }
@ -380,7 +380,7 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
} }
err = updateConnectionTimes() err = updateConnectionTimes()
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error()) _ = conn.Close(websocket.StatusGoingAway, err.Error())
return return
} }
err := ensureLatestBuild() err := ensureLatestBuild()
@ -571,7 +571,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques
}) })
return return
} }
defer conn.Close(websocket.StatusAbnormalClosure, "") defer conn.Close(websocket.StatusGoingAway, "")
var lastReport codersdk.AgentStatsReportResponse var lastReport codersdk.AgentStatsReportResponse
latestStat, err := api.Database.GetLatestAgentStat(ctx, workspaceAgent.ID) latestStat, err := api.Database.GetLatestAgentStat(ctx, workspaceAgent.ID)

View File

@ -128,7 +128,7 @@ func TestWorkspaceAgentListen(t *testing.T) {
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
_, err := conn.Ping() _, err := conn.Ping()
return err == nil return err == nil
}, testutil.WaitMedium, testutil.IntervalFast) }, testutil.WaitLong, testutil.IntervalFast)
}) })
t.Run("FailNonLatestBuild", func(t *testing.T) { t.Run("FailNonLatestBuild", func(t *testing.T) {

View File

@ -442,8 +442,9 @@ func TestWorkspaceBuildLogs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
logs, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before.Add(-time.Hour)) logs, closer, err := client.WorkspaceBuildLogsAfter(ctx, workspace.LatestBuild.ID, before.Add(-time.Hour))
require.NoError(t, err) require.NoError(t, err)
defer closer.Close()
for { for {
log, ok := <-logs log, ok := <-logs
if !ok { if !ok {

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
@ -104,18 +105,18 @@ func (c *Client) provisionerJobLogsBefore(ctx context.Context, path string, befo
} }
// provisionerJobLogsAfter streams logs that occurred after a specific time. // provisionerJobLogsAfter streams logs that occurred after a specific time.
func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after time.Time) (<-chan ProvisionerJobLog, error) { func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) {
afterQuery := "" afterQuery := ""
if !after.IsZero() { if !after.IsZero() {
afterQuery = fmt.Sprintf("&after=%d", after.UTC().UnixMilli()) afterQuery = fmt.Sprintf("&after=%d", after.UTC().UnixMilli())
} }
followURL, err := c.URL.Parse(fmt.Sprintf("%s?follow%s", path, afterQuery)) followURL, err := c.URL.Parse(fmt.Sprintf("%s?follow%s", path, afterQuery))
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
jar, err := cookiejar.New(nil) jar, err := cookiejar.New(nil)
if err != nil { if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err) return nil, nil, xerrors.Errorf("create cookie jar: %w", err)
} }
jar.SetCookies(followURL, []*http.Cookie{{ jar.SetCookies(followURL, []*http.Cookie{{
Name: SessionTokenKey, Name: SessionTokenKey,
@ -129,11 +130,13 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
CompressionMode: websocket.CompressionDisabled, CompressionMode: websocket.CompressionDisabled,
}) })
if err != nil { if err != nil {
return nil, readBodyAsError(res) return nil, nil, readBodyAsError(res)
} }
logs := make(chan ProvisionerJobLog) logs := make(chan ProvisionerJobLog)
decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText)) decoder := json.NewDecoder(websocket.NetConn(ctx, conn, websocket.MessageText))
closed := make(chan struct{})
go func() { go func() {
defer close(closed)
defer close(logs) defer close(logs)
defer conn.Close(websocket.StatusGoingAway, "") defer conn.Close(websocket.StatusGoingAway, "")
var log ProvisionerJobLog var log ProvisionerJobLog
@ -149,5 +152,9 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
} }
} }
}() }()
return logs, nil return logs, closeFunc(func() error {
_ = conn.Close(websocket.StatusNormalClosure, "")
<-closed
return nil
}), nil
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"time" "time"
@ -99,7 +100,7 @@ func (c *Client) TemplateVersionLogsBefore(ctx context.Context, version uuid.UUI
} }
// TemplateVersionLogsAfter streams logs for a template version that occurred after a specific time. // TemplateVersionLogsAfter streams logs for a template version that occurred after a specific time.
func (c *Client) TemplateVersionLogsAfter(ctx context.Context, version uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { func (c *Client) TemplateVersionLogsAfter(ctx context.Context, version uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) {
return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/logs", version), after) return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/logs", version), after)
} }
@ -166,7 +167,7 @@ func (c *Client) TemplateVersionDryRunLogsBefore(ctx context.Context, version, j
// TemplateVersionDryRunLogsAfter streams logs for a template version dry-run // TemplateVersionDryRunLogsAfter streams logs for a template version dry-run
// that occurred after a specific time. // that occurred after a specific time.
func (c *Client) TemplateVersionDryRunLogsAfter(ctx context.Context, version, job uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { func (c *Client) TemplateVersionDryRunLogsAfter(ctx context.Context, version, job uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) {
return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/dry-run/%s/logs", version, job), after) return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/templateversions/%s/dry-run/%s/logs", version, job), after)
} }

View File

@ -308,15 +308,15 @@ func (c *Client) DialWorkspaceAgentTailnet(ctx context.Context, logger slog.Logg
logger.Debug(ctx, "serving coordinator") logger.Debug(ctx, "serving coordinator")
err = <-errChan err = <-errChan
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
_ = ws.Close(websocket.StatusAbnormalClosure, "") _ = ws.Close(websocket.StatusGoingAway, "")
return return
} }
if err != nil { if err != nil {
logger.Debug(ctx, "error serving coordinator", slog.Error(err)) logger.Debug(ctx, "error serving coordinator", slog.Error(err))
_ = ws.Close(websocket.StatusAbnormalClosure, "") _ = ws.Close(websocket.StatusGoingAway, "")
continue continue
} }
_ = ws.Close(websocket.StatusAbnormalClosure, "") _ = ws.Close(websocket.StatusGoingAway, "")
} }
}() }()
err = <-first err = <-first
@ -446,7 +446,7 @@ func (c *Client) AgentReportStats(
var req AgentStatsReportRequest var req AgentStatsReportRequest
err := wsjson.Read(ctx, conn, &req) err := wsjson.Read(ctx, conn, &req)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "") _ = conn.Close(websocket.StatusGoingAway, "")
return err return err
} }
@ -460,7 +460,7 @@ func (c *Client) AgentReportStats(
err = wsjson.Write(ctx, conn, resp) err = wsjson.Write(ctx, conn, resp)
if err != nil { if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, "") _ = conn.Close(websocket.StatusGoingAway, "")
return err return err
} }
} }

View File

@ -102,7 +102,7 @@ func (c *Client) WorkspaceBuildLogsBefore(ctx context.Context, build uuid.UUID,
} }
// WorkspaceBuildLogsAfter streams logs for a workspace build that occurred after a specific time. // WorkspaceBuildLogsAfter streams logs for a workspace build that occurred after a specific time.
func (c *Client) WorkspaceBuildLogsAfter(ctx context.Context, build uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, error) { func (c *Client) WorkspaceBuildLogsAfter(ctx context.Context, build uuid.UUID, after time.Time) (<-chan ProvisionerJobLog, io.Closer, error) {
return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/workspacebuilds/%s/logs", build), after) return c.provisionerJobLogsAfter(ctx, fmt.Sprintf("/api/v2/workspacebuilds/%s/logs", build), after)
} }