chore: get TUN/DNS working on Windows for CoderVPN (#16310)

This commit is contained in:
Dean Sheather
2025-01-29 18:09:36 +10:00
committed by GitHub
parent a658ccf362
commit 28088165a1
9 changed files with 199 additions and 55 deletions

View File

@ -41,7 +41,10 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command {
},
Handler: func(inv *serpent.Invocation) error {
ctx := inv.Context()
logger := inv.Logger.AppendSinks(sloghuman.Sink(inv.Stderr)).Leveled(slog.LevelDebug)
sinks := []slog.Sink{
sloghuman.Sink(inv.Stderr),
}
logger := inv.Logger.AppendSinks(sinks...).Leveled(slog.LevelDebug)
if rpcReadHandleInt < 0 || rpcWriteHandleInt < 0 {
return xerrors.Errorf("rpc-read-handle (%v) and rpc-write-handle (%v) must be positive", rpcReadHandleInt, rpcWriteHandleInt)
@ -60,7 +63,11 @@ func (r *RootCmd) vpnDaemonRun() *serpent.Command {
defer pipe.Close()
logger.Info(ctx, "starting tunnel")
tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient())
tunnel, err := vpn.NewTunnel(ctx, logger, pipe, vpn.NewClient(),
vpn.UseOSNetworkingStack(),
vpn.UseAsLogger(),
vpn.UseCustomLogSinks(sinks...),
)
if err != nil {
return xerrors.Errorf("create new tunnel for client: %w", err)
}

2
go.mod
View File

@ -423,7 +423,7 @@ require (
go.opentelemetry.io/proto/otlp v1.4.0 // indirect
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
golang.org/x/time v0.9.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/appengine v1.6.8 // indirect

View File

@ -116,6 +116,9 @@ type Options struct {
Router router.Router
// TUNDev is optional, and is passed to the underlying wireguard engine.
TUNDev tun.Device
// WireguardMonitor is optional, and is passed to the underlying wireguard
// engine.
WireguardMonitor *netmon.Monitor
}
// TelemetrySink allows tailnet.Conn to send network telemetry to the Coder
@ -171,13 +174,15 @@ func NewConn(options *Options) (conn *Conn, err error) {
nodeID = tailcfg.NodeID(uid)
}
wireguardMonitor, err := netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
if options.WireguardMonitor == nil {
options.WireguardMonitor, err = netmon.New(Logger(options.Logger.Named("net.wgmonitor")))
if err != nil {
return nil, xerrors.Errorf("create wireguard link monitor: %w", err)
}
}
defer func() {
if err != nil {
wireguardMonitor.Close()
options.WireguardMonitor.Close()
}
}()
@ -186,7 +191,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
}
sys := new(tsd.System)
wireguardEngine, err := wgengine.NewUserspaceEngine(Logger(options.Logger.Named("net.wgengine")), wgengine.Config{
NetMon: wireguardMonitor,
NetMon: options.WireguardMonitor,
Dialer: dialer,
ListenPort: options.ListenPort,
SetSubsystem: sys.Set,
@ -293,7 +298,7 @@ func NewConn(options *Options) (conn *Conn, err error) {
listeners: map[listenKey]*listener{},
tunDevice: sys.Tun.Get(),
netStack: netStack,
wireguardMonitor: wireguardMonitor,
wireguardMonitor: options.WireguardMonitor,
wireguardRouter: &router.Config{
LocalAddrs: options.Addresses,
},

View File

@ -8,6 +8,7 @@ import (
"golang.org/x/xerrors"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/wgengine/router"
"github.com/google/uuid"
@ -57,12 +58,13 @@ func NewClient() Client {
}
type Options struct {
Headers http.Header
Logger slog.Logger
DNSConfigurator dns.OSConfigurator
Router router.Router
TUNFileDescriptor *int
UpdateHandler tailnet.UpdatesHandler
Headers http.Header
Logger slog.Logger
DNSConfigurator dns.OSConfigurator
Router router.Router
TUNDevice tun.Device
WireguardMonitor *netmon.Monitor
UpdateHandler tailnet.UpdatesHandler
}
func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string, options *Options) (vpnC Conn, err error) {
@ -74,15 +76,6 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
options.Headers = http.Header{}
}
var dev tun.Device
if options.TUNFileDescriptor != nil {
// No-op on non-Darwin platforms.
dev, err = makeTUN(*options.TUNFileDescriptor)
if err != nil {
return nil, xerrors.Errorf("make TUN: %w", err)
}
}
headers := options.Headers
sdk := codersdk.New(serverURL)
sdk.SetSessionToken(token)
@ -134,7 +127,8 @@ func (*client) NewConn(initCtx context.Context, serverURL *url.URL, token string
BlockEndpoints: connInfo.DisableDirectConnections,
DNSConfigurator: options.DNSConfigurator,
Router: options.Router,
TUNDev: dev,
TUNDev: options.TUNDevice,
WireguardMonitor: options.WireguardMonitor,
})
if err != nil {
return nil, xerrors.Errorf("create tailnet: %w", err)

View File

@ -47,8 +47,7 @@ func OpenTunnel(cReadFD, cWriteFD int32) int32 {
}
_, err = vpn.NewTunnel(ctx, slog.Make(), conn, vpn.NewClient(),
vpn.UseAsDNSConfig(),
vpn.UseAsRouter(),
vpn.UseOSNetworkingStack(),
vpn.UseAsLogger(),
)
if err != nil {

View File

@ -1,10 +1,10 @@
//go:build !darwin
//go:build !darwin && !windows
package vpn
import "github.com/tailscale/wireguard-go/tun"
import "cdr.dev/slog"
// This is a no-op on non-Darwin platforms.
func makeTUN(int) (tun.Device, error) {
return nil, nil
// This is a no-op on every platform except Darwin and Windows.
func GetNetworkingStack(_ *Tunnel, _ *StartRequest, _ slog.Logger) (NetworkStack, error) {
return NetworkStack{}, nil
}

View File

@ -5,26 +5,34 @@ package vpn
import (
"os"
"cdr.dev/slog"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/sys/unix"
"golang.org/x/xerrors"
)
func makeTUN(tunFD int) (tun.Device, error) {
dupTunFd, err := unix.Dup(tunFD)
func GetNetworkingStack(t *Tunnel, req *StartRequest, _ slog.Logger) (NetworkStack, error) {
tunFd := int(req.GetTunnelFileDescriptor())
dupTunFd, err := unix.Dup(tunFd)
if err != nil {
return nil, xerrors.Errorf("dup tun fd: %w", err)
return NetworkStack{}, xerrors.Errorf("dup tun fd: %w", err)
}
err = unix.SetNonblock(dupTunFd, true)
if err != nil {
unix.Close(dupTunFd)
return nil, xerrors.Errorf("set nonblock: %w", err)
return NetworkStack{}, xerrors.Errorf("set nonblock: %w", err)
}
fileTun, err := tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), 0)
if err != nil {
unix.Close(dupTunFd)
return nil, xerrors.Errorf("create TUN from File: %w", err)
return NetworkStack{}, xerrors.Errorf("create TUN from File: %w", err)
}
return fileTun, nil
return NetworkStack{
WireguardMonitor: nil, // default is fine
TUNDevice: fileTun,
Router: NewRouter(t),
DNSConfigurator: NewDNSConfigurator(t),
}, nil
}

115
vpn/tun_windows.go Normal file
View File

@ -0,0 +1,115 @@
//go:build windows
package vpn
import (
"context"
"errors"
"time"
"github.com/coder/retry"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/sys/windows"
"golang.org/x/xerrors"
"golang.zx2c4.com/wintun"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/net/tstun"
"tailscale.com/types/logger"
"tailscale.com/util/winutil"
"tailscale.com/wgengine/router"
"cdr.dev/slog"
"github.com/coder/coder/v2/tailnet"
)
const tunName = "Coder"
func GetNetworkingStack(t *Tunnel, _ *StartRequest, logger slog.Logger) (NetworkStack, error) {
tun.WintunTunnelType = tunName
guid, err := windows.GUIDFromString("{0ed1515d-04a4-4c46-abae-11ad07cf0e6d}")
if err != nil {
panic(err)
}
tun.WintunStaticRequestedGUID = &guid
tunDev, tunName, err := tstunNewWithWindowsRetries(tailnet.Logger(logger.Named("net.tun.device")), tunName)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create tun device: %w", err)
}
logger.Info(context.Background(), "tun created", slog.F("name", tunName))
wireguardMonitor, err := netmon.New(tailnet.Logger(logger.Named("net.wgmonitor")))
coderRouter, err := router.New(tailnet.Logger(logger.Named("net.router")), tunDev, wireguardMonitor)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create router: %w", err)
}
dnsConfigurator, err := dns.NewOSConfigurator(tailnet.Logger(logger.Named("net.dns")), tunName)
if err != nil {
return NetworkStack{}, xerrors.Errorf("create dns configurator: %w", err)
}
return NetworkStack{
WireguardMonitor: nil, // default is fine
TUNDevice: tunDev,
Router: coderRouter,
DNSConfigurator: dnsConfigurator,
}, nil
}
// tstunNewOrRetry is a wrapper around tstun.New that retries on Windows for certain
// errors.
//
// This is taken from Tailscale:
// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L107
func tstunNewWithWindowsRetries(logf logger.Logf, tunName string) (_ tun.Device, devName string, _ error) {
r := retry.New(250*time.Millisecond, 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
for r.Wait(ctx) {
dev, devName, err := tstun.New(logf, tunName)
if err == nil {
return dev, devName, err
}
if errors.Is(err, windows.ERROR_DEVICE_NOT_AVAILABLE) || windowsUptime() < 10*time.Minute {
// Wintun is not installing correctly. Dump the state of NetSetupSvc
// (which is a user-mode service that must be active for network devices
// to install) and its dependencies to the log.
winutil.LogSvcState(logf, "NetSetupSvc")
}
}
return nil, "", ctx.Err()
}
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
getTickCount64Proc = kernel32.NewProc("GetTickCount64")
)
func windowsUptime() time.Duration {
r, _, _ := getTickCount64Proc.Call()
return time.Duration(int64(r)) * time.Millisecond
}
// TODO(@dean): implement a way to install/uninstall the wintun driver, most
// likely as a CLI command
//
// This is taken from Tailscale:
// https://github.com/tailscale/tailscale/blob/3abfbf50aebbe3ba57dc749165edb56be6715c0a/cmd/tailscaled/tailscaled_windows.go#L543
func uninstallWinTun(logf logger.Logf) {
dll := windows.NewLazyDLL("wintun.dll")
if err := dll.Load(); err != nil {
logf("Cannot load wintun.dll for uninstall: %v", err)
return
}
logf("Removing wintun driver...")
err := wintun.Uninstall()
logf("Uninstall: %v", err)
}
// TODO(@dean): remove
var _ = uninstallWinTun

View File

@ -15,16 +15,16 @@ import (
"time"
"unicode"
"github.com/google/uuid"
"github.com/tailscale/wireguard-go/tun"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/dns"
"tailscale.com/net/netmon"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/router"
"github.com/google/uuid"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/tailnet"
"github.com/coder/quartz"
)
@ -51,9 +51,8 @@ type Tunnel struct {
// option is used, to avoid the tunnel using itself as a sink for it's own
// logs, which could lead to deadlocks.
clientLogger slog.Logger
// router and dnsConfigurator may be nil
router router.Router
dnsConfigurator dns.OSConfigurator
// the following may be nil
networkingStackFn func(*Tunnel, *StartRequest, slog.Logger) (NetworkStack, error)
}
type TunnelOption func(t *Tunnel)
@ -169,21 +168,28 @@ func (t *Tunnel) handleRPC(req *request[*TunnelMessage, *ManagerMessage]) {
}
}
func UseAsRouter() TunnelOption {
type NetworkStack struct {
WireguardMonitor *netmon.Monitor
TUNDevice tun.Device
Router router.Router
DNSConfigurator dns.OSConfigurator
}
func UseOSNetworkingStack() TunnelOption {
return func(t *Tunnel) {
t.router = NewRouter(t)
t.networkingStackFn = GetNetworkingStack
}
}
func UseAsLogger() TunnelOption {
return func(t *Tunnel) {
t.clientLogger = slog.Make(t)
t.clientLogger = t.clientLogger.AppendSinks(t)
}
}
func UseAsDNSConfig() TunnelOption {
func UseCustomLogSinks(sinks ...slog.Sink) TunnelOption {
return func(t *Tunnel) {
t.dnsConfigurator = NewDNSConfigurator(t)
t.clientLogger = t.clientLogger.AppendSinks(sinks...)
}
}
@ -227,18 +233,28 @@ func (t *Tunnel) start(req *StartRequest) error {
for _, h := range req.GetHeaders() {
header.Add(h.GetName(), h.GetValue())
}
var networkingStack NetworkStack
if t.networkingStackFn != nil {
networkingStack, err = t.networkingStackFn(t, req, t.clientLogger)
if err != nil {
return xerrors.Errorf("failed to create networking stack dependencies: %w", err)
}
} else {
t.logger.Debug(t.ctx, "using default networking stack as no custom stack was provided")
}
conn, err := t.client.NewConn(
t.ctx,
svrURL,
apiToken,
&Options{
Headers: header,
Logger: t.clientLogger,
DNSConfigurator: t.dnsConfigurator,
Router: t.router,
TUNFileDescriptor: ptr.Ref(int(req.GetTunnelFileDescriptor())),
UpdateHandler: t,
Headers: header,
Logger: t.clientLogger,
DNSConfigurator: networkingStack.DNSConfigurator,
Router: networkingStack.Router,
TUNDevice: networkingStack.TUNDevice,
WireguardMonitor: networkingStack.WireguardMonitor,
UpdateHandler: t,
},
)
if err != nil {