mirror of
https://github.com/coder/coder.git
synced 2025-07-13 21:36:50 +00:00
fix: fix workspacesdk to return error on API mismatch (#13683)
This commit is contained in:
@ -3,8 +3,10 @@ package workspacesdk
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -14,6 +16,7 @@ import (
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/buildinfo"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
@ -101,6 +104,9 @@ func (tac *tailnetAPIConnector) run() {
|
|||||||
defer close(tac.closed)
|
defer close(tac.closed)
|
||||||
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
|
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(tac.ctx); {
|
||||||
tailnetClient, err := tac.dial()
|
tailnetClient, err := tac.dial()
|
||||||
|
if xerrors.Is(err, &codersdk.Error{}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -110,13 +116,29 @@ func (tac *tailnetAPIConnector) run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var permanentErrorStatuses = []int{
|
||||||
|
http.StatusConflict, // returned if client/agent connections disabled (browser only)
|
||||||
|
http.StatusBadRequest, // returned if API mismatch
|
||||||
|
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
|
||||||
|
}
|
||||||
|
|
||||||
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
func (tac *tailnetAPIConnector) dial() (proto.DRPCTailnetClient, error) {
|
||||||
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
|
tac.logger.Debug(tac.ctx, "dialing Coder tailnet v2+ API")
|
||||||
// nolint:bodyclose
|
// nolint:bodyclose
|
||||||
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
|
ws, res, err := websocket.Dial(tac.ctx, tac.coordinateURL, tac.dialOptions)
|
||||||
if tac.isFirst {
|
if tac.isFirst {
|
||||||
if res != nil && res.StatusCode == http.StatusConflict {
|
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
|
||||||
err = codersdk.ReadBodyAsError(res)
|
err = codersdk.ReadBodyAsError(res)
|
||||||
|
// A bit more human-readable help in the case the API version was rejected
|
||||||
|
var sdkErr *codersdk.Error
|
||||||
|
if xerrors.As(err, &sdkErr) {
|
||||||
|
if sdkErr.Message == AgentAPIMismatchMessage &&
|
||||||
|
sdkErr.StatusCode() == http.StatusBadRequest {
|
||||||
|
sdkErr.Helper = fmt.Sprintf(
|
||||||
|
"Ensure your client release version (%s, different than the API version) matches the server release version",
|
||||||
|
buildinfo.Version())
|
||||||
|
}
|
||||||
|
}
|
||||||
tac.connected <- err
|
tac.connected <- err
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,8 @@ import (
|
|||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/apiversion"
|
||||||
|
"github.com/coder/coder/v2/coderd/httpapi"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/tailnet"
|
"github.com/coder/coder/v2/tailnet"
|
||||||
"github.com/coder/coder/v2/tailnet/proto"
|
"github.com/coder/coder/v2/tailnet/proto"
|
||||||
@ -97,6 +99,41 @@ func TestTailnetAPIConnector_Disconnects(t *testing.T) {
|
|||||||
require.NotNil(t, reqDisc.Disconnect)
|
require.NotNil(t, reqDisc.Disconnect)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTailnetAPIConnector_UplevelVersion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
|
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
|
||||||
|
agentID := uuid.UUID{0x55}
|
||||||
|
|
||||||
|
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sVer := apiversion.New(proto.CurrentMajor, proto.CurrentMinor-1)
|
||||||
|
|
||||||
|
// the following matches what Coderd does;
|
||||||
|
// c.f. coderd/workspaceagents.go: workspaceAgentClientCoordinate
|
||||||
|
cVer := r.URL.Query().Get("version")
|
||||||
|
if err := sVer.Validate(cVer); err != nil {
|
||||||
|
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
|
||||||
|
Message: AgentAPIMismatchMessage,
|
||||||
|
Validations: []codersdk.ValidationError{
|
||||||
|
{Field: "version", Detail: err.Error()},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
fConn := newFakeTailnetConn()
|
||||||
|
|
||||||
|
uut := runTailnetAPIConnector(ctx, logger, agentID, svr.URL, &websocket.DialOptions{}, fConn)
|
||||||
|
|
||||||
|
err := testutil.RequireRecvCtx(ctx, t, uut.connected)
|
||||||
|
var sdkErr *codersdk.Error
|
||||||
|
require.ErrorAs(t, err, &sdkErr)
|
||||||
|
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||||
|
require.Equal(t, AgentAPIMismatchMessage, sdkErr.Message)
|
||||||
|
require.NotEmpty(t, sdkErr.Helper)
|
||||||
|
}
|
||||||
|
|
||||||
type fakeTailnetConn struct{}
|
type fakeTailnetConn struct{}
|
||||||
|
|
||||||
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
func (*fakeTailnetConn) UpdatePeers([]*proto.CoordinateResponse_PeerUpdate) error {
|
||||||
|
@ -55,6 +55,8 @@ const (
|
|||||||
AgentMinimumListeningPort = 9
|
AgentMinimumListeningPort = 9
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const AgentAPIMismatchMessage = "Unknown or unsupported API version"
|
||||||
|
|
||||||
// AgentIgnoredListeningPorts contains a list of ports to ignore when looking for
|
// AgentIgnoredListeningPorts contains a list of ports to ignore when looking for
|
||||||
// running applications inside a workspace. We want to ignore non-HTTP servers,
|
// running applications inside a workspace. We want to ignore non-HTTP servers,
|
||||||
// so we pre-populate this list with common ports that are not HTTP servers.
|
// so we pre-populate this list with common ports that are not HTTP servers.
|
||||||
|
@ -46,8 +46,9 @@ func TestBlockNonBrowser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
r := setupWorkspaceAgent(t, client, user, 0)
|
r := setupWorkspaceAgent(t, client, user, 0)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
//nolint:gocritic // Testing that even the owner gets blocked.
|
//nolint:gocritic // Testing that even the owner gets blocked.
|
||||||
_, err := workspacesdk.New(client).DialAgent(context.Background(), r.sdkAgent.ID, nil)
|
_, err := workspacesdk.New(client).DialAgent(ctx, r.sdkAgent.ID, nil)
|
||||||
var apiErr *codersdk.Error
|
var apiErr *codersdk.Error
|
||||||
require.ErrorAs(t, err, &apiErr)
|
require.ErrorAs(t, err, &apiErr)
|
||||||
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
|
require.Equal(t, http.StatusConflict, apiErr.StatusCode())
|
||||||
@ -65,8 +66,9 @@ func TestBlockNonBrowser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
r := setupWorkspaceAgent(t, client, user, 0)
|
r := setupWorkspaceAgent(t, client, user, 0)
|
||||||
|
ctx := testutil.Context(t, testutil.WaitShort)
|
||||||
//nolint:gocritic // Testing RBAC is not the point of this test.
|
//nolint:gocritic // Testing RBAC is not the point of this test.
|
||||||
conn, err := workspacesdk.New(client).DialAgent(context.Background(), r.sdkAgent.ID, nil)
|
conn, err := workspacesdk.New(client).DialAgent(ctx, r.sdkAgent.ID, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user