Files
coder/vpn/speaker_internal_test.go
Cian Johnston 7b88776403 chore(testutil): add testutil.GoleakOptions (#16070)
- Adds `testutil.GoleakOptions` and consolidates existing options to
this location
- Pre-emptively adds required ignore for this Dependabot PR to pass CI
https://github.com/coder/coder/pull/16066
2025-01-08 15:38:37 +00:00

457 lines
13 KiB
Go

package vpn
import (
"context"
"encoding/binary"
"io"
"net"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/protobuf/proto"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
)
// TestMain runs every test in the package and then uses goleak to verify that
// no unexpected goroutines are still running, ignoring those allowed by
// testutil.GoleakOptions.
func TestMain(m *testing.M) {
	leakOpts := testutil.GoleakOptions
	goleak.VerifyTestMain(m, leakOpts...)
}
// TestSpeaker_RawPeer tests the speaker with a peer that we simulate by directly making reads and
// writes to the other end of the pipe. There should be at least one test that does this, rather
// than use 2 speakers so that we don't have a bug where we don't adhere to the stated protocol, but
// both sides have the bug and can still communicate.
func TestSpeaker_RawPeer(t *testing.T) {
	t.Parallel()
	mp, tp := net.Pipe()
	defer mp.Close()
	defer tp.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	// We're going to use deadlines for this test so that we don't hang the main test thread if
	// the speaker misbehaves.
	err := mp.SetReadDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	err = mp.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	logger := testutil.Logger(t)
	var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	errCh := make(chan error, 1)
	go func() {
		s, err := newSpeaker[*TunnelMessage, *ManagerMessage](ctx, logger, tp, SpeakerRoleTunnel, SpeakerRoleManager)
		tun = s
		errCh <- err
	}()
	expectedHandshake := "codervpn tunnel 1.0\n"
	b := make([]byte, 256)
	n, err := mp.Read(b)
	require.NoError(t, err)
	require.Equal(t, expectedHandshake, string(b[:n]))
	_, err = mp.Write([]byte("codervpn manager 1.3,2.1\n"))
	require.NoError(t, err)
	err = testutil.RequireRecvCtx(ctx, t, errCh)
	require.NoError(t, err)
	tun.start()
	// send a message and verify it follows protocol for encoding
	testutil.RequireSendCtx(ctx, t, tun.sendCh, &TunnelMessage{
		Msg: &TunnelMessage_Start{
			Start: &StartResponse{},
		},
	})
	// Expected framing: a big-endian uint32 length prefix, then that many bytes
	// of protobuf-encoded TunnelMessage.
	var msgLen uint32
	err = binary.Read(mp, binary.BigEndian, &msgLen)
	require.NoError(t, err)
	msgBuf := make([]byte, msgLen)
	// A single Read on a net.Pipe returns at most the data of one Write on the
	// other end, so it can come back short. io.ReadFull keeps reading until the
	// whole body has arrived (or an error/deadline occurs).
	n, err = io.ReadFull(mp, msgBuf)
	require.NoError(t, err)
	require.Equal(t, msgLen, uint32(n))
	msg := new(TunnelMessage)
	err = proto.Unmarshal(msgBuf, msg)
	require.NoError(t, err)
	_, ok := msg.Msg.(*TunnelMessage_Start)
	require.True(t, ok)
	// Should close the pipe on close of the speaker.
	err = tun.Close()
	require.NoError(t, err)
	_, err = mp.Read(b)
	require.ErrorIs(t, err, io.EOF)
}
// TestSpeaker_HandshakeRWFailure closes the peer end of the pipe before the
// handshake starts, so the tunnel speaker's handshake reads and writes fail.
// newSpeaker is expected to return a "handshake failed" error and a nil speaker.
func TestSpeaker_HandshakeRWFailure(t *testing.T) {
	t.Parallel()
	mp, tp := net.Pipe()
	// immediately close the pipe, so we'll get read & write failures on handshake
	_ = mp.Close()
	// Only the peer end was closed above; close the tunnel end too when the test
	// finishes so it is not leaked.
	defer tp.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	errCh := make(chan error, 1)
	go func() {
		s, err := newSpeaker[*TunnelMessage, *ManagerMessage](
			ctx, logger.Named("tun"), tp, SpeakerRoleTunnel, SpeakerRoleManager,
		)
		tun = s
		errCh <- err
	}()
	err := testutil.RequireRecvCtx(ctx, t, errCh)
	require.ErrorContains(t, err, "handshake failed")
	require.Nil(t, tun)
}
// TestSpeaker_HandshakeCtxDone starts a tunnel-side handshake against a peer
// that never writes anything, cancels the handshake context, and expects
// newSpeaker to return a "handshake failed" error and a nil speaker.
func TestSpeaker_HandshakeCtxDone(t *testing.T) {
	t.Parallel()

	peerEnd, tunnelEnd := net.Pipe()
	defer peerEnd.Close()
	defer tunnelEnd.Close()

	testCtx := testutil.Context(t, testutil.WaitShort)
	handshakeCtx, abortHandshake := context.WithCancel(testCtx)
	tunLogger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)

	var got *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	result := make(chan error, 1)
	connect := func() {
		s, err := newSpeaker[*TunnelMessage, *ManagerMessage](
			handshakeCtx, tunLogger.Named("tun"), tunnelEnd, SpeakerRoleTunnel, SpeakerRoleManager,
		)
		got = s
		result <- err
	}
	go connect()

	// Cancel without ever answering on the peer end.
	abortHandshake()

	err := testutil.RequireRecvCtx(testCtx, t, result)
	require.ErrorContains(t, err, "handshake failed")
	require.Nil(t, got)
}
// TestSpeaker_OversizeHandshake answers the tunnel's handshake with a long
// header containing no newline. The speaker is expected to close the pipe
// (making the peer's write fail) and report "handshake failed" with a nil
// speaker.
func TestSpeaker_OversizeHandshake(t *testing.T) {
	t.Parallel()

	peer, tunnelConn := net.Pipe()
	defer peer.Close()
	defer tunnelConn.Close()
	ctx := testutil.Context(t, testutil.WaitShort)

	// Deadlines on the peer end keep this test from hanging the main test
	// goroutine if the speaker misbehaves.
	require.NoError(t, peer.SetReadDeadline(time.Now().Add(testutil.WaitShort)))
	require.NoError(t, peer.SetWriteDeadline(time.Now().Add(testutil.WaitShort)))

	lg := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	done := make(chan error, 1)
	go func() {
		var err error
		tun, err = newSpeaker[*TunnelMessage, *ManagerMessage](ctx, lg, tunnelConn, SpeakerRoleTunnel, SpeakerRoleManager)
		done <- err
	}()

	buf := make([]byte, 256)
	n, err := peer.Read(buf)
	require.NoError(t, err)
	require.Equal(t, "codervpn tunnel 1.0\n", string(buf[:n]))

	// 768 bytes and no newline: other side closes when we write too much.
	_, err = peer.Write([]byte(strings.Repeat("bad", 256)))
	require.Error(t, err)

	err = testutil.RequireRecvCtx(ctx, t, done)
	require.ErrorContains(t, err, "handshake failed")
	require.Nil(t, tun)
}
// TestSpeaker_HandshakeInvalid feeds the tunnel speaker a series of bad
// manager handshakes and expects each to be rejected with a "validate header"
// error and a nil speaker.
func TestSpeaker_HandshakeInvalid(t *testing.T) {
	t.Parallel()
	// nolint: paralleltest // no longer need to reinitialize loop vars in go 1.22
	for _, tc := range []struct {
		name, handshake string
	}{
		// The header layout is "codervpn <role> <versions>\n", as in the
		// handshake the tunnel sends (checked below). Except for
		// "swapped_fields", each case follows that layout apart from the single
		// defect its name describes, so it exercises that check rather than
		// being rejected earlier for an unrelated reason.
		{name: "preamble", handshake: "ssh manager 1.0\n"},
		{name: "2components", handshake: "codervpn manager\n"},
		{name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"},
		{name: "0version", handshake: "codervpn manager 0.1\n"},
		{name: "unknown_role", handshake: "codervpn supervisor 1.0\n"},
		{name: "unexpected_role", handshake: "codervpn tunnel 1.0\n"},
		// Role and version in each other's positions.
		{name: "swapped_fields", handshake: "codervpn 0.1 manager\n"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			mp, tp := net.Pipe()
			defer mp.Close()
			defer tp.Close()
			ctx := testutil.Context(t, testutil.WaitShort)
			// We're going to use deadlines for this test so that we don't hang the main test thread if
			// the speaker misbehaves.
			err := mp.SetReadDeadline(time.Now().Add(testutil.WaitShort))
			require.NoError(t, err)
			err = mp.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
			require.NoError(t, err)
			logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
			var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
			errCh := make(chan error, 1)
			go func() {
				s, err := newSpeaker[*TunnelMessage, *ManagerMessage](ctx, logger, tp, SpeakerRoleTunnel, SpeakerRoleManager)
				tun = s
				errCh <- err
			}()
			_, err = mp.Write([]byte(tc.handshake))
			require.NoError(t, err)
			expectedHandshake := "codervpn tunnel 1.0\n"
			b := make([]byte, 256)
			n, err := mp.Read(b)
			require.NoError(t, err)
			require.Equal(t, expectedHandshake, string(b[:n]))
			err = testutil.RequireRecvCtx(ctx, t, errCh)
			require.ErrorContains(t, err, "validate header")
			require.Nil(t, tun)
		})
	}
}
// TestSpeaker_CorruptMessage completes a valid handshake with a raw (simulated)
// peer and starts the tunnel speaker. It then writes a frame whose length
// prefix is 10 and whose body is ten zero bytes, which is not a valid protobuf
// message (field number 0 is reserved). The speaker is expected to hang up, so
// the peer's next Read returns io.EOF.
func TestSpeaker_CorruptMessage(t *testing.T) {
	t.Parallel()
	mp, tp := net.Pipe()
	defer mp.Close()
	defer tp.Close()
	ctx := testutil.Context(t, testutil.WaitShort)
	// We're going to use deadlines for this test so that we don't hang the main test thread if
	// the speaker misbehaves.
	err := mp.SetReadDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	err = mp.SetWriteDeadline(time.Now().Add(testutil.WaitShort))
	require.NoError(t, err)
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
	var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	errCh := make(chan error, 1)
	go func() {
		s, err := newSpeaker[*TunnelMessage, *ManagerMessage](ctx, logger, tp, SpeakerRoleTunnel, SpeakerRoleManager)
		tun = s
		errCh <- err
	}()
	expectedHandshake := "codervpn tunnel 1.0\n"
	b := make([]byte, 256)
	n, err := mp.Read(b)
	require.NoError(t, err)
	require.Equal(t, expectedHandshake, string(b[:n]))
	_, err = mp.Write([]byte("codervpn manager 1.0\n"))
	require.NoError(t, err)
	err = testutil.RequireRecvCtx(ctx, t, errCh)
	require.NoError(t, err)
	tun.start()
	// Big-endian length prefix, then a body that does not decode as protobuf.
	err = binary.Write(mp, binary.BigEndian, uint32(10))
	require.NoError(t, err)
	n, err = mp.Write([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
	require.NoError(t, err)
	require.EqualValues(t, 10, n)
	// it should hang up on us if we write nonsense
	_, err = mp.Read(b)
	require.ErrorIs(t, err, io.EOF)
}
// TestSpeaker_unaryRPC_mainline sends a Start RPC from the manager, answers it
// from the tunnel, and checks that the manager receives the Start reply. It
// then closes the manager and expects the tunnel's requests channel to close.
func TestSpeaker_unaryRPC_mainline(t *testing.T) {
	t.Parallel()
	ctx, tun, mgr := setupSpeakers(t)

	startReq := &ManagerMessage{
		Msg: &ManagerMessage_Start{
			Start: &StartRequest{CoderUrl: "https://coder.example.com"},
		},
	}
	var resp *TunnelMessage
	rpcErr := make(chan error, 1)
	go func() {
		var err error
		resp, err = mgr.unaryRPC(ctx, startReq)
		rpcErr <- err
	}()

	// The tunnel should see the request with a non-zero RPC message ID.
	req := testutil.RequireRecvCtx(ctx, t, tun.requests)
	require.NotEqualValues(t, 0, req.msg.GetRpc().GetMsgId())
	require.Equal(t, "https://coder.example.com", req.msg.GetStart().GetCoderUrl())

	require.NoError(t, req.sendReply(&TunnelMessage{
		Msg: &TunnelMessage_Start{Start: &StartResponse{}},
	}))
	require.NoError(t, testutil.RequireRecvCtx(ctx, t, rpcErr))

	_, isStart := resp.Msg.(*TunnelMessage_Start)
	require.True(t, isStart)

	// Closing the manager should close the tun.requests channel.
	require.NoError(t, mgr.Close())
	select {
	case <-ctx.Done():
		t.Fatal("timed out waiting for requests to close")
	case _, open := <-tun.requests:
		require.False(t, open)
	}
}
// TestSpeaker_unaryRPC_canceled cancels the caller's context while an RPC is
// outstanding. unaryRPC is expected to return context.Canceled and a nil
// response, and a reply the tunnel sends afterwards must still succeed.
func TestSpeaker_unaryRPC_canceled(t *testing.T) {
	t.Parallel()
	testCtx, tun, mgr := setupSpeakers(t)
	rpcCtx, cancelRPC := context.WithCancel(testCtx)
	defer cancelRPC()

	var resp *TunnelMessage
	rpcErr := make(chan error, 1)
	go func() {
		r, err := mgr.unaryRPC(rpcCtx, &ManagerMessage{
			Msg: &ManagerMessage_Start{
				Start: &StartRequest{CoderUrl: "https://coder.example.com"},
			},
		})
		resp = r
		rpcErr <- err
	}()

	req := testutil.RequireRecvCtx(testCtx, t, tun.requests)
	require.NotEqualValues(t, 0, req.msg.GetRpc().GetMsgId())
	require.Equal(t, "https://coder.example.com", req.msg.GetStart().GetCoderUrl())

	// Give up on the RPC before the tunnel answers.
	cancelRPC()
	err := testutil.RequireRecvCtx(testCtx, t, rpcErr)
	require.ErrorIs(t, err, context.Canceled)
	require.Nil(t, resp)

	// A reply sent after the caller gave up must still go out without error.
	require.NoError(t, req.sendReply(&TunnelMessage{
		Msg: &TunnelMessage_Start{Start: &StartResponse{}},
	}))
}
// TestSpeaker_unaryRPC_hung_up closes the tunnel speaker instead of replying to
// an outstanding RPC. The manager's unaryRPC is expected to fail with
// io.ErrUnexpectedEOF and return a nil response.
func TestSpeaker_unaryRPC_hung_up(t *testing.T) {
	t.Parallel()
	testCtx, tun, mgr := setupSpeakers(t)
	ctx, cancel := context.WithCancel(testCtx)
	defer cancel()

	startReq := &ManagerMessage{
		Msg: &ManagerMessage_Start{
			Start: &StartRequest{CoderUrl: "https://coder.example.com"},
		},
	}
	var resp *TunnelMessage
	rpcErr := make(chan error, 1)
	go func() {
		var err error
		resp, err = mgr.unaryRPC(ctx, startReq)
		rpcErr <- err
	}()

	req := testutil.RequireRecvCtx(testCtx, t, tun.requests)
	require.NotEqualValues(t, 0, req.msg.GetRpc().GetMsgId())
	require.Equal(t, "https://coder.example.com", req.msg.GetStart().GetCoderUrl())

	// When: the tunnel closes instead of replying.
	require.NoError(t, tun.Close())

	// Then: the pending RPC fails.
	err := testutil.RequireRecvCtx(testCtx, t, rpcErr)
	require.ErrorIs(t, err, io.ErrUnexpectedEOF)
	require.Nil(t, resp)
}
// TestSpeaker_unaryRPC_sendLoop closes the tunnel, then makes the manager's
// serdes send loop stop by queuing a message that fails to write to the closed
// pipe. A subsequent unaryRPC is expected to fail with io.ErrUnexpectedEOF and
// return a nil response.
func TestSpeaker_unaryRPC_sendLoop(t *testing.T) {
	t.Parallel()
	testCtx, tun, mgr := setupSpeakers(t)
	ctx, cancel := context.WithCancel(testCtx)
	defer cancel()

	// When: Tunnel closes before we send the RPC
	require.NoError(t, tun.Close())

	// When: serdes sendloop is closed
	// Queue a message on the manager; writing it to the closed pipe errors,
	// which shuts down the manager's serdes send loop.
	testutil.RequireSendCtx(ctx, t, mgr.sendCh, &ManagerMessage{
		Msg: &ManagerMessage_GetPeerUpdate{},
	})

	// When: we send an RPC
	var resp *TunnelMessage
	rpcErr := make(chan error, 1)
	go func() {
		var err error
		resp, err = mgr.unaryRPC(ctx, &ManagerMessage{
			Msg: &ManagerMessage_Start{
				Start: &StartRequest{CoderUrl: "https://coder.example.com"},
			},
		})
		rpcErr <- err
	}()

	// Then: we should get an error on the RPC.
	err := testutil.RequireRecvCtx(testCtx, t, rpcErr)
	require.ErrorIs(t, err, io.ErrUnexpectedEOF)
	require.Nil(t, resp)
}
// setupSpeakers connects a tunnel speaker and a manager speaker over an
// in-memory net.Pipe, runs both handshakes concurrently, starts both speakers,
// and returns the test context together with the tunnel and manager speakers.
// Both pipe ends are closed via t.Cleanup. Any handshake error fails the test.
func setupSpeakers(t *testing.T) (
	context.Context, *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage], *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage],
) {
	// Report require failures at the caller's line rather than inside this helper.
	t.Helper()
	mp, tp := net.Pipe()
	t.Cleanup(func() { _ = mp.Close() })
	t.Cleanup(func() { _ = tp.Close() })
	ctx := testutil.Context(t, testutil.WaitShort)
	logger := testutil.Logger(t)
	var tun *speaker[*TunnelMessage, *ManagerMessage, ManagerMessage]
	var mgr *speaker[*ManagerMessage, *TunnelMessage, TunnelMessage]
	// Buffered for both results so neither handshake goroutine blocks on send.
	errCh := make(chan error, 2)
	// The two handshakes run concurrently because each side must read the
	// other's header over the unbuffered pipe.
	go func() {
		s, err := newSpeaker[*TunnelMessage, *ManagerMessage](
			ctx, logger.Named("tun"), tp, SpeakerRoleTunnel, SpeakerRoleManager,
		)
		tun = s
		errCh <- err
	}()
	go func() {
		s, err := newSpeaker[*ManagerMessage, *TunnelMessage](
			ctx, logger.Named("mgr"), mp, SpeakerRoleManager, SpeakerRoleTunnel,
		)
		mgr = s
		errCh <- err
	}()
	err := testutil.RequireRecvCtx(ctx, t, errCh)
	require.NoError(t, err)
	err = testutil.RequireRecvCtx(ctx, t, errCh)
	require.NoError(t, err)
	tun.start()
	mgr.start()
	return ctx, tun, mgr
}