Files
coder/peer/conn_test.go
Kyle Carberry 53cfa8a45a feat: Create broker for negotiating connections (#14)
* feat: Create broker for negotiating connections

WebRTC requires an exchange of encryption keys and network hops to connect. This package pipes that exchange over gRPC. It will be used by all connecting clients and agents.
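
As a rough sketch (not part of this commit), the forwarding loop on each side of the exchange looks like the following. The sendCandidate and sendDescription hooks are hypothetical stand-ins for the broker stream, and the channel element types are assumptions; only the peer.Conn methods mirror those exercised in the test below.

	// pipeNegotiation relays one peer's negotiation messages until the
	// connection closes. The hook signatures and element types are
	// assumed for illustration; they are not the package's confirmed API.
	func pipeNegotiation(
		conn *peer.Conn,
		sendCandidate func(webrtc.ICECandidateInit),
		sendDescription func(webrtc.SessionDescription),
	) {
		for {
			select {
			case c := <-conn.LocalCandidate():
				sendCandidate(c) // trickle our ICE candidate to the remote peer
			case s := <-conn.LocalSessionDescription():
				sendDescription(s) // relay our SDP offer/answer
			case <-conn.Closed():
				return
			}
		}
	}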

* Regenerate protobuf definition

* Cache Go build and test

* Fix gRPC language with dRPC

Co-authored-by: Bryan <bryan@coder.com>
2022-01-11 09:28:41 -06:00

305 lines
7.4 KiB
Go

package peer_test

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/pion/transport/vnet"
	"github.com/pion/webrtc/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"

	"github.com/coder/coder/peer"
)
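
// These ICE timeouts are far shorter than the WebRTC defaults so that
// disconnect and failure states surface quickly in tests.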
const (
	disconnectedTimeout = time.Millisecond * 200
	failedTimeout       = disconnectedTimeout * 5
	keepAliveInterval   = time.Millisecond * 2
)

var (
	// There's a global race in the vnet library allocation code.
	// This mutex locks around the creation of the vnet.
	vnetMutex = sync.Mutex{}
)
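
// TestMain fails the suite if any test leaks a goroutine.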
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}
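
// TestConn exercises the lifecycle of a peer connection: pinging,
// dialing and accepting channels, surviving network loss, and closing.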
func TestConn(t *testing.T) {
	t.Parallel()

	t.Run("Ping", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		_, err := client.Ping()
		require.NoError(t, err)
		_, err = server.Ping()
		require.NoError(t, err)
	})
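
	// Stopping the virtual WAN should cause pings to fail with
	// ErrFailed once the ICE failure timeout elapses.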
t.Run("PingNetworkOffline", func(t *testing.T) {
t.Parallel()
_, server, wan := createPair(t)
_, err := server.Ping()
require.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
_, err = server.Ping()
require.ErrorIs(t, err, peer.ErrFailed)
})
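
	// After the virtual WAN stops and restarts, pings should succeed
	// again.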
t.Run("PingReconnect", func(t *testing.T) {
t.Parallel()
_, server, wan := createPair(t)
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
ch, err := server.Dial(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
// Once the connection is marked as disconnected, this
// channel will be closed.
_, err = ch.Read(make([]byte, 4))
assert.ErrorIs(t, err, peer.ErrClosed)
err = wan.Start()
require.NoError(t, err)
_, err = server.Ping()
require.NoError(t, err)
})
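
	// Closing a dialed channel should surface ErrClosed to reads on
	// the accepted side.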
t.Run("Accept", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
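
	// Closing a channel while the network is down should still
	// surface ErrClosed on the remote read instead of hanging.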
t.Run("AcceptNetworkOffline", func(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
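
	// The sender writes 4 MiB then closes; the reader drains until
	// the channel reports ErrClosed.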
t.Run("Buffering", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
go func() {
for i := 0; i < 1024; i++ {
_, err := cch.Write(make([]byte, 4096))
require.NoError(t, err)
}
_ = cch.Close()
}()
for {
_, err = sch.Read(make([]byte, 4096))
if err != nil {
require.ErrorIs(t, err, peer.ErrClosed)
break
}
}
})
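
	// Proxies an HTTP request through a data channel's net.Conn
	// adapter to a local TCP listener, then verifies the channel
	// closes once the transport's idle connections are closed.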
t.Run("NetConn", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer srv.Close()
go func() {
sch, err := server.Accept(context.Background())
require.NoError(t, err)
nc2 := sch.NetConn()
nc1, err := net.Dial("tcp", srv.Addr().String())
require.NoError(t, err)
go func() {
_, _ = io.Copy(nc1, nc2)
}()
_, _ = io.Copy(nc2, nc1)
}()
go func() {
s := http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
}),
}
defer s.Close()
_ = s.Serve(srv)
}()
dt := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
dt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
if err != nil {
return nil, err
}
return cch.NetConn(), nil
}
c := http.Client{
Transport: dt,
}
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, resp.StatusCode, 200)
// Triggers any connections to close.
// This test below ensures the DataChannel actually closes.
dt.CloseIdleConnections()
err = cch.Close()
require.ErrorIs(t, err, peer.ErrClosed)
})
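
	// Closing either end before negotiation completes should not
	// return an error.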
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
err := client.Close()
require.NoError(t, err)
err = server.Close()
require.NoError(t, err)
})
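
	// The error passed to CloseWithError should be returned by
	// operations attempted after the close.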
t.Run("CloseWithError", func(t *testing.T) {
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
require.NoError(t, err)
expectedErr := errors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.Dial(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})
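
	// Ping must be safe to call from both peers concurrently.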
t.Run("PingConcurrent", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := client.Ping()
require.NoError(t, err)
}()
go func() {
defer wg.Done()
_, err := server.Ping()
require.NoError(t, err)
}()
wg.Wait()
})
}
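
// createPair constructs a client and a server connection joined through
// an in-memory vnet router, wires trickle-ICE signaling between them,
// and registers cleanup so everything shuts down with the test.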
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
	lf := logging.NewDefaultLoggerFactory()
	lf.DefaultLogLevel = logging.LogLevelDisabled
	vnetMutex.Lock()
	defer vnetMutex.Unlock()
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "1.2.3.0/24",
		LoggerFactory: lf,
	})
	require.NoError(t, err)
	c1Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.4"},
	})
	err = wan.AddNet(c1Net)
	require.NoError(t, err)
	c2Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.5"},
	})
	err = wan.AddNet(c2Net)
	require.NoError(t, err)

	c1SettingEngine := webrtc.SettingEngine{}
	c1SettingEngine.SetVNet(c1Net)
	c1SettingEngine.SetPrflxAcceptanceMinWait(0)
	c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	c1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c1SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		c1.Close()
	})

	c2SettingEngine := webrtc.SettingEngine{}
	c2SettingEngine.SetVNet(c2Net)
	c2SettingEngine.SetPrflxAcceptanceMinWait(0)
	c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	c2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c2SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		c2.Close()
	})

	err = wan.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = wan.Stop()
	})
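
	// Shuttle ICE candidates and session descriptions between the
	// two peers until either side closes. This plays the role the
	// broker fills over the network.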
	go func() {
		for {
			select {
			case c := <-c2.LocalCandidate():
				_ = c1.AddRemoteCandidate(c)
			case c := <-c2.LocalSessionDescription():
				c1.SetRemoteSessionDescription(c)
			case <-c2.Closed():
				return
			}
		}
	}()
	go func() {
		for {
			select {
			case c := <-c1.LocalCandidate():
				_ = c2.AddRemoteCandidate(c)
			case c := <-c1.LocalSessionDescription():
				c2.SetRemoteSessionDescription(c)
			case <-c1.Closed():
				return
			}
		}
	}()
	return c1, c2, wan
}