fix: use typed wireguard public keys in database structs (#2639)

This commit is contained in:
Colin Adler
2022-06-24 15:45:28 -05:00
committed by GitHub
parent 115730341e
commit 26e85b0bbc
6 changed files with 95 additions and 25 deletions

View File

@ -0,0 +1,74 @@
package dbtypes
import (
"database/sql/driver"
"golang.org/x/xerrors"
"tailscale.com/types/key"
)
// NodePublic is a wrapper around a key.NodePublic which represents the
// Wireguard public key for an agent..
type NodePublic key.NodePublic
func (n NodePublic) String() string {
return key.NodePublic(n).String()
}
// This is necessary so NodePublic can be serialized in JSON loggers.
func (n NodePublic) MarshalJSON() ([]byte, error) {
j, err := key.NodePublic(n).MarshalText()
// surround in quotes to make it a JSON string
j = append([]byte{'"'}, append(j, '"')...)
return j, err
}
// Value is so NodePublic can be inserted into the database.
func (n NodePublic) Value() (driver.Value, error) {
return key.NodePublic(n).MarshalText()
}
// Scan is so NodePublic can be read from the database.
func (n *NodePublic) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return (*key.NodePublic)(n).UnmarshalText(v)
case string:
return (*key.NodePublic)(n).UnmarshalText([]byte(v))
default:
return xerrors.Errorf("unexpected type: %T", v)
}
}
// NodePublic is a wrapper around a key.NodePublic which represents the
// Tailscale disco key for an agent.
type DiscoPublic key.DiscoPublic
func (n DiscoPublic) String() string {
return key.DiscoPublic(n).String()
}
// This is necessary so DiscoPublic can be serialized in JSON loggers.
func (n DiscoPublic) MarshalJSON() ([]byte, error) {
j, err := key.DiscoPublic(n).MarshalText()
// surround in quotes to make it a JSON string
j = append([]byte{'"'}, append(j, '"')...)
return j, err
}
// Value is so DiscoPublic can be inserted into the database.
func (n DiscoPublic) Value() (driver.Value, error) {
return key.DiscoPublic(n).MarshalText()
}
// Scan is so DiscoPublic can be read from the database.
func (n *DiscoPublic) Scan(value interface{}) error {
switch v := value.(type) {
case []byte:
return (*key.DiscoPublic)(n).UnmarshalText(v)
case string:
return (*key.DiscoPublic)(n).UnmarshalText([]byte(v))
default:
return xerrors.Errorf("unexpected type: %T", v)
}
}

View File

@ -10,6 +10,7 @@ import (
"fmt"
"time"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/google/uuid"
"github.com/tabbed/pqtype"
)
@ -521,8 +522,8 @@ type WorkspaceAgent struct {
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
Directory string `db:"directory" json:"directory"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
type WorkspaceApp struct {

View File

@ -10,6 +10,7 @@ import (
"encoding/json"
"time"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@ -3105,8 +3106,8 @@ type InsertWorkspaceAgentParams struct {
InstanceMetadata pqtype.NullRawMessage `db:"instance_metadata" json:"instance_metadata"`
ResourceMetadata pqtype.NullRawMessage `db:"resource_metadata" json:"resource_metadata"`
WireguardNodeIPv6 pqtype.Inet `db:"wireguard_node_ipv6" json:"wireguard_node_ipv6"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) InsertWorkspaceAgent(ctx context.Context, arg InsertWorkspaceAgentParams) (WorkspaceAgent, error) {
@ -3196,9 +3197,9 @@ WHERE
`
type UpdateWorkspaceAgentKeysByIDParams struct {
ID uuid.UUID `db:"id" json:"id"`
WireguardNodePublicKey string `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey string `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
ID uuid.UUID `db:"id" json:"id"`
WireguardNodePublicKey dbtypes.NodePublic `db:"wireguard_node_public_key" json:"wireguard_node_public_key"`
WireguardDiscoPublicKey dbtypes.DiscoPublic `db:"wireguard_disco_public_key" json:"wireguard_disco_public_key"`
}
func (q *sqlQuerier) UpdateWorkspaceAgentKeysByID(ctx context.Context, arg UpdateWorkspaceAgentKeysByIDParams) error {

View File

@ -17,10 +17,10 @@ packages:
output_db_file_name: db_tmp.go
overrides:
- column: workspaces.wireguard_public_key
go_type: tailscale.com/types/key.MachinePublic
- column: workspaces.disco_public_key
go_type: tailscale.com/types/key.DiscoPublic
- column: workspace_agents.wireguard_node_public_key
go_type: github.com/coder/coder/coderd/database/dbtypes.NodePublic
- column: workspace_agents.wireguard_disco_public_key
go_type: github.com/coder/coder/coderd/database/dbtypes.DiscoPublic
rename:
api_key: APIKey

View File

@ -19,11 +19,11 @@ import (
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/types/key"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/parameter"
"github.com/coder/coder/coderd/rbac"
@ -761,8 +761,8 @@ func insertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
Valid: prAgent.StartupScript != "",
},
WireguardNodeIPv6: peerwg.UUIDToInet(agentID),
WireguardNodePublicKey: key.NodePublic{}.String(),
WireguardDiscoPublicKey: key.DiscoPublic{}.String(),
WireguardNodePublicKey: dbtypes.NodePublic{},
WireguardDiscoPublicKey: dbtypes.DiscoPublic{},
})
if err != nil {
return xerrors.Errorf("insert agent: %w", err)

View File

@ -22,6 +22,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/agent"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtypes"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/rbac"
@ -488,8 +489,8 @@ func (api *API) postWorkspaceAgentKeys(rw http.ResponseWriter, r *http.Request)
err := api.Database.UpdateWorkspaceAgentKeysByID(ctx, database.UpdateWorkspaceAgentKeysByIDParams{
ID: workspaceAgent.ID,
WireguardNodePublicKey: keys.Public.String(),
WireguardDiscoPublicKey: keys.Disco.String(),
WireguardNodePublicKey: dbtypes.NodePublic(keys.Public),
WireguardDiscoPublicKey: dbtypes.DiscoPublic(keys.Disco),
})
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
@ -711,15 +712,8 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, apps []codersdk.Work
Directory: dbAgent.Directory,
Apps: apps,
IPv6: inetToNetaddr(dbAgent.WireguardNodeIPv6),
}
err := workspaceAgent.WireguardPublicKey.UnmarshalText([]byte(dbAgent.WireguardNodePublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal wireguard node public key %q: %w", dbAgent.WireguardNodePublicKey, err)
}
err = workspaceAgent.DiscoPublicKey.UnmarshalText([]byte(dbAgent.WireguardDiscoPublicKey))
if err != nil {
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal disco public key %q: %w", dbAgent.WireguardDiscoPublicKey, err)
WireguardPublicKey: key.NodePublic(dbAgent.WireguardNodePublicKey),
DiscoPublicKey: key.DiscoPublic(dbAgent.WireguardDiscoPublicKey),
}
if dbAgent.FirstConnectedAt.Valid {