chore: Remove WebRTC networking (#3881)

* chore: Remove WebRTC networking

* Fix race condition

* Fix WebSocket not closing

Kyle Carberry authored 2022-09-19 19:46:29 -05:00 · committed by GitHub
parent 1186e643ec
commit 714c366d16
44 changed files with 310 additions and 4225 deletions

View File

@@ -13,7 +13,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/klauspost/compress/zstd"
"github.com/pion/webrtc/v3"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
@@ -35,7 +34,6 @@ import (
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/site"
@@ -65,17 +63,14 @@ type Options struct {
GithubOAuth2Config *GithubOAuth2Config
OIDCConfig *OIDCConfig
PrometheusRegistry *prometheus.Registry
ICEServers []webrtc.ICEServer
SecureAuthCookie bool
SSHKeygenAlgorithm gitsshkey.Algorithm
Telemetry telemetry.Reporter
TURNServer *turnconn.Server
TracerProvider trace.TracerProvider
AutoImportTemplates []AutoImportTemplate
LicenseHandler http.Handler
FeaturesService features.Service
TailscaleEnable bool
TailnetCoordinator *tailnet.Coordinator
DERPMap *tailcfg.DERPMap
@@ -92,6 +87,12 @@ func New(options *Options) *API {
// Multiply the update by two to allow for some lag-time.
options.AgentInactiveDisconnectTimeout = options.AgentConnectionUpdateFrequency * 2
}
if options.AgentStatsRefreshInterval == 0 {
options.AgentStatsRefreshInterval = 10 * time.Minute
}
if options.MetricsCacheRefreshInterval == 0 {
options.MetricsCacheRefreshInterval = time.Hour
}
if options.APIRateLimit == 0 {
options.APIRateLimit = 512
}
@@ -149,11 +150,7 @@ func New(options *Options) *API {
},
metricsCache: metricsCache,
}
if options.TailscaleEnable {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
} else {
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgent, 0)
}
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
api.derpServer = derp.NewServer(key.NewNode(), tailnet.Logger(options.Logger))
oauthConfigs := &httpmw.OAuth2Configs{
Github: options.GithubOAuth2Config,
@@ -415,14 +412,8 @@ func New(options *Options) *API {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/metadata", api.workspaceAgentMetadata)
r.Post("/version", api.postWorkspaceAgentVersion)
r.Get("/listen", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/coordinate", api.workspaceAgentCoordinate)
r.Get("/report-stats", api.workspaceAgentReportStats)
})
r.Route("/{workspaceagent}", func(r chi.Router) {
@@ -432,11 +423,7 @@ func New(options *Options) *API {
httpmw.ExtractWorkspaceParam(options.Database),
)
r.Get("/", api.workspaceAgent)
r.Get("/dial", api.workspaceAgentDial)
r.Get("/turn", api.userWorkspaceAgentTurn)
r.Get("/pty", api.workspaceAgentPTY)
r.Get("/iceservers", api.workspaceAgentICEServers)
r.Get("/connection", api.workspaceAgentConnection)
r.Get("/coordinate", api.workspaceAgentClientCoordinate)
})
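
A quick way to see what survives: the agent API now has a single networking entrypoint, /coordinate. Below is a minimal, self-contained sketch of the post-commit route shape using chi (the stub handlers are placeholders, not the real coderd handlers or middleware):

package main

import (
	"net/http"

	"github.com/go-chi/chi/v5"
)

func stub(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
}

func main() {
	r := chi.NewRouter()
	r.Route("/api/v2/workspaceagents", func(r chi.Router) {
		r.Route("/me", func(r chi.Router) {
			// The WebRTC-era endpoints (/listen, /dial, /turn, /iceservers)
			// are gone; /coordinate is the only networking route left.
			r.Get("/metadata", stub)
			r.Post("/version", stub)
			r.Get("/gitsshkey", stub)
			r.Get("/coordinate", stub)
			r.Get("/report-stats", stub)
		})
		r.Route("/{workspaceagent}", func(r chi.Router) {
			r.Get("/", stub)
			r.Get("/pty", stub)
			r.Get("/connection", stub)
			r.Get("/coordinate", stub)
		})
	})
	_ = http.ListenAndServe("127.0.0.1:8080", r)
}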

View File

@@ -188,18 +188,14 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
"GET:/api/v2/users/oidc/callback": {NoAuthorize: true},
// All workspaceagents endpoints do not use rbac
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/iceservers": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/listen": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/turn": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/{workspaceagent}/iceservers": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/aws-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/azure-instance-identity": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/google-instance-identity": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/gitsshkey": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/metadata": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/coordinate": {NoAuthorize: true},
"POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true},
"GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true},
// These endpoints have more assertions. This is good, add more endpoints to assert if you can!
"GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)},
@@ -256,14 +252,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionRead,
AssertObject: workspaceRBACObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/dial": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/turn": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,
},
"GET:/api/v2/workspaceagents/{workspaceagent}/pty": {
AssertAction: rbac.ActionCreate,
AssertObject: workspaceExecObj,

View File

@@ -54,7 +54,6 @@ import (
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/coderd/util/ptr"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/cryptorand"
@@ -202,12 +201,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519
}
turnServer, err := turnconn.New(nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = turnServer.Close()
})
features := coderd.DisabledImplementations
if options.Auditor != nil {
features.Auditor = options.Auditor
@@ -231,7 +224,6 @@ func newWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
OIDCConfig: options.OIDCConfig,
GoogleTokenValidator: options.GoogleTokenValidator,
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,
TURNServer: turnServer,
APIRateLimit: options.APIRateLimit,
Authorizer: options.Authorizer,
Telemetry: telemetry.NewNoop(),

View File

@@ -604,7 +604,6 @@ func TestTemplateDAUs(t *testing.T) {
agentCloser := agent.New(agent.Options{
Logger: slogtest.Make(t, nil),
StatsReporter: agentClient.AgentReportStats,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
})

View File

@@ -1,203 +0,0 @@
package turnconn
import (
"io"
"net"
"sync"
"github.com/pion/logging"
"github.com/pion/turn/v2"
"github.com/pion/webrtc/v3"
"golang.org/x/net/proxy"
"golang.org/x/xerrors"
)
var (
// reservedAddress is a magic address that's used exclusively
// for proxying via Coder. We don't proxy all TURN connections,
// because that'd exclude the possibility of a customer using
// their own TURN server.
reservedAddress = "127.0.0.1:12345"
credential = "coder"
localhost = &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
}
// Proxy is an ICE server that uses a special hostname
// to indicate traffic should be proxied.
Proxy = webrtc.ICEServer{
URLs: []string{"turns:" + reservedAddress},
Username: "coder",
Credential: credential,
}
)
// New constructs a new TURN server binding to the relay address provided.
// The relay address is used to broadcast the location of an accepted connection.
func New(relayAddress *turn.RelayAddressGeneratorStatic) (*Server, error) {
if relayAddress == nil {
relayAddress = &turn.RelayAddressGeneratorStatic{
RelayAddress: localhost.IP,
Address: "127.0.0.1",
}
}
logger := logging.NewDefaultLoggerFactory()
logger.DefaultLogLevel = logging.LogLevelDisabled
server := &Server{
conns: make(chan net.Conn, 1),
closed: make(chan struct{}),
}
server.listener = &listener{
srv: server,
}
var err error
server.turn, err = turn.NewServer(turn.ServerConfig{
AuthHandler: func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) {
// TURN connections require credentials. It's not important
// for our use-case, because our listener is entirely in-memory.
return turn.GenerateAuthKey(Proxy.Username, "", credential), true
},
ListenerConfigs: []turn.ListenerConfig{{
Listener: server.listener,
RelayAddressGenerator: relayAddress,
}},
LoggerFactory: logger,
})
if err != nil {
return nil, xerrors.Errorf("create server: %w", err)
}
return server, nil
}
// Server accepts and connects TURN allocations.
//
// This is a thin wrapper around pion/turn that pipes
// connections directly to the in-memory handler.
type Server struct {
listener *listener
turn *turn.Server
closeMutex sync.Mutex
closed chan (struct{})
conns chan (net.Conn)
}
// Accept consumes a new connection into the TURN server.
// A unique remote address must exist per-connection.
// pion/turn indexes allocations based on the address.
func (s *Server) Accept(nc net.Conn, remoteAddress, localAddress *net.TCPAddr) *Conn {
if localAddress == nil {
localAddress = localhost
}
conn := &Conn{
Conn: nc,
remoteAddress: remoteAddress,
localAddress: localAddress,
closed: make(chan struct{}),
}
s.conns <- conn
return conn
}
// Close ends the TURN server.
func (s *Server) Close() error {
s.closeMutex.Lock()
defer s.closeMutex.Unlock()
if s.isClosed() {
return nil
}
err := s.turn.Close()
close(s.conns)
close(s.closed)
return err
}
func (s *Server) isClosed() bool {
select {
case <-s.closed:
return true
default:
return false
}
}
// listener implements net.Listener for the TURN
// server to consume.
type listener struct {
srv *Server
}
func (l *listener) Accept() (net.Conn, error) {
conn, ok := <-l.srv.conns
if !ok {
return nil, io.EOF
}
return conn, nil
}
func (*listener) Close() error {
return nil
}
func (*listener) Addr() net.Addr {
return nil
}
type Conn struct {
net.Conn
closed chan struct{}
localAddress *net.TCPAddr
remoteAddress *net.TCPAddr
}
func (c *Conn) LocalAddr() net.Addr {
return c.localAddress
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddress
}
// Closed returns a channel which is closed when
// the connection is.
func (c *Conn) Closed() <-chan struct{} {
return c.closed
}
func (c *Conn) Close() error {
err := c.Conn.Close()
select {
case <-c.closed:
default:
close(c.closed)
}
return err
}
type dialer func(network, addr string) (c net.Conn, err error)
func (d dialer) Dial(network, addr string) (c net.Conn, err error) {
return d(network, addr)
}
// ProxyDialer accepts a proxy function that's called when the connection
// address matches the reserved host in the "Proxy" ICE server.
//
// This should be passed to WebRTC connections as an ICE dialer.
func ProxyDialer(proxyFunc func() (c net.Conn, err error)) proxy.Dialer {
return dialer(func(network, addr string) (net.Conn, error) {
if addr != reservedAddress {
return proxy.Direct.Dial(network, addr)
}
netConn, err := proxyFunc()
if err != nil {
return nil, err
}
return &Conn{
localAddress: localhost,
closed: make(chan struct{}),
Conn: netConn,
}, nil
})
}
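
For reference, the removed pieces composed like this: ProxyDialer intercepts dials to the reserved address and returns one half of an in-memory pipe, while the other half is registered with the server via Accept. A hedged sketch against the package as it existed before this commit (the port number is illustrative):

package main

import (
	"fmt"
	"net"

	"github.com/coder/coder/coderd/turnconn"
)

func main() {
	srv, err := turnconn.New(nil)
	if err != nil {
		panic(err)
	}
	defer srv.Close()

	dialer := turnconn.ProxyDialer(func() (net.Conn, error) {
		client, server := net.Pipe()
		// Feed the server half to the TURN server; each connection
		// needs a unique remote address for allocation indexing.
		srv.Accept(server, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 16000}, nil)
		return client, nil
	})

	// Dialing the reserved address goes through proxyFunc instead of
	// the network; any other address falls back to a direct dial.
	conn, err := dialer.Dial("tcp", "127.0.0.1:12345")
	if err != nil {
		panic(err)
	}
	fmt.Println("proxied conn, local addr:", conn.LocalAddr())
	_ = conn.Close()
}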

View File

@@ -1,107 +0,0 @@
package turnconn_test
import (
"net"
"sync"
"testing"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/peer"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestTURNConn(t *testing.T) {
t.Parallel()
turnServer, err := turnconn.New(nil)
require.NoError(t, err)
defer turnServer.Close()
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
clientDialer, clientTURN := net.Pipe()
turnServer.Accept(clientTURN, &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 16000,
}, nil)
require.NoError(t, err)
clientSettings := webrtc.SettingEngine{}
clientSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
clientSettings.SetRelayAcceptanceMinWait(0)
clientSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
return clientDialer, nil
}))
client, err := peer.Client([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
SettingEngine: clientSettings,
Logger: logger.Named("client"),
})
require.NoError(t, err)
defer func() {
_ = client.Close()
}()
serverDialer, serverTURN := net.Pipe()
turnServer.Accept(serverTURN, &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 16001,
}, nil)
require.NoError(t, err)
serverSettings := webrtc.SettingEngine{}
serverSettings.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4, webrtc.NetworkTypeTCP6})
serverSettings.SetRelayAcceptanceMinWait(0)
serverSettings.SetICEProxyDialer(turnconn.ProxyDialer(func() (net.Conn, error) {
return serverDialer, nil
}))
server, err := peer.Server([]webrtc.ICEServer{turnconn.Proxy}, &peer.ConnOptions{
SettingEngine: serverSettings,
Logger: logger.Named("server"),
})
require.NoError(t, err)
defer func() {
_ = server.Close()
}()
exchange(t, client, server)
_, err = client.Ping()
require.NoError(t, err)
}
func exchange(t *testing.T, client, server *peer.Conn) {
var wg sync.WaitGroup
wg.Add(2)
t.Cleanup(wg.Wait)
go func() {
defer wg.Done()
for {
select {
case c := <-server.LocalCandidate():
client.AddRemoteCandidate(c)
case c := <-server.LocalSessionDescription():
client.SetRemoteSessionDescription(c)
case <-server.Closed():
return
}
}
}()
go func() {
defer wg.Done()
for {
select {
case c := <-client.LocalCandidate():
server.AddRemoteCandidate(c)
case c := <-client.LocalSessionDescription():
server.SetRemoteSessionDescription(c)
case <-client.Closed():
return
}
}
}()
}

View File

@@ -15,7 +15,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"go.opentelemetry.io/otel/trace"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
@@ -30,12 +29,7 @@ import (
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/turnconn"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/tailnet"
)
@@ -66,67 +60,6 @@ func (api *API) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(rw, http.StatusOK, apiAgent)
}
func (api *API) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error reading workspace agent.",
Detail: err.Error(),
})
return
}
if apiAgent.Status != codersdk.WorkspaceAgentConnected {
httpapi.Write(rw, http.StatusPreconditionFailed, codersdk.Response{
Message: fmt.Sprintf("Agent isn't connected! Status: %s.", apiAgent.Status),
})
return
}
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
err = peerbroker.ProxyListen(ctx, session, peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
}
func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
workspaceAgent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(api.DERPMap, api.TailnetCoordinator, workspaceAgent, nil, api.AgentInactiveDisconnectTimeout)
@@ -186,231 +119,6 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques
httpapi.Write(rw, http.StatusOK, nil)
}
func (api *API) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return
}
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(wsNetConn, config)
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
ChannelID: workspaceAgent.ID.String(),
Pubsub: api.Pubsub,
Logger: api.Logger.Named("peerbroker-proxy-listen"),
})
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
defer closer.Close()
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
UpdatedAt: database.Now(),
})
if err != nil {
return err
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-session.CloseChan():
return
case <-ticker.C:
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err = ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
}
func (api *API) workspaceAgentICEServers(rw http.ResponseWriter, _ *http.Request) {
httpapi.Write(rw, http.StatusOK, api.ICEServers)
}
// userWorkspaceAgentTurn handles a user connecting to a remote workspace agent
// through TURN.
func (api *API) userWorkspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
workspace := httpmw.WorkspaceParam(r)
if !api.Authorize(r, rbac.ActionCreate, workspace.ExecutionRBAC()) {
httpapi.ResourceNotFound(rw)
return
}
// Passed authorization
api.workspaceAgentTurn(rw, r)
}
// workspaceAgentTurn proxies a WebSocket connection to the TURN server.
func (api *API) workspaceAgentTurn(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
remoteAddress := &net.TCPAddr{
IP: net.ParseIP(r.RemoteAddr),
}
// By default requests have the remote address and port.
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid remote address.",
Detail: err.Error(),
})
return
}
remoteAddress.IP = net.ParseIP(host)
remoteAddress.Port, err = strconv.Atoi(port)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Port for remote address %q must be an integer.", r.RemoteAddr),
Detail: err.Error(),
})
return
}
wsConn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
ctx, wsNetConn := websocketNetConn(r.Context(), wsConn, websocket.MessageBinary)
defer wsNetConn.Close() // Also closes conn.
// end span so we don't get long lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Debug(ctx, "accepting turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
select {
case <-api.TURNServer.Accept(wsNetConn, remoteAddress, localAddress).Closed():
case <-ctx.Done():
}
api.Logger.Debug(ctx, "completed turn connection", slog.F("remote-address", r.RemoteAddr), slog.F("local-address", localAddress))
}
// workspaceAgentPTY spawns a PTY and pipes it over a WebSocket.
// This is used for the web terminal.
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
@@ -492,75 +200,7 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(ptNetConn, wsNetConn)
}
// dialWorkspaceAgent connects to a workspace agent by ID. Only rely on
// r.Context() for cancellation if its use is safe or r.Hijack() has
// not been performed.
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
client, server := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
ChannelID: agentID.String(),
Logger: api.Logger.Named("peerbroker-proxy-dial"),
Pubsub: api.Pubsub,
})
_ = client.Close()
_ = server.Close()
}()
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
stream, err := peerClient.NegotiateConnection(ctx)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("negotiate: %w", err)
}
options := &peer.ConnOptions{
Logger: api.Logger.Named("agent-dialer"),
}
options.SettingEngine.SetSrflxAcceptanceMinWait(0)
options.SettingEngine.SetRelayAcceptanceMinWait(0)
// Use the ProxyDialer for the TURN server.
// This is required for connections where P2P is not enabled.
options.SettingEngine.SetICEProxyDialer(turnconn.ProxyDialer(func() (c net.Conn, err error) {
clientPipe, serverPipe := net.Pipe()
go func() {
<-ctx.Done()
_ = clientPipe.Close()
_ = serverPipe.Close()
}()
localAddress, _ := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
remoteAddress := &net.TCPAddr{
IP: net.ParseIP(r.RemoteAddr),
}
// By default requests have the remote address and port.
host, port, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return nil, xerrors.Errorf("split remote address: %w", err)
}
remoteAddress.IP = net.ParseIP(host)
remoteAddress.Port, err = strconv.Atoi(port)
if err != nil {
return nil, xerrors.Errorf("convert remote port: %w", err)
}
api.TURNServer.Accept(clientPipe, remoteAddress, localAddress)
return serverPipe, nil
}))
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
if err != nil {
cancelFunc()
return nil, xerrors.Errorf("dial: %w", err)
}
go func() {
<-peerConn.Closed()
cancelFunc()
}()
return &agent.WebRTCConn{
Negotiator: peerClient,
Conn: peerConn,
}, nil
}
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (agent.Conn, error) {
func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
clientConn, serverConn := net.Pipe()
go func() {
<-r.Context().Done()
@@ -587,7 +227,7 @@ func (api *API) dialWorkspaceAgentTailnet(r *http.Request, agentID uuid.UUID) (a
_ = conn.Close()
}
}()
return &agent.TailnetConn{
return &agent.Conn{
Conn: conn,
}, nil
}
@@ -609,6 +249,48 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Failed to accept websocket.",
Detail: err.Error(),
})
return
}
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
if err != nil {
httpapi.Write(rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching workspace build job.",
Detail: err.Error(),
})
return
}
// Ensure the resource is still valid!
// We only accept agents for resources on the latest build.
ensureLatestBuild := func() error {
latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(r.Context(), build.WorkspaceID)
if err != nil {
return err
}
if build.ID != latestBuild.ID {
return xerrors.New("build is outdated")
}
return nil
}
err = ensureLatestBuild()
if err != nil {
api.Logger.Debug(r.Context(), "agent tried to connect from non-latest build",
slog.F("resource", resource),
slog.F("agent", workspaceAgent),
)
httpapi.Write(rw, http.StatusForbidden, codersdk.Response{
Message: "Agent trying to connect from non-latest build.",
Detail: err.Error(),
})
return
}
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
@@ -618,12 +300,88 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request
})
return
}
defer conn.Close(websocket.StatusNormalClosure, "")
err = api.TailnetCoordinator.ServeAgent(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), workspaceAgent.ID)
ctx, wsNetConn := websocketNetConn(r.Context(), conn, websocket.MessageBinary)
defer wsNetConn.Close()
firstConnectedAt := workspaceAgent.FirstConnectedAt
if !firstConnectedAt.Valid {
firstConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
}
lastConnectedAt := sql.NullTime{
Time: database.Now(),
Valid: true,
}
disconnectedAt := workspaceAgent.DisconnectedAt
updateConnectionTimes := func() error {
err = api.Database.UpdateWorkspaceAgentConnectionByID(ctx, database.UpdateWorkspaceAgentConnectionByIDParams{
ID: workspaceAgent.ID,
FirstConnectedAt: firstConnectedAt,
LastConnectedAt: lastConnectedAt,
DisconnectedAt: disconnectedAt,
UpdatedAt: database.Now(),
})
if err != nil {
return err
}
return nil
}
defer func() {
disconnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
_ = updateConnectionTimes()
}()
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
// end span so we don't get long lived trace data
tracing.EndHTTPSpan(r, http.StatusOK, trace.SpanFromContext(ctx))
api.Logger.Info(ctx, "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
defer conn.Close(websocket.StatusNormalClosure, "")
closeChan := make(chan struct{})
go func() {
defer close(closeChan)
err := api.TailnetCoordinator.ServeAgent(wsNetConn, workspaceAgent.ID)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, err.Error())
return
}
}()
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
defer ticker.Stop()
for {
select {
case <-closeChan:
return
case <-ticker.C:
}
lastConnectedAt = sql.NullTime{
Time: database.Now(),
Valid: true,
}
err = updateConnectionTimes()
if err != nil {
_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
return
}
err := ensureLatestBuild()
if err != nil {
// Disconnect agents that are no longer valid.
_ = conn.Close(websocket.StatusGoingAway, "")
return
}
}
}
// workspaceAgentClientCoordinate accepts a WebSocket that reads node network updates.
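
The coordinate handlers above lean on websocketNetConn, a small coderd helper over nhooyr.io/websocket's NetConn adapter. A minimal sketch of that accept-wrap-serve pattern; serveAgent and the wiring are stand-ins for TailnetCoordinator.ServeAgent, not the coderd implementation:

package main

import (
	"context"
	"net"
	"net/http"

	"nhooyr.io/websocket"
)

// serveAgent stands in for TailnetCoordinator.ServeAgent: it owns the
// net.Conn until the peer disconnects.
func serveAgent(conn net.Conn) error {
	defer conn.Close()
	buf := make([]byte, 512)
	for {
		if _, err := conn.Read(buf); err != nil {
			return err
		}
	}
}

func handler(rw http.ResponseWriter, r *http.Request) {
	conn, err := websocket.Accept(rw, r, nil)
	if err != nil {
		// Accept has already written an HTTP error response.
		return
	}
	// NetConn adapts the WebSocket to net.Conn; closing the net.Conn
	// closes the underlying WebSocket too.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
	defer netConn.Close()

	if err := serveAgent(netConn); err != nil {
		_ = conn.Close(websocket.StatusAbnormalClosure, err.Error())
		return
	}
	_ = conn.Close(websocket.StatusNormalClosure, "")
}

func main() {
	_ = http.ListenAndServe("127.0.0.1:8080", http.HandlerFunc(handler))
}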

View File

@@ -10,7 +10,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"
"cdr.dev/slog"
@@ -18,7 +17,6 @@ import (
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
@@ -112,7 +110,6 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
@@ -123,13 +120,15 @@
defer cancel()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil)
conn, err := client.DialWorkspaceAgentTailnet(ctx, slog.Logger{}, resources[0].Agents[0].ID)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
_, err = conn.Ping()
require.NoError(t, err)
require.Eventually(t, func() bool {
_, err := conn.Ping()
return err == nil
}, testutil.WaitMedium, testutil.IntervalFast)
})
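
The switch from a single Ping to require.Eventually reflects that the tailnet tunnel comes up asynchronously after DialWorkspaceAgentTailnet returns. A self-contained sketch of that polling pattern; fakeConn and the timings are illustrative:

package example

import (
	"errors"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

var errNotReady = errors.New("tunnel not established")

// fakeConn stands in for the agent connection: Ping fails until the
// coordinator has exchanged nodes.
type fakeConn struct{ readyAt time.Time }

func (c *fakeConn) Ping() (time.Duration, error) {
	if time.Now().Before(c.readyAt) {
		return 0, errNotReady
	}
	return time.Millisecond, nil
}

func TestPingEventually(t *testing.T) {
	conn := &fakeConn{readyAt: time.Now().Add(50 * time.Millisecond)}
	// Poll instead of asserting a single Ping.
	require.Eventually(t, func() bool {
		_, err := conn.Ping()
		return err == nil
	}, time.Second, 10*time.Millisecond)
}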
t.Run("FailNonLatestBuild", func(t *testing.T) {
@@ -202,75 +201,12 @@ func TestWorkspaceAgentListen(t *testing.T) {
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
_, err = agentClient.ListenWorkspaceAgent(ctx, slogtest.Make(t, nil))
_, err = agentClient.ListenWorkspaceAgentTailnet(ctx)
require.Error(t, err)
require.ErrorContains(t, err, "build is outdated")
})
}
func TestWorkspaceAgentTURN(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{
IncludeProvisionerDaemon: true,
})
user := coderdtest.CreateFirstUser(t, client)
authToken := uuid.NewString()
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {
_ = agentCloser.Close()
}()
resources := coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
opts := &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client"),
}
// Force a TURN connection!
opts.SettingEngine.SetNetworkTypes([]webrtc.NetworkType{webrtc.NetworkTypeTCP4})
conn, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, opts)
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
_, err = conn.Ping()
require.NoError(t, err)
}
func TestWorkspaceAgentTailnet(t *testing.T) {
t.Parallel()
client, daemonCloser := coderdtest.NewWithProvisionerCloser(t, nil)
@@ -306,7 +242,6 @@ func TestWorkspaceAgentTailnet(t *testing.T) {
agentClient.SessionToken = authToken
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
@@ -373,7 +308,6 @@ func TestWorkspaceAgentPTY(t *testing.T) {
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent").Leveled(slog.LevelDebug),
})
defer func() {

View File

@@ -103,7 +103,6 @@ func setupProxyTest(t *testing.T) (*codersdk.Client, uuid.UUID, codersdk.Workspa
agentCloser := agent.New(agent.Options{
FetchMetadata: agentClient.WorkspaceAgentMetadata,
CoordinatorDialer: agentClient.ListenWorkspaceAgentTailnet,
WebRTCDialer: agentClient.ListenWorkspaceAgent,
Logger: slogtest.Make(t, nil).Named("agent"),
})
t.Cleanup(func() {

View File

@@ -32,11 +32,11 @@ func New(dialer Dialer, inactiveTimeout time.Duration) *Cache {
}
// Dialer creates a new agent connection by ID.
type Dialer func(r *http.Request, id uuid.UUID) (agent.Conn, error)
type Dialer func(r *http.Request, id uuid.UUID) (*agent.Conn, error)
// Conn wraps an agent connection with a reusable HTTP transport.
type Conn struct {
agent.Conn
*agent.Conn
locks atomic.Uint64
timeoutMutex sync.Mutex
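
With one dialer, the cache's job reduces to memoizing a live connection per agent ID. A stripped-down sketch of that idea; this is not the real wsconncache, which also reference-counts locks and expires idle connections:

package main

import (
	"fmt"
	"sync"

	"github.com/google/uuid"
)

// Conn stands in for an established agent connection.
type Conn struct{ AgentID uuid.UUID }

// Dialer mirrors the post-commit shape: one code path that always
// dials over the tailnet.
type Dialer func(id uuid.UUID) (*Conn, error)

// Cache memoizes live connections so repeated requests to the same
// agent reuse a single tunnel.
type Cache struct {
	dial  Dialer
	mu    sync.Mutex
	conns map[uuid.UUID]*Conn
}

func NewCache(dial Dialer) *Cache {
	return &Cache{dial: dial, conns: map[uuid.UUID]*Conn{}}
}

func (c *Cache) Acquire(id uuid.UUID) (*Conn, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if conn, ok := c.conns[id]; ok {
		return conn, nil
	}
	conn, err := c.dial(id)
	if err != nil {
		return nil, err
	}
	c.conns[id] = conn
	return conn, nil
}

func main() {
	cache := NewCache(func(id uuid.UUID) (*Conn, error) {
		return &Conn{AgentID: id}, nil
	})
	id := uuid.New()
	a, _ := cache.Acquire(id)
	b, _ := cache.Acquire(id)
	fmt.Println(a == b) // true: the second acquire reuses the cached conn
}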

View File

@@ -35,7 +35,7 @@ func TestCache(t *testing.T) {
t.Parallel()
t.Run("Same", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, 0)
defer func() {
@@ -50,7 +50,7 @@ func TestCache(t *testing.T) {
t.Run("Expire", func(t *testing.T) {
t.Parallel()
called := atomic.NewInt32(0)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
called.Add(1)
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
@@ -69,7 +69,7 @@ func TestCache(t *testing.T) {
})
t.Run("NoExpireWhenLocked", func(t *testing.T) {
t.Parallel()
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@@ -102,7 +102,7 @@ func TestCache(t *testing.T) {
}()
go server.Serve(random)
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (agent.Conn, error) {
cache := wsconncache.New(func(r *http.Request, id uuid.UUID) (*agent.Conn, error) {
return setupAgent(t, agent.Metadata{}, 0), nil
}, time.Microsecond)
defer func() {
@@ -139,7 +139,7 @@ func TestCache(t *testing.T) {
})
}
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) agent.Conn {
func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration) *agent.Conn {
metadata.DERPMap = tailnettest.RunDERPAndSTUN(t)
coordinator := tailnet.NewCoordinator()
@@ -180,7 +180,7 @@ func setupAgent(t *testing.T, metadata agent.Metadata, ptyTimeout time.Duration)
return conn.UpdateNodes(node)
})
conn.SetNodeCallback(sendNode)
return &agent.TailnetConn{
return &agent.Conn{
Conn: conn,
}
}