fix: remove unnecessary user lookup in agent API calls (#17934)

# Use workspace.OwnerUsername instead of fetching the owner

This PR optimizes the agent API by using the `workspace.OwnerUsername` field directly instead of making an additional database query to fetch the owner's username. The change removes the need to call `GetUserByID` in the manifest API and workspace agent RPC endpoints.

An issue arose when the agent token was scoped without access to user data (`api_key_scope = "no_user_data"`), causing the agent to fail to fetch the manifest due to an RBAC issue.

Change-Id: I3b6e7581134e2374b364ee059e3b18ece3d98b41
Signed-off-by: Thomas Kosiewski <tk@coder.com>
This commit is contained in:
Thomas Kosiewski
2025-05-20 17:07:50 +02:00
committed by GitHub
parent 1267c9c405
commit 93f17bc73e
6 changed files with 195 additions and 118 deletions

View File

@ -47,7 +47,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
scripts []database.WorkspaceAgentScript scripts []database.WorkspaceAgentScript
metadata []database.WorkspaceAgentMetadatum metadata []database.WorkspaceAgentMetadatum
workspace database.Workspace workspace database.Workspace
owner database.User
devcontainers []database.WorkspaceAgentDevcontainer devcontainers []database.WorkspaceAgentDevcontainer
) )
@ -76,10 +75,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
if err != nil { if err != nil {
return xerrors.Errorf("getting workspace by id: %w", err) return xerrors.Errorf("getting workspace by id: %w", err)
} }
owner, err = a.Database.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
return xerrors.Errorf("getting workspace owner by id: %w", err)
}
return err return err
}) })
eg.Go(func() (err error) { eg.Go(func() (err error) {
@ -98,7 +93,7 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
AppSlugOrPort: "{{port}}", AppSlugOrPort: "{{port}}",
AgentName: workspaceAgent.Name, AgentName: workspaceAgent.Name,
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
Username: owner.Username, Username: workspace.OwnerUsername,
} }
vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname) vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname)
@ -115,7 +110,7 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
} }
} }
apps, err := dbAppsToProto(dbApps, workspaceAgent, owner.Username, workspace) apps, err := dbAppsToProto(dbApps, workspaceAgent, workspace.OwnerUsername, workspace)
if err != nil { if err != nil {
return nil, xerrors.Errorf("converting workspace apps: %w", err) return nil, xerrors.Errorf("converting workspace apps: %w", err)
} }
@ -128,7 +123,7 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
return &agentproto.Manifest{ return &agentproto.Manifest{
AgentId: workspaceAgent.ID[:], AgentId: workspaceAgent.ID[:],
AgentName: workspaceAgent.Name, AgentName: workspaceAgent.Name,
OwnerUsername: owner.Username, OwnerUsername: workspace.OwnerUsername,
WorkspaceId: workspace.ID[:], WorkspaceId: workspace.ID[:],
WorkspaceName: workspace.Name, WorkspaceName: workspace.Name,
GitAuthConfigs: gitAuthConfigs, GitAuthConfigs: gitAuthConfigs,

View File

@ -46,9 +46,10 @@ func TestGetManifest(t *testing.T) {
Username: "cool-user", Username: "cool-user",
} }
workspace = database.Workspace{ workspace = database.Workspace{
ID: uuid.New(), ID: uuid.New(),
OwnerID: owner.ID, OwnerID: owner.ID,
Name: "cool-workspace", OwnerUsername: owner.Username,
Name: "cool-workspace",
} }
agent = database.WorkspaceAgent{ agent = database.WorkspaceAgent{
ID: uuid.New(), ID: uuid.New(),
@ -336,7 +337,6 @@ func TestGetManifest(t *testing.T) {
}).Return(metadata, nil) }).Return(metadata, nil)
mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil) mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil)
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
require.NoError(t, err) require.NoError(t, err)
@ -404,7 +404,6 @@ func TestGetManifest(t *testing.T) {
}).Return([]database.WorkspaceAgentMetadatum{}, nil) }).Return([]database.WorkspaceAgentMetadatum{}, nil)
mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), childAgent.ID).Return([]database.WorkspaceAgentDevcontainer{}, nil) mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), childAgent.ID).Return([]database.WorkspaceAgentDevcontainer{}, nil)
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
require.NoError(t, err) require.NoError(t, err)
@ -468,7 +467,6 @@ func TestGetManifest(t *testing.T) {
}).Return(metadata, nil) }).Return(metadata, nil)
mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil) mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil)
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
require.NoError(t, err) require.NoError(t, err)

View File

