mirror of
https://github.com/coder/coder.git
synced 2025-07-09 11:45:56 +00:00
98 lines
2.2 KiB
Go
98 lines
2.2 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"strconv"
|
|
|
|
"golang.org/x/xerrors"
|
|
"inet.af/netaddr"
|
|
|
|
"cdr.dev/slog"
|
|
"github.com/coder/coder/peer/peerwg"
|
|
)
|
|
|
|
func (a *agent) startWireguard(ctx context.Context, addrs []netaddr.IPPrefix) error {
|
|
if a.network != nil {
|
|
_ = a.network.Close()
|
|
a.network = nil
|
|
}
|
|
|
|
// We can't create a wireguard network without these.
|
|
if len(addrs) == 0 || a.listenWireguardPeers == nil || a.postKeys == nil {
|
|
return xerrors.New("wireguard is enabled, but no addresses were provided or necessary functions were not provided")
|
|
}
|
|
|
|
wg, err := peerwg.New(a.logger.Named("wireguard"), addrs)
|
|
if err != nil {
|
|
return xerrors.Errorf("create wireguard network: %w", err)
|
|
}
|
|
|
|
// A new keypair is generated on each agent start.
|
|
// This keypair must be sent to Coder to allow for incoming connections.
|
|
err = a.postKeys(ctx, WireguardPublicKeys{
|
|
Public: wg.NodePrivateKey.Public(),
|
|
Disco: wg.DiscoPublicKey,
|
|
})
|
|
if err != nil {
|
|
a.logger.Warn(ctx, "post keys", slog.Error(err))
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
ch, listenClose, err := a.listenWireguardPeers(ctx, a.logger)
|
|
if err != nil {
|
|
a.logger.Warn(ctx, "listen wireguard peers", slog.Error(err))
|
|
return
|
|
}
|
|
|
|
for {
|
|
peer, ok := <-ch
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
err := wg.AddPeer(peer)
|
|
a.logger.Info(ctx, "added wireguard peer", slog.F("peer", peer.NodePublicKey.ShortString()), slog.Error(err))
|
|
}
|
|
|
|
listenClose()
|
|
}
|
|
}()
|
|
|
|
a.startWireguardListeners(ctx, wg, []handlerPort{
|
|
{port: 12212, handler: a.sshServer.HandleConn},
|
|
})
|
|
|
|
a.network = wg
|
|
return nil
|
|
}
|
|
|
|
// handlerPort pairs a TCP port on the wireguard network with the function
// that serves each connection accepted on it.
type handlerPort struct {
	// handler is invoked once per accepted connection.
	handler func(net.Conn)
	// port is the TCP port to listen on.
	port uint16
}
|
|
|
|
func (a *agent) startWireguardListeners(ctx context.Context, network *peerwg.Network, handlers []handlerPort) {
|
|
for _, h := range handlers {
|
|
go func(h handlerPort) {
|
|
a.logger.Debug(ctx, "starting wireguard listener", slog.F("port", h.port))
|
|
|
|
listener, err := network.Listen("tcp", net.JoinHostPort("", strconv.Itoa(int(h.port))))
|
|
if err != nil {
|
|
a.logger.Warn(ctx, "listen wireguard", slog.F("port", h.port), slog.Error(err))
|
|
return
|
|
}
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
go h.handler(conn)
|
|
}
|
|
}(h)
|
|
}
|
|
}
|