mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
fix(agent): guard against multiple rpty race for same id (#7998)
* fix(agent): guard against multiple rpty race for same id * fix(agent): ensure pty is closed on error
This commit is contained in:
committed by
GitHub
parent
9440b3da66
commit
c916a9e67f
@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var rpty *reconnectingPTY
|
var rpty *reconnectingPTY
|
||||||
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
|
sendConnected := make(chan *reconnectingPTY, 1)
|
||||||
|
// On store, reserve this ID to prevent multiple concurrent new connections.
|
||||||
|
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
|
||||||
if ok {
|
if ok {
|
||||||
|
close(sendConnected) // Unused.
|
||||||
logger.Debug(ctx, "connecting to existing session")
|
logger.Debug(ctx, "connecting to existing session")
|
||||||
rpty, ok = rawRPTY.(*reconnectingPTY)
|
c, ok := waitReady.(chan *reconnectingPTY)
|
||||||
if !ok {
|
if !ok {
|
||||||
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
|
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
|
||||||
}
|
}
|
||||||
|
rpty, ok = <-c
|
||||||
|
if !ok || rpty == nil {
|
||||||
|
return xerrors.Errorf("reconnecting pty closed before connection")
|
||||||
|
}
|
||||||
|
c <- rpty // Put it back for the next reconnect.
|
||||||
} else {
|
} else {
|
||||||
logger.Debug(ctx, "creating new session")
|
logger.Debug(ctx, "creating new session")
|
||||||
|
|
||||||
|
connected := false
|
||||||
|
defer func() {
|
||||||
|
if !connected && retErr != nil {
|
||||||
|
a.reconnectingPTYs.Delete(msg.ID)
|
||||||
|
close(sendConnected)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Empty command will default to the users shell!
|
// Empty command will default to the users shell!
|
||||||
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||||||
return xerrors.Errorf("start command: %w", err)
|
return xerrors.Errorf("start command: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelFunc := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
rpty = &reconnectingPTY{
|
rpty = &reconnectingPTY{
|
||||||
activeConns: map[string]net.Conn{
|
activeConns: map[string]net.Conn{
|
||||||
// We have to put the connection in the map instantly otherwise
|
// We have to put the connection in the map instantly otherwise
|
||||||
@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||||||
},
|
},
|
||||||
ptty: ptty,
|
ptty: ptty,
|
||||||
// Timeouts created with an after func can be reset!
|
// Timeouts created with an after func can be reset!
|
||||||
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
|
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
|
||||||
circularBuffer: circularBuffer,
|
circularBuffer: circularBuffer,
|
||||||
}
|
}
|
||||||
a.reconnectingPTYs.Store(msg.ID, rpty)
|
|
||||||
// We don't need to separately monitor for the process exiting.
|
// We don't need to separately monitor for the process exiting.
|
||||||
// When it exits, our ptty.OutputReader() will return EOF after
|
// When it exits, our ptty.OutputReader() will return EOF after
|
||||||
// reading all process output.
|
// reading all process output.
|
||||||
@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
|||||||
rpty.Close()
|
rpty.Close()
|
||||||
a.reconnectingPTYs.Delete(msg.ID)
|
a.reconnectingPTYs.Delete(msg.ID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
_ = process.Kill()
|
||||||
|
_ = ptty.Close()
|
||||||
return xerrors.Errorf("start routine: %w", err)
|
return xerrors.Errorf("start routine: %w", err)
|
||||||
}
|
}
|
||||||
|
connected = true
|
||||||
|
sendConnected <- rpty
|
||||||
}
|
}
|
||||||
// Resize the PTY to initial height + width.
|
// Resize the PTY to initial height + width.
|
||||||
err := rpty.ptty.Resize(msg.Height, msg.Width)
|
err := rpty.ptty.Resize(msg.Height, msg.Width)
|
||||||
|
Reference in New Issue
Block a user