coder/peer/conn_test.go
Kyle Carberry 8be245616a feat: Add organizations endpoint for users (#50)
* feat: Add organizations endpoint for users

This moves the /user endpoint to /users/me instead. This
will reduce code duplication.

This adds /users/<name>/organizations to list organizations
a user has access to. It doesn't contain the permissions a
user has over the organizations, but that will come in a future
contribution.

* Fix requested changes

* Fix tests

* Fix timeout

* Add test for UserOrgs

* Add test for userparam getting

* Add test for NoUser
2022-01-22 23:58:10 -06:00
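
For context, a rough sketch of calling the endpoints this commit describes. The paths come straight from the message above; the deployment address, the /api/v2 prefix, and the use of "me" in the <name> slot are assumptions for illustration, and authentication is omitted entirely, so treat this as a sketch rather than the real client surface.

// listorgs.go — hypothetical client for the endpoints described above.
// The base URL and /api/v2 prefix are assumptions; authentication is
// omitted (a real deployment would require a session token).
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	base := "http://localhost:3000/api/v2" // assumed address and prefix
	for _, path := range []string{
		base + "/users/me",
		base + "/users/me/organizations", // assumes "me" works in the <name> slot
	} {
		resp, err := http.Get(path)
		if err != nil {
			panic(err)
		}
		body, _ := io.ReadAll(resp.Body)
		_ = resp.Body.Close()
		fmt.Printf("GET %s -> %d\n%s\n", path, resp.StatusCode, body)
	}
}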

308 lines · 7.7 KiB · Go

package peer_test

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/pion/transport/vnet"
	"github.com/pion/webrtc/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/peer"
)

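// ICE timing shared by every test connection: a short disconnected window,
// a failed window five times that, and an aggressive keepalive so state
// transitions happen quickly on the virtual network.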
const (
	disconnectedTimeout = 5 * time.Second
	failedTimeout       = disconnectedTimeout * 5
	keepAliveInterval   = time.Millisecond * 2
)

var (
	// There's a global race in the vnet library allocation code.
	// This mutex locks around the creation of the vnet.
	vnetMutex = sync.Mutex{}
)

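// TestMain fails the run if any test leaks a goroutine.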
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

func TestConn(t *testing.T) {
	t.Parallel()

	t.Run("Ping", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		_, err := client.Ping()
		require.NoError(t, err)
		_, err = server.Ping()
		require.NoError(t, err)
	})

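	// Ping should return ErrFailed once the virtual WAN is stopped
	// and the ICE failed timeout elapses.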
t.Run("PingNetworkOffline", func(t *testing.T) {
t.Parallel()
_, server, wan := createPair(t)
_, err := server.Ping()
require.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
_, err = server.Ping()
require.ErrorIs(t, err, peer.ErrFailed)
})
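	// PingReconnect verifies the Conn recovers from a transient outage:
	// channels opened before the outage die, but pings succeed again
	// once the WAN comes back.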
t.Run("PingReconnect", func(t *testing.T) {
t.Parallel()
_, server, wan := createPair(t)
_, err := server.Ping()
require.NoError(t, err)
// Create a channel that closes on disconnect.
channel, err := server.Dial(context.Background(), "wow", nil)
assert.NoError(t, err)
err = wan.Stop()
require.NoError(t, err)
// Once the connection is marked as disconnected, this
// channel will be closed.
_, err = channel.Read(make([]byte, 4))
assert.ErrorIs(t, err, peer.ErrClosed)
err = wan.Start()
require.NoError(t, err)
_, err = server.Ping()
require.NoError(t, err)
})
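	// Accept mirrors Dial from the other side: closing the dialing end
	// must surface ErrClosed on the accepted channel.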
t.Run("Accept", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
defer sch.Close()
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
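	// Same as Accept, but the network drops before the client closes;
	// the accepted channel should still report ErrClosed rather than hang.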
t.Run("AcceptNetworkOffline", func(t *testing.T) {
t.Parallel()
client, server, wan := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
defer sch.Close()
err = wan.Stop()
require.NoError(t, err)
_ = cch.Close()
_, err = sch.Read(make([]byte, 4))
require.ErrorIs(t, err, peer.ErrClosed)
})
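	// Buffering pushes 4 MiB (1024 writes of 4 KiB) through a channel to
	// exercise backpressure, then checks the reader drains everything and
	// finishes with ErrClosed.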
t.Run("Buffering", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
require.NoError(t, err)
sch, err := server.Accept(context.Background())
require.NoError(t, err)
defer sch.Close()
go func() {
for i := 0; i < 1024; i++ {
_, err := cch.Write(make([]byte, 4096))
require.NoError(t, err)
}
_ = cch.Close()
}()
for {
_, err = sch.Read(make([]byte, 4096))
if err != nil {
require.ErrorIs(t, err, peer.ErrClosed)
break
}
}
})
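	// NetConn adapts a Channel into a net.Conn and tunnels a real HTTP
	// round trip through it via a custom http.Transport dialer.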
t.Run("NetConn", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer srv.Close()
go func() {
sch, err := server.Accept(context.Background())
require.NoError(t, err)
nc2 := sch.NetConn()
nc1, err := net.Dial("tcp", srv.Addr().String())
require.NoError(t, err)
go func() {
_, _ = io.Copy(nc1, nc2)
}()
_, _ = io.Copy(nc2, nc1)
}()
go func() {
server := http.Server{
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
}),
}
defer server.Close()
_ = server.Serve(srv)
}()
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
var cch *peer.Channel
defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
cch, err = client.Dial(ctx, "hello", &peer.ChannelOpts{})
if err != nil {
return nil, err
}
return cch.NetConn(), nil
}
c := http.Client{
Transport: defaultTransport,
}
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
require.NoError(t, err)
resp, err := c.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, resp.StatusCode, 200)
// Triggers any connections to close.
// This test below ensures the DataChannel actually closes.
defaultTransport.CloseIdleConnections()
err = cch.Close()
require.ErrorIs(t, err, peer.ErrClosed)
})
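	// Closing both sides before any negotiation has occurred must not
	// hang or return an error.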
t.Run("CloseBeforeNegotiate", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
err := client.Close()
require.NoError(t, err)
err = server.Close()
require.NoError(t, err)
})
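	// The error passed to CloseWithError should be surfaced by any
	// operation attempted on the Conn afterwards.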
t.Run("CloseWithError", func(t *testing.T) {
conn, err := peer.Client([]webrtc.ICEServer{}, nil)
require.NoError(t, err)
expectedErr := errors.New("wow")
_ = conn.CloseWithError(expectedErr)
_, err = conn.Dial(context.Background(), "", nil)
require.ErrorIs(t, err, expectedErr)
})
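	// Ping must be safe to call concurrently from both ends of the Conn.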
t.Run("PingConcurrent", func(t *testing.T) {
t.Parallel()
client, server, _ := createPair(t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := client.Ping()
require.NoError(t, err)
}()
go func() {
defer wg.Done()
_, err := server.Ping()
require.NoError(t, err)
}()
wg.Wait()
})
}
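// createPair constructs a connected client/server Conn pair on top of a pion
// virtual network, so no real sockets are opened. It shuttles ICE candidates
// and session descriptions between the two sides by hand, and returns the
// vnet router so tests can simulate outages with wan.Stop and wan.Start.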
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
	loggingFactory := logging.NewDefaultLoggerFactory()
	loggingFactory.DefaultLogLevel = logging.LogLevelDisabled
	vnetMutex.Lock()
	defer vnetMutex.Unlock()
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "1.2.3.0/24",
		LoggerFactory: loggingFactory,
	})
	require.NoError(t, err)
	c1Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.4"},
	})
	err = wan.AddNet(c1Net)
	require.NoError(t, err)
	c2Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.5"},
	})
	err = wan.AddNet(c2Net)
	require.NoError(t, err)

	c1SettingEngine := webrtc.SettingEngine{}
	c1SettingEngine.SetVNet(c1Net)
	c1SettingEngine.SetPrflxAcceptanceMinWait(0)
	c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	channel1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c1SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		channel1.Close()
	})

	c2SettingEngine := webrtc.SettingEngine{}
	c2SettingEngine.SetVNet(c2Net)
	c2SettingEngine.SetPrflxAcceptanceMinWait(0)
	c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	channel2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c2SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		channel2.Close()
	})

	err = wan.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = wan.Stop()
	})

	go func() {
		for {
			select {
			case c := <-channel2.LocalCandidate():
				_ = channel1.AddRemoteCandidate(c)
			case c := <-channel2.LocalSessionDescription():
				channel1.SetRemoteSessionDescription(c)
			case <-channel2.Closed():
				return
			}
		}
	}()
	go func() {
		for {
			select {
			case c := <-channel1.LocalCandidate():
				_ = channel2.AddRemoteCandidate(c)
			case c := <-channel1.LocalSessionDescription():
				channel2.SetRemoteSessionDescription(c)
			case <-channel1.Closed():
				return
			}
		}
	}()

	return channel1, channel2, wan
}