// Package coderdtest contains helpers for exercising coderd in tests: an
// in-memory server, a provisioner daemon, fixture creators and fake cloud
// instance-identity providers.
package coderdtest

import (
	"bytes"
	"context"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"crypto/x509/pkix"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"io"
	"math/big"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strings"
	"testing"
	"time"

	"cloud.google.com/go/compute/metadata"
	"github.com/fullsailor/pkcs7"
	"github.com/golang-jwt/jwt"
	"github.com/google/uuid"
	"github.com/moby/moby/pkg/namesgenerator"
	"github.com/stretchr/testify/require"
	"google.golang.org/api/idtoken"
	"google.golang.org/api/option"

	"cdr.dev/slog"
	"cdr.dev/slog/sloggers/slogtest"
	"github.com/coder/coder/coderd"
	"github.com/coder/coder/coderd/autobuild/executor"
	"github.com/coder/coder/coderd/awsidentity"
	"github.com/coder/coder/coderd/database"
	"github.com/coder/coder/coderd/database/databasefake"
	"github.com/coder/coder/coderd/database/postgres"
	"github.com/coder/coder/coderd/gitsshkey"
	"github.com/coder/coder/coderd/turnconn"
	"github.com/coder/coder/codersdk"
	"github.com/coder/coder/cryptorand"
	"github.com/coder/coder/provisioner/echo"
	"github.com/coder/coder/provisionerd"
	"github.com/coder/coder/provisionersdk"
	"github.com/coder/coder/provisionersdk/proto"
)

// Options customizes the coderd instance started by New. Every field may be
// left at its zero value.
type Options struct {
	// AWSCertificates is forwarded unchanged to coderd.Options.
	AWSCertificates awsidentity.Certificates

	// AzureCertificates is forwarded unchanged to coderd.Options.
	AzureCertificates x509.VerifyOptions

	// GithubOAuth2Config is forwarded unchanged to coderd.Options.
	GithubOAuth2Config *coderd.GithubOAuth2Config

	// GoogleTokenValidator is forwarded to coderd.Options. When nil, New
	// creates one with option.WithoutAuthentication().
	GoogleTokenValidator *idtoken.Validator

	// SSHKeygenAlgorithm is forwarded to coderd.Options. When empty, New
	// uses gitsshkey.AlgorithmEd25519.
	SSHKeygenAlgorithm gitsshkey.Algorithm

	// APIRateLimit is forwarded unchanged to coderd.Options.
	APIRateLimit int

	// LifecycleTicker is handed to the autobuild executor. When nil, New
	// supplies a channel it never sends on and closes it on test cleanup.
	LifecycleTicker <-chan time.Time
}

