test: Fix goroutine leak in peer exchange exit (#370)

Closes #361.
This commit is contained in:
Kyle Carberry
2022-02-27 19:19:02 -06:00
committed by GitHub
parent f630fc5787
commit 68ceea8a28

View File

@ -64,7 +64,7 @@ func TestConn(t *testing.T) {
t.Run("Ping", func(t *testing.T) { t.Run("Ping", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
_, err := client.Ping() _, err := client.Ping()
require.NoError(t, err) require.NoError(t, err)
_, err = server.Ping() _, err = server.Ping()
@ -74,7 +74,7 @@ func TestConn(t *testing.T) {
t.Run("PingNetworkOffline", func(t *testing.T) { t.Run("PingNetworkOffline", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, wan := createPair(t) client, server, wan := createPair(t)
exchange(client, server) exchange(t, client, server)
_, err := server.Ping() _, err := server.Ping()
require.NoError(t, err) require.NoError(t, err)
err = wan.Stop() err = wan.Stop()
@ -86,7 +86,7 @@ func TestConn(t *testing.T) {
t.Run("PingReconnect", func(t *testing.T) { t.Run("PingReconnect", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, wan := createPair(t) client, server, wan := createPair(t)
exchange(client, server) exchange(t, client, server)
_, err := server.Ping() _, err := server.Ping()
require.NoError(t, err) require.NoError(t, err)
// Create a channel that closes on disconnect. // Create a channel that closes on disconnect.
@ -107,7 +107,7 @@ func TestConn(t *testing.T) {
t.Run("Accept", func(t *testing.T) { t.Run("Accept", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -123,7 +123,7 @@ func TestConn(t *testing.T) {
t.Run("AcceptNetworkOffline", func(t *testing.T) { t.Run("AcceptNetworkOffline", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, wan := createPair(t) client, server, wan := createPair(t)
exchange(client, server) exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err) require.NoError(t, err)
sch, err := server.Accept(context.Background()) sch, err := server.Accept(context.Background())
@ -140,7 +140,7 @@ func TestConn(t *testing.T) {
t.Run("Buffering", func(t *testing.T) { t.Run("Buffering", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{}) cch, err := client.Dial(context.Background(), "hello", &peer.ChannelOptions{})
require.NoError(t, err) require.NoError(t, err)
sch, err := server.Accept(context.Background()) sch, err := server.Accept(context.Background())
@ -167,7 +167,7 @@ func TestConn(t *testing.T) {
t.Run("NetConn", func(t *testing.T) { t.Run("NetConn", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
srv, err := net.Listen("tcp", "127.0.0.1:0") srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
defer srv.Close() defer srv.Close()
@ -220,7 +220,7 @@ func TestConn(t *testing.T) {
t.Run("CloseBeforeNegotiate", func(t *testing.T) { t.Run("CloseBeforeNegotiate", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
err := client.Close() err := client.Close()
require.NoError(t, err) require.NoError(t, err)
err = server.Close() err = server.Close()
@ -240,7 +240,7 @@ func TestConn(t *testing.T) {
t.Run("PingConcurrent", func(t *testing.T) { t.Run("PingConcurrent", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
@ -271,7 +271,7 @@ func TestConn(t *testing.T) {
t.Run("ShortBuffer", func(t *testing.T) { t.Run("ShortBuffer", func(t *testing.T) {
t.Parallel() t.Parallel()
client, server, _ := createPair(t) client, server, _ := createPair(t)
exchange(client, server) exchange(t, client, server)
go func() { go func() {
channel, err := client.Dial(context.Background(), "test", nil) channel, err := client.Dial(context.Background(), "test", nil)
require.NoError(t, err) require.NoError(t, err)
@ -345,8 +345,17 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R
return channel1, channel2, wan return channel1, channel2, wan
} }
func exchange(client, server *peer.Conn) { func exchange(t *testing.T, client, server *peer.Conn) {
var wg sync.WaitGroup
wg.Add(2)
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
wg.Wait()
})
go func() { go func() {
defer wg.Done()
for { for {
select { select {
case c := <-server.LocalCandidate(): case c := <-server.LocalCandidate():
@ -358,8 +367,8 @@ func exchange(client, server *peer.Conn) {
} }
} }
}() }()
go func() { go func() {
defer wg.Done()
for { for {
select { select {
case c := <-client.LocalCandidate(): case c := <-client.LocalCandidate():