chore: rework RPC version negotiation (#15687)

Changes the RPC header format from `codervpn <version> <role>` to
`codervpn <role> <version1,version2,...>`.

The versions list is a list of the maximum supported minor version for
each major version, sorted by major versions.

E.g. `1.0,2.3,3.1` means `1.0, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1` are
supported.

When we eventually support multiple versions, the peer's version list
will be compared against the current supported versions list to
determine the maximum major and minor version supported by both peers.

Closes #15601
This commit is contained in:
Dean Sheather
2024-12-04 18:38:24 +09:00
committed by GitHub
parent 887ea14b6a
commit 14a60303ac
4 changed files with 428 additions and 27 deletions

View File

@ -11,7 +11,6 @@ import (
"google.golang.org/protobuf/proto"
"cdr.dev/slog"
"github.com/coder/coder/v2/apiversion"
)
type SpeakerRole string
@ -258,7 +257,7 @@ func handshake(
// read and write simultaneously to avoid deadlocking if the conn is not buffered
errCh := make(chan error, 2)
go func() {
ours := headerString(CurrentVersion, me)
ours := headerString(me, CurrentSupportedVersions)
_, err := conn.Write([]byte(ours))
logger.Debug(ctx, "wrote out header")
if err != nil {
@ -316,34 +315,43 @@ func handshake(
}
}
logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader))
err := validateHeader(theirHeader, them)
gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions)
if err != nil {
return xerrors.Errorf("validate header (%s): %w", theirHeader, err)
}
logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion))
// TODO: actually use the common version to perform different behavior once
// we have multiple versions
return nil
}
const headerPreamble = "codervpn"
func headerString(version *apiversion.APIVersion, role SpeakerRole) string {
return fmt.Sprintf("%s %s %s\n", headerPreamble, version.String(), role)
func headerString(role SpeakerRole, versions RPCVersionList) string {
return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String())
}
func validateHeader(header string, expectedRole SpeakerRole) error {
func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) {
parts := strings.Split(header, " ")
if len(parts) != 3 {
return xerrors.New("wrong number of parts")
return RPCVersion{}, xerrors.New("wrong number of parts")
}
if parts[0] != headerPreamble {
return xerrors.New("invalid preamble")
return RPCVersion{}, xerrors.New("invalid preamble")
}
if err := CurrentVersion.Validate(parts[1]); err != nil {
return xerrors.Errorf("version: %w", err)
if parts[1] != string(expectedRole) {
return RPCVersion{}, xerrors.New("unexpected role")
}
if parts[2] != string(expectedRole) {
return xerrors.New("unexpected role")
otherVersions, err := ParseRPCVersionList(parts[2])
if err != nil {
return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err)
}
return nil
compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions)
if !ok {
return RPCVersion{},
xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), otherVersions.String())
}
return compatibleVersion, nil
}
type request[S rpcMessage, R rpcMessage] struct {

View File

@ -47,14 +47,14 @@ func TestSpeaker_RawPeer(t *testing.T) {
errCh <- err
}()
expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"
b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
require.Equal(t, expectedHandshake, string(b[:n]))
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
_, err = mp.Write([]byte("codervpn manager 1.3,2.1\n"))
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, errCh)
@ -155,7 +155,7 @@ func TestSpeaker_OversizeHandshake(t *testing.T) {
errCh <- err
}()
expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"
b := make([]byte, 256)
n, err := mp.Read(b)
@ -177,10 +177,10 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
for _, tc := range []struct {
name, handshake string
}{
{name: "preamble", handshake: "ssh 1.0 manager\n"},
{name: "preamble", handshake: "ssh manager 1.0\n"},
{name: "2components", handshake: "ssh manager\n"},
{name: "newversion", handshake: "codervpn 1.1 manager\n"},
{name: "oldversion", handshake: "codervpn 0.1 manager\n"},
{name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"},
{name: "0version", handshake: "codervpn 0.1 manager\n"},
{name: "unknown_role", handshake: "codervpn 1.0 supervisor\n"},
{name: "unexpected_role", handshake: "codervpn 1.0 tunnel\n"},
} {
@ -208,7 +208,7 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
_, err = mp.Write([]byte(tc.handshake))
require.NoError(t, err)
expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"
b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
@ -246,14 +246,14 @@ func TestSpeaker_CorruptMessage(t *testing.T) {
errCh <- err
}()
expectedHandshake := "codervpn 1.0 tunnel\n"
expectedHandshake := "codervpn tunnel 1.0\n"
b := make([]byte, 256)
n, err := mp.Read(b)
require.NoError(t, err)
require.Equal(t, expectedHandshake, string(b[:n]))
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
_, err = mp.Write([]byte("codervpn manager 1.0\n"))
require.NoError(t, err)
err = testutil.RequireRecvCtx(ctx, t, errCh)

View File

@ -1,10 +1,141 @@
package vpn
import "github.com/coder/coder/v2/apiversion"
import (
"fmt"
"strconv"
"strings"
const (
CurrentMajor = 1
CurrentMinor = 0
"golang.org/x/xerrors"
)
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
// CurrentSupportedVersions is the list of versions supported by this
// implementation of the VPN RPC protocol.
var CurrentSupportedVersions = RPCVersionList{
Versions: []RPCVersion{
{Major: 1, Minor: 0},
},
}
// RPCVersion represents a single version of the RPC protocol. Any given version
// is expected to be backwards compatible with all previous minor versions on
// the same major version.
//
// e.g. RPCVersion{2, 3} is backwards compatible with RPCVersion{2, 2} but is
// not backwards compatible with RPCVersion{1, 2}.
type RPCVersion struct {
Major uint64 `json:"major"`
Minor uint64 `json:"minor"`
}
// ParseRPCVersion parses a version string in the format "major.minor" into a
// RPCVersion.
func ParseRPCVersion(str string) (RPCVersion, error) {
split := strings.Split(str, ".")
if len(split) != 2 {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
major, err := strconv.ParseUint(split[0], 10, 64)
if err != nil {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
if major == 0 {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
minor, err := strconv.ParseUint(split[1], 10, 64)
if err != nil {
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
}
return RPCVersion{Major: major, Minor: minor}, nil
}
func (v RPCVersion) String() string {
return fmt.Sprintf("%d.%d", v.Major, v.Minor)
}
// IsCompatibleWith returns the lowest version that is compatible with both
// versions. If the versions are not compatible, the second return value will be
// false.
func (v RPCVersion) IsCompatibleWith(other RPCVersion) (RPCVersion, bool) {
if v.Major != other.Major {
return RPCVersion{}, false
}
// The lowest minor version from the two versions should be returned.
if v.Minor < other.Minor {
return v, true
}
return other, true
}
// RPCVersionList represents a list of RPC versions supported by a RPC peer. An
type RPCVersionList struct {
Versions []RPCVersion `json:"versions"`
}
// ParseRPCVersionList parses a version string in the format
// "major.minor,major.minor" into a RPCVersionList.
func ParseRPCVersionList(str string) (RPCVersionList, error) {
split := strings.Split(str, ",")
versions := make([]RPCVersion, len(split))
for i, v := range split {
version, err := ParseRPCVersion(v)
if err != nil {
return RPCVersionList{}, xerrors.Errorf("invalid version list: %s", str)
}
versions[i] = version
}
vl := RPCVersionList{Versions: versions}
err := vl.Validate()
if err != nil {
return RPCVersionList{}, xerrors.Errorf("invalid parsed version list %q: %w", str, err)
}
return vl, nil
}
func (vl RPCVersionList) String() string {
versionStrings := make([]string, len(vl.Versions))
for i, v := range vl.Versions {
versionStrings[i] = v.String()
}
return strings.Join(versionStrings, ",")
}
// Validate returns an error if the version list is not sorted or contains
// duplicate major versions.
func (vl RPCVersionList) Validate() error {
if len(vl.Versions) == 0 {
return xerrors.New("no versions")
}
for i := 0; i < len(vl.Versions); i++ {
if vl.Versions[i].Major == 0 {
return xerrors.Errorf("invalid version: %s", vl.Versions[i].String())
}
if i > 0 && vl.Versions[i-1].Major == vl.Versions[i].Major {
return xerrors.Errorf("duplicate major version: %d", vl.Versions[i].Major)
}
if i > 0 && vl.Versions[i-1].Major > vl.Versions[i].Major {
return xerrors.Errorf("versions are not sorted")
}
}
return nil
}
// IsCompatibleWith returns the lowest version that is compatible with both
// version lists. If the versions are not compatible, the second return value
// will be false.
func (vl RPCVersionList) IsCompatibleWith(other RPCVersionList) (RPCVersion, bool) {
bestVersion := RPCVersion{}
for _, v1 := range vl.Versions {
for _, v2 := range other.Versions {
if v1.Major == v2.Major && v1.Major > bestVersion.Major {
v, ok := v1.IsCompatibleWith(v2)
if ok {
bestVersion = v
}
}
}
}
if bestVersion.Major == 0 {
return bestVersion, false
}
return bestVersion, true
}

262
vpn/version_test.go Normal file
View File

@ -0,0 +1,262 @@
package vpn_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/vpn"
)
func TestRPCVersionLatest(t *testing.T) {
t.Parallel()
require.NoError(t, vpn.CurrentSupportedVersions.Validate())
}
func TestRPCVersionParseString(t *testing.T) {
t.Parallel()
cases := []struct {
name string
input string
want vpn.RPCVersion
}{
{
name: "valid version",
input: "1.0",
want: vpn.RPCVersion{Major: 1, Minor: 0},
},
{
name: "valid version with larger numbers",
input: "12.34",
want: vpn.RPCVersion{Major: 12, Minor: 34},
},
{
name: "empty string",
input: "",
},
{
name: "one part",
input: "1",
},
{
name: "three parts",
input: "1.0.0",
},
{
name: "major version is 0",
input: "0.1",
},
{
name: "invalid major version",
input: "a.1",
},
{
name: "invalid minor version",
input: "1.a",
},
}
// nolint:paralleltest
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := vpn.ParseRPCVersion(tc.input)
if tc.want.Major == 0 {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.want, got)
require.Equal(t, tc.input, got.String())
}
})
}
}
func TestRPCVersionIsCompatibleWith(t *testing.T) {
t.Parallel()
cases := []struct {
name string
v1 vpn.RPCVersion
v2 vpn.RPCVersion
want vpn.RPCVersion
wantBool bool
}{
{
name: "same version",
v1: vpn.RPCVersion{Major: 1, Minor: 0},
v2: vpn.RPCVersion{Major: 1, Minor: 0},
want: vpn.RPCVersion{Major: 1, Minor: 0},
},
{
name: "compatible minor versions",
v1: vpn.RPCVersion{Major: 1, Minor: 2},
v2: vpn.RPCVersion{Major: 1, Minor: 3},
want: vpn.RPCVersion{Major: 1, Minor: 2},
},
{
name: "incompatible major versions",
v1: vpn.RPCVersion{Major: 1, Minor: 0},
v2: vpn.RPCVersion{Major: 2, Minor: 0},
want: vpn.RPCVersion{},
},
}
// nolint:paralleltest
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, ok := tc.v1.IsCompatibleWith(tc.v2)
if tc.want.Major == 0 {
require.False(t, ok)
return
}
require.True(t, ok)
require.Equal(t, got, tc.want)
})
}
}
func TestRPCVersionListParseString(t *testing.T) {
t.Parallel()
cases := []struct {
name string
input string
want vpn.RPCVersionList
errContains string
}{
{
name: "single version",
input: "1.0",
want: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 1, Minor: 0},
},
},
},
{
name: "multiple versions",
input: "1.1,2.3,3.2",
want: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 1, Minor: 1},
{Major: 2, Minor: 3},
{Major: 3, Minor: 2},
},
},
},
{
name: "invalid version",
input: "1.0,invalid",
errContains: "invalid version list",
},
{
name: "empty string",
input: "",
errContains: "invalid version list",
},
{
name: "duplicate versions",
input: "1.0,1.0",
errContains: "duplicate major version",
},
{
name: "duplicate major versions",
input: "1.0,1.2",
errContains: "duplicate major version",
},
{
name: "out of order versions",
input: "2.0,1.0",
errContains: "versions are not sorted",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := vpn.ParseRPCVersionList(tc.input)
if tc.errContains != "" {
require.ErrorContains(t, err, tc.errContains)
return
}
require.NoError(t, err)
require.Equal(t, tc.want, got)
require.Equal(t, tc.input, got.String())
})
}
}
func TestRPCVersionListValidate(t *testing.T) {
t.Parallel()
cases := []struct {
name string
list vpn.RPCVersionList
errContains string
}{
{
name: "valid list",
list: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 1, Minor: 1},
{Major: 2, Minor: 3},
{Major: 3, Minor: 2},
},
},
},
{
name: "empty list",
list: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{},
},
errContains: "no versions",
},
{
name: "duplicate versions",
list: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 1, Minor: 0},
{Major: 1, Minor: 0},
},
},
errContains: "duplicate major version",
},
{
name: "duplicate major versions",
list: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 1, Minor: 0},
{Major: 1, Minor: 2},
},
},
errContains: "duplicate major version",
},
{
name: "out of order versions",
list: vpn.RPCVersionList{
Versions: []vpn.RPCVersion{
{Major: 2, Minor: 0},
{Major: 1, Minor: 0},
},
},
errContains: "versions are not sorted",
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := tc.list.Validate()
if tc.errContains != "" {
require.ErrorContains(t, err, tc.errContains)
} else {
require.NoError(t, err)
}
})
}
}