fix(agent): guard against multiple rpty race for same id (#7998)

* fix(agent): guard against multiple rpty race for same id
* fix(agent): ensure pty is closed on error
This commit is contained in:
Mathias Fredriksson
2023-06-13 18:14:07 +03:00
committed by GitHub
parent 9440b3da66
commit c916a9e67f

View File

@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
}() }()
var rpty *reconnectingPTY var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID) sendConnected := make(chan *reconnectingPTY, 1)
// On store, reserve this ID to prevent multiple concurrent new connections.
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
if ok { if ok {
close(sendConnected) // Unused.
logger.Debug(ctx, "connecting to existing session") logger.Debug(ctx, "connecting to existing session")
rpty, ok = rawRPTY.(*reconnectingPTY) c, ok := waitReady.(chan *reconnectingPTY)
if !ok { if !ok {
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY) return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
} }
rpty, ok = <-c
if !ok || rpty == nil {
return xerrors.Errorf("reconnecting pty closed before connection")
}
c <- rpty // Put it back for the next reconnect.
} else { } else {
logger.Debug(ctx, "creating new session") logger.Debug(ctx, "creating new session")
connected := false
defer func() {
if !connected && retErr != nil {
a.reconnectingPTYs.Delete(msg.ID)
close(sendConnected)
}
}()
// Empty command will default to the users shell! // Empty command will default to the users shell!
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil) cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
if err != nil { if err != nil {
@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
return xerrors.Errorf("start command: %w", err) return xerrors.Errorf("start command: %w", err)
} }
ctx, cancelFunc := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
rpty = &reconnectingPTY{ rpty = &reconnectingPTY{
activeConns: map[string]net.Conn{ activeConns: map[string]net.Conn{
// We have to put the connection in the map instantly otherwise // We have to put the connection in the map instantly otherwise
@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
}, },
ptty: ptty, ptty: ptty,
// Timeouts created with an after func can be reset! // Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc), timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
circularBuffer: circularBuffer, circularBuffer: circularBuffer,
} }
a.reconnectingPTYs.Store(msg.ID, rpty)
// We don't need to separately monitor for the process exiting. // We don't need to separately monitor for the process exiting.
// When it exits, our ptty.OutputReader() will return EOF after // When it exits, our ptty.OutputReader() will return EOF after
// reading all process output. // reading all process output.
@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
rpty.Close() rpty.Close()
a.reconnectingPTYs.Delete(msg.ID) a.reconnectingPTYs.Delete(msg.ID)
}); err != nil { }); err != nil {
_ = process.Kill()
_ = ptty.Close()
return xerrors.Errorf("start routine: %w", err) return xerrors.Errorf("start routine: %w", err)
} }
connected = true
sendConnected <- rpty
} }
// Resize the PTY to initial height + width. // Resize the PTY to initial height + width.
err := rpty.ptty.Resize(msg.Height, msg.Width) err := rpty.ptty.Resize(msg.Height, msg.Width)