mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
chore: Remove WebRTC networking (#3881)
* chore: Remove WebRTC networking * Fix race condition * Fix WebSocket not closing
This commit is contained in:
@ -20,12 +20,10 @@ import (
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/net/speedtest"
|
||||
"tailscale.com/tailcfg"
|
||||
|
||||
scp "github.com/bramvdbogaerde/go-scp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pion/udp"
|
||||
"github.com/pion/webrtc/v3"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -37,10 +35,6 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"cdr.dev/slog/sloggers/slogtest"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/peer"
|
||||
"github.com/coder/coder/peerbroker"
|
||||
"github.com/coder/coder/peerbroker/proto"
|
||||
"github.com/coder/coder/provisionersdk"
|
||||
"github.com/coder/coder/pty/ptytest"
|
||||
"github.com/coder/coder/tailnet"
|
||||
"github.com/coder/coder/tailnet/tailnettest"
|
||||
@ -54,64 +48,49 @@ func TestMain(m *testing.M) {
|
||||
func TestAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
for _, tailscale := range []bool{true, false} {
|
||||
t.Run(fmt.Sprintf("tailscale=%v", tailscale), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Parallel()
|
||||
|
||||
setupAgent := func(t *testing.T) (agent.Conn, <-chan *agent.Stats) {
|
||||
var derpMap *tailcfg.DERPMap
|
||||
if tailscale {
|
||||
derpMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
conn, stats := setupAgent(t, agent.Metadata{
|
||||
DERPMap: derpMap,
|
||||
}, 0)
|
||||
assert.Empty(t, <-stats)
|
||||
return conn, stats
|
||||
}
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, agent.Metadata{}, 0)
|
||||
|
||||
t.Run("SSH", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
|
||||
assert.EqualValues(t, 1, (<-stats).NumConns)
|
||||
assert.Greater(t, (<-stats).RxBytes, int64(0))
|
||||
assert.Greater(t, (<-stats).TxBytes, int64(0))
|
||||
})
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ReconnectingPTY", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, stats := setupAgent(t, agent.Metadata{}, 0)
|
||||
|
||||
conn, stats := setupAgent(t)
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
ptyConn, err := conn.ReconnectingPTY(uuid.NewString(), 128, 128, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
defer ptyConn.Close()
|
||||
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *agent.Stats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
data, err := json.Marshal(agent.ReconnectingPTYRequest{
|
||||
Data: "echo test\r\n",
|
||||
})
|
||||
}
|
||||
require.NoError(t, err)
|
||||
_, err = ptyConn.Write(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var s *agent.Stats
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var ok bool
|
||||
s, ok = (<-stats)
|
||||
return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0
|
||||
}, testutil.WaitLong, testutil.IntervalFast,
|
||||
"never saw stats: %+v", s,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SessionExec", func(t *testing.T) {
|
||||
@ -235,6 +214,7 @@ func TestAgent(t *testing.T) {
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
client, err := sftp.NewClient(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "sftp")
|
||||
@ -252,6 +232,7 @@ func TestAgent(t *testing.T) {
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
sshClient, err := conn.SSHClient()
|
||||
require.NoError(t, err)
|
||||
defer sshClient.Close()
|
||||
scpClient, err := scp.NewClientBySSH(sshClient)
|
||||
require.NoError(t, err)
|
||||
tempFile := filepath.Join(t.TempDir(), "scp")
|
||||
@ -384,9 +365,7 @@ func TestAgent(t *testing.T) {
|
||||
t.Skip("ConPTY appears to be inconsistent on Windows.")
|
||||
}
|
||||
|
||||
conn, _ := setupAgent(t, agent.Metadata{
|
||||
DERPMap: tailnettest.RunDERPAndSTUN(t),
|
||||
}, 0)
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
id := uuid.NewString()
|
||||
netConn, err := conn.ReconnectingPTY(id, 100, 100, "/bin/bash")
|
||||
require.NoError(t, err)
|
||||
@ -462,19 +441,6 @@ func TestAgent(t *testing.T) {
|
||||
return l
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Unix",
|
||||
setup: func(t *testing.T) net.Listener {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix socket forwarding isn't supported on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
l, err := net.Listen("unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.NoError(t, err, "create UDP listener")
|
||||
return l
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
@ -496,8 +462,11 @@ func TestAgent(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Dial the listener over WebRTC twice and test out of order
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
require.Eventually(t, func() bool {
|
||||
_, err := conn.Ping()
|
||||
return err == nil
|
||||
}, testutil.WaitMedium, testutil.IntervalFast)
|
||||
conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
@ -506,36 +475,11 @@ func TestAgent(t *testing.T) {
|
||||
defer conn2.Close()
|
||||
testDial(t, conn2)
|
||||
testDial(t, conn1)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// This test uses Unix listeners so we can very easily ensure that
|
||||
// no other tests decide to listen on the same random port we
|
||||
// picked.
|
||||
t.Skip("this test is unsupported on Windows")
|
||||
return
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "coderd_agent_test_")
|
||||
require.NoError(t, err, "create temp dir")
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
})
|
||||
|
||||
// Try to dial the non-existent Unix socket over WebRTC
|
||||
conn, _ := setupAgent(t, agent.Metadata{}, 0)
|
||||
netConn, err := conn.DialContext(context.Background(), "unix", filepath.Join(tmpDir, "test.sock"))
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "remote dial error")
|
||||
require.ErrorContains(t, err, "no such file")
|
||||
require.Nil(t, netConn)
|
||||
})
|
||||
|
||||
t.Run("Tailnet", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
derpMap := tailnettest.RunDERPAndSTUN(t)
|
||||
@ -578,7 +522,7 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
|
||||
return
|
||||
}
|
||||
ssh, err := agentConn.SSH()
|
||||
if !assert.NoError(t, err) {
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
@ -622,11 +566,12 @@ func (c closeFunc) Close() error {
|
||||
}
|
||||
|
||||
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) (
|
||||
agent.Conn,
|
||||
*agent.Conn,
|
||||
<-chan *agent.Stats,
|
||||
) {
|
||||
client, server := provisionersdk.TransportPipe()
|
||||
tailscale := metadata.DERPMap != nil
|
||||
if metadata.DERPMap == nil {
|
||||
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
|
||||
}
|
||||
coordinator := tailnet.NewCoordinator()
|
||||
agentID := uuid.New()
|
||||
statsCh := make(chan *agent.Stats)
|
||||
@ -634,17 +579,18 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
||||
FetchMetadata: func(ctx context.Context) (agent.Metadata, error) {
|
||||
return metadata, nil
|
||||
},
|
||||
WebRTCDialer: func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
|
||||
listener, err := peerbroker.Listen(server, nil)
|
||||
return listener, err
|
||||
},
|
||||
CoordinatorDialer: func(ctx context.Context) (net.Conn, error) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
closed := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
_ = serverConn.Close()
|
||||
_ = clientConn.Close()
|
||||
<-closed
|
||||
})
|
||||
go coordinator.ServeAgent(serverConn, agentID)
|
||||
go func() {
|
||||
_ = coordinator.ServeAgent(serverConn, agentID)
|
||||
close(closed)
|
||||
}()
|
||||
return clientConn, nil
|
||||
},
|
||||
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
|
||||
@ -683,46 +629,27 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
_ = closer.Close()
|
||||
})
|
||||
api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
|
||||
stream, err := api.NegotiateConnection(context.Background())
|
||||
assert.NoError(t, err)
|
||||
if tailscale {
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.TailnetConn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
conn, err := peerbroker.Dial(stream, []webrtc.ICEServer{}, &peer.ConnOptions{
|
||||
Logger: slogtest.Make(t, nil),
|
||||
conn, err := tailnet.NewConn(&tailnet.Options{
|
||||
Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)},
|
||||
DERPMap: metadata.DERPMap,
|
||||
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
_ = conn.Close()
|
||||
})
|
||||
|
||||
return &agent.WebRTCConn{
|
||||
Negotiator: api,
|
||||
Conn: conn,
|
||||
go coordinator.ServeClient(serverConn, uuid.New(), agentID)
|
||||
sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error {
|
||||
return conn.UpdateNodes(node)
|
||||
})
|
||||
conn.SetNodeCallback(sendNode)
|
||||
return &agent.Conn{
|
||||
Conn: conn,
|
||||
}, statsCh
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user