mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
feat: Add support for update checks and notifications (#4810)
Co-authored-by: Kira Pilot <kira@coder.com>
This commit is contained in:
committed by
GitHub
parent
4f1cf6c9d8
commit
d9f2aaf3b4
@ -47,6 +47,7 @@ import (
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/tracing"
|
||||
"github.com/coder/coder/coderd/updatecheck"
|
||||
"github.com/coder/coder/coderd/wsconncache"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/provisionerd/proto"
|
||||
@ -105,6 +106,7 @@ type Options struct {
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
Experimental bool
|
||||
DeploymentConfig *codersdk.DeploymentConfig
|
||||
UpdateCheckOptions *updatecheck.Options // Set non-nil to enable update checking.
|
||||
}
|
||||
|
||||
// New constructs a Coder API handler.
|
||||
@ -123,7 +125,7 @@ func New(options *Options) *API {
|
||||
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
|
||||
}
|
||||
if options.AgentStatsRefreshInterval == 0 {
|
||||
options.AgentStatsRefreshInterval = 10 * time.Minute
|
||||
options.AgentStatsRefreshInterval = 5 * time.Minute
|
||||
}
|
||||
if options.MetricsCacheRefreshInterval == 0 {
|
||||
options.MetricsCacheRefreshInterval = time.Hour
|
||||
@ -131,12 +133,6 @@ func New(options *Options) *API {
|
||||
if options.APIRateLimit == 0 {
|
||||
options.APIRateLimit = 512
|
||||
}
|
||||
if options.AgentStatsRefreshInterval == 0 {
|
||||
options.AgentStatsRefreshInterval = 5 * time.Minute
|
||||
}
|
||||
if options.MetricsCacheRefreshInterval == 0 {
|
||||
options.MetricsCacheRefreshInterval = time.Hour
|
||||
}
|
||||
if options.Authorizer == nil {
|
||||
options.Authorizer = rbac.NewAuthorizer()
|
||||
}
|
||||
@ -181,6 +177,13 @@ func New(options *Options) *API {
|
||||
metricsCache: metricsCache,
|
||||
Auditor: atomic.Pointer[audit.Auditor]{},
|
||||
}
|
||||
if options.UpdateCheckOptions != nil {
|
||||
api.updateChecker = updatecheck.New(
|
||||
options.Database,
|
||||
options.Logger.Named("update_checker"),
|
||||
*options.UpdateCheckOptions,
|
||||
)
|
||||
}
|
||||
api.Auditor.Store(&options.Auditor)
|
||||
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
|
||||
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
|
||||
@ -308,6 +311,9 @@ func New(options *Options) *API {
|
||||
})
|
||||
})
|
||||
})
|
||||
r.Route("/updatecheck", func(r chi.Router) {
|
||||
r.Get("/", api.updateCheck)
|
||||
})
|
||||
r.Route("/config", func(r chi.Router) {
|
||||
r.Use(apiKeyMiddleware)
|
||||
r.Get("/deployment", api.deploymentConfig)
|
||||
@ -590,13 +596,14 @@ type API struct {
|
||||
// RootHandler serves "/"
|
||||
RootHandler chi.Router
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
siteHandler http.Handler
|
||||
siteHandler http.Handler
|
||||
|
||||
WebsocketWaitMutex sync.Mutex
|
||||
WebsocketWaitGroup sync.WaitGroup
|
||||
|
||||
metricsCache *metricscache.Cache
|
||||
workspaceAgentCache *wsconncache.Cache
|
||||
updateChecker *updatecheck.Checker
|
||||
}
|
||||
|
||||
// Close waits for all WebSocket connections to drain before returning.
|
||||
@ -606,6 +613,9 @@ func (api *API) Close() error {
|
||||
api.WebsocketWaitMutex.Unlock()
|
||||
|
||||
api.metricsCache.Close()
|
||||
if api.updateChecker != nil {
|
||||
api.updateChecker.Close()
|
||||
}
|
||||
coordinator := api.TailnetCoordinator.Load()
|
||||
if coordinator != nil {
|
||||
_ = (*coordinator).Close()
|
||||
|
@ -48,6 +48,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
|
||||
"GET:/healthz": {NoAuthorize: true},
|
||||
"GET:/api/v2": {NoAuthorize: true},
|
||||
"GET:/api/v2/buildinfo": {NoAuthorize: true},
|
||||
"GET:/api/v2/updatecheck": {NoAuthorize: true},
|
||||
"GET:/api/v2/users/first": {NoAuthorize: true},
|
||||
"POST:/api/v2/users/first": {NoAuthorize: true},
|
||||
"POST:/api/v2/users/login": {NoAuthorize: true},
|
||||
|
@ -64,6 +64,7 @@ import (
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
"github.com/coder/coder/coderd/telemetry"
|
||||
"github.com/coder/coder/coderd/updatecheck"
|
||||
"github.com/coder/coder/coderd/util/ptr"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/cryptorand"
|
||||
@ -102,6 +103,9 @@ type Options struct {
|
||||
AgentStatsRefreshInterval time.Duration
|
||||
DeploymentConfig *codersdk.DeploymentConfig
|
||||
|
||||
// Set update check options to enable update check.
|
||||
UpdateCheckOptions *updatecheck.Options
|
||||
|
||||
// Overriding the database is heavily discouraged.
|
||||
// It should only be used in cases where multiple Coder
|
||||
// test instances are running against the same database.
|
||||
@ -283,6 +287,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
|
||||
MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval,
|
||||
AgentStatsRefreshInterval: options.AgentStatsRefreshInterval,
|
||||
DeploymentConfig: options.DeploymentConfig,
|
||||
UpdateCheckOptions: options.UpdateCheckOptions,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,9 +119,10 @@ type data struct {
|
||||
workspaceResources []database.WorkspaceResource
|
||||
workspaces []database.Workspace
|
||||
|
||||
deploymentID string
|
||||
derpMeshKey string
|
||||
lastLicenseID int32
|
||||
deploymentID string
|
||||
derpMeshKey string
|
||||
lastUpdateCheck []byte
|
||||
lastLicenseID int32
|
||||
}
|
||||
|
||||
func (fakeQuerier) IsFakeDB() {}
|
||||
@ -3272,6 +3273,24 @@ func (q *fakeQuerier) GetDERPMeshKey(_ context.Context) (string, error) {
|
||||
return q.derpMeshKey, nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertOrUpdateLastUpdateCheck(_ context.Context, data string) error {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
q.lastUpdateCheck = []byte(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) GetLastUpdateCheck(_ context.Context) (string, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
if q.lastUpdateCheck == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return string(q.lastUpdateCheck), nil
|
||||
}
|
||||
|
||||
func (q *fakeQuerier) InsertLicense(
|
||||
_ context.Context, arg database.InsertLicenseParams,
|
||||
) (database.License, error) {
|
||||
|
@ -50,6 +50,7 @@ type sqlcQuerier interface {
|
||||
GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrgAndNameParams) (Group, error)
|
||||
GetGroupMembers(ctx context.Context, groupID uuid.UUID) ([]User, error)
|
||||
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
|
||||
GetLastUpdateCheck(ctx context.Context) (string, error)
|
||||
GetLatestAgentStat(ctx context.Context, agentID uuid.UUID) (AgentStat, error)
|
||||
GetLatestWorkspaceBuildByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (WorkspaceBuild, error)
|
||||
GetLatestWorkspaceBuilds(ctx context.Context) ([]WorkspaceBuild, error)
|
||||
@ -141,6 +142,7 @@ type sqlcQuerier interface {
|
||||
InsertGroup(ctx context.Context, arg InsertGroupParams) (Group, error)
|
||||
InsertGroupMember(ctx context.Context, arg InsertGroupMemberParams) error
|
||||
InsertLicense(ctx context.Context, arg InsertLicenseParams) (License, error)
|
||||
InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error
|
||||
InsertOrganization(ctx context.Context, arg InsertOrganizationParams) (Organization, error)
|
||||
InsertOrganizationMember(ctx context.Context, arg InsertOrganizationMemberParams) (OrganizationMember, error)
|
||||
InsertParameterSchema(ctx context.Context, arg InsertParameterSchemaParams) (ParameterSchema, error)
|
||||
|
@ -2970,6 +2970,17 @@ func (q *sqlQuerier) GetDeploymentID(ctx context.Context) (string, error) {
|
||||
return value, err
|
||||
}
|
||||
|
||||
const getLastUpdateCheck = `-- name: GetLastUpdateCheck :one
|
||||
SELECT value FROM site_configs WHERE key = 'last_update_check'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetLastUpdateCheck(ctx context.Context) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getLastUpdateCheck)
|
||||
var value string
|
||||
err := row.Scan(&value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
const insertDERPMeshKey = `-- name: InsertDERPMeshKey :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1)
|
||||
`
|
||||
@ -2988,6 +2999,16 @@ func (q *sqlQuerier) InsertDeploymentID(ctx context.Context, value string) error
|
||||
return err
|
||||
}
|
||||
|
||||
const insertOrUpdateLastUpdateCheck = `-- name: InsertOrUpdateLastUpdateCheck :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('last_update_check', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'last_update_check'
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) InsertOrUpdateLastUpdateCheck(ctx context.Context, value string) error {
|
||||
_, err := q.db.ExecContext(ctx, insertOrUpdateLastUpdateCheck, value)
|
||||
return err
|
||||
}
|
||||
|
||||
const getTemplateAverageBuildTime = `-- name: GetTemplateAverageBuildTime :one
|
||||
WITH build_times AS (
|
||||
SELECT
|
||||
|
@ -9,3 +9,10 @@ INSERT INTO site_configs (key, value) VALUES ('derp_mesh_key', $1);
|
||||
|
||||
-- name: GetDERPMeshKey :one
|
||||
SELECT value FROM site_configs WHERE key = 'derp_mesh_key';
|
||||
|
||||
-- name: InsertOrUpdateLastUpdateCheck :exec
|
||||
INSERT INTO site_configs (key, value) VALUES ('last_update_check', $1)
|
||||
ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'last_update_check';
|
||||
|
||||
-- name: GetLastUpdateCheck :one
|
||||
SELECT value FROM site_configs WHERE key = 'last_update_check';
|
||||
|
54
coderd/updatecheck.go
Normal file
54
coderd/updatecheck.go
Normal file
@ -0,0 +1,54 @@
|
||||
package coderd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/codersdk"
|
||||
)
|
||||
|
||||
func (api *API) updateCheck(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
currentVersion := codersdk.UpdateCheckResponse{
|
||||
Current: true,
|
||||
Version: buildinfo.Version(),
|
||||
URL: buildinfo.ExternalURL(),
|
||||
}
|
||||
|
||||
if api.updateChecker == nil {
|
||||
// If update checking is disabled, echo the current
|
||||
// version.
|
||||
httpapi.Write(ctx, rw, http.StatusOK, currentVersion)
|
||||
return
|
||||
}
|
||||
|
||||
uc, err := api.updateChecker.Latest(ctx)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
// Update checking is enabled, but has never
|
||||
// succeeded, reproduce behavior as if disabled.
|
||||
httpapi.Write(ctx, rw, http.StatusOK, currentVersion)
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Since our dev version (v0.12.9-devel+f7246386) is not semver compatible,
|
||||
// ignore everything after "-"."
|
||||
versionWithoutDevel := strings.SplitN(buildinfo.Version(), "-", 2)[0]
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UpdateCheckResponse{
|
||||
Current: semver.Compare(versionWithoutDevel, uc.Version) >= 0,
|
||||
Version: uc.Version,
|
||||
URL: uc.URL,
|
||||
})
|
||||
}
|
238
coderd/updatecheck/updatecheck.go
Normal file
238
coderd/updatecheck/updatecheck.go
Normal file
@ -0,0 +1,238 @@
|
||||
// Package updatecheck provides a mechanism for periodically checking
|
||||
// for updates to Coder.
|
||||
//
|
||||
// The update check is performed by querying the GitHub API for the
|
||||
// latest release of Coder.
|
||||
package updatecheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultURL defines the URL to check for the latest version of Coder.
|
||||
defaultURL = "https://api.github.com/repos/coder/coder/releases/latest"
|
||||
)
|
||||
|
||||
// Checker is responsible for periodically checking for updates.
|
||||
type Checker struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
db database.Store
|
||||
log slog.Logger
|
||||
opts Options
|
||||
firstCheck chan struct{}
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// Options set optional parameters for the update check.
|
||||
type Options struct {
|
||||
// Client is the HTTP client to use for the update check,
|
||||
// if omitted, http.DefaultClient will be used.
|
||||
Client *http.Client
|
||||
// URL is the URL to check for the latest version of Coder,
|
||||
// if omitted, the default URL will be used.
|
||||
URL string
|
||||
// Interval is the interval at which to check for updates,
|
||||
// default 24h.
|
||||
Interval time.Duration
|
||||
// UpdateTimeout sets the timeout for the update check,
|
||||
// default 30s.
|
||||
UpdateTimeout time.Duration
|
||||
// Notify is called when a newer version of Coder (than the
|
||||
// last update check) is available.
|
||||
Notify func(r Result)
|
||||
}
|
||||
|
||||
// New returns a new Checker that periodically checks for Coder updates.
|
||||
func New(db database.Store, log slog.Logger, opts Options) *Checker {
|
||||
if opts.Client == nil {
|
||||
opts.Client = http.DefaultClient
|
||||
}
|
||||
if opts.URL == "" {
|
||||
opts.URL = defaultURL
|
||||
}
|
||||
if opts.Interval == 0 {
|
||||
opts.Interval = 24 * time.Hour
|
||||
}
|
||||
if opts.UpdateTimeout == 0 {
|
||||
opts.UpdateTimeout = 30 * time.Second
|
||||
}
|
||||
if opts.Notify == nil {
|
||||
opts.Notify = func(r Result) {}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
c := &Checker{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
db: db,
|
||||
log: log,
|
||||
opts: opts,
|
||||
firstCheck: make(chan struct{}),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go c.start()
|
||||
return c
|
||||
}
|
||||
|
||||
// Result is the result from the last update check.
|
||||
type Result struct {
|
||||
Checked time.Time `json:"checked,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// Latest returns the latest version of Coder.
|
||||
func (c *Checker) Latest(ctx context.Context) (r Result, err error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return r, c.ctx.Err()
|
||||
case <-ctx.Done():
|
||||
return r, ctx.Err()
|
||||
case <-c.firstCheck:
|
||||
}
|
||||
|
||||
return c.lastUpdateCheck(ctx)
|
||||
}
|
||||
|
||||
func (c *Checker) init() (Result, error) {
|
||||
defer close(c.firstCheck)
|
||||
|
||||
r, err := c.lastUpdateCheck(c.ctx)
|
||||
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
||||
return Result{}, xerrors.Errorf("last update check: %w", err)
|
||||
}
|
||||
if r.Checked.IsZero() || time.Since(r.Checked) > c.opts.Interval {
|
||||
r, err = c.update()
|
||||
if err != nil {
|
||||
return Result{}, xerrors.Errorf("update check failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Checker) start() {
|
||||
defer close(c.closed)
|
||||
|
||||
r, err := c.init()
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
c.log.Error(c.ctx, "init failed", slog.Error(err))
|
||||
} else {
|
||||
c.opts.Notify(r)
|
||||
}
|
||||
|
||||
t := time.NewTicker(c.opts.Interval)
|
||||
defer t.Stop()
|
||||
|
||||
diff := time.Until(r.Checked.Add(c.opts.Interval))
|
||||
if diff > 0 {
|
||||
c.log.Info(c.ctx, "time until next update check", slog.F("duration", diff))
|
||||
t.Reset(diff)
|
||||
} else {
|
||||
c.log.Info(c.ctx, "time until next update check", slog.F("duration", c.opts.Interval))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
rr, err := c.update()
|
||||
if err != nil {
|
||||
if xerrors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
c.log.Error(c.ctx, "update check failed", slog.Error(err))
|
||||
} else {
|
||||
c.notifyIfNewer(r, rr)
|
||||
r = rr
|
||||
}
|
||||
c.log.Info(c.ctx, "time until next update check", slog.F("duration", c.opts.Interval))
|
||||
t.Reset(c.opts.Interval)
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checker) update() (r Result, err error) {
|
||||
ctx, cancel := context.WithTimeout(c.ctx, c.opts.UpdateTimeout)
|
||||
defer cancel()
|
||||
|
||||
c.log.Info(c.ctx, "checking for update")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.opts.URL, nil)
|
||||
if err != nil {
|
||||
return r, xerrors.Errorf("new request: %w", err)
|
||||
}
|
||||
resp, err := c.opts.Client.Do(req)
|
||||
if err != nil {
|
||||
return r, xerrors.Errorf("client do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
return r, xerrors.Errorf("unexpected status code %d: %s", resp.StatusCode, b)
|
||||
}
|
||||
|
||||
var rr github.RepositoryRelease
|
||||
err = json.NewDecoder(resp.Body).Decode(&rr)
|
||||
if err != nil {
|
||||
return r, xerrors.Errorf("json decode: %w", err)
|
||||
}
|
||||
|
||||
r = Result{
|
||||
Checked: time.Now(),
|
||||
Version: rr.GetTagName(),
|
||||
URL: rr.GetHTMLURL(),
|
||||
}
|
||||
c.log.Info(ctx, "update check result", slog.F("latest_version", r.Version))
|
||||
|
||||
b, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return r, xerrors.Errorf("json marshal result: %w", err)
|
||||
}
|
||||
|
||||
err = c.db.InsertOrUpdateLastUpdateCheck(ctx, string(b))
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Checker) notifyIfNewer(prev, next Result) {
|
||||
if (prev.Version == "" && next.Version != "") || semver.Compare(next.Version, prev.Version) > 0 {
|
||||
c.opts.Notify(next)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Checker) lastUpdateCheck(ctx context.Context) (r Result, err error) {
|
||||
s, err := c.db.GetLastUpdateCheck(ctx)
|
||||
if err != nil {
|
||||
return r, err
|
||||
}
|
||||
return r, json.Unmarshal([]byte(s), &r)
|
||||
}
|
||||
|
||||
func (c *Checker) Close() error {
|
||||
c.cancel()
|
||||
<-c.closed
|
||||
return nil
|
||||
}
|
158
coderd/updatecheck/updatecheck_test.go
Normal file
158
coderd/updatecheck/updatecheck_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package updatecheck_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
|
||||
"github.com/coder/coder/coderd/database/databasefake"
|
||||
"github.com/coder/coder/coderd/updatecheck"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestChecker_Notify(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
responses := []github.RepositoryRelease{
|
||||
{TagName: github.String("v1.2.3"), HTMLURL: github.String("https://someurl.com")},
|
||||
{TagName: github.String("v1.2.4"), HTMLURL: github.String("https://someurl.com")},
|
||||
{TagName: github.String("v1.2.4"), HTMLURL: github.String("https://someurl.com")},
|
||||
{TagName: github.String("v1.2.5"), HTMLURL: github.String("https://someurl.com")},
|
||||
}
|
||||
responseC := make(chan github.RepositoryRelease, len(responses))
|
||||
for _, r := range responses {
|
||||
responseC <- r
|
||||
}
|
||||
|
||||
wantVersion := []string{"v1.2.3", "v1.2.4", "v1.2.5"}
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
case resp := <-responseC:
|
||||
b, err := json.Marshal(resp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
db := databasefake.New()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name())
|
||||
notify := make(chan updatecheck.Result, len(wantVersion))
|
||||
c := updatecheck.New(db, logger, updatecheck.Options{
|
||||
Interval: 1 * time.Nanosecond, // Zero means unset.
|
||||
URL: srv.URL,
|
||||
Notify: func(r updatecheck.Result) {
|
||||
select {
|
||||
case notify <- r:
|
||||
default:
|
||||
t.Error("unexpected notification")
|
||||
}
|
||||
},
|
||||
})
|
||||
defer c.Close()
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
|
||||
for i := 0; i < len(wantVersion); i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Error("timed out waiting for notification")
|
||||
case r := <-notify:
|
||||
assert.Equal(t, wantVersion[i], r.Version)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecker_Latest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rr := github.RepositoryRelease{
|
||||
TagName: github.String("v1.2.3"),
|
||||
HTMLURL: github.String("https://someurl.com"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
release github.RepositoryRelease
|
||||
wantR updatecheck.Result
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "check latest",
|
||||
release: rr,
|
||||
wantR: updatecheck.Result{
|
||||
Version: "v1.2.3",
|
||||
URL: "https://someurl.com",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing release data",
|
||||
release: github.RepositoryRelease{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
release: rr,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if tt.wantErr {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rrJSON, err := json.Marshal(rr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(rrJSON)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
db := databasefake.New()
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Named(t.Name())
|
||||
c := updatecheck.New(db, logger, updatecheck.Options{
|
||||
URL: srv.URL,
|
||||
})
|
||||
defer c.Close()
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
_ = ctx
|
||||
|
||||
gotR, err := c.Latest(ctx)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
// Zero out the time so we can compare the rest of the struct.
|
||||
gotR.Checked = time.Time{}
|
||||
require.Equal(t, tt.wantR, gotR, "wrong version")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
79
coderd/updatecheck_test.go
Normal file
79
coderd/updatecheck_test.go
Normal file
@ -0,0 +1,79 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-github/v43/github"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/buildinfo"
|
||||
"github.com/coder/coder/coderd/coderdtest"
|
||||
"github.com/coder/coder/coderd/updatecheck"
|
||||
"github.com/coder/coder/codersdk"
|
||||
"github.com/coder/coder/testutil"
|
||||
)
|
||||
|
||||
func TestUpdateCheck_NewVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
resp github.RepositoryRelease
|
||||
want codersdk.UpdateCheckResponse
|
||||
}{
|
||||
{
|
||||
name: "New version",
|
||||
resp: github.RepositoryRelease{
|
||||
TagName: github.String("v99.999.999"),
|
||||
HTMLURL: github.String("https://someurl.com"),
|
||||
},
|
||||
want: codersdk.UpdateCheckResponse{
|
||||
Current: false,
|
||||
Version: "v99.999.999",
|
||||
URL: "https://someurl.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Same version",
|
||||
resp: github.RepositoryRelease{
|
||||
TagName: github.String(buildinfo.Version()),
|
||||
HTMLURL: github.String("https://someurl.com"),
|
||||
},
|
||||
want: codersdk.UpdateCheckResponse{
|
||||
Current: true,
|
||||
Version: buildinfo.Version(),
|
||||
URL: "https://someurl.com",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
b, err := json.Marshal(tt.resp)
|
||||
assert.NoError(t, err)
|
||||
w.Write(b)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
UpdateCheckOptions: &updatecheck.Options{
|
||||
URL: srv.URL,
|
||||
},
|
||||
})
|
||||
|
||||
ctx, _ := testutil.Context(t)
|
||||
|
||||
got, err := client.UpdateCheck(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user