package peer_test

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/pion/transport/vnet"
	"github.com/pion/webrtc/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/peer"
)

// ICE timing values handed to SetICETimeouts for both peers in createPair.
const (
	disconnectedTimeout = time.Millisecond * 200
	failedTimeout       = disconnectedTimeout * 5
	keepAliveInterval   = time.Millisecond * 2
)

var (
	// There's a global race in the vnet library allocation code.
	// This mutex locks around the creation of the vnet.
	vnetMutex sync.Mutex
)

// TestMain runs the package's tests through goleak.VerifyTestMain.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

// TestConn exercises a client/server peer.Conn pair built by createPair:
// pinging, dialing and accepting channels, buffering, use as a net.Conn,
// and the close paths.
func TestConn(t *testing.T) {
	t.Parallel()

	t.Run("Ping", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		_, err := client.Ping()
		require.NoError(t, err)
		_, err = server.Ping()
		require.NoError(t, err)
	})

	t.Run("PingNetworkOffline", func(t *testing.T) {
		t.Parallel()
		_, server, wan := createPair(t)
		_, err := server.Ping()
		require.NoError(t, err)
		err = wan.Stop()
		require.NoError(t, err)
		// With the router stopped the ping is expected to fail.
		_, err = server.Ping()
		require.ErrorIs(t, err, peer.ErrFailed)
	})

	t.Run("PingReconnect", func(t *testing.T) {
		t.Parallel()
		_, server, wan := createPair(t)
		_, err := server.Ping()
		require.NoError(t, err)
		// Create a channel that closes on disconnect.
		// require (not assert): ch is dereferenced below, so the test
		// must stop here if the dial failed.
		ch, err := server.Dial(context.Background(), "wow", nil)
		require.NoError(t, err)
		err = wan.Stop()
		require.NoError(t, err)
		// Once the connection is marked as disconnected, this
		// channel will be closed.
		_, err = ch.Read(make([]byte, 4))
		assert.ErrorIs(t, err, peer.ErrClosed)
		err = wan.Start()
		require.NoError(t, err)
		_, err = server.Ping()
		require.NoError(t, err)
	})

	t.Run("Accept", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
		require.NoError(t, err)

		sch, err := server.Accept(context.Background())
		require.NoError(t, err)

		_ = cch.Close()
		_, err = sch.Read(make([]byte, 4))
		require.ErrorIs(t, err, peer.ErrClosed)
	})

	t.Run("AcceptNetworkOffline", func(t *testing.T) {
		t.Parallel()
		client, server, wan := createPair(t)
		cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
		require.NoError(t, err)
		sch, err := server.Accept(context.Background())
		require.NoError(t, err)

		err = wan.Stop()
		require.NoError(t, err)
		_ = cch.Close()
		_, err = sch.Read(make([]byte, 4))
		require.ErrorIs(t, err, peer.ErrClosed)
	})

	t.Run("Buffering", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
		require.NoError(t, err)
		sch, err := server.Accept(context.Background())
		require.NoError(t, err)

		go func() {
			// Close on every exit path so the read loop below terminates
			// even if a write fails.
			defer func() { _ = cch.Close() }()
			for i := 0; i < 1024; i++ {
				_, err := cch.Write(make([]byte, 4096))
				// assert, not require: FailNow must only be called from
				// the goroutine running the test.
				if !assert.NoError(t, err) {
					return
				}
			}
		}()

		// Drain the server side until the channel reports it is closed.
		buf := make([]byte, 4096)
		for {
			_, err = sch.Read(buf)
			if err != nil {
				require.ErrorIs(t, err, peer.ErrClosed)
				break
			}
		}
	})

	t.Run("NetConn", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		srv, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err)
		defer srv.Close()

		// Bridge the accepted peer channel to the local TCP listener.
		go func() {
			// assert, not require: FailNow must only be called from the
			// goroutine running the test.
			sch, err := server.Accept(context.Background())
			if !assert.NoError(t, err) {
				return
			}
			nc2 := sch.NetConn()
			nc1, err := net.Dial("tcp", srv.Addr().String())
			if !assert.NoError(t, err) {
				return
			}
			go func() {
				_, _ = io.Copy(nc1, nc2)
			}()
			_, _ = io.Copy(nc2, nc1)
		}()

		// Serve HTTP on the listener; every request gets a 200.
		go func() {
			s := http.Server{
				Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
					rw.WriteHeader(http.StatusOK)
				}),
			}
			defer s.Close()
			_ = s.Serve(srv)
		}()

		dt := http.DefaultTransport.(*http.Transport).Clone()
		var cch *peer.Channel
		// Route the HTTP client's connections over a peer channel. The
		// closure keeps its own error variable so it never writes to the
		// test's err.
		dt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			dialed, dialErr := client.Dial(context.Background(), "hello", &peer.ChannelOpts{})
			if dialErr != nil {
				return nil, dialErr
			}
			cch = dialed
			return dialed.NetConn(), nil
		}
		c := http.Client{
			Transport: dt,
		}
		req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost/", nil)
		require.NoError(t, err)
		resp, err := c.Do(req)
		require.NoError(t, err)
		defer resp.Body.Close()
		require.Equal(t, http.StatusOK, resp.StatusCode)
		// Triggers any connections to close.
		// The check below ensures the DataChannel actually closed: an
		// explicit Close is expected to report ErrClosed.
		dt.CloseIdleConnections()
		err = cch.Close()
		require.ErrorIs(t, err, peer.ErrClosed)
	})

	t.Run("CloseBeforeNegotiate", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		err := client.Close()
		require.NoError(t, err)
		err = server.Close()
		require.NoError(t, err)
	})

	t.Run("CloseWithError", func(t *testing.T) {
		t.Parallel()
		conn, err := peer.Client([]webrtc.ICEServer{}, nil)
		require.NoError(t, err)
		expectedErr := errors.New("wow")
		_ = conn.CloseWithError(expectedErr)
		// Dialing after CloseWithError is expected to surface that error.
		_, err = conn.Dial(context.Background(), "", nil)
		require.ErrorIs(t, err, expectedErr)
	})

	t.Run("PingConcurrent", func(t *testing.T) {
		t.Parallel()
		client, server, _ := createPair(t)
		var wg sync.WaitGroup
		wg.Add(2)
		// assert, not require, inside the goroutines: FailNow must only be
		// called from the goroutine running the test.
		go func() {
			defer wg.Done()
			_, err := client.Ping()
			assert.NoError(t, err)
		}()
		go func() {
			defer wg.Done()
			_, err := server.Ping()
			assert.NoError(t, err)
		}()
		wg.Wait()
	})
}

