mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
fix(agent): guard against multiple rpty race for same id (#7998)
* fix(agent): guard against multiple rpty race for same id * fix(agent): ensure pty is closed on error
This commit is contained in:
committed by
GitHub
parent
9440b3da66
commit
c916a9e67f
@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
||||
}()
|
||||
|
||||
var rpty *reconnectingPTY
|
||||
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
|
||||
sendConnected := make(chan *reconnectingPTY, 1)
|
||||
// On store, reserve this ID to prevent multiple concurrent new connections.
|
||||
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
|
||||
if ok {
|
||||
close(sendConnected) // Unused.
|
||||
logger.Debug(ctx, "connecting to existing session")
|
||||
rpty, ok = rawRPTY.(*reconnectingPTY)
|
||||
c, ok := waitReady.(chan *reconnectingPTY)
|
||||
if !ok {
|
||||
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
|
||||
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
|
||||
}
|
||||
rpty, ok = <-c
|
||||
if !ok || rpty == nil {
|
||||
return xerrors.Errorf("reconnecting pty closed before connection")
|
||||
}
|
||||
c <- rpty // Put it back for the next reconnect.
|
||||
} else {
|
||||
logger.Debug(ctx, "creating new session")
|
||||
|
||||
connected := false
|
||||
defer func() {
|
||||
if !connected && retErr != nil {
|
||||
a.reconnectingPTYs.Delete(msg.ID)
|
||||
close(sendConnected)
|
||||
}
|
||||
}()
|
||||
|
||||
// Empty command will default to the users shell!
|
||||
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
|
||||
if err != nil {
|
||||
@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
||||
return xerrors.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(ctx)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
rpty = &reconnectingPTY{
|
||||
activeConns: map[string]net.Conn{
|
||||
// We have to put the connection in the map instantly otherwise
|
||||
@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
||||
},
|
||||
ptty: ptty,
|
||||
// Timeouts created with an after func can be reset!
|
||||
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
|
||||
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
|
||||
circularBuffer: circularBuffer,
|
||||
}
|
||||
a.reconnectingPTYs.Store(msg.ID, rpty)
|
||||
// We don't need to separately monitor for the process exiting.
|
||||
// When it exits, our ptty.OutputReader() will return EOF after
|
||||
// reading all process output.
|
||||
@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
|
||||
rpty.Close()
|
||||
a.reconnectingPTYs.Delete(msg.ID)
|
||||
}); err != nil {
|
||||
_ = process.Kill()
|
||||
_ = ptty.Close()
|
||||
return xerrors.Errorf("start routine: %w", err)
|
||||
}
|
||||
connected = true
|
||||
sendConnected <- rpty
|
||||
}
|
||||
// Resize the PTY to initial height + width.
|
||||
err := rpty.ptty.Resize(msg.Height, msg.Width)
|
||||
|
Reference in New Issue
Block a user