Files
coder/cli/server_test.go
Hugo Dutka a68d11506c chore: track disabled telemetry (#16347)
Addresses https://github.com/coder/nexus/issues/116.

## Core Concept

Send one final telemetry report after the user disables telemetry with
the message that the telemetry was disabled. No other information about
the deployment is sent in this report.

This final report is submitted only if the deployment ever had telemetry
on.

## Changes

1. Refactored how our telemetry is initialized.
2. Introduced the `TelemetryEnabled` telemetry item, which allows to
decide whether a final report should be sent.
3. Added the `RecordTelemetryStatus` telemetry method, which decides
whether a final report should be sent and updates the telemetry item.
4. Added tests to ensure the implementation is correct.
2025-02-03 14:50:55 +01:00

2143 lines
63 KiB
Go

package cli_test
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"gopkg.in/yaml.v3"
"tailscale.com/derp/derphttp"
"tailscale.com/types/key"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/cli/config"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/migrations"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/telemetry"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/pty/ptytest"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/coder/v2/testutil"
)
func TestReadExternalAuthProvidersFromEnv(t *testing.T) {
t.Parallel()
t.Run("Valid", func(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_EXTERNAL_AUTH_0_ID=1",
"CODER_EXTERNAL_AUTH_0_TYPE=gitlab",
"CODER_EXTERNAL_AUTH_1_ID=2",
"CODER_EXTERNAL_AUTH_1_CLIENT_ID=sid",
"CODER_EXTERNAL_AUTH_1_CLIENT_SECRET=hunter12",
"CODER_EXTERNAL_AUTH_1_TOKEN_URL=google.com",
"CODER_EXTERNAL_AUTH_1_VALIDATE_URL=bing.com",
"CODER_EXTERNAL_AUTH_1_SCOPES=repo:read repo:write",
"CODER_EXTERNAL_AUTH_1_NO_REFRESH=true",
"CODER_EXTERNAL_AUTH_1_DISPLAY_NAME=Google",
"CODER_EXTERNAL_AUTH_1_DISPLAY_ICON=/icon/google.svg",
})
require.NoError(t, err)
require.Len(t, providers, 2)
// Validate the first provider.
assert.Equal(t, "1", providers[0].ID)
assert.Equal(t, "gitlab", providers[0].Type)
// Validate the second provider.
assert.Equal(t, "2", providers[1].ID)
assert.Equal(t, "sid", providers[1].ClientID)
assert.Equal(t, "hunter12", providers[1].ClientSecret)
assert.Equal(t, "google.com", providers[1].TokenURL)
assert.Equal(t, "bing.com", providers[1].ValidateURL)
assert.Equal(t, []string{"repo:read", "repo:write"}, providers[1].Scopes)
assert.Equal(t, true, providers[1].NoRefresh)
assert.Equal(t, "Google", providers[1].DisplayName)
assert.Equal(t, "/icon/google.svg", providers[1].DisplayIcon)
})
}
// TestReadGitAuthProvidersFromEnv ensures that the deprecated `CODER_GITAUTH_`
// environment variables are still supported.
func TestReadGitAuthProvidersFromEnv(t *testing.T) {
t.Parallel()
t.Run("Empty", func(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"HOME=/home/frodo",
})
require.NoError(t, err)
require.Empty(t, providers)
})
t.Run("InvalidKey", func(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_GITAUTH_XXX=invalid",
})
require.Error(t, err, "providers: %+v", providers)
require.Empty(t, providers)
})
t.Run("SkipKey", func(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_GITAUTH_0_ID=invalid",
"CODER_GITAUTH_2_ID=invalid",
})
require.Error(t, err, "%+v", providers)
require.Empty(t, providers)
})
t.Run("Valid", func(t *testing.T) {
t.Parallel()
providers, err := cli.ReadExternalAuthProvidersFromEnv([]string{
"CODER_GITAUTH_0_ID=1",
"CODER_GITAUTH_0_TYPE=gitlab",
"CODER_GITAUTH_1_ID=2",
"CODER_GITAUTH_1_CLIENT_ID=sid",
"CODER_GITAUTH_1_CLIENT_SECRET=hunter12",
"CODER_GITAUTH_1_TOKEN_URL=google.com",
"CODER_GITAUTH_1_VALIDATE_URL=bing.com",
"CODER_GITAUTH_1_SCOPES=repo:read repo:write",
"CODER_GITAUTH_1_NO_REFRESH=true",
})
require.NoError(t, err)
require.Len(t, providers, 2)
// Validate the first provider.
assert.Equal(t, "1", providers[0].ID)
assert.Equal(t, "gitlab", providers[0].Type)
// Validate the second provider.
assert.Equal(t, "2", providers[1].ID)
assert.Equal(t, "sid", providers[1].ClientID)
assert.Equal(t, "hunter12", providers[1].ClientSecret)
assert.Equal(t, "google.com", providers[1].TokenURL)
assert.Equal(t, "bing.com", providers[1].ValidateURL)
assert.Equal(t, []string{"repo:read", "repo:write"}, providers[1].Scopes)
assert.Equal(t, true, providers[1].NoRefresh)
})
}
func TestServer(t *testing.T) {
t.Parallel()
t.Run("BuiltinPostgres", func(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
inv, cfg := clitest.New(t,
"server",
"--http-address", ":0",
"--access-url", "http://example.com",
"--cache-dir", t.TempDir(),
)
const superDuperLong = testutil.WaitSuperLong * 3
ctx := testutil.Context(t, superDuperLong)
clitest.Start(t, inv.WithContext(ctx))
//nolint:gocritic // Embedded postgres take a while to fire up.
require.Eventually(t, func() bool {
rawURL, err := cfg.URL().Read()
return err == nil && rawURL != ""
}, superDuperLong, testutil.IntervalFast, "failed to get access URL")
})
t.Run("EphemeralDeployment", func(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
inv, _ := clitest.New(t,
"server",
"--http-address", ":0",
"--access-url", "http://example.com",
"--ephemeral",
)
pty := ptytest.New(t).Attach(inv)
// Embedded postgres takes a while to fire up.
const superDuperLong = testutil.WaitSuperLong * 3
ctx, cancelFunc := context.WithCancel(testutil.Context(t, superDuperLong))
errCh := make(chan error, 1)
go func() {
errCh <- inv.WithContext(ctx).Run()
}()
pty.ExpectMatch("Using an ephemeral deployment directory")
rootDirLine := pty.ReadLine(ctx)
rootDir := strings.TrimPrefix(rootDirLine, "Using an ephemeral deployment directory")
rootDir = strings.TrimSpace(rootDir)
rootDir = strings.TrimPrefix(rootDir, "(")
rootDir = strings.TrimSuffix(rootDir, ")")
require.NotEmpty(t, rootDir)
require.DirExists(t, rootDir)
pty.ExpectMatchContext(ctx, "View the Web UI")
cancelFunc()
<-errCh
require.NoDirExists(t, rootDir)
})
t.Run("BuiltinPostgresURL", func(t *testing.T) {
t.Parallel()
root, _ := clitest.New(t, "server", "postgres-builtin-url")
pty := ptytest.New(t)
root.Stdout = pty.Output()
err := root.Run()
require.NoError(t, err)
pty.ExpectMatch("psql")
})
t.Run("BuiltinPostgresURLRaw", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url")
pty := ptytest.New(t)
root.Stdout = pty.Output()
err := root.WithContext(ctx).Run()
require.NoError(t, err)
got := pty.ReadLine(ctx)
if !strings.HasPrefix(got, "postgres://") {
t.Fatalf("expected postgres URL to start with \"postgres://\", got %q", got)
}
})
// Validate that a warning is printed that it may not be externally
// reachable.
t.Run("LocalAccessURL", func(t *testing.T) {
t.Parallel()
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://localhost:3000/",
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
// Just wait for startup
_ = waitAccessURL(t, cfg)
pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
pty.ExpectMatch("View the Web UI:")
pty.ExpectMatch("http://localhost:3000/")
})
// Validate that an https scheme is prepended to a remote access URL
// and that a warning is printed for a host that cannot be resolved.
t.Run("RemoteAccessURL", func(t *testing.T) {
t.Parallel()
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "https://foobarbaz.mydomain",
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
// Just wait for startup
_ = waitAccessURL(t, cfg)
pty.ExpectMatch("this may cause unexpected problems when creating workspaces")
pty.ExpectMatch("View the Web UI:")
pty.ExpectMatch("https://foobarbaz.mydomain")
})
t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
t.Parallel()
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "https://google.com",
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
// Just wait for startup
_ = waitAccessURL(t, cfg)
pty.ExpectMatch("View the Web UI:")
pty.ExpectMatch("https://google.com")
})
t.Run("NoSchemeAccessURL", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "google.com",
"--cache-dir", t.TempDir(),
)
err := root.WithContext(ctx).Run()
require.Error(t, err)
})
t.Run("TLSBadVersion", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", "",
"--access-url", "http://example.com",
"--tls-enable",
"--tls-address", ":0",
"--tls-min-version", "tls9",
"--cache-dir", t.TempDir(),
)
err := root.WithContext(ctx).Run()
require.Error(t, err)
})
t.Run("TLSBadClientAuth", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", "",
"--access-url", "http://example.com",
"--tls-enable",
"--tls-address", ":0",
"--tls-client-auth", "something",
"--cache-dir", t.TempDir(),
)
err := root.WithContext(ctx).Run()
require.Error(t, err)
})
t.Run("TLSInvalid", func(t *testing.T) {
t.Parallel()
cert1Path, key1Path := generateTLSCertificate(t)
cert2Path, key2Path := generateTLSCertificate(t)
cases := []struct {
name string
args []string
errContains string
}{
{
name: "NoCert",
args: []string{"--tls-enable", "--tls-key-file", key1Path},
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
},
{
name: "NoKey",
args: []string{"--tls-enable", "--tls-cert-file", cert1Path},
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
},
{
name: "MismatchedCount",
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key1Path, "--tls-cert-file", cert2Path},
errContains: "--tls-cert-file and --tls-key-file must be used the same amount of times",
},
{
name: "MismatchedCertAndKey",
args: []string{"--tls-enable", "--tls-cert-file", cert1Path, "--tls-key-file", key2Path},
errContains: "load TLS key pair",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
args := []string{
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--cache-dir", t.TempDir(),
}
args = append(args, c.args...)
root, _ := clitest.New(t, args...)
err := root.WithContext(ctx).Run()
require.Error(t, err)
t.Logf("args: %v", args)
require.ErrorContains(t, err, c.errContains)
})
}
})
t.Run("TLSValid", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
certPath, keyPath := generateTLSCertificate(t)
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", "",
"--access-url", "https://example.com",
"--tls-enable",
"--tls-address", ":0",
"--tls-cert-file", certPath,
"--tls-key-file", keyPath,
"--cache-dir", t.TempDir(),
)
clitest.Start(t, root.WithContext(ctx))
// Verify HTTPS
accessURL := waitAccessURL(t, cfg)
require.Equal(t, "https", accessURL.Scheme)
client := codersdk.New(accessURL)
client.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
},
}
defer client.HTTPClient.CloseIdleConnections()
_, err := client.HasFirstUser(ctx)
require.NoError(t, err)
})
t.Run("TLSValidMultiple", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
cert1Path, key1Path := generateTLSCertificate(t, "alpaca.com")
cert2Path, key2Path := generateTLSCertificate(t, "*.llama.com")
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", "",
"--access-url", "https://example.com",
"--tls-enable",
"--tls-address", ":0",
"--tls-cert-file", cert1Path,
"--tls-key-file", key1Path,
"--tls-cert-file", cert2Path,
"--tls-key-file", key2Path,
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t)
root.Stdout = pty.Output()
clitest.Start(t, root.WithContext(ctx))
accessURL := waitAccessURL(t, cfg)
require.Equal(t, "https", accessURL.Scheme)
originalHost := accessURL.Host
var (
expectAddr string
dials int64
)
client := codersdk.New(accessURL)
client.HTTPClient = &http.Client{
Transport: &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
atomic.AddInt64(&dials, 1)
assert.Equal(t, expectAddr, addr)
host, _, err := net.SplitHostPort(addr)
require.NoError(t, err)
// Always connect to the accessURL ip:port regardless of
// hostname.
conn, err := tls.Dial(network, originalHost, &tls.Config{
MinVersion: tls.VersionTLS12,
//nolint:gosec
InsecureSkipVerify: true,
ServerName: host,
})
if err != nil {
return nil, err
}
// We can't call conn.VerifyHostname because it requires
// that the certificates are valid, so we call
// VerifyHostname on the first certificate instead.
require.Len(t, conn.ConnectionState().PeerCertificates, 1)
err = conn.ConnectionState().PeerCertificates[0].VerifyHostname(host)
assert.NoError(t, err, "invalid cert common name")
return conn, nil
},
},
}
defer client.HTTPClient.CloseIdleConnections()
// Use the first certificate and hostname.
client.URL.Host = "alpaca.com:443"
expectAddr = "alpaca.com:443"
_, err := client.HasFirstUser(ctx)
require.NoError(t, err)
require.EqualValues(t, 1, atomic.LoadInt64(&dials))
// Use the second certificate (wildcard) and hostname.
client.URL.Host = "hi.llama.com:443"
expectAddr = "hi.llama.com:443"
_, err = client.HasFirstUser(ctx)
require.NoError(t, err)
require.EqualValues(t, 2, atomic.LoadInt64(&dials))
})
t.Run("TLSAndHTTP", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
certPath, keyPath := generateTLSCertificate(t)
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "https://example.com",
"--tls-enable",
"--tls-redirect-http-to-https=false",
"--tls-address", ":0",
"--tls-cert-file", certPath,
"--tls-key-file", keyPath,
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
// We can't use waitAccessURL as it will only return the HTTP URL.
const httpLinePrefix = "Started HTTP listener at"
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine(ctx)
httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
const tlsLinePrefix = "Started TLS/HTTPS listener at "
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine(ctx)
tlsAddr := strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
// Verify HTTP
httpURL, err := url.Parse(httpAddr)
require.NoError(t, err)
client := codersdk.New(httpURL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
_, err = client.HasFirstUser(ctx)
require.NoError(t, err)
// Verify TLS
tlsURL, err := url.Parse(tlsAddr)
require.NoError(t, err)
client = codersdk.New(tlsURL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
client.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
},
}
defer client.HTTPClient.CloseIdleConnections()
_, err = client.HasFirstUser(ctx)
require.NoError(t, err)
})
t.Run("TLSRedirect", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
httpListener bool
tlsListener bool
redirect bool
accessURL string
requestURL string
// Empty string means no redirect.
expectRedirect string
}{
{
name: "OK",
httpListener: true,
tlsListener: true,
redirect: true,
accessURL: "https://example.com",
expectRedirect: "https://example.com",
},
{
name: "NoRedirect",
httpListener: true,
tlsListener: true,
accessURL: "https://example.com",
expectRedirect: "",
},
{
name: "NoRedirectWithWildcard",
tlsListener: true,
accessURL: "https://example.com",
requestURL: "https://dev.example.com",
expectRedirect: "",
redirect: true,
},
{
name: "NoTLSListener",
httpListener: true,
tlsListener: false,
accessURL: "https://example.com",
expectRedirect: "",
},
{
name: "NoHTTPListener",
httpListener: false,
tlsListener: true,
accessURL: "https://example.com",
expectRedirect: "",
},
}
for _, c := range cases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if c.requestURL == "" {
c.requestURL = c.accessURL
}
httpListenAddr := ""
if c.httpListener {
httpListenAddr = ":0"
}
certPath, keyPath := generateTLSCertificate(t)
flags := []string{
"server",
"--in-memory",
"--cache-dir", t.TempDir(),
"--http-address", httpListenAddr,
}
if c.tlsListener {
flags = append(flags,
"--tls-enable",
"--tls-address", ":0",
"--tls-cert-file", certPath,
"--tls-key-file", keyPath,
"--wildcard-access-url", "*.example.com",
)
}
if c.accessURL != "" {
flags = append(flags, "--access-url", c.accessURL)
}
if c.redirect {
flags = append(flags, "--redirect-to-access-url")
}
inv, _ := clitest.New(t, flags...)
pty := ptytest.New(t)
pty.Attach(inv)
clitest.Start(t, inv)
var (
httpAddr string
tlsAddr string
)
// We can't use waitAccessURL as it will only return the HTTP URL.
if c.httpListener {
const httpLinePrefix = "Started HTTP listener at"
pty.ExpectMatch(httpLinePrefix)
httpLine := pty.ReadLine(ctx)
httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix))
require.NotEmpty(t, httpAddr)
}
if c.tlsListener {
const tlsLinePrefix = "Started TLS/HTTPS listener at"
pty.ExpectMatch(tlsLinePrefix)
tlsLine := pty.ReadLine(ctx)
tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix))
require.NotEmpty(t, tlsAddr)
}
// Verify HTTP redirects (or not)
if c.httpListener {
httpURL, err := url.Parse(httpAddr)
require.NoError(t, err)
client := codersdk.New(httpURL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
resp, err := client.Request(ctx, http.MethodGet, "/api/v2/buildinfo", nil)
require.NoError(t, err)
defer resp.Body.Close()
if c.expectRedirect == "" {
require.Equal(t, http.StatusOK, resp.StatusCode)
} else {
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
require.Equal(t, c.expectRedirect, resp.Header.Get("Location"))
}
// We should never readirect /healthz
respHealthz, err := client.Request(ctx, http.MethodGet, "/healthz", nil)
require.NoError(t, err)
defer respHealthz.Body.Close()
require.Equal(t, http.StatusOK, respHealthz.StatusCode, "/healthz should never redirect")
// We should never redirect DERP
respDERP, err := client.Request(ctx, http.MethodGet, "/derp", nil)
require.NoError(t, err)
defer respDERP.Body.Close()
require.Equal(t, http.StatusUpgradeRequired, respDERP.StatusCode, "/derp should never redirect")
}
// Verify TLS
if c.tlsListener {
accessURLParsed, err := url.Parse(c.requestURL)
require.NoError(t, err)
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Transport: &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return tls.Dial(network, strings.TrimPrefix(tlsAddr, "https://"), &tls.Config{
// nolint:gosec
InsecureSkipVerify: true,
})
},
},
}
defer client.CloseIdleConnections()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, accessURLParsed.String(), nil)
require.NoError(t, err)
resp, err := client.Do(req)
// We don't care much about the response, just that TLS
// worked.
require.NoError(t, err)
defer resp.Body.Close()
}
})
}
})
t.Run("CanListenUnspecifiedv4", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", "0.0.0.0:0",
"--access-url", "http://example.com",
)
pty := ptytest.New(t)
root.Stdout = pty.Output()
root.Stderr = pty.Output()
serverStop := make(chan error, 1)
go func() {
err := root.WithContext(ctx).Run()
if err != nil {
t.Error(err)
}
close(serverStop)
}()
pty.ExpectMatch("Started HTTP listener")
pty.ExpectMatch("http://0.0.0.0:")
cancelFunc()
<-serverStop
})
t.Run("CanListenUnspecifiedv6", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", "[::]:0",
"--access-url", "http://example.com",
)
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
pty.ExpectMatch("Started HTTP listener at")
pty.ExpectMatch("http://[::]:")
})
t.Run("NoAddress", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":80",
"--tls-enable=false",
"--tls-address", "",
)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, "tls-address")
})
t.Run("NoTLSAddress", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--tls-enable=true",
"--tls-address", "",
)
err := inv.WithContext(ctx).Run()
require.Error(t, err)
require.ErrorContains(t, err, "must not be empty")
})
// DeprecatedAddress is a test for the deprecated --address flag. If
// specified, --http-address and --tls-address are both ignored, a warning
// is printed, and the server will either be HTTP-only or TLS-only depending
// on if --tls-enable is set.
t.Run("DeprecatedAddress", func(t *testing.T) {
t.Parallel()
t.Run("HTTP", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--address", ":0",
"--access-url", "http://example.com",
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t)
inv.Stdout = pty.Output()
inv.Stderr = pty.Output()
clitest.Start(t, inv.WithContext(ctx))
pty.ExpectMatch("is deprecated")
accessURL := waitAccessURL(t, cfg)
require.Equal(t, "http", accessURL.Scheme)
client := codersdk.New(accessURL)
_, err := client.HasFirstUser(ctx)
require.NoError(t, err)
})
t.Run("TLS", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
certPath, keyPath := generateTLSCertificate(t)
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--address", ":0",
"--access-url", "https://example.com",
"--tls-enable",
"--tls-cert-file", certPath,
"--tls-key-file", keyPath,
"--cache-dir", t.TempDir(),
)
pty := ptytest.New(t)
root.Stdout = pty.Output()
root.Stderr = pty.Output()
clitest.Start(t, root.WithContext(ctx))
pty.ExpectMatch("is deprecated")
accessURL := waitAccessURL(t, cfg)
require.Equal(t, "https", accessURL.Scheme)
client := codersdk.New(accessURL)
client.HTTPClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
//nolint:gosec
InsecureSkipVerify: true,
},
},
}
defer client.HTTPClient.CloseIdleConnections()
_, err := client.HasFirstUser(ctx)
require.NoError(t, err)
})
})
t.Run("TracerNoLeak", func(t *testing.T) {
t.Parallel()
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--trace=true",
"--cache-dir", t.TempDir(),
)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
clitest.Start(t, inv.WithContext(ctx))
cancel()
require.Error(t, goleak.Find())
})
t.Run("Telemetry", func(t *testing.T) {
t.Parallel()
telemetryServerURL, deployment, snapshot := mockTelemetryServer(t)
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--telemetry",
"--telemetry-url", telemetryServerURL.String(),
"--cache-dir", t.TempDir(),
)
clitest.Start(t, inv)
<-deployment
<-snapshot
accessURL := waitAccessURL(t, cfg)
ctx := testutil.Context(t, testutil.WaitMedium)
client := codersdk.New(accessURL)
body, err := client.Request(ctx, http.MethodGet, "/", nil)
require.NoError(t, err)
require.NoError(t, body.Body.Close())
require.Eventually(t, func() bool {
snap := <-snapshot
htmlFirstServedFound := false
for _, item := range snap.TelemetryItems {
if item.Key == string(telemetry.TelemetryItemKeyHTMLFirstServedAt) {
htmlFirstServedFound = true
}
}
return htmlFirstServedFound
}, testutil.WaitMedium, testutil.IntervalFast, "no html_first_served telemetry item")
})
t.Run("Prometheus", func(t *testing.T) {
t.Parallel()
t.Run("DBMetricsDisabled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
randPort := testutil.RandomPort(t)
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons", "1",
"--prometheus-enable",
"--prometheus-address", ":"+strconv.Itoa(randPort),
// "--prometheus-collect-db-metrics", // disabled by default
"--cache-dir", t.TempDir(),
)
clitest.Start(t, inv)
_ = waitAccessURL(t, cfg)
var res *http.Response
require.Eventually(t, func() bool {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://127.0.0.1:%d", randPort), nil)
assert.NoError(t, err)
// nolint:bodyclose
res, err = http.DefaultClient.Do(req)
if err != nil {
return false
}
defer res.Body.Close()
scanner := bufio.NewScanner(res.Body)
hasActiveUsers := false
for scanner.Scan() {
// This metric is manually registered to be tracked in the server. That's
// why we test it's tracked here.
if strings.HasPrefix(scanner.Text(), "coderd_api_active_users_duration_hour") {
hasActiveUsers = true
continue
}
if strings.HasPrefix(scanner.Text(), "coderd_db_query_latencies_seconds") {
t.Fatal("db metrics should not be tracked when --prometheus-collect-db-metrics is not enabled")
}
t.Logf("scanned %s", scanner.Text())
}
if scanner.Err() != nil {
t.Logf("scanner err: %s", scanner.Err().Error())
return false
}
return hasActiveUsers
}, testutil.WaitShort, testutil.IntervalFast, "didn't find coderd_api_active_users_duration_hour in time")
})
t.Run("DBMetricsEnabled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
randPort := testutil.RandomPort(t)
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons", "1",
"--prometheus-enable",
"--prometheus-address", ":"+strconv.Itoa(randPort),
"--prometheus-collect-db-metrics",
"--cache-dir", t.TempDir(),
)
clitest.Start(t, inv)
_ = waitAccessURL(t, cfg)
var res *http.Response
require.Eventually(t, func() bool {
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://127.0.0.1:%d", randPort), nil)
assert.NoError(t, err)
// nolint:bodyclose
res, err = http.DefaultClient.Do(req)
if err != nil {
return false
}
defer res.Body.Close()
scanner := bufio.NewScanner(res.Body)
hasDBMetrics := false
for scanner.Scan() {
if strings.HasPrefix(scanner.Text(), "coderd_db_query_latencies_seconds") {
hasDBMetrics = true
}
t.Logf("scanned %s", scanner.Text())
}
if scanner.Err() != nil {
t.Logf("scanner err: %s", scanner.Err().Error())
return false
}
return hasDBMetrics
}, testutil.WaitShort, testutil.IntervalFast, "didn't find coderd_db_query_latencies_seconds in time")
})
})
t.Run("GitHubOAuth", func(t *testing.T) {
t.Parallel()
fakeRedirect := "https://fake-url.com"
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--oauth2-github-allow-everyone",
"--oauth2-github-client-id", "fake",
"--oauth2-github-client-secret", "fake",
"--oauth2-github-enterprise-base-url", fakeRedirect,
)
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
githubURL, err := accessURL.Parse("/api/v2/users/oauth2/github")
require.NoError(t, err)
req, err := http.NewRequestWithContext(inv.Context(), http.MethodGet, githubURL.String(), nil)
require.NoError(t, err)
res, err := client.HTTPClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
fakeURL, err := res.Location()
require.NoError(t, err)
require.True(t, strings.HasPrefix(fakeURL.String(), fakeRedirect), fakeURL.String())
})
t.Run("OIDC", func(t *testing.T) {
t.Parallel()
t.Run("Defaults", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Startup a fake server that just responds to .well-known/openid-configuration
// This is just needed to get Coder to start up.
oidcServer := httptest.NewServer(nil)
fakeWellKnownHandler := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
payload := fmt.Sprintf("{\"issuer\": %q}", oidcServer.URL)
_, _ = w.Write([]byte(payload))
}
oidcServer.Config.Handler = http.HandlerFunc(fakeWellKnownHandler)
t.Cleanup(oidcServer.Close)
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--oidc-client-id", "fake",
"--oidc-client-secret", "fake",
"--oidc-issuer-url", oidcServer.URL,
// Leaving the rest of the flags as defaults.
)
// Ensure that the server starts up without error.
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
randPassword, err := cryptorand.String(24)
require.NoError(t, err)
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: "admin@coder.com",
Password: randPassword,
Username: "admin",
Trial: true,
})
require.NoError(t, err)
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: "admin@coder.com",
Password: randPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
deploymentConfig, err := client.DeploymentConfig(ctx)
require.NoError(t, err)
// Ensure that the OIDC provider is configured correctly.
require.Equal(t, "fake", deploymentConfig.Values.OIDC.ClientID.Value())
// The client secret is not returned from the API.
require.Empty(t, deploymentConfig.Values.OIDC.ClientSecret.Value())
require.Equal(t, oidcServer.URL, deploymentConfig.Values.OIDC.IssuerURL.Value())
// These are the default values returned from the API. See codersdk/deployment.go for the default values.
require.True(t, deploymentConfig.Values.OIDC.AllowSignups.Value())
require.Empty(t, deploymentConfig.Values.OIDC.EmailDomain.Value())
require.Equal(t, []string{"openid", "profile", "email"}, deploymentConfig.Values.OIDC.Scopes.Value())
require.False(t, deploymentConfig.Values.OIDC.IgnoreEmailVerified.Value())
require.Equal(t, "preferred_username", deploymentConfig.Values.OIDC.UsernameField.Value())
require.Equal(t, "email", deploymentConfig.Values.OIDC.EmailField.Value())
require.Equal(t, map[string]string{"access_type": "offline"}, deploymentConfig.Values.OIDC.AuthURLParams.Value)
require.False(t, deploymentConfig.Values.OIDC.IgnoreUserInfo.Value())
require.Empty(t, deploymentConfig.Values.OIDC.GroupField.Value())
require.Empty(t, deploymentConfig.Values.OIDC.GroupMapping.Value)
require.Empty(t, deploymentConfig.Values.OIDC.UserRoleField.Value())
require.Empty(t, deploymentConfig.Values.OIDC.UserRoleMapping.Value)
require.Equal(t, "OpenID Connect", deploymentConfig.Values.OIDC.SignInText.Value())
require.Empty(t, deploymentConfig.Values.OIDC.IconURL.Value())
})
t.Run("Overrides", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium)
defer cancel()
// Startup a fake server that just responds to .well-known/openid-configuration
// This is just needed to get Coder to start up.
oidcServer := httptest.NewServer(nil)
fakeWellKnownHandler := func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
payload := fmt.Sprintf("{\"issuer\": %q}", oidcServer.URL)
_, _ = w.Write([]byte(payload))
}
oidcServer.Config.Handler = http.HandlerFunc(fakeWellKnownHandler)
t.Cleanup(oidcServer.Close)
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--oidc-client-id", "fake",
"--oidc-client-secret", "fake",
"--oidc-issuer-url", oidcServer.URL,
// The following values have defaults that we want to override.
"--oidc-allow-signups=false",
"--oidc-email-domain", "example.com",
"--oidc-scopes", "360noscope",
"--oidc-ignore-email-verified",
"--oidc-username-field", "not_preferred_username",
"--oidc-email-field", "not_email",
"--oidc-auth-url-params", `{"prompt":"consent"}`,
"--oidc-ignore-userinfo",
"--oidc-group-field", "serious_business_unit",
"--oidc-group-mapping", `{"serious_business_unit": "serious_business_unit"}`,
"--oidc-sign-in-text", "Sign In With Coder",
"--oidc-icon-url", "https://example.com/icon.png",
)
// Ensure that the server starts up without error.
clitest.Start(t, inv)
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
randPassword, err := cryptorand.String(24)
require.NoError(t, err)
_, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{
Email: "admin@coder.com",
Password: randPassword,
Username: "admin",
Trial: true,
})
require.NoError(t, err)
loginResp, err := client.LoginWithPassword(ctx, codersdk.LoginWithPasswordRequest{
Email: "admin@coder.com",
Password: randPassword,
})
require.NoError(t, err)
client.SetSessionToken(loginResp.SessionToken)
deploymentConfig, err := client.DeploymentConfig(ctx)
require.NoError(t, err)
// Ensure that the OIDC provider is configured correctly.
require.Equal(t, "fake", deploymentConfig.Values.OIDC.ClientID.Value())
// The client secret is not returned from the API.
require.Empty(t, deploymentConfig.Values.OIDC.ClientSecret.Value())
require.Equal(t, oidcServer.URL, deploymentConfig.Values.OIDC.IssuerURL.Value())
// These are values that we want to make sure were overridden.
require.False(t, deploymentConfig.Values.OIDC.AllowSignups.Value())
require.Equal(t, []string{"example.com"}, deploymentConfig.Values.OIDC.EmailDomain.Value())
require.Equal(t, []string{"360noscope"}, deploymentConfig.Values.OIDC.Scopes.Value())
require.True(t, deploymentConfig.Values.OIDC.IgnoreEmailVerified.Value())
require.Equal(t, "not_preferred_username", deploymentConfig.Values.OIDC.UsernameField.Value())
require.Equal(t, "not_email", deploymentConfig.Values.OIDC.EmailField.Value())
require.True(t, deploymentConfig.Values.OIDC.IgnoreUserInfo.Value())
require.Equal(t, map[string]string{"prompt": "consent"}, deploymentConfig.Values.OIDC.AuthURLParams.Value)
require.Equal(t, "serious_business_unit", deploymentConfig.Values.OIDC.GroupField.Value())
require.Equal(t, map[string]string{"serious_business_unit": "serious_business_unit"}, deploymentConfig.Values.OIDC.GroupMapping.Value)
require.Equal(t, "Sign In With Coder", deploymentConfig.Values.OIDC.SignInText.Value())
require.Equal(t, "https://example.com/icon.png", deploymentConfig.Values.OIDC.IconURL.Value().String())
// Verify the option values
for _, opt := range deploymentConfig.Options {
switch opt.Flag {
case "access-url":
require.Equal(t, "http://example.com", opt.Value.String())
case "oidc-icon-url":
require.Equal(t, "https://example.com/icon.png", opt.Value.String())
case "oidc-sign-in-text":
require.Equal(t, "Sign In With Coder", opt.Value.String())
case "redirect-to-access-url":
require.Equal(t, "false", opt.Value.String())
case "derp-server-region-id":
require.Equal(t, "999", opt.Value.String())
}
}
})
})
t.Run("RateLimit", func(t *testing.T) {
t.Parallel()
t.Run("Default", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
)
serverErr := make(chan error, 1)
go func() {
serverErr <- root.WithContext(ctx).Run()
}()
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
resp, err := client.Request(ctx, http.MethodGet, "/api/v2/buildinfo", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "512", resp.Header.Get("X-Ratelimit-Limit"))
cancelFunc()
<-serverErr
})
t.Run("Changed", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
val := "100"
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--api-rate-limit", val,
)
serverErr := make(chan error, 1)
go func() {
serverErr <- root.WithContext(ctx).Run()
}()
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
resp, err := client.Request(ctx, http.MethodGet, "/api/v2/buildinfo", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, val, resp.Header.Get("X-Ratelimit-Limit"))
cancelFunc()
<-serverErr
})
t.Run("Disabled", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--api-rate-limit", "-1",
)
serverErr := make(chan error, 1)
go func() {
serverErr <- root.WithContext(ctx).Run()
}()
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
resp, err := client.Request(ctx, http.MethodGet, "/api/v2/buildinfo", nil)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "", resp.Header.Get("X-Ratelimit-Limit"))
cancelFunc()
<-serverErr
})
})
t.Run("Logging", func(t *testing.T) {
t.Parallel()
t.Run("CreatesFile", func(t *testing.T) {
t.Parallel()
fiName := testutil.TempFile(t, "", "coder-logging-test-*")
root, _ := clitest.New(t,
"server",
"--log-filter=.*",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons=3",
"--provisioner-types=echo",
"--log-human", fiName,
)
clitest.Start(t, root)
loggingWaitFile(t, fiName, testutil.WaitLong)
})
t.Run("Human", func(t *testing.T) {
t.Parallel()
fi := testutil.TempFile(t, "", "coder-logging-test-*")
root, _ := clitest.New(t,
"server",
"--log-filter=.*",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons=3",
"--provisioner-types=echo",
"--log-human", fi,
)
clitest.Start(t, root)
loggingWaitFile(t, fi, testutil.WaitShort)
})
t.Run("JSON", func(t *testing.T) {
t.Parallel()
fi := testutil.TempFile(t, "", "coder-logging-test-*")
root, _ := clitest.New(t,
"server",
"--log-filter=.*",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons=3",
"--provisioner-types=echo",
"--log-json", fi,
)
clitest.Start(t, root)
loggingWaitFile(t, fi, testutil.WaitShort)
})
})
t.Run("YAML", func(t *testing.T) {
t.Parallel()
t.Run("WriteThenReadConfig", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
args := []string{
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--log-human", filepath.Join(t.TempDir(), "coder-logging-test-human"),
// We use ecdsa here because it's the fastest alternative algorithm.
"--ssh-keygen-algorithm", "ecdsa",
"--cache-dir", t.TempDir(),
}
// First, we get the base config as set via flags (like users before
// migrating).
inv, cfg := clitest.New(t,
args...,
)
ptytest.New(t).Attach(inv)
inv = inv.WithContext(ctx)
w := clitest.StartWithWaiter(t, inv)
gotURL := waitAccessURL(t, cfg)
client := codersdk.New(gotURL)
_ = coderdtest.CreateFirstUser(t, client)
wantConfig, err := client.DeploymentConfig(ctx)
require.NoError(t, err)
cancel()
w.RequireSuccess()
// Next, we instruct the same server to display the YAML config
// and then save it.
inv = inv.WithContext(testutil.Context(t, testutil.WaitMedium))
inv.Args = append(args, "--write-config")
fi, err := os.OpenFile(testutil.TempFile(t, "", "coder-config-test-*"), os.O_WRONLY|os.O_CREATE, 0o600)
require.NoError(t, err)
defer fi.Close()
var conf bytes.Buffer
inv.Stdout = io.MultiWriter(fi, &conf)
t.Logf("%+v", inv.Args)
err = inv.Run()
require.NoError(t, err)
// Reset the context.
ctx = testutil.Context(t, testutil.WaitMedium)
// Finally, we restart the server with just the config and no flags
// and ensure that the live configuration is equivalent.
inv, cfg = clitest.New(t, "server", "--config="+fi.Name())
w = clitest.StartWithWaiter(t, inv)
client = codersdk.New(waitAccessURL(t, cfg))
_ = coderdtest.CreateFirstUser(t, client)
gotConfig, err := client.DeploymentConfig(ctx)
require.NoError(t, err, "config:\n%s\nargs: %+v", conf.String(), inv.Args)
gotConfig.Options.ByName("Config Path").Value.Set("")
// We check the options individually for better error messages.
for i := range wantConfig.Options {
// ValueSource is not going to be correct on the `want`, so just
// match that field.
wantConfig.Options[i].ValueSource = gotConfig.Options[i].ValueSource
// If there is a wrapped value with a validator, unwrap it.
// The underlying doesn't compare well since it compares go pointers,
// and not the actual value.
if validator, isValidator := wantConfig.Options[i].Value.(interface{ Underlying() pflag.Value }); isValidator {
wantConfig.Options[i].Value = validator.Underlying()
}
if validator, isValidator := gotConfig.Options[i].Value.(interface{ Underlying() pflag.Value }); isValidator {
gotConfig.Options[i].Value = validator.Underlying()
}
assert.Equal(
t, wantConfig.Options[i],
gotConfig.Options[i],
"option %q",
wantConfig.Options[i].Name,
)
}
w.Cancel()
w.RequireSuccess()
})
})
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_Logging_NoParallel(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.Copy(io.Discard, r.Body)
_ = r.Body.Close()
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(func() { server.Close() })
// Speed up stackdriver test by using custom host. This is like
// saying we're running on GCE, so extra checks are skipped.
//
// Note, that the server isn't actually hit by the test, unsure why
// but kept just in case.
//
// From cloud.google.com/go/compute/metadata/metadata.go (used by coder/slog):
//
// metadataHostEnv is the environment variable specifying the
// GCE metadata hostname. If empty, the default value of
// metadataIP ("169.254.169.254") is used instead.
// This is variable name is not defined by any spec, as far as
// I know; it was made up for the Go package.
t.Setenv("GCE_METADATA_HOST", server.URL)
t.Run("Stackdriver", func(t *testing.T) {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi := testutil.TempFile(t, "", "coder-logging-test-*")
inv, _ := clitest.New(t,
"server",
"--log-filter=.*",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons=3",
"--provisioner-types=echo",
"--log-stackdriver", fi,
)
// Attach pty so we get debug output from the command if this test
// fails.
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv.WithContext(ctx))
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at")
loggingWaitFile(t, fi, testutil.WaitSuperLong)
})
t.Run("Multiple", func(t *testing.T) {
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong)
defer cancelFunc()
fi1 := testutil.TempFile(t, "", "coder-logging-test-*")
fi2 := testutil.TempFile(t, "", "coder-logging-test-*")
fi3 := testutil.TempFile(t, "", "coder-logging-test-*")
// NOTE(mafredri): This test might end up downloading Terraform
// which can take a long time and end up failing the test.
// This is why we wait extra long below for server to listen on
// HTTP.
inv, _ := clitest.New(t,
"server",
"--log-filter=.*",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons=3",
"--provisioner-types=echo",
"--log-human", fi1,
"--log-json", fi2,
"--log-stackdriver", fi3,
)
// Attach pty so we get debug output from the command if this test
// fails.
pty := ptytest.New(t).Attach(inv)
clitest.Start(t, inv)
// Wait for server to listen on HTTP, this is a good
// starting point for expecting logs.
_ = pty.ExpectMatchContext(ctx, "Started HTTP listener at")
loggingWaitFile(t, fi1, testutil.WaitSuperLong)
loggingWaitFile(t, fi2, testutil.WaitSuperLong)
loggingWaitFile(t, fi3, testutil.WaitSuperLong)
})
}
func loggingWaitFile(t *testing.T, fiName string, dur time.Duration) {
var lastStat os.FileInfo
require.Eventually(t, func() bool {
var err error
lastStat, err = os.Stat(fiName)
if err != nil {
if !os.IsNotExist(err) {
t.Fatalf("unexpected error: %v", err)
}
return false
}
return lastStat.Size() > 0
},
dur, //nolint:gocritic
testutil.IntervalFast,
"file at %s should exist, last stat: %+v",
fiName, lastStat,
)
}
func TestServer_Production(t *testing.T) {
t.Parallel()
if runtime.GOOS != "linux" || testing.Short() {
// Skip on non-Linux because it spawns a PostgreSQL instance.
t.SkipNow()
}
connectionURL, err := dbtestutil.Open(t)
require.NoError(t, err)
// Postgres + race detector + CI = slow.
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitSuperLong*3)
defer cancelFunc()
inv, cfg := clitest.New(t,
"server",
"--http-address", ":0",
"--access-url", "http://example.com",
"--postgres-url", connectionURL,
"--cache-dir", t.TempDir(),
)
clitest.Start(t, inv.WithContext(ctx))
accessURL := waitAccessURL(t, cfg)
client := codersdk.New(accessURL)
_, err = client.CreateFirstUser(ctx, coderdtest.FirstUserParams)
require.NoError(t, err)
}
//nolint:tparallel,paralleltest // This test sets environment variables.
func TestServer_TelemetryDisable(t *testing.T) {
// Set the default telemetry to true (normally disabled in tests).
t.Setenv("CODER_TEST_TELEMETRY_DEFAULT_ENABLE", "true")
//nolint:paralleltest // No need to reinitialise the variable tt (Go version).
for _, tt := range []struct {
key string
val string
want bool
}{
{"", "", true},
{"CODER_TELEMETRY_ENABLE", "true", true},
{"CODER_TELEMETRY_ENABLE", "false", false},
{"CODER_TELEMETRY", "true", true},
{"CODER_TELEMETRY", "false", false},
} {
t.Run(fmt.Sprintf("%s=%s", tt.key, tt.val), func(t *testing.T) {
t.Parallel()
var b bytes.Buffer
inv, _ := clitest.New(t, "server", "--write-config")
inv.Stdout = &b
inv.Environ.Set(tt.key, tt.val)
clitest.Run(t, inv)
var dv codersdk.DeploymentValues
err := yaml.Unmarshal(b.Bytes(), &dv)
require.NoError(t, err)
assert.Equal(t, tt.want, dv.Telemetry.Enable.Value())
})
}
}
//nolint:tparallel,paralleltest // This test cannot be run in parallel due to signal handling.
func TestServer_InterruptShutdown(t *testing.T) {
t.Skip("This test issues an interrupt signal which will propagate to the test runner.")
if runtime.GOOS == "windows" {
// Sending interrupt signal isn't supported on Windows!
t.SkipNow()
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons", "1",
"--cache-dir", t.TempDir(),
)
serverErr := make(chan error, 1)
go func() {
serverErr <- root.WithContext(ctx).Run()
}()
_ = waitAccessURL(t, cfg)
currentProcess, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
err = currentProcess.Signal(os.Interrupt)
require.NoError(t, err)
// We cannot send more signals here, because it's possible Coder
// has already exited, which could cause the test to fail due to interrupt.
err = <-serverErr
require.NoError(t, err)
}
func TestServer_GracefulShutdown(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// Sending interrupt signal isn't supported on Windows!
t.SkipNow()
}
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--provisioner-daemons", "1",
"--cache-dir", t.TempDir(),
)
var stopFunc context.CancelFunc
root = root.WithTestSignalNotifyContext(t, func(parent context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) {
if !reflect.DeepEqual(cli.StopSignalsNoInterrupt, signals) {
return context.WithCancel(ctx)
}
var ctx context.Context
ctx, stopFunc = context.WithCancel(parent)
return ctx, stopFunc
})
serverErr := make(chan error, 1)
pty := ptytest.New(t).Attach(root)
go func() {
serverErr <- root.WithContext(ctx).Run()
}()
_ = waitAccessURL(t, cfg)
// It's fair to assume `stopFunc` isn't nil here, because the server
// has started and access URL is propagated.
stopFunc()
pty.ExpectMatch("waiting for provisioner jobs to complete")
err := <-serverErr
require.NoError(t, err)
}
func BenchmarkServerHelp(b *testing.B) {
// server --help is a good proxy for measuring the
// constant overhead of each command.
b.ReportAllocs()
for i := 0; i < b.N; i++ {
inv, _ := clitest.New(b, "server", "--help")
inv.Stdout = io.Discard
inv.Stderr = io.Discard
err := inv.Run()
require.NoError(b, err)
}
}
func generateTLSCertificate(t testing.TB, commonName ...string) (certPath, keyPath string) {
dir := t.TempDir()
commonNameStr := "localhost"
if len(commonName) > 0 {
commonNameStr = commonName[0]
}
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Acme Co"},
CommonName: commonNameStr,
},
DNSNames: []string{commonNameStr},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certFile, err := os.CreateTemp(dir, "")
require.NoError(t, err)
defer certFile.Close()
_, err = certFile.Write(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}))
require.NoError(t, err)
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
require.NoError(t, err)
keyFile, err := os.CreateTemp(dir, "")
require.NoError(t, err)
defer keyFile.Close()
err = pem.Encode(keyFile, &pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes})
require.NoError(t, err)
return certFile.Name(), keyFile.Name()
}
func waitAccessURL(t *testing.T, cfg config.Root) *url.URL {
t.Helper()
var err error
var rawURL string
require.Eventually(t, func() bool {
rawURL, err = cfg.URL().Read()
return err == nil && rawURL != ""
}, testutil.WaitLong, testutil.IntervalFast, "failed to get access URL")
accessURL, err := url.Parse(rawURL)
require.NoError(t, err, "failed to parse access URL")
return accessURL
}
func TestServerYAMLConfig(t *testing.T) {
t.Parallel()
var deployValues codersdk.DeploymentValues
opts := deployValues.Options()
err := opts.SetDefaults()
require.NoError(t, err)
n, err := opts.MarshalYAML()
require.NoError(t, err)
// Sanity-check that we can read the config back in.
err = opts.UnmarshalYAML(n.(*yaml.Node))
require.NoError(t, err)
var wantBuf bytes.Buffer
enc := yaml.NewEncoder(&wantBuf)
enc.SetIndent(2)
err = enc.Encode(n)
require.NoError(t, err)
clitest.TestGoldenFile(t, "server-config.yaml", wantBuf.Bytes(), nil)
}
func TestConnectToPostgres(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test does not make sense without postgres")
}
t.Run("Migrate", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)
log := testutil.Logger(t)
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
sqlDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, migrations.Up)
require.NoError(t, err)
t.Cleanup(func() {
_ = sqlDB.Close()
})
require.NoError(t, sqlDB.PingContext(ctx))
})
t.Run("NoMigrate", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
t.Cleanup(cancel)
log := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
dbURL, err := dbtestutil.Open(t)
require.NoError(t, err)
okDB, err := cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.NoError(t, err)
defer okDB.Close()
// Set the migration number forward
_, err = okDB.Exec(`UPDATE schema_migrations SET version = version + 1`)
require.NoError(t, err)
_, err = cli.ConnectToPostgres(ctx, log, "postgres", dbURL, nil)
require.Error(t, err)
require.ErrorContains(t, err, "database needs migration")
require.NoError(t, okDB.PingContext(ctx))
})
}
func TestServer_InvalidDERP(t *testing.T) {
t.Parallel()
// Try to start a server with the built-in DERP server disabled and no
// external DERP map.
inv, _ := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--derp-server-enable=false",
"--derp-server-stun-addresses", "disable",
"--block-direct-connections",
)
err := inv.Run()
require.Error(t, err)
require.ErrorContains(t, err, "A valid DERP map is required for networking to work")
}
func TestServer_DisabledDERP(t *testing.T) {
t.Parallel()
derpMap, _ := tailnettest.RunDERPAndSTUN(t)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
httpapi.Write(context.Background(), w, http.StatusOK, derpMap)
}))
t.Cleanup(srv.Close)
ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancelFunc()
// Try to start a server with the built-in DERP server disabled and an
// external DERP map.
inv, cfg := clitest.New(t,
"server",
"--in-memory",
"--http-address", ":0",
"--access-url", "http://example.com",
"--derp-server-enable=false",
"--derp-config-url", srv.URL,
)
clitest.Start(t, inv.WithContext(ctx))
accessURL := waitAccessURL(t, cfg)
derpURL, err := accessURL.Parse("/derp")
require.NoError(t, err)
c, err := derphttp.NewClient(key.NewNode(), derpURL.String(), func(format string, args ...any) {})
require.NoError(t, err)
// DERP should fail to connect
err = c.Connect(ctx)
require.Error(t, err)
}
type runServerOpts struct {
waitForSnapshot bool
telemetryDisabled bool
waitForTelemetryDisabledCheck bool
}
func TestServer_TelemetryDisabled_FinalReport(t *testing.T) {
t.Parallel()
if !dbtestutil.WillUsePostgres() {
t.Skip("this test requires postgres")
}
telemetryServerURL, deployment, snapshot := mockTelemetryServer(t)
dbConnURL, err := dbtestutil.Open(t)
require.NoError(t, err)
cacheDir := t.TempDir()
runServer := func(t *testing.T, opts runServerOpts) (chan error, context.CancelFunc) {
ctx, cancelFunc := context.WithCancel(context.Background())
inv, _ := clitest.New(t,
"server",
"--postgres-url", dbConnURL,
"--http-address", ":0",
"--access-url", "http://example.com",
"--telemetry="+strconv.FormatBool(!opts.telemetryDisabled),
"--telemetry-url", telemetryServerURL.String(),
"--cache-dir", cacheDir,
"--log-filter", ".*",
)
finished := make(chan bool, 2)
errChan := make(chan error, 1)
pty := ptytest.New(t).Attach(inv)
go func() {
errChan <- inv.WithContext(ctx).Run()
finished <- true
}()
go func() {
defer func() {
finished <- true
}()
if opts.waitForSnapshot {
pty.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "submitted snapshot")
}
if opts.waitForTelemetryDisabledCheck {
pty.ExpectMatchContext(testutil.Context(t, testutil.WaitLong), "finished telemetry status check")
}
}()
<-finished
return errChan, cancelFunc
}
waitForShutdown := func(t *testing.T, errChan chan error) error {
t.Helper()
select {
case err := <-errChan:
return err
case <-time.After(testutil.WaitMedium):
t.Fatalf("timed out waiting for server to shutdown")
}
return nil
}
errChan, cancelFunc := runServer(t, runServerOpts{telemetryDisabled: true, waitForTelemetryDisabledCheck: true})
cancelFunc()
require.NoError(t, waitForShutdown(t, errChan))
// Since telemetry was disabled, we expect no deployments or snapshots.
require.Empty(t, deployment)
require.Empty(t, snapshot)
errChan, cancelFunc = runServer(t, runServerOpts{waitForSnapshot: true})
cancelFunc()
require.NoError(t, waitForShutdown(t, errChan))
// we expect to see a deployment and a snapshot twice:
// 1. the first pair is sent when the server starts
// 2. the second pair is sent when the server shuts down
for i := 0; i < 2; i++ {
select {
case <-snapshot:
case <-time.After(testutil.WaitShort / 2):
t.Fatalf("timed out waiting for snapshot")
}
select {
case <-deployment:
case <-time.After(testutil.WaitShort / 2):
t.Fatalf("timed out waiting for deployment")
}
}
errChan, cancelFunc = runServer(t, runServerOpts{telemetryDisabled: true, waitForTelemetryDisabledCheck: true})
cancelFunc()
require.NoError(t, waitForShutdown(t, errChan))
// Since telemetry is disabled, we expect no deployment. We expect a snapshot
// with the telemetry disabled item.
require.Empty(t, deployment)
select {
case ss := <-snapshot:
require.Len(t, ss.TelemetryItems, 1)
require.Equal(t, string(telemetry.TelemetryItemKeyTelemetryEnabled), ss.TelemetryItems[0].Key)
require.Equal(t, "false", ss.TelemetryItems[0].Value)
case <-time.After(testutil.WaitShort / 2):
t.Fatalf("timed out waiting for snapshot")
}
errChan, cancelFunc = runServer(t, runServerOpts{telemetryDisabled: true, waitForTelemetryDisabledCheck: true})
cancelFunc()
require.NoError(t, waitForShutdown(t, errChan))
// Since telemetry is disabled and we've already sent a snapshot, we expect no
// new deployments or snapshots.
require.Empty(t, deployment)
require.Empty(t, snapshot)
}
func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, chan *telemetry.Snapshot) {
t.Helper()
deployment := make(chan *telemetry.Deployment, 64)
snapshot := make(chan *telemetry.Snapshot, 64)
r := chi.NewRouter()
r.Post("/deployment", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, buildinfo.Version(), r.Header.Get(telemetry.VersionHeader))
dd := &telemetry.Deployment{}
err := json.NewDecoder(r.Body).Decode(dd)
require.NoError(t, err)
deployment <- dd
// Ensure the header is sent only after deployment is sent
w.WriteHeader(http.StatusAccepted)
})
r.Post("/snapshot", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, buildinfo.Version(), r.Header.Get(telemetry.VersionHeader))
ss := &telemetry.Snapshot{}
err := json.NewDecoder(r.Body).Decode(ss)
require.NoError(t, err)
snapshot <- ss
// Ensure the header is sent only after snapshot is sent
w.WriteHeader(http.StatusAccepted)
})
server := httptest.NewServer(r)
t.Cleanup(server.Close)
serverURL, err := url.Parse(server.URL)
require.NoError(t, err)
return serverURL, deployment, snapshot
}