chore: refactor coordination (#15343)
Refactors the way clients of the Tailnet API interact with it (API clients here include both workspace "agents" and "clients"). Introduces the idea of abstract "controllers" for each of the RPCs in the API, and implements a Coordination controller by refactoring from `workspacesdk`. chore re: #14729
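To make the "controller" idea concrete, here is a rough sketch of the shape such an abstraction could take; the names and signatures below are illustrative only and are not taken from this PR:

package sketch // hypothetical illustration, not code from this PR

import (
    "context"

    "github.com/coder/coder/v2/tailnet/proto"
)

// CoordinationController is a hypothetical controller for the Coordinate RPC:
// it owns the coordination logic and is handed each newly dialed stream.
type CoordinationController interface {
    // New starts coordinating over a freshly dialed Coordinate RPC client.
    New(client proto.DRPCTailnet_CoordinateClient) CloserWaiter
}

// CloserWaiter is the handle a caller gets back: it can stop the coordination
// and wait for it to end (e.g. when the underlying RPC fails).
type CloserWaiter interface {
    Close(context.Context) error
    Wait() <-chan error
}

Under a shape like this, dialing code hands each freshly established Coordinate RPC to the controller and blocks on the returned handle, instead of wiring the stream directly the way the removed code below does.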
@@ -90,302 +90,6 @@ type Coordinatee interface {
    SetTunnelDestination(id uuid.UUID)
}

type Coordination interface {
    Close(context.Context) error
    Error() <-chan error
}

type remoteCoordination struct {
    sync.Mutex
    closed       bool
    errChan      chan error
    coordinatee  Coordinatee
    tgt          uuid.UUID
    logger       slog.Logger
    protocol     proto.DRPCTailnet_CoordinateClient
    respLoopDone chan struct{}
}

// Close attempts to gracefully close the remoteCoordination by sending a Disconnect message and
// waiting for the server to hang up the coordination. If the provided context expires, we stop
// waiting for the server and close the coordination stream from our end.
func (c *remoteCoordination) Close(ctx context.Context) (retErr error) {
    c.Lock()
    defer c.Unlock()
    if c.closed {
        return nil
    }
    c.closed = true
    defer func() {
        // We shouldn't just close the protocol right away, because the way dRPC streams work is
        // that if you close them, that could take effect immediately, even before the Disconnect
        // message is processed. Coordinators are supposed to hang up on us once they get a
        // Disconnect message, so we should wait around for that until the context expires.
        select {
        case <-c.respLoopDone:
            c.logger.Debug(ctx, "responses closed after disconnect")
            return
        case <-ctx.Done():
            c.logger.Warn(ctx, "context expired while waiting for coordinate responses to close")
        }
        // forcefully close the stream
        protoErr := c.protocol.Close()
        <-c.respLoopDone
        if retErr == nil {
            retErr = protoErr
        }
    }()
    err := c.protocol.Send(&proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}})
    if err != nil && !xerrors.Is(err, io.EOF) {
        // Coordinator RPC hangs up when it gets disconnect, so EOF is expected.
        return xerrors.Errorf("send disconnect: %w", err)
    }
    c.logger.Debug(context.Background(), "sent disconnect")
    return nil
}

func (c *remoteCoordination) Error() <-chan error {
    return c.errChan
}

func (c *remoteCoordination) sendErr(err error) {
    select {
    case c.errChan <- err:
    default:
    }
}

func (c *remoteCoordination) respLoop() {
    defer func() {
        c.coordinatee.SetAllPeersLost()
        close(c.respLoopDone)
    }()
    for {
        resp, err := c.protocol.Recv()
        if err != nil {
            c.logger.Debug(context.Background(), "failed to read from protocol", slog.Error(err))
            c.sendErr(xerrors.Errorf("read: %w", err))
            return
        }

        err = c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
        if err != nil {
            c.logger.Debug(context.Background(), "failed to update peers", slog.Error(err))
            c.sendErr(xerrors.Errorf("update peers: %w", err))
            return
        }

        // Only send acks from peers without a target.
        if c.tgt == uuid.Nil {
            // Send an ack back for all received peers. This could
            // potentially be smarter to only send an ACK once per client,
            // but there's nothing currently stopping clients from reusing
            // IDs.
            rfh := []*proto.CoordinateRequest_ReadyForHandshake{}
            for _, peer := range resp.GetPeerUpdates() {
                if peer.Kind != proto.CoordinateResponse_PeerUpdate_NODE {
                    continue
                }

                rfh = append(rfh, &proto.CoordinateRequest_ReadyForHandshake{Id: peer.Id})
            }
            if len(rfh) > 0 {
                err := c.protocol.Send(&proto.CoordinateRequest{
                    ReadyForHandshake: rfh,
                })
                if err != nil {
                    c.logger.Debug(context.Background(), "failed to send ready for handshake", slog.Error(err))
                    c.sendErr(xerrors.Errorf("send: %w", err))
                    return
                }
            }
        }
    }
}

// NewRemoteCoordination uses the provided protocol to coordinate the provided coordinatee (usually a
// Conn). If the tunnelTarget is not uuid.Nil, then we add a tunnel to the peer (i.e. we are acting as
// a client---agents should NOT set this!).
func NewRemoteCoordination(logger slog.Logger,
    protocol proto.DRPCTailnet_CoordinateClient, coordinatee Coordinatee,
    tunnelTarget uuid.UUID,
) Coordination {
    c := &remoteCoordination{
        errChan:      make(chan error, 1),
        coordinatee:  coordinatee,
        tgt:          tunnelTarget,
        logger:       logger,
        protocol:     protocol,
        respLoopDone: make(chan struct{}),
    }
    if tunnelTarget != uuid.Nil {
        c.coordinatee.SetTunnelDestination(tunnelTarget)
        c.Lock()
        err := c.protocol.Send(&proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: tunnelTarget[:]}})
        c.Unlock()
        if err != nil {
            c.sendErr(err)
        }
    }

    coordinatee.SetNodeCallback(func(node *Node) {
        pn, err := NodeToProto(node)
        if err != nil {
            c.logger.Critical(context.Background(), "failed to convert node", slog.Error(err))
            c.sendErr(err)
            return
        }
        c.Lock()
        defer c.Unlock()
        if c.closed {
            c.logger.Debug(context.Background(), "ignored node update because coordination is closed")
            return
        }
        err = c.protocol.Send(&proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}})
        if err != nil {
            c.sendErr(xerrors.Errorf("write: %w", err))
        }
    })
    go c.respLoop()
    return c
}

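For context on how this removed API was driven, a minimal usage sketch follows; it assumes a dialed Coordinate RPC client and a tailnet Conn already exist, and the helper name, package, and import paths are illustrative rather than taken from the repository:

package example // illustrative usage of the removed API, not code from this PR

import (
    "context"
    "time"

    "github.com/google/uuid"

    "cdr.dev/slog"

    "github.com/coder/coder/v2/tailnet"
    "github.com/coder/coder/v2/tailnet/proto"
)

// coordinateAsClient is a hypothetical helper showing the lifecycle of the removed
// remoteCoordination: construct it, watch its error channel, then close gracefully.
func coordinateAsClient(
    ctx context.Context, logger slog.Logger,
    client proto.DRPCTailnet_CoordinateClient, conn tailnet.Coordinatee, agentID uuid.UUID,
) error {
    // As a client we pass the agent's ID as the tunnel target; an agent would pass uuid.Nil.
    coordination := tailnet.NewRemoteCoordination(logger, client, conn, agentID)

    select {
    case <-ctx.Done():
        // Caller is done; fall through to the graceful close.
    case err := <-coordination.Error():
        logger.Warn(ctx, "coordination failed", slog.Error(err))
    }

    // Give the coordinator a bounded window to hang up after our Disconnect.
    closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    return coordination.Close(closeCtx)
}
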
type inMemoryCoordination struct {
    sync.Mutex
    ctx          context.Context
    errChan      chan error
    closed       bool
    respLoopDone chan struct{}
    coordinatee  Coordinatee
    logger       slog.Logger
    resps        <-chan *proto.CoordinateResponse
    reqs         chan<- *proto.CoordinateRequest
}

func (c *inMemoryCoordination) sendErr(err error) {
    select {
    case c.errChan <- err:
    default:
    }
}

func (c *inMemoryCoordination) Error() <-chan error {
    return c.errChan
}

// NewInMemoryCoordination connects a Coordinatee (usually Conn) to an in memory Coordinator, for testing
// or local clients. Set ClientID to uuid.Nil for an agent.
func NewInMemoryCoordination(
    ctx context.Context, logger slog.Logger,
    clientID, agentID uuid.UUID,
    coordinator Coordinator, coordinatee Coordinatee,
) Coordination {
    thisID := agentID
    logger = logger.With(slog.F("agent_id", agentID))
    var auth CoordinateeAuth = AgentCoordinateeAuth{ID: agentID}
    if clientID != uuid.Nil {
        // this is a client connection
        auth = ClientCoordinateeAuth{AgentID: agentID}
        logger = logger.With(slog.F("client_id", clientID))
        thisID = clientID
    }
    c := &inMemoryCoordination{
        ctx:          ctx,
        errChan:      make(chan error, 1),
        coordinatee:  coordinatee,
        logger:       logger,
        respLoopDone: make(chan struct{}),
    }

    // use the background context since we will depend exclusively on closing the req channel to
    // tell the coordinator we are done.
    c.reqs, c.resps = coordinator.Coordinate(context.Background(),
        thisID, fmt.Sprintf("inmemory%s", thisID),
        auth,
    )
    go c.respLoop()
    if agentID != uuid.Nil {
        select {
        case <-ctx.Done():
            c.logger.Warn(ctx, "context expired before we could add tunnel", slog.Error(ctx.Err()))
            return c
        case c.reqs <- &proto.CoordinateRequest{AddTunnel: &proto.CoordinateRequest_Tunnel{Id: agentID[:]}}:
            // OK!
        }
    }
    coordinatee.SetNodeCallback(func(n *Node) {
        pn, err := NodeToProto(n)
        if err != nil {
            c.logger.Critical(ctx, "failed to convert node", slog.Error(err))
            c.sendErr(err)
            return
        }
        c.Lock()
        defer c.Unlock()
        if c.closed {
            return
        }
        select {
        case <-ctx.Done():
            c.logger.Info(ctx, "context expired before sending node update")
            return
        case c.reqs <- &proto.CoordinateRequest{UpdateSelf: &proto.CoordinateRequest_UpdateSelf{Node: pn}}:
            c.logger.Debug(ctx, "sent node in-memory to coordinator")
        }
    })
    return c
}

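A similarly hedged sketch of the in-memory variant as it might appear in a test; the coordinator construction and the Conn are assumed to be set up elsewhere, and the package and helper names are illustrative:

package example // illustrative usage of the removed in-memory API, not code from this PR

import (
    "context"
    "time"

    "github.com/google/uuid"

    "cdr.dev/slog"

    "github.com/coder/coder/v2/tailnet"
)

// coordinateInMemory wires a Coordinatee to an in-memory Coordinator, as tests did.
// Per the doc comment above, clientID is uuid.Nil when the Coordinatee is an agent.
func coordinateInMemory(
    ctx context.Context, logger slog.Logger,
    coord tailnet.Coordinator, conn tailnet.Coordinatee,
    clientID, agentID uuid.UUID,
) error {
    coordination := tailnet.NewInMemoryCoordination(ctx, logger, clientID, agentID, coord, conn)

    // ... exercise the connection under test ...

    closeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    return coordination.Close(closeCtx)
}
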
func (c *inMemoryCoordination) respLoop() {
    defer func() {
        c.coordinatee.SetAllPeersLost()
        close(c.respLoopDone)
    }()
    for resp := range c.resps {
        c.logger.Debug(context.Background(), "got in-memory response from coordinator", slog.F("resp", resp))
        err := c.coordinatee.UpdatePeers(resp.GetPeerUpdates())
        if err != nil {
            c.sendErr(xerrors.Errorf("failed to update peers: %w", err))
            return
        }
    }
    c.logger.Debug(context.Background(), "in-memory response channel closed")
}

func (*inMemoryCoordination) AwaitAck() <-chan struct{} {
    // This is only used for tests, so just return a closed channel.
    ch := make(chan struct{})
    close(ch)
    return ch
}

// Close attempts to gracefully close the inMemoryCoordination by sending a Disconnect message and
// waiting for the response loop to finish. If the provided context expires, we stop waiting and
// return an error.
func (c *inMemoryCoordination) Close(ctx context.Context) error {
    c.Lock()
    defer c.Unlock()
    c.logger.Debug(context.Background(), "closing in-memory coordination")
    if c.closed {
        return nil
    }
    defer close(c.reqs)
    c.closed = true
    select {
    case <-ctx.Done():
        return xerrors.Errorf("failed to gracefully disconnect: %w", c.ctx.Err())
    case c.reqs <- &proto.CoordinateRequest{Disconnect: &proto.CoordinateRequest_Disconnect{}}:
        c.logger.Debug(context.Background(), "sent graceful disconnect in-memory")
    }

    select {
    case <-ctx.Done():
        return xerrors.Errorf("context expired waiting for responses to close: %w", c.ctx.Err())
    case <-c.respLoopDone:
        return nil
    }
}

const LoggerName = "coord"

var (