mirror of
https://github.com/coder/coder.git
synced 2025-03-14 10:09:57 +00:00
chore: rework RPC version negotiation (#15687)
Changes the RPC header format from `codervpn <version> <role>` to `codervpn <role> <version1,version2,...>`. The versions list is a list of the maximum supported minor version for each major version, sorted by major versions. E.g. `1.0,2.3,3.1` means `1.0, 2.0, 2.1, 2.2, 2.3, 3.0, 3.1` are supported. When we eventually support multiple versions, the peer's version list will be compared against the current supported versions list to determine the maximum major and minor version supported by both peers. Closes #15601
This commit is contained in:
@ -11,7 +11,6 @@ import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/v2/apiversion"
|
||||
)
|
||||
|
||||
type SpeakerRole string
|
||||
@ -258,7 +257,7 @@ func handshake(
|
||||
// read and write simultaneously to avoid deadlocking if the conn is not buffered
|
||||
errCh := make(chan error, 2)
|
||||
go func() {
|
||||
ours := headerString(CurrentVersion, me)
|
||||
ours := headerString(me, CurrentSupportedVersions)
|
||||
_, err := conn.Write([]byte(ours))
|
||||
logger.Debug(ctx, "wrote out header")
|
||||
if err != nil {
|
||||
@ -316,34 +315,43 @@ func handshake(
|
||||
}
|
||||
}
|
||||
logger.Debug(ctx, "handshake read/write complete", slog.F("their_header", theirHeader))
|
||||
err := validateHeader(theirHeader, them)
|
||||
gotVersion, err := validateHeader(theirHeader, them, CurrentSupportedVersions)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("validate header (%s): %w", theirHeader, err)
|
||||
}
|
||||
logger.Debug(ctx, "handshake validated", slog.F("common_version", gotVersion))
|
||||
// TODO: actually use the common version to perform different behavior once
|
||||
// we have multiple versions
|
||||
return nil
|
||||
}
|
||||
|
||||
const headerPreamble = "codervpn"
|
||||
|
||||
func headerString(version *apiversion.APIVersion, role SpeakerRole) string {
|
||||
return fmt.Sprintf("%s %s %s\n", headerPreamble, version.String(), role)
|
||||
func headerString(role SpeakerRole, versions RPCVersionList) string {
|
||||
return fmt.Sprintf("%s %s %s\n", headerPreamble, role, versions.String())
|
||||
}
|
||||
|
||||
func validateHeader(header string, expectedRole SpeakerRole) error {
|
||||
func validateHeader(header string, expectedRole SpeakerRole, supportedVersions RPCVersionList) (RPCVersion, error) {
|
||||
parts := strings.Split(header, " ")
|
||||
if len(parts) != 3 {
|
||||
return xerrors.New("wrong number of parts")
|
||||
return RPCVersion{}, xerrors.New("wrong number of parts")
|
||||
}
|
||||
if parts[0] != headerPreamble {
|
||||
return xerrors.New("invalid preamble")
|
||||
return RPCVersion{}, xerrors.New("invalid preamble")
|
||||
}
|
||||
if err := CurrentVersion.Validate(parts[1]); err != nil {
|
||||
return xerrors.Errorf("version: %w", err)
|
||||
if parts[1] != string(expectedRole) {
|
||||
return RPCVersion{}, xerrors.New("unexpected role")
|
||||
}
|
||||
if parts[2] != string(expectedRole) {
|
||||
return xerrors.New("unexpected role")
|
||||
otherVersions, err := ParseRPCVersionList(parts[2])
|
||||
if err != nil {
|
||||
return RPCVersion{}, xerrors.Errorf("parse version list %q: %w", parts[2], err)
|
||||
}
|
||||
return nil
|
||||
compatibleVersion, ok := supportedVersions.IsCompatibleWith(otherVersions)
|
||||
if !ok {
|
||||
return RPCVersion{},
|
||||
xerrors.Errorf("current supported versions %q is not compatible with peer versions %q", supportedVersions.String(), otherVersions.String())
|
||||
}
|
||||
return compatibleVersion, nil
|
||||
}
|
||||
|
||||
type request[S rpcMessage, R rpcMessage] struct {
|
||||
|
@ -47,14 +47,14 @@ func TestSpeaker_RawPeer(t *testing.T) {
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
expectedHandshake := "codervpn 1.0 tunnel\n"
|
||||
expectedHandshake := "codervpn tunnel 1.0\n"
|
||||
|
||||
b := make([]byte, 256)
|
||||
n, err := mp.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedHandshake, string(b[:n]))
|
||||
|
||||
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
|
||||
_, err = mp.Write([]byte("codervpn manager 1.3,2.1\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
@ -155,7 +155,7 @@ func TestSpeaker_OversizeHandshake(t *testing.T) {
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
expectedHandshake := "codervpn 1.0 tunnel\n"
|
||||
expectedHandshake := "codervpn tunnel 1.0\n"
|
||||
|
||||
b := make([]byte, 256)
|
||||
n, err := mp.Read(b)
|
||||
@ -177,10 +177,10 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
name, handshake string
|
||||
}{
|
||||
{name: "preamble", handshake: "ssh 1.0 manager\n"},
|
||||
{name: "preamble", handshake: "ssh manager 1.0\n"},
|
||||
{name: "2components", handshake: "ssh manager\n"},
|
||||
{name: "newversion", handshake: "codervpn 1.1 manager\n"},
|
||||
{name: "oldversion", handshake: "codervpn 0.1 manager\n"},
|
||||
{name: "newmajors", handshake: "codervpn manager 2.0,3.0\n"},
|
||||
{name: "0version", handshake: "codervpn 0.1 manager\n"},
|
||||
{name: "unknown_role", handshake: "codervpn 1.0 supervisor\n"},
|
||||
{name: "unexpected_role", handshake: "codervpn 1.0 tunnel\n"},
|
||||
} {
|
||||
@ -208,7 +208,7 @@ func TestSpeaker_HandshakeInvalid(t *testing.T) {
|
||||
_, err = mp.Write([]byte(tc.handshake))
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedHandshake := "codervpn 1.0 tunnel\n"
|
||||
expectedHandshake := "codervpn tunnel 1.0\n"
|
||||
b := make([]byte, 256)
|
||||
n, err := mp.Read(b)
|
||||
require.NoError(t, err)
|
||||
@ -246,14 +246,14 @@ func TestSpeaker_CorruptMessage(t *testing.T) {
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
expectedHandshake := "codervpn 1.0 tunnel\n"
|
||||
expectedHandshake := "codervpn tunnel 1.0\n"
|
||||
|
||||
b := make([]byte, 256)
|
||||
n, err := mp.Read(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedHandshake, string(b[:n]))
|
||||
|
||||
_, err = mp.Write([]byte("codervpn 1.0 manager\n"))
|
||||
_, err = mp.Write([]byte("codervpn manager 1.0\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testutil.RequireRecvCtx(ctx, t, errCh)
|
||||
|
141
vpn/version.go
141
vpn/version.go
@ -1,10 +1,141 @@
|
||||
package vpn
|
||||
|
||||
import "github.com/coder/coder/v2/apiversion"
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
const (
|
||||
CurrentMajor = 1
|
||||
CurrentMinor = 0
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor)
|
||||
// CurrentSupportedVersions is the list of versions supported by this
|
||||
// implementation of the VPN RPC protocol.
|
||||
var CurrentSupportedVersions = RPCVersionList{
|
||||
Versions: []RPCVersion{
|
||||
{Major: 1, Minor: 0},
|
||||
},
|
||||
}
|
||||
|
||||
// RPCVersion represents a single version of the RPC protocol. Any given version
|
||||
// is expected to be backwards compatible with all previous minor versions on
|
||||
// the same major version.
|
||||
//
|
||||
// e.g. RPCVersion{2, 3} is backwards compatible with RPCVersion{2, 2} but is
|
||||
// not backwards compatible with RPCVersion{1, 2}.
|
||||
type RPCVersion struct {
|
||||
Major uint64 `json:"major"`
|
||||
Minor uint64 `json:"minor"`
|
||||
}
|
||||
|
||||
// ParseRPCVersion parses a version string in the format "major.minor" into a
|
||||
// RPCVersion.
|
||||
func ParseRPCVersion(str string) (RPCVersion, error) {
|
||||
split := strings.Split(str, ".")
|
||||
if len(split) != 2 {
|
||||
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
|
||||
}
|
||||
major, err := strconv.ParseUint(split[0], 10, 64)
|
||||
if err != nil {
|
||||
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
|
||||
}
|
||||
if major == 0 {
|
||||
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
|
||||
}
|
||||
minor, err := strconv.ParseUint(split[1], 10, 64)
|
||||
if err != nil {
|
||||
return RPCVersion{}, xerrors.Errorf("invalid version string: %s", str)
|
||||
}
|
||||
return RPCVersion{Major: major, Minor: minor}, nil
|
||||
}
|
||||
|
||||
func (v RPCVersion) String() string {
|
||||
return fmt.Sprintf("%d.%d", v.Major, v.Minor)
|
||||
}
|
||||
|
||||
// IsCompatibleWith returns the lowest version that is compatible with both
|
||||
// versions. If the versions are not compatible, the second return value will be
|
||||
// false.
|
||||
func (v RPCVersion) IsCompatibleWith(other RPCVersion) (RPCVersion, bool) {
|
||||
if v.Major != other.Major {
|
||||
return RPCVersion{}, false
|
||||
}
|
||||
// The lowest minor version from the two versions should be returned.
|
||||
if v.Minor < other.Minor {
|
||||
return v, true
|
||||
}
|
||||
return other, true
|
||||
}
|
||||
|
||||
// RPCVersionList represents a list of RPC versions supported by a RPC peer. An
|
||||
type RPCVersionList struct {
|
||||
Versions []RPCVersion `json:"versions"`
|
||||
}
|
||||
|
||||
// ParseRPCVersionList parses a version string in the format
|
||||
// "major.minor,major.minor" into a RPCVersionList.
|
||||
func ParseRPCVersionList(str string) (RPCVersionList, error) {
|
||||
split := strings.Split(str, ",")
|
||||
versions := make([]RPCVersion, len(split))
|
||||
for i, v := range split {
|
||||
version, err := ParseRPCVersion(v)
|
||||
if err != nil {
|
||||
return RPCVersionList{}, xerrors.Errorf("invalid version list: %s", str)
|
||||
}
|
||||
versions[i] = version
|
||||
}
|
||||
vl := RPCVersionList{Versions: versions}
|
||||
err := vl.Validate()
|
||||
if err != nil {
|
||||
return RPCVersionList{}, xerrors.Errorf("invalid parsed version list %q: %w", str, err)
|
||||
}
|
||||
return vl, nil
|
||||
}
|
||||
|
||||
func (vl RPCVersionList) String() string {
|
||||
versionStrings := make([]string, len(vl.Versions))
|
||||
for i, v := range vl.Versions {
|
||||
versionStrings[i] = v.String()
|
||||
}
|
||||
return strings.Join(versionStrings, ",")
|
||||
}
|
||||
|
||||
// Validate returns an error if the version list is not sorted or contains
|
||||
// duplicate major versions.
|
||||
func (vl RPCVersionList) Validate() error {
|
||||
if len(vl.Versions) == 0 {
|
||||
return xerrors.New("no versions")
|
||||
}
|
||||
for i := 0; i < len(vl.Versions); i++ {
|
||||
if vl.Versions[i].Major == 0 {
|
||||
return xerrors.Errorf("invalid version: %s", vl.Versions[i].String())
|
||||
}
|
||||
if i > 0 && vl.Versions[i-1].Major == vl.Versions[i].Major {
|
||||
return xerrors.Errorf("duplicate major version: %d", vl.Versions[i].Major)
|
||||
}
|
||||
if i > 0 && vl.Versions[i-1].Major > vl.Versions[i].Major {
|
||||
return xerrors.Errorf("versions are not sorted")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsCompatibleWith returns the lowest version that is compatible with both
|
||||
// version lists. If the versions are not compatible, the second return value
|
||||
// will be false.
|
||||
func (vl RPCVersionList) IsCompatibleWith(other RPCVersionList) (RPCVersion, bool) {
|
||||
bestVersion := RPCVersion{}
|
||||
for _, v1 := range vl.Versions {
|
||||
for _, v2 := range other.Versions {
|
||||
if v1.Major == v2.Major && v1.Major > bestVersion.Major {
|
||||
v, ok := v1.IsCompatibleWith(v2)
|
||||
if ok {
|
||||
bestVersion = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if bestVersion.Major == 0 {
|
||||
return bestVersion, false
|
||||
}
|
||||
return bestVersion, true
|
||||
}
|
||||
|
262
vpn/version_test.go
Normal file
262
vpn/version_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
package vpn_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/vpn"
|
||||
)
|
||||
|
||||
func TestRPCVersionLatest(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, vpn.CurrentSupportedVersions.Validate())
|
||||
}
|
||||
|
||||
func TestRPCVersionParseString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want vpn.RPCVersion
|
||||
}{
|
||||
{
|
||||
name: "valid version",
|
||||
input: "1.0",
|
||||
want: vpn.RPCVersion{Major: 1, Minor: 0},
|
||||
},
|
||||
{
|
||||
name: "valid version with larger numbers",
|
||||
input: "12.34",
|
||||
want: vpn.RPCVersion{Major: 12, Minor: 34},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "one part",
|
||||
input: "1",
|
||||
},
|
||||
{
|
||||
name: "three parts",
|
||||
input: "1.0.0",
|
||||
},
|
||||
{
|
||||
name: "major version is 0",
|
||||
input: "0.1",
|
||||
},
|
||||
{
|
||||
name: "invalid major version",
|
||||
input: "a.1",
|
||||
},
|
||||
{
|
||||
name: "invalid minor version",
|
||||
input: "1.a",
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:paralleltest
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := vpn.ParseRPCVersion(tc.input)
|
||||
if tc.want.Major == 0 {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.want, got)
|
||||
|
||||
require.Equal(t, tc.input, got.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCVersionIsCompatibleWith(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
v1 vpn.RPCVersion
|
||||
v2 vpn.RPCVersion
|
||||
want vpn.RPCVersion
|
||||
wantBool bool
|
||||
}{
|
||||
{
|
||||
name: "same version",
|
||||
v1: vpn.RPCVersion{Major: 1, Minor: 0},
|
||||
v2: vpn.RPCVersion{Major: 1, Minor: 0},
|
||||
want: vpn.RPCVersion{Major: 1, Minor: 0},
|
||||
},
|
||||
{
|
||||
name: "compatible minor versions",
|
||||
v1: vpn.RPCVersion{Major: 1, Minor: 2},
|
||||
v2: vpn.RPCVersion{Major: 1, Minor: 3},
|
||||
want: vpn.RPCVersion{Major: 1, Minor: 2},
|
||||
},
|
||||
{
|
||||
name: "incompatible major versions",
|
||||
v1: vpn.RPCVersion{Major: 1, Minor: 0},
|
||||
v2: vpn.RPCVersion{Major: 2, Minor: 0},
|
||||
want: vpn.RPCVersion{},
|
||||
},
|
||||
}
|
||||
|
||||
// nolint:paralleltest
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, ok := tc.v1.IsCompatibleWith(tc.v2)
|
||||
if tc.want.Major == 0 {
|
||||
require.False(t, ok)
|
||||
return
|
||||
}
|
||||
require.True(t, ok)
|
||||
require.Equal(t, got, tc.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCVersionListParseString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
input string
|
||||
want vpn.RPCVersionList
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "single version",
|
||||
input: "1.0",
|
||||
want: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 1, Minor: 0},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple versions",
|
||||
input: "1.1,2.3,3.2",
|
||||
want: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 1, Minor: 1},
|
||||
{Major: 2, Minor: 3},
|
||||
{Major: 3, Minor: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid version",
|
||||
input: "1.0,invalid",
|
||||
errContains: "invalid version list",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
errContains: "invalid version list",
|
||||
},
|
||||
{
|
||||
name: "duplicate versions",
|
||||
input: "1.0,1.0",
|
||||
errContains: "duplicate major version",
|
||||
},
|
||||
{
|
||||
name: "duplicate major versions",
|
||||
input: "1.0,1.2",
|
||||
errContains: "duplicate major version",
|
||||
},
|
||||
{
|
||||
name: "out of order versions",
|
||||
input: "2.0,1.0",
|
||||
errContains: "versions are not sorted",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := vpn.ParseRPCVersionList(tc.input)
|
||||
if tc.errContains != "" {
|
||||
require.ErrorContains(t, err, tc.errContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.want, got)
|
||||
require.Equal(t, tc.input, got.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRPCVersionListValidate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
list vpn.RPCVersionList
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid list",
|
||||
list: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 1, Minor: 1},
|
||||
{Major: 2, Minor: 3},
|
||||
{Major: 3, Minor: 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
list: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{},
|
||||
},
|
||||
errContains: "no versions",
|
||||
},
|
||||
{
|
||||
name: "duplicate versions",
|
||||
list: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 1, Minor: 0},
|
||||
{Major: 1, Minor: 0},
|
||||
},
|
||||
},
|
||||
errContains: "duplicate major version",
|
||||
},
|
||||
{
|
||||
name: "duplicate major versions",
|
||||
list: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 1, Minor: 0},
|
||||
{Major: 1, Minor: 2},
|
||||
},
|
||||
},
|
||||
errContains: "duplicate major version",
|
||||
},
|
||||
{
|
||||
name: "out of order versions",
|
||||
list: vpn.RPCVersionList{
|
||||
Versions: []vpn.RPCVersion{
|
||||
{Major: 2, Minor: 0},
|
||||
{Major: 1, Minor: 0},
|
||||
},
|
||||
},
|
||||
errContains: "versions are not sorted",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := tc.list.Validate()
|
||||
if tc.errContains != "" {
|
||||
require.ErrorContains(t, err, tc.errContains)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user