feat: Add support for update checks and notifications (#4810)

Co-authored-by: Kira Pilot <kira@coder.com>
This commit is contained in:
Mathias Fredriksson
2022-12-01 19:43:28 +02:00
committed by GitHub
parent 4f1cf6c9d8
commit d9f2aaf3b4
32 changed files with 1088 additions and 22 deletions

View File

@ -47,6 +47,7 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
@ -105,6 +106,7 @@ type Options struct {
AgentStatsRefreshInterval time.Duration
Experimental bool
DeploymentConfig *codersdk.DeploymentConfig
UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking.
}
// New constructs a Coder API handler.
@ -123,7 +125,7 @@ func New(options *Options) *API {
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
}
if options.AgentStatsRefreshInterval == 0 {
options.AgentStatsRefreshInterval = 10 * time.Minute
options.AgentStatsRefreshInterval = 5 * time.Minute
}
if options.MetricsCacheRefreshInterval == 0 {
options.MetricsCacheRefreshInterval = time.Hour
@ -131,12 +133,6 @@ func New(options *Options) *API {
if options.APIRateLimit == 0 {
options.APIRateLimit = 512
}
if options.AgentStatsRefreshInterval == 0 {
options.AgentStatsRefreshInterval = 5 * time.Minute
}
if options.MetricsCacheRefreshInterval == 0 {
options.MetricsCacheRefreshInterval = time.Hour
}
if options.Authorizer == nil {
options.Authorizer = rbac.NewAuthorizer()
}
@ -181,6 +177,13 @@ func New(options *Options) *API {
metricsCache: metricsCache,
Auditor: atomic.Pointer[audit.Auditor]{},
}
if options.UpdateCheckOptions != nil {
api.updateChecker = updatecheck.New(
options.Database,
options.Logger.Named("update_checker"),
*options.UpdateCheckOptions,
)
}
api.Auditor.Store(&options.Auditor)
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
@ -308,6 +311,9 @@ func New(options *Options) *API {
})
})
})
r.Route("/updatecheck", func(r chi.Router) {
r.Get("/", api.updateCheck)
})
r.Route("/config", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Get("/deployment", api.deploymentConfig)
@ -590,13 +596,14 @@ type API struct {
// RootHandler serves "/"
RootHandler chi.Router
metricsCache *metricscache.Cache
siteHandler http.Handler
siteHandler http.Handler
WebsocketWaitMutex sync.Mutex
WebsocketWaitGroup sync.WaitGroup
metricsCache *metricscache.Cache
workspaceAgentCache *wsconncache.Cache
updateChecker *updatecheck.Checker
}
// Close waits for all WebSocket connections to drain before returning.
@ -606,6 +613,9 @@ func (api *API) Close() error {
api.WebsocketWaitMutex.Unlock()
api.metricsCache.Close()
if api.updateChecker != nil {
api.updateChecker.Close()
}
coordinator := api.TailnetCoordinator.Load()
if coordinator != nil {
_ = (*coordinator).Close()

View File

@ -48,6 +48,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/healthz": {NoAuthorize: true},
"GET:/api/v2": {NoAuthorize: true},
"GET:/api/v2/buildinfo": {NoAuthorize: true},
"GET:/api/v2/updatecheck": {NoAuthorize: true},
"GET:/api/v2/users/first": {NoAuthorize: true},
"POST:/api/v2/users/first": {NoAuthorize: true},
"POST:/api/v2/users/login": {NoAuthorize: true},

View File

@ -64,6 +64,7 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
@ -102,6 +103,9 @@ type Options struct {
AgentStatsRefreshInterval time.Duration
DeploymentConfig *codersdk.DeploymentConfig
// Set update check options to enable update check.
UpdateCheckOptions *updatecheck.Options
// Overriding the database is heavily discouraged.
// It should only be used in cases where multiple Coder
// test instances are running against the same database.
@ -283,6 +287,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
DeploymentConfig: options.DeploymentConfig,
UpdateCheckOptions: options.UpdateCheckOptions,
}
}

View File

@ -119,9 +119,10 @@ type data struct {
workspaceResources []database.WorkspaceResource
workspaces []database.Workspace
deploymentID string
derpMeshKey string
lastLicenseID int32
deploymentID string
derpMeshKey string
lastUpdateCheck []byte
lastLicenseID int32
}
func (fakeQuerier) IsFakeDB() {}
@ -3272,6 +3273,24 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
return q.derpMeshKey, nil
}
func (q *fakeQuerier) InsertOrUpdateLastUpdateCheck(_ context.Context, data string) error {
q.mutex.RLock()
defer q.mutex.RUnlock()
q.lastUpdateCheck = []byte(data)
return nil
}
func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
if q.lastUpdateCheck == nil {
return "", sql.ErrNoRows
}
return string(q.lastUpdateCheck), nil
}
func (q *fakeQuerier) InsertLicense(
_ context.Context, arg database.InsertLicenseParams,
) (database.License, error) {

View File

@ -50,6 +50,7 @@ type sqlcQuerier interface {
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetLastUpdateCheck(ctx context.Context) (string, error)
GetLatestAgentStat(ctx context.Context, agentID uuid.UUID) (AgentStat, error)
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
@ -141,6 +142,7 @@ type sqlcQuerier interface {
InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error)
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
InsertParameterSchema(ctx context.Context, arg InsertParameterSchemaParams) (ParameterSchema, error)

View File

@ -2970,6 +2970,17 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) {
return value, err
}
const getLastUpdateCheck = `-- name: GetLastUpdateCheck :one
SELECT value FROM site_configs WHERE key = 'last_update_check'
`
func (q *sqlQuerier) GetLastUpdateCheck(ctx context.Context) (string, error) {
row := q.db.QueryRowContext(ctx, getLastUpdateCheck)
var value string
err := row.Scan(&value)
return value, err
}
const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1)
`
@ -2988,6 +2999,16 @@ func (q *sqlQuerier) InsertDeploymentID(ctx context.Context, value string) error
return err
}
const insertOrUpdateLastUpdateCheck = `-- name: InsertOrUpdateLastUpdateCheck :exec
INSERT INTO site_configs (key, value) VALUES ('last_update_check', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'last_update_check'
`
func (q *sqlQuerier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error {
_, err := q.db.ExecContext(ctx, insertOrUpdateLastUpdateCheck, value)
return err
}
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
WITH build_times AS (
SELECT

View File

@ -9,3 +9,10 @@ INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1);
-- name: GetDERPMeshKey :one
SELECT value FROM site_configs WHERE key = 'derp_mesh_key';
-- name: InsertOrUpdateLastUpdateCheck :exec
INSERT INTO site_configs (key, value) VALUES ('last_update_check', $1)
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'last_update_check';
-- name: GetLastUpdateCheck :one
SELECT value FROM site_configs WHERE key = 'last_update_check';

54
coderd/updatecheck.go Normal file
View File

@ -0,0 +1,54 @@
package coderd
import (
"database/sql"
"net/http"
"strings"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/codersdk"
)
func (api *API) updateCheck(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
currentVersion := codersdk.UpdateCheckResponse{
Current: true,
Version: buildinfo.Version(),
URL: buildinfo.ExternalURL(),
}
if api.updateChecker == nil {
// If update checking is disabled, echo the current
// version.
httpapi.Write(ctx, rw, http.StatusOK, currentVersion)
return
}
uc, err := api.updateChecker.Latest(ctx)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
// Update checking is enabled, but has never
// succeeded, reproduce behavior as if disabled.
httpapi.Write(ctx, rw, http.StatusOK, currentVersion)
return
}
httpapi.InternalServerError(rw, err)
return
}
// Since our dev version (v0.12.9-devel+f7246386) is not semver compatible,
// ignore everything after "-"."
versionWithoutDevel := strings.SplitN(buildinfo.Version(), "-", 2)[0]
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UpdateCheckResponse{
Current: semver.Compare(versionWithoutDevel, uc.Version) >= 0,
Version: uc.Version,
URL: uc.URL,
})
}

View File

@ -0,0 +1,238 @@
// Package updatecheck provides a mechanism for periodically checking
// for updates to Coder.
//
// The update check is performed by querying the GitHub API for the
// latest release of Coder.
package updatecheck
import (
"context"
"database/sql"
"encoding/json"
"io"
"net/http"
"time"
"github.com/google/go-github/v43/github"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
)
const (
// defaultURL defines the URL to check for the latest version of Coder.
defaultURL = "https://api.github.com/repos/coder/coder/releases/latest"
)
// Checker is responsible for periodically checking for updates.
type Checker struct {
ctx context.Context
cancel context.CancelFunc
db database.Store
log slog.Logger
opts Options
firstCheck chan struct{}
closed chan struct{}
}
// Options set optional parameters for the update check.
type Options struct {
// Client is the HTTP client to use for the update check,
// if omitted, http.DefaultClient will be used.
Client *http.Client
// URL is the URL to check for the latest version of Coder,
// if omitted, the default URL will be used.
URL string
// Interval is the interval at which to check for updates,
// default 24h.
Interval time.Duration
// UpdateTimeout sets the timeout for the update check,
// default 30s.
UpdateTimeout time.Duration
// Notify is called when a newer version of Coder (than the
// last update check) is available.
Notify func(r Result)
}
// New returns a new Checker that periodically checks for Coder updates.
func New(db database.Store, log slog.Logger, opts Options) *Checker {
if opts.Client == nil {
opts.Client = http.DefaultClient
}
if opts.URL == "" {
opts.URL = defaultURL
}
if opts.Interval == 0 {
opts.Interval = 24 * time.Hour
}
if opts.UpdateTimeout == 0 {
opts.UpdateTimeout = 30 * time.Second
}
if opts.Notify == nil {
opts.Notify = func(r Result) {}
}
ctx, cancel := context.WithCancel(context.Background())
c := &Checker{
ctx: ctx,
cancel: cancel,
db: db,
log: log,
opts: opts,
firstCheck: make(chan struct{}),
closed: make(chan struct{}),
}
go c.start()
return c
}
// Result is the result from the last update check.
type Result struct {
Checked time.Time `json:"checked,omitempty"`
Version string `json:"version,omitempty"`
URL string `json:"url,omitempty"`
}
// Latest returns the latest version of Coder.
func (c *Checker) Latest(ctx context.Context) (r Result, err error) {
select {
case <-c.ctx.Done():
return r, c.ctx.Err()
case <-ctx.Done():
return r, ctx.Err()
case <-c.firstCheck:
}
return c.lastUpdateCheck(ctx)
}
func (c *Checker) init() (Result, error) {
defer close(c.firstCheck)
r, err := c.lastUpdateCheck(c.ctx)
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return Result{}, xerrors.Errorf("last update check: %w", err)
}
if r.Checked.IsZero() || time.Since(r.Checked) > c.opts.Interval {
r, err = c.update()
if err != nil {
return Result{}, xerrors.Errorf("update check failed: %w", err)
}
}
return r, nil
}
func (c *Checker) start() {
defer close(c.closed)
r, err := c.init()
if err != nil {
if xerrors.Is(err, context.Canceled) {
return
}
c.log.Error(c.ctx, "init failed", slog.Error(err))
} else {
c.opts.Notify(r)
}
t := time.NewTicker(c.opts.Interval)
defer t.Stop()
diff := time.Until(r.Checked.Add(c.opts.Interval))
if diff > 0 {
c.log.Info(c.ctx, "time until next update check", slog.F("duration", diff))
t.Reset(diff)
} else {
c.log.Info(c.ctx, "time until next update check", slog.F("duration", c.opts.Interval))
}
for {
select {
case <-t.C:
rr, err := c.update()
if err != nil {
if xerrors.Is(err, context.Canceled) {
return
}
c.log.Error(c.ctx, "update check failed", slog.Error(err))
} else {
c.notifyIfNewer(r, rr)
r = rr
}
c.log.Info(c.ctx, "time until next update check", slog.F("duration", c.opts.Interval))
t.Reset(c.opts.Interval)
case <-c.ctx.Done():
return
}
}
}
func (c *Checker) update() (r Result, err error) {
ctx, cancel := context.WithTimeout(c.ctx, c.opts.UpdateTimeout)
defer cancel()
c.log.Info(c.ctx, "checking for update")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.opts.URL, nil)
if err != nil {
return r, xerrors.Errorf("new request: %w", err)
}
resp, err := c.opts.Client.Do(req)
if err != nil {
return r, xerrors.Errorf("client do: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(resp.Body)
return r, xerrors.Errorf("unexpected status code %d: %s", resp.StatusCode, b)
}
var rr github.RepositoryRelease
err = json.NewDecoder(resp.Body).Decode(&rr)
if err != nil {
return r, xerrors.Errorf("json decode: %w", err)
}
r = Result{
Checked: time.Now(),
Version: rr.GetTagName(),
URL: rr.GetHTMLURL(),
}
c.log.Info(ctx, "update check result", slog.F("latest_version", r.Version))
b, err := json.Marshal(r)
if err != nil {
return r, xerrors.Errorf("json marshal result: %w", err)
}
err = c.db.InsertOrUpdateLastUpdateCheck(ctx, string(b))
if err != nil {
return r, err
}
return r, nil
}
func (c *Checker) notifyIfNewer(prev, next Result) {
if (prev.Version == "" && next.Version != "") || semver.Compare(next.Version, prev.Version) > 0 {
c.opts.Notify(next)
}
}
func (c *Checker) lastUpdateCheck(ctx context.Context) (r Result, err error) {
s, err := c.db.GetLastUpdateCheck(ctx)
if err != nil {
return r, err
}
return r, json.Unmarshal([]byte(s), &r)
}
func (c *Checker) Close() error {
c.cancel()
<-c.closed
return nil
}

View File

@ -0,0 +1,158 @@
package updatecheck_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-github/v43/github"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/database/databasefake"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/testutil"
)
func TestChecker_Notify(t *testing.T) {
t.Parallel()
responses := []github.RepositoryRelease{
{TagName: github.String("v1.2.3"), HTMLURL: github.String("https://someurl.com")},
{TagName: github.String("v1.2.4"), HTMLURL: github.String("https://someurl.com")},
{TagName: github.String("v1.2.4"), HTMLURL: github.String("https://someurl.com")},
{TagName: github.String("v1.2.5"), HTMLURL: github.String("https://someurl.com")},
}
responseC := make(chan github.RepositoryRelease, len(responses))
for _, r := range responses {
responseC <- r
}
wantVersion := []string{"v1.2.3", "v1.2.4", "v1.2.5"}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-r.Context().Done():
case resp := <-responseC:
b, err := json.Marshal(resp)
assert.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
}))
defer srv.Close()
db := databasefake.New()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name())
notify := make(chan updatecheck.Result, len(wantVersion))
c := updatecheck.New(db, logger, updatecheck.Options{
Interval: 1 * time.Nanosecond, // Zero means unset.
URL: srv.URL,
Notify: func(r updatecheck.Result) {
select {
case notify <- r:
default:
t.Error("unexpected notification")
}
},
})
defer c.Close()
ctx, _ := testutil.Context(t)
for i := 0; i < len(wantVersion); i++ {
select {
case <-ctx.Done():
t.Error("timed out waiting for notification")
case r := <-notify:
assert.Equal(t, wantVersion[i], r.Version)
}
}
}
func TestChecker_Latest(t *testing.T) {
t.Parallel()
rr := github.RepositoryRelease{
TagName: github.String("v1.2.3"),
HTMLURL: github.String("https://someurl.com"),
}
tests := []struct {
name string
release github.RepositoryRelease
wantR updatecheck.Result
wantErr bool
}{
{
name: "check latest",
release: rr,
wantR: updatecheck.Result{
Version: "v1.2.3",
URL: "https://someurl.com",
},
wantErr: false,
},
{
name: "missing release data",
release: github.RepositoryRelease{},
wantErr: true,
},
{
name: "error",
release: rr,
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if tt.wantErr {
w.WriteHeader(http.StatusInternalServerError)
return
}
rrJSON, err := json.Marshal(rr)
assert.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(rrJSON)
}))
defer srv.Close()
db := databasefake.New()
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name())
c := updatecheck.New(db, logger, updatecheck.Options{
URL: srv.URL,
})
defer c.Close()
ctx, _ := testutil.Context(t)
_ = ctx
gotR, err := c.Latest(ctx)
if tt.wantErr {
require.Error(t, err)
return
}
// Zero out the time so we can compare the rest of the struct.
gotR.Checked = time.Time{}
require.Equal(t, tt.wantR, gotR, "wrong version")
})
}
}
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

View File

@ -0,0 +1,79 @@
package coderd_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/google/go-github/v43/github"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/buildinfo"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/updatecheck"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/testutil"
)
func TestUpdateCheck_NewVersion(t *testing.T) {
t.Parallel()
tests := []struct {
name string
resp github.RepositoryRelease
want codersdk.UpdateCheckResponse
}{
{
name: "New version",
resp: github.RepositoryRelease{
TagName: github.String("v99.999.999"),
HTMLURL: github.String("https://someurl.com"),
},
want: codersdk.UpdateCheckResponse{
Current: false,
Version: "v99.999.999",
URL: "https://someurl.com",
},
},
{
name: "Same version",
resp: github.RepositoryRelease{
TagName: github.String(buildinfo.Version()),
HTMLURL: github.String("https://someurl.com"),
},
want: codersdk.UpdateCheckResponse{
Current: true,
Version: buildinfo.Version(),
URL: "https://someurl.com",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
b, err := json.Marshal(tt.resp)
assert.NoError(t, err)
w.Write(b)
}))
defer srv.Close()
client := coderdtest.New(t, &coderdtest.Options{
UpdateCheckOptions: &updatecheck.Options{
URL: srv.URL,
},
})
ctx, _ := testutil.Context(t)
got, err := client.UpdateCheck(ctx)
require.NoError(t, err)
require.Equal(t, tt.want, got)
})
}
}