@ -439,25 +439,55 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) {
t.Run("Connect", func(t *testing.T) { t.Run("Connect", func(t *testing.T) {
t.Parallel() t.Parallel()
client, db := coderdtest.NewWithDatabase(t, nil) for _, tc := range []struct {
user := coderdtest.CreateFirstUser(t, client) name string
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ apiKeyScope rbac.ScopeName
OrganizationID: user.OrganizationID, }{
OwnerID: user.UserID, {
}).WithAgent().Do() name: "empty (backwards compat)",
_ = agenttest.New(t, client.URL, r.AgentToken) apiKeyScope: "",
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) },
{
name: "all",
apiKeyScope: rbac.ScopeAll,
},
{
name: "no_user_data",
apiKeyScope: rbac.ScopeNoUserData,
},
{
name: "application_connect",
apiKeyScope: rbac.ScopeApplicationConnect,
},
} {
t.Run(tc.name, func(t *testing.T) {
client, db := coderdtest.NewWithDatabase(t, nil)
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
for _, agent := range agents {
agent.ApiKeyScope = string(tc.apiKeyScope)
}
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) return agents
defer cancel() }).Do()
_ = agenttest.New(t, client.URL, r.AgentToken)
resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).AgentNames([]string{}).Wait()
conn, err := workspacesdk.New(client). ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
DialAgent(ctx, resources[0].Agents[0].ID, nil) defer cancel()
require.NoError(t, err)
defer func() { conn, err := workspacesdk.New(client).
_ = conn.Close() DialAgent(ctx, resources[0].Agents[0].ID, nil)
}() require.NoError(t, err)
conn.AwaitReachable(ctx) defer func() {
_ = conn.Close()
}()
conn.AwaitReachable(ctx)
})
}
}) })
t.Run("FailNonLatestBuild", func(t *testing.T) { t.Run("FailNonLatestBuild", func(t *testing.T) {

View File

@ -76,17 +76,8 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
return return
} }
owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error fetching user.",
Detail: err.Error(),
})
return
}
logger = logger.With( logger = logger.With(
slog.F("owner", owner.Username), slog.F("owner", workspace.OwnerUsername),
slog.F("workspace_name", workspace.Name), slog.F("workspace_name", workspace.Name),
slog.F("agent_name", workspaceAgent.Name), slog.F("agent_name", workspaceAgent.Name),
) )
@ -170,7 +161,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
}) })
streamID := tailnet.StreamID{ streamID := tailnet.StreamID{
Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name), Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
ID: workspaceAgent.ID, ID: workspaceAgent.ID,
Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID}, Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/codersdk/agentsdk"
"github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
@ -22,89 +23,150 @@ import (
func TestWorkspaceAgentReportStats(t *testing.T) { func TestWorkspaceAgentReportStats(t *testing.T) {
t.Parallel() t.Parallel()
tickCh := make(chan time.Time) for _, tc := range []struct {
flushCh := make(chan int, 1) name string
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ apiKeyScope rbac.ScopeName
WorkspaceUsageTrackerFlush: flushCh, }{
WorkspaceUsageTrackerTick: tickCh, {
}) name: "empty (backwards compat)",
user := coderdtest.CreateFirstUser(t, client) apiKeyScope: "",
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
LastUsedAt: dbtime.Now().Add(-time.Minute),
}).WithAgent().Do()
ac := agentsdk.New(client.URL)
ac.SetSessionToken(r.AgentToken)
conn, err := ac.ConnectRPC(context.Background())
require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
agentAPI := agentproto.NewDRPCAgentClient(conn)
_, err = agentAPI.UpdateStats(context.Background(), &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{"TCP": 1},
ConnectionCount: 1,
RxPackets: 1,
RxBytes: 1,
TxPackets: 1,
TxBytes: 1,
SessionCountVscode: 1,
SessionCountJetbrains: 0,
SessionCountReconnectingPty: 0,
SessionCountSsh: 0,
ConnectionMedianLatencyMs: 10,
}, },
}) {
require.NoError(t, err) name: "all",
apiKeyScope: rbac.ScopeAll,
},
{
name: "no_user_data",
apiKeyScope: rbac.ScopeNoUserData,
},
{
name: "application_connect",
apiKeyScope: rbac.ScopeApplicationConnect,
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tickCh <- dbtime.Now() tickCh := make(chan time.Time)
count := <-flushCh flushCh := make(chan int, 1)
require.Equal(t, 1, count, "expected one flush with one id") client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
WorkspaceUsageTrackerFlush: flushCh,
WorkspaceUsageTrackerTick: tickCh,
})
user := coderdtest.CreateFirstUser(t, client)
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
OrganizationID: user.OrganizationID,
OwnerID: user.UserID,
LastUsedAt: dbtime.Now().Add(-time.Minute),
}).WithAgent(
func(agent []*proto.Agent) []*proto.Agent {
for _, a := range agent {
a.ApiKeyScope = string(tc.apiKeyScope)
}
newWorkspace, err := client.Workspace(context.Background(), r.Workspace.ID) return agent
require.NoError(t, err) },
).Do()
assert.True(t, ac := agentsdk.New(client.URL)
newWorkspace.LastUsedAt.After(r.Workspace.LastUsedAt), ac.SetSessionToken(r.AgentToken)
"%s is not after %s", newWorkspace.LastUsedAt, r.Workspace.LastUsedAt, conn, err := ac.ConnectRPC(context.Background())
) require.NoError(t, err)
defer func() {
_ = conn.Close()
}()
agentAPI := agentproto.NewDRPCAgentClient(conn)
_, err = agentAPI.UpdateStats(context.Background(), &agentproto.UpdateStatsRequest{
Stats: &agentproto.Stats{
ConnectionsByProto: map[string]int64{"TCP": 1},
ConnectionCount: 1,
RxPackets: 1,
RxBytes: 1,
TxPackets: 1,
TxBytes: 1,
SessionCountVscode: 1,
SessionCountJetbrains: 0,
SessionCountReconnectingPty: 0,
SessionCountSsh: 0,
ConnectionMedianLatencyMs: 10,
},
})
require.NoError(t, err)
tickCh <- dbtime.Now()
count := <-flushCh
require.Equal(t, 1, count, "expected one flush with one id")
newWorkspace, err := client.Workspace(context.Background(), r.Workspace.ID)
require.NoError(t, err)
assert.True(t,
newWorkspace.LastUsedAt.After(r.Workspace.LastUsedAt),
"%s is not after %s", newWorkspace.LastUsedAt, r.Workspace.LastUsedAt,
)
})
}
} }
func TestAgentAPI_LargeManifest(t *testing.T) { func TestAgentAPI_LargeManifest(t *testing.T) {
t.Parallel() t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, store := coderdtest.NewWithDatabase(t, nil) for _, tc := range []struct {
adminUser := coderdtest.CreateFirstUser(t, client) name string
n := 512000 apiKeyScope rbac.ScopeName
longScript := make([]byte, n) }{
for i := range longScript { {
longScript[i] = 'q' name: "empty (backwards compat)",
apiKeyScope: "",
},
{
name: "all",
apiKeyScope: rbac.ScopeAll,
},
{
name: "no_user_data",
apiKeyScope: rbac.ScopeNoUserData,
},
{
name: "application_connect",
apiKeyScope: rbac.ScopeApplicationConnect,
},
} {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, store := coderdtest.NewWithDatabase(t, nil)
adminUser := coderdtest.CreateFirstUser(t, client)
n := 512000
longScript := make([]byte, n)
for i := range longScript {
longScript[i] = 'q'
}
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: adminUser.OrganizationID,
OwnerID: adminUser.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Scripts = []*proto.Script{
{
Script: string(longScript),
},
}
agents[0].ApiKeyScope = string(tc.apiKeyScope)
return agents
}).Do()
ac := agentsdk.New(client.URL)
ac.SetSessionToken(r.AgentToken)
conn, err := ac.ConnectRPC(ctx)
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)
agentAPI := agentproto.NewDRPCAgentClient(conn)
manifest, err := agentAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
require.NoError(t, err)
require.Len(t, manifest.Scripts, 1)
require.Len(t, manifest.Scripts[0].Script, n)
})
} }
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
OrganizationID: adminUser.OrganizationID,
OwnerID: adminUser.UserID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Scripts = []*proto.Script{
{
Script: string(longScript),
},
}
return agents
}).Do()
ac := agentsdk.New(client.URL)
ac.SetSessionToken(r.AgentToken)
conn, err := ac.ConnectRPC(ctx)
defer func() {
_ = conn.Close()
}()
require.NoError(t, err)
agentAPI := agentproto.NewDRPCAgentClient(conn)
manifest, err := agentAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
require.NoError(t, err)
require.Len(t, manifest.Scripts, 1)
require.Len(t, manifest.Scripts[0].Script, n)
} }

View File

@ -141,6 +141,7 @@
kubectl kubectl
kubectx kubectx
kubernetes-helm kubernetes-helm
lazydocker
lazygit lazygit
less less
mockgen mockgen