mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
fix: use typed wireguard public keys in database structs (#2639)
This commit is contained in:
74
coderd/database/dbtypes/dbtypes.go
Normal file
74
coderd/database/dbtypes/dbtypes.go
Normal file
@ -0,0 +1,74 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// NodePublic is a wrapper around a key.NodePublic which represents the
|
||||
// Wireguard public key for an agent..
|
||||
type NodePublic key.NodePublic
|
||||
|
||||
func (n NodePublic) String() string {
|
||||
return key.NodePublic(n).String()
|
||||
}
|
||||
|
||||
// This is necessary so NodePublic can be serialized in JSON loggers.
|
||||
func (n NodePublic) MarshalJSON() ([]byte, error) {
|
||||
j, err := key.NodePublic(n).MarshalText()
|
||||
// surround in quotes to make it a JSON string
|
||||
j = append([]byte{'"'}, append(j, '"')...)
|
||||
return j, err
|
||||
}
|
||||
|
||||
// Value is so NodePublic can be inserted into the database.
|
||||
func (n NodePublic) Value() (driver.Value, error) {
|
||||
return key.NodePublic(n).MarshalText()
|
||||
}
|
||||
|
||||
// Scan is so NodePublic can be read from the database.
|
||||
func (n *NodePublic) Scan(value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return (*key.NodePublic)(n).UnmarshalText(v)
|
||||
case string:
|
||||
return (*key.NodePublic)(n).UnmarshalText([]byte(v))
|
||||
default:
|
||||
return xerrors.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// NodePublic is a wrapper around a key.NodePublic which represents the
|
||||
// Tailscale disco key for an agent.
|
||||
type DiscoPublic key.DiscoPublic
|
||||
|
||||
func (n DiscoPublic) String() string {
|
||||
return key.DiscoPublic(n).String()
|
||||
}
|
||||
|
||||
// This is necessary so DiscoPublic can be serialized in JSON loggers.
|
||||
func (n DiscoPublic) MarshalJSON() ([]byte, error) {
|
||||
j, err := key.DiscoPublic(n).MarshalText()
|
||||
// surround in quotes to make it a JSON string
|
||||
j = append([]byte{'"'}, append(j, '"')...)
|
||||
return j, err
|
||||
}
|
||||
|
||||
// Value is so DiscoPublic can be inserted into the database.
|
||||
func (n DiscoPublic) Value() (driver.Value, error) {
|
||||
return key.DiscoPublic(n).MarshalText()
|
||||
}
|
||||
|
||||
// Scan is so DiscoPublic can be read from the database.
|
||||
func (n *DiscoPublic) Scan(value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
return (*key.DiscoPublic)(n).UnmarshalText(v)
|
||||
case string:
|
||||
return (*key.DiscoPublic)(n).UnmarshalText([]byte(v))
|
||||
default:
|
||||
return xerrors.Errorf("unexpected type: %T", v)
|
||||
}
|
||||
}
|
@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbtypes"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tabbed/pqtype"
|
||||
)
|
||||
@ -521,8 +522,8 @@ type WorkspaceAgent struct {
|
||||
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
|
||||
Directory string `db:"directory" json:"directory"`
|
||||
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
|
||||
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
}
|
||||
|
||||
type WorkspaceApp struct {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/coderd/database/dbtypes"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"github.com/tabbed/pqtype"
|
||||
@ -3105,8 +3106,8 @@ type InsertWorkspaceAgentParams struct {
|
||||
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
|
||||
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
|
||||
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
|
||||
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) {
|
||||
@ -3196,9 +3197,9 @@ WHERE
|
||||
`
|
||||
|
||||
type UpdateWorkspaceAgentKeysByIDParams struct {
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
|
||||
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error {
|
||||
|
@ -17,10 +17,10 @@ packages:
|
||||
output_db_file_name: db_tmp.go
|
||||
|
||||
overrides:
|
||||
- column: workspaces.wireguard_public_key
|
||||
go_type: tailscale.com/types/key.MachinePublic
|
||||
- column: workspaces.disco_public_key
|
||||
go_type: tailscale.com/types/key.DiscoPublic
|
||||
- column: workspace_agents.wireguard_node_public_key
|
||||
go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic
|
||||
- column: workspace_agents.wireguard_disco_public_key
|
||||
go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic
|
||||
|
||||
rename:
|
||||
api_key: APIKey
|
||||
|
@ -19,11 +19,11 @@ import (
|
||||
protobuf "google.golang.org/protobuf/proto"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"tailscale.com/types/key"
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbtypes"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/parameter"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
@ -761,8 +761,8 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
|
||||
Valid: prAgent.StartupScript != "",
|
||||
},
|
||||
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
|
||||
WireguardNodePublicKey: key.NodePublic{}.String(),
|
||||
WireguardDiscoPublicKey: key.DiscoPublic{}.String(),
|
||||
WireguardNodePublicKey: dbtypes.NodePublic{},
|
||||
WireguardDiscoPublicKey: dbtypes.DiscoPublic{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert agent: %w", err)
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"github.com/coder/coder/agent"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
"github.com/coder/coder/coderd/database/dbtypes"
|
||||
"github.com/coder/coder/coderd/httpapi"
|
||||
"github.com/coder/coder/coderd/httpmw"
|
||||
"github.com/coder/coder/coderd/rbac"
|
||||
@ -488,8 +489,8 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request)
|
||||
|
||||
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
|
||||
ID: workspaceAgent.ID,
|
||||
WireguardNodePublicKey: keys.Public.String(),
|
||||
WireguardDiscoPublicKey: keys.Disco.String(),
|
||||
WireguardNodePublicKey: dbtypes.NodePublic(keys.Public),
|
||||
WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
|
||||
@ -711,15 +712,8 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
|
||||
Directory: dbAgent.Directory,
|
||||
Apps: apps,
|
||||
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
|
||||
}
|
||||
|
||||
err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey))
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err)
|
||||
}
|
||||
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey))
|
||||
if err != nil {
|
||||
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err)
|
||||
WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey),
|
||||
DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey),
|
||||
}
|
||||
|
||||
if dbAgent.FirstConnectedAt.Valid {
|
||||
|
Reference in New Issue
Block a user