// createPair builds a client and a server peer.Conn attached to a shared
// virtual router (1.2.3.0/24, client at 1.2.3.4, server at 1.2.3.5), starts
// the router, and spawns two goroutines that relay local candidates and
// session descriptions between the peers until the source peer is closed.
// It returns the client, the server, and the router so tests can stop and
// restart the network. Both connections are closed and the router stopped
// via t.Cleanup.
func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) {
	t.Helper()

	lf := logging.NewDefaultLoggerFactory()
	lf.DefaultLogLevel = logging.LogLevelDisabled
	// Held for the whole setup; see the comment on vnetMutex.
	vnetMutex.Lock()
	defer vnetMutex.Unlock()
	wan, err := vnet.NewRouter(&vnet.RouterConfig{
		CIDR:          "1.2.3.0/24",
		LoggerFactory: lf,
	})
	require.NoError(t, err)
	c1Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.4"},
	})
	err = wan.AddNet(c1Net)
	require.NoError(t, err)
	c2Net := vnet.NewNet(&vnet.NetConfig{
		StaticIPs: []string{"1.2.3.5"},
	})
	err = wan.AddNet(c2Net)
	require.NoError(t, err)

	c1SettingEngine := webrtc.SettingEngine{}
	c1SettingEngine.SetVNet(c1Net)
	c1SettingEngine.SetPrflxAcceptanceMinWait(0)
	c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	c1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c1SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort teardown; tests may already have closed the conn.
		_ = c1.Close()
	})

	c2SettingEngine := webrtc.SettingEngine{}
	c2SettingEngine.SetVNet(c2Net)
	c2SettingEngine.SetPrflxAcceptanceMinWait(0)
	c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval)
	c2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{
		SettingEngine: c2SettingEngine,
		Logger:        slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
	})
	require.NoError(t, err)
	t.Cleanup(func() {
		// Best-effort teardown; tests may already have closed the conn.
		_ = c2.Close()
	})

	err = wan.Start()
	require.NoError(t, err)
	t.Cleanup(func() {
		_ = wan.Stop()
	})

	// Relay server -> client signaling until the server is closed.
	go func() {
		for {
			select {
			case c := <-c2.LocalCandidate():
				_ = c1.AddRemoteCandidate(c)
			case c := <-c2.LocalSessionDescription():
				c1.SetRemoteSessionDescription(c)
			case <-c2.Closed():
				return
			}
		}
	}()

	// Relay client -> server signaling until the client is closed.
	go func() {
		for {
			select {
			case c := <-c1.LocalCandidate():
				_ = c2.AddRemoteCandidate(c)
			case c := <-c1.LocalSessionDescription():
				c2.SetRemoteSessionDescription(c)
			case <-c1.Closed():
				return
			}
		}
	}()

	return c1, c2, wan
}