// New constructs an in-memory coderd instance and returns
// the connected client.
func New(t *testing.T, options *Options) *codersdk.Client { if options == nil { options = &Options{} } if options.GoogleTokenValidator == nil { ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(cancelFunc) var err error options.GoogleTokenValidator, err = idtoken.NewValidator(ctx, option.WithoutAuthentication()) require.NoError(t, err) } if options.LifecycleTicker == nil { ticker := make(chan time.Time) options.LifecycleTicker = ticker t.Cleanup(func() { close(ticker) }) } // This can be hotswapped for a live database instance. db := databasefake.New() pubsub := database.NewPubsubInMemory() if os.Getenv("DB") != "" { connectionURL, closePg, err := postgres.Open() require.NoError(t, err) t.Cleanup(closePg) sqlDB, err := sql.Open("postgres", connectionURL) require.NoError(t, err) t.Cleanup(func() { _ = sqlDB.Close() }) err = database.MigrateUp(sqlDB) require.NoError(t, err) db = database.New(sqlDB) pubsub, err = database.NewPubsub(context.Background(), sqlDB, connectionURL) require.NoError(t, err) t.Cleanup(func() { _ = pubsub.Close() }) } ctx, cancelFunc := context.WithCancel(context.Background()) lifecycleExecutor := executor.New( ctx, db, slogtest.Make(t, nil).Named("autobuild.executor").Leveled(slog.LevelDebug), options.LifecycleTicker, ) lifecycleExecutor.Run() srv := httptest.NewUnstartedServer(nil) srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } srv.Start() serverURL, err := url.Parse(srv.URL) require.NoError(t, err) var closeWait func() // match default with cli default if options.SSHKeygenAlgorithm == "" { options.SSHKeygenAlgorithm = gitsshkey.AlgorithmEd25519 } turnServer, err := turnconn.New(nil) require.NoError(t, err) // We set the handler after server creation for the access URL. 
srv.Config.Handler, closeWait = coderd.New(&coderd.Options{ AgentConnectionUpdateFrequency: 150 * time.Millisecond, AccessURL: serverURL, Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), Database: db, Pubsub: pubsub, AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, GithubOAuth2Config: options.GithubOAuth2Config, GoogleTokenValidator: options.GoogleTokenValidator, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, TURNServer: turnServer, APIRateLimit: options.APIRateLimit, }) t.Cleanup(func() { cancelFunc() _ = turnServer.Close() srv.Close() closeWait() }) return codersdk.New(serverURL) } // NewProvisionerDaemon launches a provisionerd instance configured to work // well with coderd testing. It registers the "echo" provisioner for // quick testing. func NewProvisionerDaemon(t *testing.T, client *codersdk.Client) io.Closer { echoClient, echoServer := provisionersdk.TransportPipe() ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(func() { _ = echoClient.Close() _ = echoServer.Close() cancelFunc() }) go func() { err := echo.Serve(ctx, &provisionersdk.ServeOptions{ Listener: echoServer, }) require.NoError(t, err) }() closer := provisionerd.New(client.ListenProvisionerDaemon, &provisionerd.Options{ Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), PollInterval: 50 * time.Millisecond, UpdateInterval: 250 * time.Millisecond, ForceCancelInterval: 250 * time.Millisecond, Provisioners: provisionerd.Provisioners{ string(database.ProvisionerTypeEcho): proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)), }, WorkDirectory: t.TempDir(), }) t.Cleanup(func() { _ = closer.Close() }) return closer } var FirstUserParams = codersdk.CreateFirstUserRequest{ Email: "testuser@coder.com", Username: "testuser", Password: "testpass", OrganizationName: "testorg", } // CreateFirstUser creates a user with preset credentials and authenticates // with the passed in codersdk client. 
func CreateFirstUser(t *testing.T, client *codersdk.Client) codersdk.CreateFirstUserResponse { resp, err := client.CreateFirstUser(context.Background(), FirstUserParams) require.NoError(t, err) login, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ Email: FirstUserParams.Email, Password: FirstUserParams.Password, }) require.NoError(t, err) client.SessionToken = login.SessionToken return resp } // CreateAnotherUser creates and authenticates a new user. func CreateAnotherUser(t *testing.T, client *codersdk.Client, organizationID uuid.UUID) *codersdk.Client { req := codersdk.CreateUserRequest{ Email: namesgenerator.GetRandomName(1) + "@coder.com", Username: randomUsername(), Password: "testpass", OrganizationID: organizationID, } _, err := client.CreateUser(context.Background(), req) require.NoError(t, err) login, err := client.LoginWithPassword(context.Background(), codersdk.LoginWithPasswordRequest{ Email: req.Email, Password: req.Password, }) require.NoError(t, err) other := codersdk.New(client.URL) other.SessionToken = login.SessionToken return other } // CreateTemplateVersion creates a template import provisioner job // with the responses provided. It uses the "echo" provisioner for compatibility // with testing. func CreateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses) codersdk.TemplateVersion { data, err := echo.Tar(res) require.NoError(t, err) file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data) require.NoError(t, err) templateVersion, err := client.CreateTemplateVersion(context.Background(), organizationID, codersdk.CreateTemplateVersionRequest{ StorageSource: file.Hash, StorageMethod: database.ProvisionerStorageMethodFile, Provisioner: database.ProvisionerTypeEcho, }) require.NoError(t, err) return templateVersion } // CreateTemplate creates a template with the "echo" provisioner for // compatibility with testing. 
The name assigned is randomly generated. func CreateTemplate(t *testing.T, client *codersdk.Client, organization uuid.UUID, version uuid.UUID) codersdk.Template { template, err := client.CreateTemplate(context.Background(), organization, codersdk.CreateTemplateRequest{ Name: randomUsername(), VersionID: version, }) require.NoError(t, err) return template } // UpdateTemplateVersion creates a new template version with the "echo" provisioner // and associates it with the given templateID. func UpdateTemplateVersion(t *testing.T, client *codersdk.Client, organizationID uuid.UUID, res *echo.Responses, templateID uuid.UUID) codersdk.TemplateVersion { data, err := echo.Tar(res) require.NoError(t, err) file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data) require.NoError(t, err) templateVersion, err := client.CreateTemplateVersion(context.Background(), organizationID, codersdk.CreateTemplateVersionRequest{ TemplateID: templateID, StorageSource: file.Hash, StorageMethod: database.ProvisionerStorageMethodFile, Provisioner: database.ProvisionerTypeEcho, }) require.NoError(t, err) return templateVersion } // AwaitTemplateImportJob awaits for an import job to reach completed status. func AwaitTemplateVersionJob(t *testing.T, client *codersdk.Client, version uuid.UUID) codersdk.TemplateVersion { var templateVersion codersdk.TemplateVersion require.Eventually(t, func() bool { var err error templateVersion, err = client.TemplateVersion(context.Background(), version) require.NoError(t, err) return templateVersion.Job.CompletedAt != nil }, 5*time.Second, 25*time.Millisecond) return templateVersion } // AwaitWorkspaceBuildJob waits for a workspace provision job to reach completed status. 
func AwaitWorkspaceBuildJob(t *testing.T, client *codersdk.Client, build uuid.UUID) codersdk.WorkspaceBuild { var workspaceBuild codersdk.WorkspaceBuild require.Eventually(t, func() bool { var err error workspaceBuild, err = client.WorkspaceBuild(context.Background(), build) require.NoError(t, err) return workspaceBuild.Job.CompletedAt != nil }, 5*time.Second, 25*time.Millisecond) return workspaceBuild } // AwaitWorkspaceAgents waits for all resources with agents to be connected. func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID) []codersdk.WorkspaceResource { var resources []codersdk.WorkspaceResource require.Eventually(t, func() bool { var err error resources, err = client.WorkspaceResourcesByBuild(context.Background(), build) require.NoError(t, err) for _, resource := range resources { for _, agent := range resource.Agents { if agent.Status != codersdk.WorkspaceAgentConnected { return false } } } return true }, 5*time.Second, 25*time.Millisecond) return resources } // CreateWorkspace creates a workspace for the user and template provided. // A random name is generated for it. func CreateWorkspace(t *testing.T, client *codersdk.Client, organization uuid.UUID, templateID uuid.UUID) codersdk.Workspace { workspace, err := client.CreateWorkspace(context.Background(), organization, codersdk.CreateWorkspaceRequest{ TemplateID: templateID, Name: randomUsername(), }) require.NoError(t, err) return workspace } // NewGoogleInstanceIdentity returns a metadata client and ID token validator for faking // instance authentication for Google Cloud. 
// nolint:revive func NewGoogleInstanceIdentity(t *testing.T, instanceID string, expired bool) (*idtoken.Validator, *metadata.Client) { keyID, err := cryptorand.String(12) require.NoError(t, err) claims := jwt.MapClaims{ "google": map[string]interface{}{ "compute_engine": map[string]string{ "instance_id": instanceID, }, }, } if !expired { claims["exp"] = time.Now().AddDate(1, 0, 0).Unix() } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = keyID privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) signedKey, err := token.SignedString(privateKey) require.NoError(t, err) // Taken from: https://github.com/googleapis/google-api-go-client/blob/4bb729045d611fa77bdbeb971f6a1204ba23161d/idtoken/validate.go#L57-L75 type jwk struct { Kid string `json:"kid"` N string `json:"n"` E string `json:"e"` } type certResponse struct { Keys []jwk `json:"keys"` } validator, err := idtoken.NewValidator(context.Background(), option.WithHTTPClient(&http.Client{ Transport: roundTripper(func(r *http.Request) (*http.Response, error) { data, err := json.Marshal(certResponse{ Keys: []jwk{{ Kid: keyID, N: base64.RawURLEncoding.EncodeToString(privateKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(privateKey.E)).Bytes()), }}, }) require.NoError(t, err) return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(data)), Header: make(http.Header), }, nil }), })) require.NoError(t, err) return validator, metadata.NewClient(&http.Client{ Transport: roundTripper(func(r *http.Request) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte(signedKey))), Header: make(http.Header), }, nil }), }) } // NewAWSInstanceIdentity returns a metadata client and ID token validator for faking // instance authentication for AWS. 
func NewAWSInstanceIdentity(t *testing.T, instanceID string) (awsidentity.Certificates, *http.Client) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) document := []byte(`{"instanceId":"` + instanceID + `"}`) hashedDocument := sha256.Sum256(document) signatureRaw, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hashedDocument[:]) require.NoError(t, err) signature := make([]byte, base64.StdEncoding.EncodedLen(len(signatureRaw))) base64.StdEncoding.Encode(signature, signatureRaw) certificate, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ SerialNumber: big.NewInt(2022), }, &x509.Certificate{}, &privateKey.PublicKey, privateKey) require.NoError(t, err) certificatePEM := bytes.Buffer{} err = pem.Encode(&certificatePEM, &pem.Block{ Type: "CERTIFICATE", Bytes: certificate, }) require.NoError(t, err) return awsidentity.Certificates{ awsidentity.Other: certificatePEM.String(), }, &http.Client{ Transport: roundTripper(func(r *http.Request) (*http.Response, error) { // Only handle metadata server requests. if r.URL.Host != "169.254.169.254" { return http.DefaultTransport.RoundTrip(r) } switch r.URL.Path { case "/latest/api/token": return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader([]byte("faketoken"))), Header: make(http.Header), }, nil case "/latest/dynamic/instance-identity/signature": return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(signature)), Header: make(http.Header), }, nil case "/latest/dynamic/instance-identity/document": return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(document)), Header: make(http.Header), }, nil default: panic("unhandled route: " + r.URL.Path) } }), } } // NewAzureInstanceIdentity returns a metadata client and ID token validator for faking // instance authentication for Azure. 
func NewAzureInstanceIdentity(t *testing.T, instanceID string) (x509.VerifyOptions, *http.Client) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) rawCertificate, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ SerialNumber: big.NewInt(2022), NotAfter: time.Now().AddDate(1, 0, 0), Subject: pkix.Name{ CommonName: "metadata.azure.com", }, }, &x509.Certificate{}, &privateKey.PublicKey, privateKey) require.NoError(t, err) certificate, err := x509.ParseCertificate(rawCertificate) require.NoError(t, err) signed, err := pkcs7.NewSignedData([]byte(`{"vmId":"` + instanceID + `"}`)) require.NoError(t, err) err = signed.AddSigner(certificate, privateKey, pkcs7.SignerInfoConfig{}) require.NoError(t, err) signatureRaw, err := signed.Finish() require.NoError(t, err) signature := make([]byte, base64.StdEncoding.EncodedLen(len(signatureRaw))) base64.StdEncoding.Encode(signature, signatureRaw) payload, err := json.Marshal(codersdk.AzureInstanceIdentityToken{ Signature: string(signature), Encoding: "pkcs7", }) require.NoError(t, err) certPool := x509.NewCertPool() certPool.AddCert(certificate) return x509.VerifyOptions{ Intermediates: certPool, Roots: certPool, }, &http.Client{ Transport: roundTripper(func(r *http.Request) (*http.Response, error) { // Only handle metadata server requests. if r.URL.Host != "169.254.169.254" { return http.DefaultTransport.RoundTrip(r) } switch r.URL.Path { case "/metadata/attested/document": return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(payload)), Header: make(http.Header), }, nil default: panic("unhandled route: " + r.URL.Path) } }), } } func randomUsername() string { return strings.ReplaceAll(namesgenerator.GetRandomName(0), "_", "-") } // Used to easily create an HTTP transport! type roundTripper func(req *http.Request) (*http.Response, error) func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r(req) }