fix(agent): guard against multiple rpty race for same id (#7998)

* fix(agent): guard against multiple rpty race for same id
* fix(agent): ensure pty is closed on error
This commit is contained in:
Mathias Fredriksson
2023-06-13 18:14:07 +03:00
committed by GitHub
parent 9440b3da66
commit c916a9e67f

View File

@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
}()
var rpty *reconnectingPTY
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
sendConnected := make(chan *reconnectingPTY, 1)
// On store, reserve this ID to prevent multiple concurrent new connections.
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
if ok {
close(sendConnected) // Unused.
logger.Debug(ctx, "connecting to existing session")
rpty, ok = rawRPTY.(*reconnectingPTY)
c, ok := waitReady.(chan *reconnectingPTY)
if !ok {
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
}
rpty, ok = <-c
if !ok || rpty == nil {
return xerrors.Errorf("reconnecting pty closed before connection")
}
c <- rpty // Put it back for the next reconnect.
} else {
logger.Debug(ctx, "creating new session")
connected := false
defer func() {
if !connected && retErr != nil {
a.reconnectingPTYs.Delete(msg.ID)
close(sendConnected)
}
}()
// Empty command will default to the users shell!
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
if err != nil {
@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
return xerrors.Errorf("start command: %w", err)
}
ctx, cancelFunc := context.WithCancel(ctx)
ctx, cancel := context.WithCancel(ctx)
rpty = &reconnectingPTY{
activeConns: map[string]net.Conn{
// We have to put the connection in the map instantly otherwise
@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
},
ptty: ptty,
// Timeouts created with an after func can be reset!
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
circularBuffer: circularBuffer,
}
a.reconnectingPTYs.Store(msg.ID, rpty)
// We don't need to separately monitor for the process exiting.
// When it exits, our ptty.OutputReader() will return EOF after
// reading all process output.
@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
rpty.Close()
a.reconnectingPTYs.Delete(msg.ID)
}); err != nil {
_ = process.Kill()
_ = ptty.Close()
return xerrors.Errorf("start routine: %w", err)
}
connected = true
sendConnected <- rpty
}
// Resize the PTY to initial height + width.
err := rpty.ptty.Resize(msg.Height, msg.Width)