Mirror of https://github.com/coder/coder.git, synced 2025-07-13 21:36:50 +00:00
chore: refactor agent connection updates (#11301)
Refactors the code that handles monitoring an agent websocket with pings and updating the connection times in the DB, consolidating the v1 and v2 agent APIs under the same code path. One substantive change (not _just_ a refactor): I've made it so that we actually disconnect if the agent fails to respond to our pings, rather than the old behavior where we updated the database but never tore down the websocket.
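
The substantive change reduces to a small pattern. Below is a minimal, self-contained sketch of it in Go (using the nhooyr.io/websocket package the diff already imports; monitorPings and its parameters are illustrative names, not coderd's actual API): a companion goroutine pings the agent and records the time of each success, and this loop tears the connection down once the last successful ping is older than the disconnect timeout, instead of merely recording the disconnect in the database.

package sketch

import (
	"context"
	"sync/atomic"
	"time"

	"nhooyr.io/websocket"
)

// monitorPings closes the websocket when the agent goes quiet.
// lastPing must be seeded and then refreshed by a companion goroutine
// each time conn.Ping succeeds.
func monitorPings(ctx context.Context, conn *websocket.Conn, lastPing *atomic.Pointer[time.Time], pingPeriod, disconnectTimeout time.Duration) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			_ = conn.Close(websocket.StatusGoingAway, "canceled")
			return
		case <-ticker.C:
		}
		if time.Since(*lastPing.Load()) > disconnectTimeout {
			// The old behavior stopped at a DB update; the new behavior
			// also closes the socket so the agent reconnects cleanly.
			_ = conn.Close(websocket.StatusGoingAway, "ping timeout")
			return
		}
	}
}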
@@ -12,11 +12,9 @@ import (
 	"net/http"
 	"net/netip"
 	"net/url"
-	"runtime/pprof"
 	"sort"
 	"strconv"
 	"strings"
-	"sync/atomic"
 	"time"
 
 	"github.com/google/uuid"
@@ -42,7 +40,6 @@ import (
 	"github.com/coder/coder/v2/coderd/httpmw"
 	"github.com/coder/coder/v2/coderd/prometheusmetrics"
 	"github.com/coder/coder/v2/coderd/rbac"
-	"github.com/coder/coder/v2/coderd/util/ptr"
 	"github.com/coder/coder/v2/codersdk"
 	"github.com/coder/coder/v2/codersdk/agentsdk"
 	"github.com/coder/coder/v2/tailnet"
@@ -1084,21 +1081,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
 	api.WebsocketWaitMutex.Unlock()
 	defer api.WebsocketWaitGroup.Done()
 	workspaceAgent := httpmw.WorkspaceAgent(r)
-	resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Failed to accept websocket.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-
-	build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID)
-	if err != nil {
-		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
-			Message: "Internal error fetching workspace build job.",
-			Detail:  err.Error(),
-		})
+	// Ensure the resource is still valid!
+	// We only accept agents for resources on the latest build.
+	build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
+	if !ok {
 		return
 	}
 
@@ -1120,32 +1106,6 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
 		return
 	}
 
-	// Ensure the resource is still valid!
-	// We only accept agents for resources on the latest build.
-	ensureLatestBuild := func() error {
-		latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
-		if err != nil {
-			return err
-		}
-		if build.ID != latestBuild.ID {
-			return xerrors.New("build is outdated")
-		}
-		return nil
-	}
-
-	err = ensureLatestBuild()
-	if err != nil {
-		api.Logger.Debug(ctx, "agent tried to connect from non-latest build",
-			slog.F("resource", resource),
-			slog.F("agent", workspaceAgent),
-		)
-		httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{
-			Message: "Agent trying to connect from non-latest build.",
-			Detail:  err.Error(),
-		})
-		return
-	}
-
 	conn, err := websocket.Accept(rw, r, nil)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -1158,109 +1118,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
 	ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary)
 	defer wsNetConn.Close()
 
-	// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
-	// because we want to log the agent's last ping time.
-	var lastPing atomic.Pointer[time.Time]
-	lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.
-
-	go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
-		// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
-		t := time.NewTicker(api.AgentConnectionUpdateFrequency)
-		defer t.Stop()
-
-		for {
-			select {
-			case <-t.C:
-			case <-ctx.Done():
-				return
-			}
-
-			// We don't need a context that times out here because the ping will
-			// eventually go through. If the context times out, then other
-			// websocket read operations will receive an error, obfuscating the
-			// actual problem.
-			err := conn.Ping(ctx)
-			if err != nil {
-				return
-			}
-			lastPing.Store(ptr.Ref(time.Now()))
-		}
-	})
-
-	firstConnectedAt := workspaceAgent.FirstConnectedAt
-	if !firstConnectedAt.Valid {
-		firstConnectedAt = sql.NullTime{
-			Time:  dbtime.Now(),
-			Valid: true,
-		}
-	}
-	lastConnectedAt := sql.NullTime{
-		Time:  dbtime.Now(),
-		Valid: true,
-	}
-	disconnectedAt := workspaceAgent.DisconnectedAt
-	updateConnectionTimes := func(ctx context.Context) error {
-		//nolint:gocritic // We only update ourself.
-		err = api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
-			ID:               workspaceAgent.ID,
-			FirstConnectedAt: firstConnectedAt,
-			LastConnectedAt:  lastConnectedAt,
-			DisconnectedAt:   disconnectedAt,
-			UpdatedAt:        dbtime.Now(),
-			LastConnectedReplicaID: uuid.NullUUID{
-				UUID:  api.ID,
-				Valid: true,
-			},
-		})
-		if err != nil {
-			return err
-		}
-		return nil
-	}
-
-	defer func() {
-		// If connection closed then context will be canceled, try to
-		// ensure our final update is sent. By waiting at most the agent
-		// inactive disconnect timeout we ensure that we don't block but
-		// also guarantee that the agent will be considered disconnected
-		// by normal status check.
-		//
-		// Use a system context as the agent has disconnected and that token
-		// may no longer be valid.
-		//nolint:gocritic
-		ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
-		defer cancel()
-
-		// Only update timestamp if the disconnect is new.
-		if !disconnectedAt.Valid {
-			disconnectedAt = sql.NullTime{
-				Time:  dbtime.Now(),
-				Valid: true,
-			}
-		}
-		err := updateConnectionTimes(ctx)
-		if err != nil {
-			// This is a bug with unit tests that cancel the app context and
-			// cause this error log to be generated. We should fix the unit tests
-			// as this is a valid log.
-			//
-			// The pq error occurs when the server is shutting down.
-			if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
-				api.Logger.Error(ctx, "failed to update agent disconnect time",
-					slog.Error(err),
-					slog.F("workspace_id", build.WorkspaceID),
-				)
-			}
-		}
-		api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
-	}()
-
-	err = updateConnectionTimes(ctx)
-	if err != nil {
-		_ = conn.Close(websocket.StatusGoingAway, err.Error())
-		return
-	}
-	api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
+	closeCtx, closeCtxCancel := context.WithCancel(ctx)
+	defer closeCtxCancel()
+	monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn)
+	defer monitor.close()
 
 	api.Logger.Debug(ctx, "accepting agent",
 		slog.F("owner", owner.Username),
@@ -1271,61 +1132,13 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
 
 	defer conn.Close(websocket.StatusNormalClosure, "")
 
-	closeChan := make(chan struct{})
-	go func() {
-		defer close(closeChan)
-		err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
-			fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
-		)
-		if err != nil {
-			api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
-			_ = conn.Close(websocket.StatusInternalError, err.Error())
-			return
-		}
-	}()
-	ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
-	defer ticker.Stop()
-	for {
-		select {
-		case <-closeChan:
-			return
-		case <-ticker.C:
-		}
-
-		lastPing := *lastPing.Load()
-
-		var connectionStatusChanged bool
-		if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
-			if !disconnectedAt.Valid {
-				connectionStatusChanged = true
-				disconnectedAt = sql.NullTime{
-					Time:  dbtime.Now(),
-					Valid: true,
-				}
-			}
-		} else {
-			connectionStatusChanged = disconnectedAt.Valid
-			// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
-			disconnectedAt = sql.NullTime{}
-			lastConnectedAt = sql.NullTime{
-				Time:  dbtime.Now(),
-				Valid: true,
-			}
-		}
-		err = updateConnectionTimes(ctx)
-		if err != nil {
-			_ = conn.Close(websocket.StatusGoingAway, err.Error())
-			return
-		}
-		if connectionStatusChanged {
-			api.publishWorkspaceUpdate(ctx, build.WorkspaceID)
-		}
-		err := ensureLatestBuild()
-		if err != nil {
-			// Disconnect agents that are no longer valid.
-			_ = conn.Close(websocket.StatusGoingAway, "")
-			return
-		}
+	err = (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID,
+		fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
+	)
+	if err != nil {
+		api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err))
+		_ = conn.Close(websocket.StatusInternalError, err.Error())
+		return
 	}
 }
 
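Both agent endpoints now run the same latest-build gate before accepting a connection. Stripped of the uuid types and the database lookup, the check is just an ID comparison; a toy, runnable version (checkLatest and the string IDs are illustrative — the real helper, checkBuildIsLatest, appears in the diff below):

package sketch

import "errors"

// checkLatest refuses agents that authenticated against a stale build.
func checkLatest(buildID, latestBuildID string) error {
	if buildID != latestBuildID {
		return errors.New("build is outdated")
	}
	return nil
}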
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/http"
 	"runtime/pprof"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -42,7 +43,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
 	defer api.WebsocketWaitGroup.Done()
 	workspaceAgent := httpmw.WorkspaceAgent(r)
 
-	ensureLatestBuildFn, build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
+	build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent)
 	if !ok {
 		return
 	}
@@ -96,10 +97,10 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
 
 	defer conn.Close(websocket.StatusNormalClosure, "")
 
-	pingFn, ok := api.agentConnectionUpdate(ctx, workspaceAgent, build.WorkspaceID, conn)
-	if !ok {
-		return
-	}
+	closeCtx, closeCtxCancel := context.WithCancel(ctx)
+	defer closeCtxCancel()
+	monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn)
+	defer monitor.close()
 
 	agentAPI := agentapi.New(agentapi.Options{
 		AgentID: workspaceAgent.ID,
@@ -136,29 +137,22 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
 		Auth: tailnet.AgentTunnelAuth{},
 	}
 	ctx = tailnet.WithStreamID(ctx, streamID)
-
-	closeCtx, closeCtxCancel := context.WithCancel(ctx)
-	go func() {
-		defer closeCtxCancel()
-		err := agentAPI.Serve(ctx, mux)
-		if err != nil {
-			api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))
-			_ = conn.Close(websocket.StatusInternalError, err.Error())
-			return
-		}
-	}()
-
-	pingFn(closeCtx, ensureLatestBuildFn)
+	err = agentAPI.Serve(ctx, mux)
+	if err != nil {
+		api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))
+		_ = conn.Close(websocket.StatusInternalError, err.Error())
+		return
+	}
 }
 
-func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (func() error, database.WorkspaceBuild, bool) {
+func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (database.WorkspaceBuild, bool) {
 	resource, err := db.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
 	if err != nil {
 		httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
 			Message: "Internal error fetching workspace agent resource.",
 			Detail:  err.Error(),
 		})
-		return nil, database.WorkspaceBuild{}, false
+		return database.WorkspaceBuild{}, false
 	}
 
 	build, err := db.GetWorkspaceBuildByJobID(ctx, resource.JobID)
@@ -167,23 +161,12 @@ func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logge
 			Message: "Internal error fetching workspace build job.",
 			Detail:  err.Error(),
 		})
-		return nil, database.WorkspaceBuild{}, false
+		return database.WorkspaceBuild{}, false
 	}
 
 	// Ensure the resource is still valid!
 	// We only accept agents for resources on the latest build.
-	ensureLatestBuild := func() error {
-		latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
-		if err != nil {
-			return err
-		}
-		if build.ID != latestBuild.ID {
-			return xerrors.New("build is outdated")
-		}
-		return nil
-	}
-
-	err = ensureLatestBuild()
+	err = checkBuildIsLatest(ctx, db, build)
 	if err != nil {
 		logger.Debug(ctx, "agent tried to connect from non-latest build",
 			slog.F("resource", resource),
@@ -193,73 +176,159 @@ func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logge
 			Message: "Agent trying to connect from non-latest build.",
 			Detail:  err.Error(),
 		})
-		return nil, database.WorkspaceBuild{}, false
+		return database.WorkspaceBuild{}, false
 	}
 
-	return ensureLatestBuild, build, true
+	return build, true
 }
 
-func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent database.WorkspaceAgent, workspaceID uuid.UUID, conn *websocket.Conn) (func(closeCtx context.Context, ensureLatestBuildFn func() error), bool) {
-	// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
-	// because we want to log the agent's last ping time.
-	var lastPing atomic.Pointer[time.Time]
-	lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.
-
-	go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) {
-		// TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout?
-		t := time.NewTicker(api.AgentConnectionUpdateFrequency)
-		defer t.Stop()
-
-		for {
-			select {
-			case <-t.C:
-			case <-ctx.Done():
-				return
-			}
-
-			// We don't need a context that times out here because the ping will
-			// eventually go through. If the context times out, then other
-			// websocket read operations will receive an error, obfuscating the
-			// actual problem.
-			err := conn.Ping(ctx)
-			if err != nil {
-				return
-			}
-			lastPing.Store(ptr.Ref(time.Now()))
-		}
-	})
-
-	firstConnectedAt := workspaceAgent.FirstConnectedAt
-	if !firstConnectedAt.Valid {
-		firstConnectedAt = sql.NullTime{
-			Time:  dbtime.Now(),
-			Valid: true,
-		}
-	}
-	lastConnectedAt := sql.NullTime{
-		Time:  dbtime.Now(),
-		Valid: true,
-	}
-	disconnectedAt := workspaceAgent.DisconnectedAt
-	updateConnectionTimes := func(ctx context.Context) error {
-		//nolint:gocritic // We only update ourself.
-		err := api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
-			ID:               workspaceAgent.ID,
-			FirstConnectedAt: firstConnectedAt,
-			LastConnectedAt:  lastConnectedAt,
-			DisconnectedAt:   disconnectedAt,
-			UpdatedAt:        dbtime.Now(),
-			LastConnectedReplicaID: uuid.NullUUID{
-				UUID:  api.ID,
-				Valid: true,
-			},
-		})
-		if err != nil {
-			return err
-		}
-		return nil
-	}
-
+func checkBuildIsLatest(ctx context.Context, db database.Store, build database.WorkspaceBuild) error {
+	latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID)
+	if err != nil {
+		return err
+	}
+	if build.ID != latestBuild.ID {
+		return xerrors.New("build is outdated")
+	}
+	return nil
+}
+
+func (api *API) startAgentWebsocketMonitor(ctx context.Context,
+	workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild,
+	conn *websocket.Conn,
+) *agentWebsocketMonitor {
+	monitor := &agentWebsocketMonitor{
+		apiCtx:            api.ctx,
+		workspaceAgent:    workspaceAgent,
+		workspaceBuild:    workspaceBuild,
+		conn:              conn,
+		pingPeriod:        api.AgentConnectionUpdateFrequency,
+		db:                api.Database,
+		replicaID:         api.ID,
+		updater:           api,
+		disconnectTimeout: api.AgentInactiveDisconnectTimeout,
+		logger: api.Logger.With(
+			slog.F("workspace_id", workspaceBuild.WorkspaceID),
+			slog.F("agent_id", workspaceAgent.ID),
+		),
+	}
+	monitor.init()
+	monitor.start(ctx)
+
+	return monitor
+}
+
+type workspaceUpdater interface {
+	publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID)
+}
+
+type pingerCloser interface {
+	Ping(ctx context.Context) error
+	Close(code websocket.StatusCode, reason string) error
+}
+
+type agentWebsocketMonitor struct {
+	apiCtx         context.Context
+	cancel         context.CancelFunc
+	wg             sync.WaitGroup
+	workspaceAgent database.WorkspaceAgent
+	workspaceBuild database.WorkspaceBuild
+	conn           pingerCloser
+	db             database.Store
+	replicaID      uuid.UUID
+	updater        workspaceUpdater
+	logger         slog.Logger
+	pingPeriod     time.Duration
+
+	// state manipulated by both sendPings() and monitor() goroutines: needs to be threadsafe
+	lastPing atomic.Pointer[time.Time]
+
+	// state manipulated only by monitor() goroutine: does not need to be threadsafe
+	firstConnectedAt  sql.NullTime
+	lastConnectedAt   sql.NullTime
+	disconnectedAt    sql.NullTime
+	disconnectTimeout time.Duration
+}
+
+// sendPings sends websocket pings.
+//
+// We use a custom heartbeat routine here instead of `httpapi.Heartbeat`
+// because we want to log the agent's last ping time.
+func (m *agentWebsocketMonitor) sendPings(ctx context.Context) {
+	t := time.NewTicker(m.pingPeriod)
+	defer t.Stop()
+
+	for {
+		select {
+		case <-t.C:
+		case <-ctx.Done():
+			return
+		}
+
+		// We don't need a context that times out here because the ping will
+		// eventually go through. If the context times out, then other
+		// websocket read operations will receive an error, obfuscating the
+		// actual problem.
+		err := m.conn.Ping(ctx)
+		if err != nil {
+			return
+		}
+		m.lastPing.Store(ptr.Ref(time.Now()))
+	}
+}
+
+func (m *agentWebsocketMonitor) updateConnectionTimes(ctx context.Context) error {
+	//nolint:gocritic // We only update the agent we are minding.
+	err := m.db.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{
+		ID:               m.workspaceAgent.ID,
+		FirstConnectedAt: m.firstConnectedAt,
+		LastConnectedAt:  m.lastConnectedAt,
+		DisconnectedAt:   m.disconnectedAt,
+		UpdatedAt:        dbtime.Now(),
+		LastConnectedReplicaID: uuid.NullUUID{
+			UUID:  m.replicaID,
+			Valid: true,
+		},
+	})
+	if err != nil {
+		return xerrors.Errorf("failed to update workspace agent connection times: %w", err)
+	}
+	return nil
+}
+
+func (m *agentWebsocketMonitor) init() {
+	now := dbtime.Now()
+	m.firstConnectedAt = m.workspaceAgent.FirstConnectedAt
+	if !m.firstConnectedAt.Valid {
+		m.firstConnectedAt = sql.NullTime{
+			Time:  now,
+			Valid: true,
+		}
+	}
+	m.lastConnectedAt = sql.NullTime{
+		Time:  now,
+		Valid: true,
+	}
+	m.disconnectedAt = m.workspaceAgent.DisconnectedAt
+	m.lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive.
+}
+
+func (m *agentWebsocketMonitor) start(ctx context.Context) {
+	ctx, m.cancel = context.WithCancel(ctx)
+	m.wg.Add(2)
+	go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()),
+		func(ctx context.Context) {
+			defer m.wg.Done()
+			m.sendPings(ctx)
+		})
+	go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()),
+		func(ctx context.Context) {
+			defer m.wg.Done()
+			m.monitor(ctx)
+		})
+}
+
+func (m *agentWebsocketMonitor) monitor(ctx context.Context) {
 	defer func() {
 		// If connection closed then context will be canceled, try to
 		// ensure our final update is sent. By waiting at most the agent
@@ -270,17 +339,17 @@ func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent databa
 		// Use a system context as the agent has disconnected and that token
 		// may no longer be valid.
 		//nolint:gocritic
-		ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout)
+		finalCtx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(m.apiCtx), m.disconnectTimeout)
 		defer cancel()
 
 		// Only update timestamp if the disconnect is new.
-		if !disconnectedAt.Valid {
-			disconnectedAt = sql.NullTime{
+		if !m.disconnectedAt.Valid {
+			m.disconnectedAt = sql.NullTime{
 				Time:  dbtime.Now(),
 				Valid: true,
 			}
 		}
-		err := updateConnectionTimes(ctx)
+		err := m.updateConnectionTimes(finalCtx)
 		if err != nil {
 			// This is a bug with unit tests that cancel the app context and
 			// cause this error log to be generated. We should fix the unit tests
@@ -288,66 +357,66 @@ func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent databa
 			//
 			// The pq error occurs when the server is shutting down.
 			if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) {
-				api.Logger.Error(ctx, "failed to update agent disconnect time",
+				m.logger.Error(finalCtx, "failed to update agent disconnect time",
 					slog.Error(err),
-					slog.F("workspace_id", workspaceID),
 				)
 			}
 		}
-		api.publishWorkspaceUpdate(ctx, workspaceID)
+		m.updater.publishWorkspaceUpdate(finalCtx, m.workspaceBuild.WorkspaceID)
+	}()
+	reason := "disconnect"
+	defer func() {
+		m.logger.Debug(ctx, "agent websocket monitor is closing connection",
+			slog.F("reason", reason))
+		_ = m.conn.Close(websocket.StatusGoingAway, reason)
 	}()
 
-	err := updateConnectionTimes(ctx)
+	err := m.updateConnectionTimes(ctx)
 	if err != nil {
-		_ = conn.Close(websocket.StatusGoingAway, err.Error())
-		return nil, false
+		reason = err.Error()
+		return
 	}
-	api.publishWorkspaceUpdate(ctx, workspaceID)
+	m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID)
 
-	return func(closeCtx context.Context, ensureLatestBuildFn func() error) {
-		ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
-		defer ticker.Stop()
-		for {
-			select {
-			case <-closeCtx.Done():
-				return
-			case <-ticker.C:
-			}
-
-			lastPing := *lastPing.Load()
-
-			var connectionStatusChanged bool
-			if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout {
-				if !disconnectedAt.Valid {
-					connectionStatusChanged = true
-					disconnectedAt = sql.NullTime{
-						Time:  dbtime.Now(),
-						Valid: true,
-					}
-				}
-			} else {
-				connectionStatusChanged = disconnectedAt.Valid
-				// TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it?
-				disconnectedAt = sql.NullTime{}
-				lastConnectedAt = sql.NullTime{
-					Time:  dbtime.Now(),
-					Valid: true,
-				}
-			}
-			err = updateConnectionTimes(ctx)
-			if err != nil {
-				_ = conn.Close(websocket.StatusGoingAway, err.Error())
-				return
-			}
-			if connectionStatusChanged {
-				api.publishWorkspaceUpdate(ctx, workspaceID)
-			}
-			err := ensureLatestBuildFn()
-			if err != nil {
-				// Disconnect agents that are no longer valid.
-				_ = conn.Close(websocket.StatusGoingAway, "")
-				return
-			}
-		}
-	}, true
+	ticker := time.NewTicker(m.pingPeriod)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			reason = "canceled"
+			return
+		case <-ticker.C:
+		}
+
+		lastPing := *m.lastPing.Load()
+		if time.Since(lastPing) > m.disconnectTimeout {
+			reason = "ping timeout"
+			return
+		}
+		connectionStatusChanged := m.disconnectedAt.Valid
+		m.disconnectedAt = sql.NullTime{}
+		m.lastConnectedAt = sql.NullTime{
+			Time:  dbtime.Now(),
+			Valid: true,
+		}
+
+		err = m.updateConnectionTimes(ctx)
+		if err != nil {
+			reason = err.Error()
+			return
+		}
+		if connectionStatusChanged {
+			m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID)
+		}
+		err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild)
+		if err != nil {
+			reason = err.Error()
+			return
+		}
+	}
 }
+
+func (m *agentWebsocketMonitor) close() {
+	m.cancel()
+	m.wg.Wait()
+}
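What makes `defer monitor.close()` safe in both handlers is the start/close contract: start launches sendPings and monitor as two goroutines under one WaitGroup, and close cancels their shared context and blocks until both have exited. A stripped-down, runnable sketch of the same lifecycle pattern (generic worker functions in place of the real methods; names are illustrative):

package sketch

import (
	"context"
	"sync"
)

// lifecycle mirrors agentWebsocketMonitor's start/close shape.
type lifecycle struct {
	cancel context.CancelFunc
	wg     sync.WaitGroup
}

func (l *lifecycle) start(ctx context.Context, workers ...func(context.Context)) {
	ctx, l.cancel = context.WithCancel(ctx)
	l.wg.Add(len(workers))
	for _, w := range workers {
		w := w // capture the loop variable for the goroutine
		go func() {
			defer l.wg.Done()
			w(ctx)
		}()
	}
}

// close cancels the workers and waits for them to finish.
func (l *lifecycle) close() {
	l.cancel()
	l.wg.Wait() // returns only after every worker has observed cancellation
}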
coderd/workspaceagentsrpc_internal_test.go (new file, 436 lines)
@@ -0,0 +1,436 @@
package coderd

import (
	"context"
	"database/sql"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/coder/coder/v2/coderd/util/ptr"

	"github.com/golang/mock/gomock"
	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
	"nhooyr.io/websocket"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/v2/coderd/database"
	"github.com/coder/coder/v2/coderd/database/dbmock"
	"github.com/coder/coder/v2/coderd/database/dbtime"
	"github.com/coder/coder/v2/testutil"
)

func TestAgentWebsocketMonitor_ContextCancel(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	now := dbtime.Now()
	fConn := &fakePingerCloser{}
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	fUpdater := &fakeUpdater{}
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	agent := database.WorkspaceAgent{
		ID: uuid.New(),
		FirstConnectedAt: sql.NullTime{
			Time:  now.Add(-time.Minute),
			Valid: true,
		},
	}
	build := database.WorkspaceBuild{
		ID:          uuid.New(),
		WorkspaceID: uuid.New(),
	}
	replicaID := uuid.New()

	uut := &agentWebsocketMonitor{
		apiCtx:            ctx,
		workspaceAgent:    agent,
		workspaceBuild:    build,
		conn:              fConn,
		db:                mDB,
		replicaID:         replicaID,
		updater:           fUpdater,
		logger:            logger,
		pingPeriod:        testutil.IntervalFast,
		disconnectTimeout: testutil.WaitShort,
	}
	uut.init()

	connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID),
	).
		AnyTimes().
		Return(nil)
	mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID, withDisconnected()),
	).
		After(connected).
		Times(1).
		Return(nil)
	mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID).
		AnyTimes().
		Return(database.WorkspaceBuild{ID: build.ID}, nil)

	closeCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	done := make(chan struct{})
	go func() {
		uut.monitor(closeCtx)
		close(done)
	}()
	// wait a couple intervals, but not long enough for a disconnect
	time.Sleep(3 * testutil.IntervalFast)
	fConn.requireNotClosed(t)
	fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID)
	n := fUpdater.getUpdates()
	cancel()
	fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled")

	// make sure we got at least one additional update on close
	_ = testutil.RequireRecvCtx(ctx, t, done)
	m := fUpdater.getUpdates()
	require.Greater(t, m, n)
}

func TestAgentWebsocketMonitor_PingTimeout(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	now := dbtime.Now()
	fConn := &fakePingerCloser{}
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	fUpdater := &fakeUpdater{}
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	agent := database.WorkspaceAgent{
		ID: uuid.New(),
		FirstConnectedAt: sql.NullTime{
			Time:  now.Add(-time.Minute),
			Valid: true,
		},
	}
	build := database.WorkspaceBuild{
		ID:          uuid.New(),
		WorkspaceID: uuid.New(),
	}
	replicaID := uuid.New()

	uut := &agentWebsocketMonitor{
		apiCtx:            ctx,
		workspaceAgent:    agent,
		workspaceBuild:    build,
		conn:              fConn,
		db:                mDB,
		replicaID:         replicaID,
		updater:           fUpdater,
		logger:            logger,
		pingPeriod:        testutil.IntervalFast,
		disconnectTimeout: testutil.WaitShort,
	}
	uut.init()
	// set the last ping to the past, so we go thru the timeout
	uut.lastPing.Store(ptr.Ref(now.Add(-time.Hour)))

	connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID),
	).
		AnyTimes().
		Return(nil)
	mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID, withDisconnected()),
	).
		After(connected).
		Times(1).
		Return(nil)
	mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID).
		AnyTimes().
		Return(database.WorkspaceBuild{ID: build.ID}, nil)

	go uut.monitor(ctx)
	fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "ping timeout")
	fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID)
}

func TestAgentWebsocketMonitor_BuildOutdated(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	now := dbtime.Now()
	fConn := &fakePingerCloser{}
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	fUpdater := &fakeUpdater{}
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	agent := database.WorkspaceAgent{
		ID: uuid.New(),
		FirstConnectedAt: sql.NullTime{
			Time:  now.Add(-time.Minute),
			Valid: true,
		},
	}
	build := database.WorkspaceBuild{
		ID:          uuid.New(),
		WorkspaceID: uuid.New(),
	}
	replicaID := uuid.New()

	uut := &agentWebsocketMonitor{
		apiCtx:            ctx,
		workspaceAgent:    agent,
		workspaceBuild:    build,
		conn:              fConn,
		db:                mDB,
		replicaID:         replicaID,
		updater:           fUpdater,
		logger:            logger,
		pingPeriod:        testutil.IntervalFast,
		disconnectTimeout: testutil.WaitShort,
	}
	uut.init()

	connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID),
	).
		AnyTimes().
		Return(nil)
	mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID, withDisconnected()),
	).
		After(connected).
		Times(1).
		Return(nil)

	// return a new buildID each time, meaning the connection is outdated
	mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID).
		AnyTimes().
		Return(database.WorkspaceBuild{ID: uuid.New()}, nil)

	go uut.monitor(ctx)
	fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "build is outdated")
	fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID)
}

func TestAgentWebsocketMonitor_SendPings(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	fConn := &fakePingerCloser{}
	uut := &agentWebsocketMonitor{
		pingPeriod: testutil.IntervalFast,
		conn:       fConn,
	}
	go uut.sendPings(ctx)
	fConn.requireEventuallyHasPing(t)
	lastPing := uut.lastPing.Load()
	require.NotNil(t, lastPing)
}

func TestAgentWebsocketMonitor_StartClose(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	fConn := &fakePingerCloser{}
	now := dbtime.Now()
	ctrl := gomock.NewController(t)
	mDB := dbmock.NewMockStore(ctrl)
	fUpdater := &fakeUpdater{}
	logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
	agent := database.WorkspaceAgent{
		ID: uuid.New(),
		FirstConnectedAt: sql.NullTime{
			Time:  now.Add(-time.Minute),
			Valid: true,
		},
	}
	build := database.WorkspaceBuild{
		ID:          uuid.New(),
		WorkspaceID: uuid.New(),
	}
	replicaID := uuid.New()
	uut := &agentWebsocketMonitor{
		apiCtx:            ctx,
		workspaceAgent:    agent,
		workspaceBuild:    build,
		conn:              fConn,
		db:                mDB,
		replicaID:         replicaID,
		updater:           fUpdater,
		logger:            logger,
		pingPeriod:        testutil.IntervalFast,
		disconnectTimeout: testutil.WaitShort,
	}

	connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID),
	).
		AnyTimes().
		Return(nil)
	mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
		gomock.Any(),
		connectionUpdate(agent.ID, replicaID, withDisconnected()),
	).
		After(connected).
		Times(1).
		Return(nil)
	mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID).
		AnyTimes().
		Return(database.WorkspaceBuild{ID: build.ID}, nil)

	uut.start(ctx)
	closed := make(chan struct{})
	go func() {
		uut.close()
		close(closed)
	}()
	_ = testutil.RequireRecvCtx(ctx, t, closed)
}

type fakePingerCloser struct {
	sync.Mutex
	pings  []time.Time
	code   websocket.StatusCode
	reason string
	closed bool
}

func (f *fakePingerCloser) Ping(context.Context) error {
	f.Lock()
	defer f.Unlock()
	f.pings = append(f.pings, time.Now())
	return nil
}

func (f *fakePingerCloser) Close(code websocket.StatusCode, reason string) error {
	f.Lock()
	defer f.Unlock()
	if f.closed {
		return nil
	}
	f.closed = true
	f.code = code
	f.reason = reason
	return nil
}

func (f *fakePingerCloser) requireNotClosed(t *testing.T) {
	f.Lock()
	defer f.Unlock()
	require.False(t, f.closed)
}

func (f *fakePingerCloser) requireEventuallyClosed(t *testing.T, code websocket.StatusCode, reason string) {
	require.Eventually(t, func() bool {
		f.Lock()
		defer f.Unlock()
		return f.closed
	}, testutil.WaitShort, testutil.IntervalFast)
	f.Lock()
	defer f.Unlock()
	require.Equal(t, code, f.code)
	require.Equal(t, reason, f.reason)
}

func (f *fakePingerCloser) requireEventuallyHasPing(t *testing.T) {
	require.Eventually(t, func() bool {
		f.Lock()
		defer f.Unlock()
		return len(f.pings) > 0
	}, testutil.WaitShort, testutil.IntervalFast)
}

type fakeUpdater struct {
	sync.Mutex
	updates []uuid.UUID
}

func (f *fakeUpdater) publishWorkspaceUpdate(_ context.Context, workspaceID uuid.UUID) {
	f.Lock()
	defer f.Unlock()
	f.updates = append(f.updates, workspaceID)
}

func (f *fakeUpdater) requireEventuallySomeUpdates(t *testing.T, workspaceID uuid.UUID) {
	require.Eventually(t, func() bool {
		f.Lock()
		defer f.Unlock()
		return len(f.updates) >= 1
	}, testutil.WaitShort, testutil.IntervalFast)

	f.Lock()
	defer f.Unlock()
	for _, u := range f.updates {
		require.Equal(t, workspaceID, u)
	}
}

func (f *fakeUpdater) getUpdates() int {
	f.Lock()
	defer f.Unlock()
	return len(f.updates)
}

type connectionUpdateMatcher struct {
	agentID      uuid.UUID
	replicaID    uuid.UUID
	disconnected bool
}

type connectionUpdateMatcherOption func(m connectionUpdateMatcher) connectionUpdateMatcher

func connectionUpdate(id, replica uuid.UUID, opts ...connectionUpdateMatcherOption) connectionUpdateMatcher {
	m := connectionUpdateMatcher{
		agentID:   id,
		replicaID: replica,
	}
	for _, opt := range opts {
		m = opt(m)
	}
	return m
}

func withDisconnected() connectionUpdateMatcherOption {
	return func(m connectionUpdateMatcher) connectionUpdateMatcher {
		m.disconnected = true
		return m
	}
}

func (m connectionUpdateMatcher) Matches(x interface{}) bool {
	args, ok := x.(database.UpdateWorkspaceAgentConnectionByIDParams)
	if !ok {
		return false
	}
	if args.ID != m.agentID {
		return false
	}
	if !args.LastConnectedReplicaID.Valid {
		return false
	}
	if args.LastConnectedReplicaID.UUID != m.replicaID {
		return false
	}
	if args.DisconnectedAt.Valid != m.disconnected {
		return false
	}
	return true
}

func (m connectionUpdateMatcher) String() string {
	return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}",
		m.agentID.String(), m.replicaID.String(), m.disconnected)
}

func (connectionUpdateMatcher) Got(x interface{}) string {
	args, ok := x.(database.UpdateWorkspaceAgentConnectionByIDParams)
	if !ok {
		return fmt.Sprintf("type=%T", x)
	}
	return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}",
		args.ID, args.LastConnectedReplicaID.UUID, args.DisconnectedAt.Valid)
}
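
Note the ordering these mocks pin down in each test: any number of "connected" row updates are allowed (AnyTimes), but exactly one "disconnected" update must come after them (gomock's After combined with Times(1)). Since these are internal tests in package coderd, a typical way to run just this suite from the repository root would be (assuming a standard Go toolchain):

go test ./coderd -run TestAgentWebsocketMonitor -count=1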