fix: Use buffered reader in peer to fix ShortBuffer (#303)

This prevents a io.ErrShortBuffer from occurring when the byte
slice being read is smaller than the chunks sent from the opposite
pipe.

This makes sense for unordered connections, where transmission is
not guaranteed, but does not make sense for TCP-like connections.

We use a bufio.Reader when ordered to ensure data isn't lost.
This commit is contained in:
Kyle Carberry
2022-02-17 08:45:14 -06:00
committed by GitHub
parent deb717037d
commit d43699306b
3 changed files with 41 additions and 2 deletions

View File

@ -57,6 +57,7 @@
"tfexec",
"tfstate",
"unconvert",
"webrtc",
"xerrors",
"yamux"
]

View File

@ -1,6 +1,7 @@
package peer
import (
"bufio"
"context"
"io"
"net"
@ -78,7 +79,8 @@ type Channel struct {
dc *webrtc.DataChannel
// This field can be nil. It becomes set after the DataChannel
// has been opened and is detached.
rwc datachannel.ReadWriteCloser
rwc datachannel.ReadWriteCloser
reader io.Reader
closed chan struct{}
closeMutex sync.Mutex
@ -130,6 +132,21 @@ func (c *Channel) init() {
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
return
}
// pion/webrtc will return an io.ErrShortBuffer when a read
// is triggerred with a buffer size less than the chunks written.
//
// This makes sense when considering UDP connections, because
// bufferring of data that has no transmit guarantees is likely
// to cause unexpected behavior.
//
// When ordered, this adds a bufio.Reader. This ensures additional
// data on TCP-like connections can be read in parts, while still
// being bufferred.
if c.opts.Unordered {
c.reader = c.rwc
} else {
c.reader = bufio.NewReader(c.rwc)
}
close(c.opened)
})
@ -181,7 +198,7 @@ func (c *Channel) Read(bytes []byte) (int, error) {
}
}
bytesRead, err := c.rwc.Read(bytes)
bytesRead, err := c.reader.Read(bytes)
if err != nil {
if c.isClosed() {
return 0, c.closeError

View File

@ -267,6 +267,27 @@ func TestConn(t *testing.T) {
_, err := client.Ping()
require.NoError(t, err)
})
t.Run("ShortBuffer", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
exchange(client, server)
go func() {
channel, err := client.Dial(context.Background(), "test", nil)
require.NoError(t, err)
_, err = channel.Write([]byte{1, 2})
require.NoError(t, err)
}()
channel, err := server.Accept(context.Background())
require.NoError(t, err)
data := make([]byte, 1)
_, err = channel.Read(data)
require.NoError(t, err)
require.Equal(t, uint8(0x1), data[0])
_, err = channel.Read(data)
require.NoError(t, err)
require.Equal(t, uint8(0x2), data[0])
})
}
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {