From 2654a9313240f68c90054cb9d097872122a587c3 Mon Sep 17 00:00:00 2001 From: Kyle Carberry Date: Thu, 20 Jan 2022 10:00:13 -0600 Subject: [PATCH] chore: Fix golangci-lint configuration and patch errors (#34) * chore: Fix golangci-lint configuration and patch errors Due to misconfiguration of a linting rules directory, our linter has not been working properly. This change fixes the configuration issue, and all remaining linting errors. * Fix race in peer logging * Fix race and return * Lock on bufferred amount low * Fix mutex lock --- .golangci.yml | 38 +++++++---------- cmd/coder/main.go | 2 +- coderd/coderd.go | 2 +- coderd/coderdtest/coderdtest.go | 6 +-- coderd/userpassword/userpassword.go | 6 +-- coderd/users.go | 26 +++++++----- coderd/users_test.go | 3 +- codersdk/client.go | 10 ++--- codersdk/users.go | 2 +- codersdk/users_test.go | 3 +- cryptorand/numbers.go | 31 +++++++------- cryptorand/numbers_test.go | 5 ++- cryptorand/strings.go | 4 +- cryptorand/strings_test.go | 3 +- database/databasefake/databasefake.go | 17 ++++---- database/db.go | 12 +++--- database/dump/main.go | 5 +-- database/migrate.go | 3 +- database/migrate_test.go | 3 +- database/postgres/postgres.go | 8 ++-- database/pubsub.go | 10 ++--- database/pubsub_memory.go | 2 +- database/pubsub_memory_test.go | 6 +-- database/pubsub_test.go | 6 +-- httpapi/httpapi.go | 17 ++++---- httpmw/apikey.go | 12 +++--- httpmw/apikey_test.go | 56 ++++++++++++++++++------- peer/channel.go | 24 ++++++----- peer/conn.go | 53 +++++++++++------------ peer/conn_test.go | 56 ++++++++++++------------- peer/netconn.go | 10 ++--- peerbroker/listen.go | 1 - provisioner/terraform/parse.go | 2 +- provisioner/terraform/parse_test.go | 10 ++--- provisioner/terraform/provision.go | 4 +- provisioner/terraform/provision_test.go | 20 ++++----- site/embed.go | 53 ++++++++++------------- site/embed_test.go | 7 +++- 38 files changed, 283 insertions(+), 255 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 140df27953..859e160899 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -100,10 +100,6 @@ linters-settings: # - whyNoLint # - wrapperFunc # - yodaStyleExpr - settings: - ruleguard: - failOn: all - rules: "${configDir}/lib/go/lintrules/*.go" goimports: local-prefixes: coder.com,cdr.dev,go.coder.com,github.com/cdr,github.com/coder @@ -113,24 +109,6 @@ linters-settings: importas: no-unaliased: true - alias: - - pkg: k8s.io/api/(\w+)/(v[\w\d]+) - alias: ${1}${2} - - - pkg: k8s.io/apimachinery/pkg/apis/meta/(v[\w\d]+) - alias: meta${1} - - - pkg: k8s.io/client-go/kubernetes/typed/(\w+)/(v[\w\d]+) - alias: ${1}${2}client - - - pkg: k8s.io/metrics/pkg/apis/metrics/(v[\w\d]+) - alias: metrics${1} - - - pkg: github.com/docker/docker/api/types - alias: dockertypes - - - pkg: github.com/docker/docker/client - alias: dockerclient misspell: locale: US @@ -195,6 +173,20 @@ linters-settings: - name: var-declaration - name: var-naming - name: waitgroup-by-value + varnamelen: + ignore-names: + - err + - rw + - r + - i + - db + # Optional list of variable declarations that should be ignored completely. (defaults to empty list) + # Entries must be in the form of " " or " *" for + # variables, or "const " for constants. + ignore-decls: + - rw http.ResponseWriter + - r *http.Request + - t testing.T issues: # Rules listed here: https://github.com/securego/gosec#available-rules @@ -222,7 +214,6 @@ linters: - asciicheck - bidichk - bodyclose - - contextcheck - deadcode - dogsled - errcheck @@ -239,7 +230,6 @@ linters: - govet - importas - ineffassign - # - ireturn - makezero - misspell - nilnil diff --git a/cmd/coder/main.go b/cmd/coder/main.go index b1b14d0c79..741337dc03 100644 --- a/cmd/coder/main.go +++ b/cmd/coder/main.go @@ -3,5 +3,5 @@ package main import "fmt" func main() { - fmt.Println("Hello World!") + _, _ = fmt.Println("Hello World!") } diff --git a/coderd/coderd.go b/coderd/coderd.go index c24d20158a..16a69a918d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -39,7 +39,7 @@ func New(options *Options) http.Handler { httpmw.ExtractAPIKey(options.Database, nil), httpmw.ExtractUser(options.Database), ) - r.Get("/user", users.getAuthenticatedUser) + r.Get("/user", users.authenticatedUser) }) }) r.NotFound(site.Handler().ServeHTTP) diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 7b1954ccc6..d3e880133e 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -32,11 +32,11 @@ func New(t *testing.T) Server { Database: db, }) srv := httptest.NewServer(handler) - u, err := url.Parse(srv.URL) + serverURL, err := url.Parse(srv.URL) require.NoError(t, err) t.Cleanup(srv.Close) - client := codersdk.New(u) + client := codersdk.New(serverURL) _, err = client.CreateInitialUser(context.Background(), coderd.CreateUserRequest{ Email: "testuser@coder.com", Username: "testuser", @@ -54,6 +54,6 @@ func New(t *testing.T) Server { return Server{ Client: client, - URL: u, + URL: serverURL, } } diff --git a/coderd/userpassword/userpassword.go b/coderd/userpassword/userpassword.go index 67ce873125..28e75869ec 100644 --- a/coderd/userpassword/userpassword.go +++ b/coderd/userpassword/userpassword.go @@ -35,14 +35,14 @@ func Compare(hashed string, password string) (bool, error) { if len(parts[0]) != 0 { return false, xerrors.Errorf("hash prefix is invalid") } - if string(parts[1]) != hashScheme { + if parts[1] != hashScheme { return false, xerrors.Errorf("hash isn't %q scheme: %q", hashScheme, parts[1]) } - iter, err := strconv.Atoi(string(parts[2])) + iter, err := strconv.Atoi(parts[2]) if err != nil { return false, xerrors.Errorf("parse iter from hash: %w", err) } - salt, err := base64.RawStdEncoding.DecodeString(string(parts[3])) + salt, err := base64.RawStdEncoding.DecodeString(parts[3]) if err != nil { return false, xerrors.Errorf("decode salt: %w", err) } diff --git a/coderd/users.go b/coderd/users.go index 891981be09..b5f4004d70 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -70,7 +70,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { }) return } - user, err := users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ + _, err = users.Database.GetUserByEmailOrUsername(r.Context(), database.GetUserByEmailOrUsernameParams{ Email: createUser.Email, Username: createUser.Username, }) @@ -91,7 +91,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { return } - user, err = users.Database.InsertUser(context.Background(), database.InsertUserParams{ + user, err := users.Database.InsertUser(context.Background(), database.InsertUserParams{ ID: uuid.NewString(), Email: createUser.Email, HashedPassword: []byte(hashedPassword), @@ -111,7 +111,7 @@ func (users *users) createInitialUser(rw http.ResponseWriter, r *http.Request) { } // Returns the currently authenticated user. -func (users *users) getAuthenticatedUser(rw http.ResponseWriter, r *http.Request) { +func (*users) authenticatedUser(rw http.ResponseWriter, r *http.Request) { user := httpmw.User(r) render.JSON(rw, r, User{ @@ -158,11 +158,17 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) { return } - id, secret, err := generateAPIKeyIDSecret() - hashed := sha256.Sum256([]byte(secret)) + keyID, keySecret, err := generateAPIKeyIDSecret() + if err != nil { + httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ + Message: fmt.Sprintf("generate api key parts: %s", err.Error()), + }) + return + } + hashed := sha256.Sum256([]byte(keySecret)) _, err = users.Database.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ - ID: id, + ID: keyID, UserID: user.ID, ExpiresAt: database.Now().Add(24 * time.Hour), CreatedAt: database.Now(), @@ -178,7 +184,7 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) { } // This format is consumed by the APIKey middleware. - sessionToken := fmt.Sprintf("%s-%s", id, secret) + sessionToken := fmt.Sprintf("%s-%s", keyID, keySecret) http.SetCookie(rw, &http.Cookie{ Name: httpmw.AuthCookie, Value: sessionToken, @@ -194,14 +200,14 @@ func (users *users) loginWithPassword(rw http.ResponseWriter, r *http.Request) { } // Generates a new ID and secret for an API key. -func generateAPIKeyIDSecret() (string, string, error) { +func generateAPIKeyIDSecret() (id string, secret string, err error) { // Length of an API Key ID. - id, err := cryptorand.String(10) + id, err = cryptorand.String(10) if err != nil { return "", "", err } // Length of an API Key secret. - secret, err := cryptorand.String(22) + secret, err = cryptorand.String(22) if err != nil { return "", "", err } diff --git a/coderd/users_test.go b/coderd/users_test.go index 5f5a5bef48..0aa8e4f023 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" - "github.com/stretchr/testify/require" ) func TestUsers(t *testing.T) { diff --git a/codersdk/client.go b/codersdk/client.go index 4644346b42..b4931a91e8 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -18,9 +18,9 @@ import ( ) // New creates a Coder client for the provided URL. -func New(url *url.URL) *Client { +func New(serverURL *url.URL) *Client { return &Client{ - url: url, + url: serverURL, httpClient: &http.Client{}, } } @@ -50,7 +50,7 @@ func (c *Client) SetSessionToken(token string) error { // request performs an HTTP request with the body provided. // The caller is responsible for closing the response body. func (c *Client) request(ctx context.Context, method, path string, body interface{}) (*http.Response, error) { - url, err := c.url.Parse(path) + serverURL, err := c.url.Parse(path) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } @@ -65,7 +65,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac } } - req, err := http.NewRequestWithContext(ctx, method, url.String(), &buf) + req, err := http.NewRequestWithContext(ctx, method, serverURL.String(), &buf) if err != nil { return nil, xerrors.Errorf("create request: %w", err) } @@ -81,7 +81,7 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac } // readBodyAsError reads the response as an httpapi.Message, and -// wraps it in a codersdk.Error type for easy marshalling. +// wraps it in a codersdk.Error type for easy marshaling. func readBodyAsError(res *http.Response) error { var m httpapi.Response err := json.NewDecoder(res.Body).Decode(&m) diff --git a/codersdk/users.go b/codersdk/users.go index 7f2da5f80d..e7caa36b1b 100644 --- a/codersdk/users.go +++ b/codersdk/users.go @@ -26,7 +26,7 @@ func (c *Client) CreateInitialUser(ctx context.Context, req coderd.CreateUserReq // User returns a user for the ID provided. // If the ID string is empty, the current user will be returned. -func (c *Client) User(ctx context.Context, id string) (coderd.User, error) { +func (c *Client) User(ctx context.Context, _ string) (coderd.User, error) { res, err := c.request(ctx, http.MethodGet, "/api/v2/user", nil) if err != nil { return coderd.User{}, err diff --git a/codersdk/users_test.go b/codersdk/users_test.go index 8810b62db8..a6740ff6a2 100644 --- a/codersdk/users_test.go +++ b/codersdk/users_test.go @@ -5,10 +5,11 @@ import ( "net/http" "testing" + "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/codersdk" - "github.com/stretchr/testify/require" ) func TestUsers(t *testing.T) { diff --git a/cryptorand/numbers.go b/cryptorand/numbers.go index e685aa41d0..c782f03e71 100644 --- a/cryptorand/numbers.go +++ b/cryptorand/numbers.go @@ -73,42 +73,43 @@ func Int() (int, error) { return int(i), nil } -// Int63n returns a non-negative random integer in [0,n) as a int64. -func Int63n(n int64) (int64, error) { - if n <= 0 { +// Int63n returns a non-negative random integer in [0,max) as a int64. +func Int63n(max int64) (int64, error) { + if max <= 0 { panic("invalid argument to Int63n") } - max := int64((1 << 63) - 1 - (1<<63)%uint64(n)) + trueMax := int64((1 << 63) - 1 - (1<<63)%uint64(max)) i, err := Int63() if err != nil { return 0, err } - for i > max { + for i > trueMax { i, err = Int63() if err != nil { return 0, err } } - return i % n, nil + return i % max, nil } -// Int31n returns a non-negative integer in [0,n) as a int32. -func Int31n(n int32) (int32, error) { +// Int31n returns a non-negative integer in [0,max) as a int32. +func Int31n(max int32) (int32, error) { i, err := Uint32() if err != nil { return 0, err } - return UnbiasedModulo32(i, n) + return UnbiasedModulo32(i, max) } // UnbiasedModulo32 uniformly modulos v by n over a sufficiently large data // set, regenerating v if necessary. n must be > 0. All input bits in v must be // fully random, you cannot cast a random uint8/uint16 for input into this // function. +//nolint:varnamelen func UnbiasedModulo32(v uint32, n int32) (int32, error) { prod := uint64(v) * uint64(n) low := uint32(prod) @@ -127,14 +128,14 @@ func UnbiasedModulo32(v uint32, n int32) (int32, error) { return int32(prod >> 32), nil } -// Intn returns a non-negative integer in [0,n) as a int. -func Intn(n int) (int, error) { - if n <= 0 { +// Intn returns a non-negative integer in [0,max) as a int. +func Intn(max int) (int, error) { + if max <= 0 { panic("n must be a positive nonzero number") } - if n <= 1<<31-1 { - i, err := Int31n(int32(n)) + if max <= 1<<31-1 { + i, err := Int31n(int32(max)) if err != nil { return 0, err } @@ -142,7 +143,7 @@ func Intn(n int) (int, error) { return int(i), nil } - i, err := Int63n(int64(n)) + i, err := Int63n(int64(max)) if err != nil { return 0, err } diff --git a/cryptorand/numbers_test.go b/cryptorand/numbers_test.go index 612ec5b7f0..b1602df404 100644 --- a/cryptorand/numbers_test.go +++ b/cryptorand/numbers_test.go @@ -5,8 +5,9 @@ import ( "encoding/binary" "testing" - "github.com/coder/coder/cryptorand" "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" ) func TestInt63(t *testing.T) { @@ -144,7 +145,7 @@ func TestBool(t *testing.T) { const iterations = 10000 trueCount := 0 - for i := 0; i < iterations; i += 1 { + for i := 0; i < iterations; i++ { v, err := cryptorand.Bool() require.NoError(t, err, "unexpected error from Bool") if v { diff --git a/cryptorand/strings.go b/cryptorand/strings.go index 897a91d105..fd16ba0af7 100644 --- a/cryptorand/strings.go +++ b/cryptorand/strings.go @@ -53,7 +53,7 @@ func StringCharset(charSetStr string, size int) (string, error) { buf.Grow(size) for i := 0; i < size; i++ { - c, err := UnbiasedModulo32( + count, err := UnbiasedModulo32( binary.BigEndian.Uint32(ibuf[i*4:(i+1)*4]), int32(len(charSet)), ) @@ -61,7 +61,7 @@ func StringCharset(charSetStr string, size int) (string, error) { return "", err } - _, _ = buf.WriteRune(charSet[c]) + _, _ = buf.WriteRune(charSet[count]) } return buf.String(), nil diff --git a/cryptorand/strings_test.go b/cryptorand/strings_test.go index df81a85320..3f6025e0f9 100644 --- a/cryptorand/strings_test.go +++ b/cryptorand/strings_test.go @@ -8,8 +8,9 @@ import ( "testing" "unicode/utf8" - "github.com/coder/coder/cryptorand" "github.com/stretchr/testify/require" + + "github.com/coder/coder/cryptorand" ) func TestString(t *testing.T) { diff --git a/database/databasefake/databasefake.go b/database/databasefake/databasefake.go index ec9ad6e7e1..96b39fc99c 100644 --- a/database/databasefake/databasefake.go +++ b/database/databasefake/databasefake.go @@ -22,11 +22,11 @@ type fakeQuerier struct { } // InTx doesn't rollback data properly for in-memory yet. -func (q *fakeQuerier) InTx(ctx context.Context, fn func(database.Store) error) error { +func (q *fakeQuerier) InTx(fn func(database.Store) error) error { return fn(q) } -func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.APIKey, error) { +func (q *fakeQuerier) GetAPIKeyByID(_ context.Context, id string) (database.APIKey, error) { for _, apiKey := range q.apiKeys { if apiKey.ID == id { return apiKey, nil @@ -35,7 +35,7 @@ func (q *fakeQuerier) GetAPIKeyByID(ctx context.Context, id string) (database.AP return database.APIKey{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { +func (q *fakeQuerier) GetUserByEmailOrUsername(_ context.Context, arg database.GetUserByEmailOrUsernameParams) (database.User, error) { for _, user := range q.users { if user.Email == arg.Email || user.Username == arg.Username { return user, nil @@ -44,7 +44,7 @@ func (q *fakeQuerier) GetUserByEmailOrUsername(ctx context.Context, arg database return database.User{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User, error) { +func (q *fakeQuerier) GetUserByID(_ context.Context, id string) (database.User, error) { for _, user := range q.users { if user.ID == id { return user, nil @@ -53,11 +53,12 @@ func (q *fakeQuerier) GetUserByID(ctx context.Context, id string) (database.User return database.User{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUserCount(ctx context.Context) (int64, error) { +func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { return int64(len(q.users)), nil } -func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { +func (q *fakeQuerier) InsertAPIKey(_ context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { + //nolint:gosimple key := database.APIKey{ ID: arg.ID, HashedSecret: arg.HashedSecret, @@ -79,7 +80,7 @@ func (q *fakeQuerier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKe return key, nil } -func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserParams) (database.User, error) { +func (q *fakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParams) (database.User, error) { user := database.User{ ID: arg.ID, Email: arg.Email, @@ -94,7 +95,7 @@ func (q *fakeQuerier) InsertUser(ctx context.Context, arg database.InsertUserPar return user, nil } -func (q *fakeQuerier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKeyByIDParams) error { +func (q *fakeQuerier) UpdateAPIKeyByID(_ context.Context, arg database.UpdateAPIKeyByIDParams) error { for index, apiKey := range q.apiKeys { if apiKey.ID != arg.ID { continue diff --git a/database/db.go b/database/db.go index 8fce646dce..a4399ee953 100644 --- a/database/db.go +++ b/database/db.go @@ -21,7 +21,7 @@ import ( type Store interface { querier - InTx(context.Context, func(Store) error) error + InTx(func(Store) error) error } // DBTX represents a database connection or transaction. @@ -46,16 +46,16 @@ type sqlQuerier struct { } // InTx performs database operations inside a transaction. -func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error { +func (q *sqlQuerier) InTx(function func(Store) error) error { if q.sdb == nil { return nil } - tx, err := q.sdb.Begin() + transaction, err := q.sdb.Begin() if err != nil { return xerrors.Errorf("begin transaction: %w", err) } defer func() { - rerr := tx.Rollback() + rerr := transaction.Rollback() if rerr == nil || errors.Is(rerr, sql.ErrTxDone) { // no need to do anything, tx committed successfully return @@ -63,11 +63,11 @@ func (q *sqlQuerier) InTx(ctx context.Context, fn func(Store) error) error { // couldn't roll back for some reason, extend returned error err = xerrors.Errorf("defer (%s): %w", rerr.Error(), err) }() - err = fn(&sqlQuerier{db: tx}) + err = function(&sqlQuerier{db: transaction}) if err != nil { return xerrors.Errorf("execute transaction: %w", err) } - err = tx.Commit() + err = transaction.Commit() if err != nil { return xerrors.Errorf("commit transaction: %w", err) } diff --git a/database/dump/main.go b/database/dump/main.go index 91dc0e2f0c..00e7e15271 100644 --- a/database/dump/main.go +++ b/database/dump/main.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "context" "database/sql" "fmt" "io/ioutil" @@ -25,7 +24,7 @@ func main() { if err != nil { panic(err) } - err = database.Migrate(context.Background(), db) + err = database.Migrate(db) if err != nil { panic(err) } @@ -82,7 +81,7 @@ func main() { if !ok { panic("couldn't get caller path") } - err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0644) + err = ioutil.WriteFile(filepath.Join(mainPath, "..", "..", "dump.sql"), []byte(dump), 0600) if err != nil { panic(err) } diff --git a/database/migrate.go b/database/migrate.go index f3dff135bd..3722652f48 100644 --- a/database/migrate.go +++ b/database/migrate.go @@ -1,7 +1,6 @@ package database import ( - "context" "database/sql" "embed" "errors" @@ -16,7 +15,7 @@ import ( var migrations embed.FS // Migrate runs SQL migrations to ensure the database schema is up-to-date. -func Migrate(ctx context.Context, db *sql.DB) error { +func Migrate(db *sql.DB) error { sourceDriver, err := iofs.New(migrations, "migrations") if err != nil { return xerrors.Errorf("create iofs: %w", err) diff --git a/database/migrate_test.go b/database/migrate_test.go index 270292805e..d16671198b 100644 --- a/database/migrate_test.go +++ b/database/migrate_test.go @@ -3,7 +3,6 @@ package database_test import ( - "context" "database/sql" "testing" @@ -27,6 +26,6 @@ func TestMigrate(t *testing.T) { db, err := sql.Open("postgres", connection) require.NoError(t, err) defer db.Close() - err = database.Migrate(context.Background(), db) + err = database.Migrate(db) require.NoError(t, err) } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 9ca57125aa..3b99297998 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -3,7 +3,6 @@ package postgres import ( "database/sql" "fmt" - "log" "time" "github.com/ory/dockertest/v3" @@ -32,13 +31,16 @@ func Open() (string, func(), error) { config.RestartPolicy = docker.RestartPolicy{Name: "no"} }) if err != nil { - log.Fatalf("Could not start resource: %s", err) + return "", nil, xerrors.Errorf("could not start resource: %w", err) } hostAndPort := resource.GetHostPort("5432/tcp") dbURL := fmt.Sprintf("postgres://postgres:postgres@%s/postgres?sslmode=disable", hostAndPort) // Docker should hard-kill the container after 120 seconds. - resource.Expire(120) + err = resource.Expire(120) + if err != nil { + return "", nil, xerrors.Errorf("could not expire resource: %w", err) + } pool.MaxWait = 120 * time.Second err = pool.Retry(func() error { diff --git a/database/pubsub.go b/database/pubsub.go index 656e923246..c9cae23653 100644 --- a/database/pubsub.go +++ b/database/pubsub.go @@ -122,7 +122,7 @@ func (p *pgPubsub) listenReceive(ctx context.Context, notif *pq.Notification) { } // NewPubsub creates a new Pubsub implementation using a PostgreSQL connection. -func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, error) { +func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) { // Creates a new listener using pq. errCh := make(chan error) listener := pq.NewListener(connectURL, time.Second*10, time.Minute, func(event pq.ListenerEventType, err error) { @@ -144,12 +144,12 @@ func NewPubsub(ctx context.Context, db *sql.DB, connectURL string) (Pubsub, erro case <-ctx.Done(): return nil, ctx.Err() } - pg := &pgPubsub{ - db: db, + pgPubsub := &pgPubsub{ + db: database, pgListener: listener, listeners: make(map[string]map[string]Listener), } - go pg.listen(ctx) + go pgPubsub.listen(ctx) - return pg, nil + return pgPubsub, nil } diff --git a/database/pubsub_memory.go b/database/pubsub_memory.go index 92244f8bbc..148d2f57b1 100644 --- a/database/pubsub_memory.go +++ b/database/pubsub_memory.go @@ -52,7 +52,7 @@ func (m *memoryPubsub) Publish(event string, message []byte) error { return nil } -func (m *memoryPubsub) Close() error { +func (*memoryPubsub) Close() error { return nil } diff --git a/database/pubsub_memory_test.go b/database/pubsub_memory_test.go index bf9f200558..2a930fb3e0 100644 --- a/database/pubsub_memory_test.go +++ b/database/pubsub_memory_test.go @@ -17,9 +17,9 @@ func TestPubsubMemory(t *testing.T) { pubsub := database.NewPubsubInMemory() event := "test" data := "testing" - ch := make(chan []byte) + messageChannel := make(chan []byte) cancelFunc, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) { - ch <- message + messageChannel <- message }) require.NoError(t, err) defer cancelFunc() @@ -27,7 +27,7 @@ func TestPubsubMemory(t *testing.T) { err = pubsub.Publish(event, []byte(data)) require.NoError(t, err) }() - message := <-ch + message := <-messageChannel assert.Equal(t, string(message), data) }) } diff --git a/database/pubsub_test.go b/database/pubsub_test.go index c563a1bcca..0a1eba426f 100644 --- a/database/pubsub_test.go +++ b/database/pubsub_test.go @@ -32,9 +32,9 @@ func TestPubsub(t *testing.T) { defer pubsub.Close() event := "test" data := "testing" - ch := make(chan []byte) + messageChannel := make(chan []byte) cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) { - ch <- message + messageChannel <- message }) require.NoError(t, err) defer cancelFunc() @@ -42,7 +42,7 @@ func TestPubsub(t *testing.T) { err = pubsub.Publish(event, []byte(data)) require.NoError(t, err) }() - message := <-ch + message := <-messageChannel assert.Equal(t, string(message), data) }) } diff --git a/httpapi/httpapi.go b/httpapi/httpapi.go index 1613340b8a..52fabee369 100644 --- a/httpapi/httpapi.go +++ b/httpapi/httpapi.go @@ -31,7 +31,7 @@ func init() { } return name }) - validate.RegisterValidation("username", func(fl validator.FieldLevel) bool { + err := validate.RegisterValidation("username", func(fl validator.FieldLevel) bool { f := fl.Field().Interface() str, ok := f.(string) if !ok { @@ -45,6 +45,9 @@ func init() { } return usernameRegex.MatchString(str) }) + if err != nil { + panic(err) + } } // Response represents a generic HTTP response. @@ -60,20 +63,20 @@ type Error struct { } // Write outputs a standardized format to an HTTP response body. -func Write(w http.ResponseWriter, status int, response Response) { +func Write(rw http.ResponseWriter, status int, response Response) { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(true) err := enc.Encode(response) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(rw, err.Error(), http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(status) - _, err = w.Write(buf.Bytes()) + rw.Header().Set("Content-Type", "application/json; charset=utf-8") + rw.WriteHeader(status) + _, err = rw.Write(buf.Bytes()) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(rw, err.Error(), http.StatusInternalServerError) return } } diff --git a/httpmw/apikey.go b/httpmw/apikey.go index 09bb4652fa..45a93c2f91 100644 --- a/httpmw/apikey.go +++ b/httpmw/apikey.go @@ -58,22 +58,22 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle }) return } - id := parts[0] - secret := parts[1] + keyID := parts[0] + keySecret := parts[1] // Ensuring key lengths are valid. - if len(id) != 10 { + if len(keyID) != 10 { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ Message: fmt.Sprintf("invalid %q cookie api key id", AuthCookie), }) return } - if len(secret) != 22 { + if len(keySecret) != 22 { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ Message: fmt.Sprintf("invalid %q cookie api key secret", AuthCookie), }) return } - key, err := db.GetAPIKeyByID(r.Context(), id) + key, err := db.GetAPIKeyByID(r.Context(), keyID) if err != nil { if errors.Is(err, sql.ErrNoRows) { httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{ @@ -86,7 +86,7 @@ func ExtractAPIKey(db database.Store, oauthConfig OAuth2Config) func(http.Handle }) return } - hashed := sha256.Sum256([]byte(secret)) + hashed := sha256.Sum256([]byte(keySecret)) // Checking to see if the secret is valid. if subtle.ConstantTimeCompare(key.HashedSecret, hashed[:]) != 1 { diff --git a/httpmw/apikey_test.go b/httpmw/apikey_test.go index 7a7ff25a5a..6bdebcde6e 100644 --- a/httpmw/apikey_test.go +++ b/httpmw/apikey_test.go @@ -19,9 +19,9 @@ import ( "github.com/coder/coder/httpmw" ) -func randomAPIKeyParts() (string, string) { - id, _ := cryptorand.String(10) - secret, _ := cryptorand.String(22) +func randomAPIKeyParts() (id string, secret string) { + id, _ = cryptorand.String(10) + secret, _ = cryptorand.String(22) return id, secret } @@ -41,7 +41,9 @@ func TestAPIKey(t *testing.T) { rw = httptest.NewRecorder() ) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("InvalidFormat", func(t *testing.T) { @@ -56,7 +58,9 @@ func TestAPIKey(t *testing.T) { }) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("InvalidIDLength", func(t *testing.T) { @@ -71,7 +75,9 @@ func TestAPIKey(t *testing.T) { }) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("InvalidSecretLength", func(t *testing.T) { @@ -86,7 +92,9 @@ func TestAPIKey(t *testing.T) { }) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("NotFound", func(t *testing.T) { @@ -102,7 +110,9 @@ func TestAPIKey(t *testing.T) { }) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("InvalidSecret", func(t *testing.T) { @@ -125,7 +135,9 @@ func TestAPIKey(t *testing.T) { }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("Expired", func(t *testing.T) { @@ -147,7 +159,9 @@ func TestAPIKey(t *testing.T) { }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusUnauthorized, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) }) t.Run("Valid", func(t *testing.T) { @@ -177,7 +191,9 @@ func TestAPIKey(t *testing.T) { Message: "it worked!", }) })).ServeHTTP(rw, r) - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) @@ -207,7 +223,9 @@ func TestAPIKey(t *testing.T) { }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) @@ -237,7 +255,9 @@ func TestAPIKey(t *testing.T) { }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) @@ -268,7 +288,9 @@ func TestAPIKey(t *testing.T) { }) require.NoError(t, err) httpmw.ExtractAPIKey(db, nil)(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) @@ -310,7 +332,9 @@ func TestAPIKey(t *testing.T) { }, }, })(successHandler).ServeHTTP(rw, r) - require.Equal(t, http.StatusOK, rw.Result().StatusCode) + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) require.NoError(t, err) @@ -325,7 +349,7 @@ type oauth2Config struct { tokenSource *oauth2TokenSource } -func (o *oauth2Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource { +func (o *oauth2Config) TokenSource(_ context.Context, _ *oauth2.Token) oauth2.TokenSource { return o.tokenSource } diff --git a/peer/channel.go b/peer/channel.go index f61f35590c..2d5f62c006 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -28,7 +28,7 @@ const ( // the channel on open. The datachannel should not be manually // mutated after being passed to this function. func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel { - c := &Channel{ + channel := &Channel{ opts: opts, conn: conn, dc: dc, @@ -37,8 +37,8 @@ func newChannel(conn *Conn, dc *webrtc.DataChannel, opts *ChannelOpts) *Channel closed: make(chan struct{}), sendMore: make(chan struct{}, 1), } - c.init() - return c + channel.init() + return channel } type ChannelOpts struct { @@ -109,6 +109,8 @@ func (c *Channel) init() { return } select { + case <-c.closed: + return case c.sendMore <- struct{}{}: default: } @@ -167,7 +169,7 @@ func (c *Channel) init() { // Read blocks until data is received. // // This will block until the underlying DataChannel has been opened. -func (c *Channel) Read(b []byte) (n int, err error) { +func (c *Channel) Read(bytes []byte) (int, error) { if c.isClosed() { return 0, c.closeError } @@ -178,7 +180,7 @@ func (c *Channel) Read(b []byte) (n int, err error) { } } - n, err = c.rwc.Read(b) + bytesRead, err := c.rwc.Read(bytes) if err != nil { if c.isClosed() { return 0, c.closeError @@ -189,9 +191,9 @@ func (c *Channel) Read(b []byte) (n int, err error) { if xerrors.Is(err, io.EOF) { err = c.closeWithError(ErrClosed) } - return + return bytesRead, err } - return + return bytesRead, err } // Write sends data to the underlying DataChannel. @@ -202,8 +204,8 @@ func (c *Channel) Read(b []byte) (n int, err error) { // // If the Channel is setup to close on disconnect, any buffered // data will be lost. -func (c *Channel) Write(b []byte) (n int, err error) { - if len(b) > maxMessageLength { +func (c *Channel) Write(bytes []byte) (n int, err error) { + if len(bytes) > maxMessageLength { return 0, xerrors.Errorf("outbound packet larger than maximum message size: %d", maxMessageLength) } @@ -220,7 +222,7 @@ func (c *Channel) Write(b []byte) (n int, err error) { } } - if c.dc.BufferedAmount()+uint64(len(b)) >= maxBufferedAmount { + if c.dc.BufferedAmount()+uint64(len(bytes)) >= maxBufferedAmount { <-c.sendMore } // TODO (@kyle): There's an obvious race-condition here. @@ -230,7 +232,7 @@ func (c *Channel) Write(b []byte) (n int, err error) { // See: https://github.com/pion/sctp/issues/181 time.Sleep(time.Microsecond) - return c.rwc.Write(b) + return c.rwc.Write(bytes) } // Close gracefully closes the DataChannel. diff --git a/peer/conn.go b/peer/conn.go index f8f0e437bc..721e6f0f45 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -42,6 +42,7 @@ func Server(servers []webrtc.ICEServer, opts *ConnOpts) (*Conn, error) { } // newWithClientOrServer constructs a new connection with the client option. +// nolint:revive func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOpts) (*Conn, error) { if opts == nil { opts = &ConnOpts{} @@ -60,7 +61,7 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp if err != nil { return nil, xerrors.Errorf("create peer connection: %w", err) } - c := &Conn{ + conn := &Conn{ pingChannelID: 1, pingEchoChannelID: 2, opts: opts, @@ -77,13 +78,13 @@ func newWithClientOrServer(servers []webrtc.ICEServer, client bool, opts *ConnOp if client { // If we're the client, we want to flip the echo and // ping channel IDs so pings don't accidentally hit each other. - c.pingChannelID, c.pingEchoChannelID = c.pingEchoChannelID, c.pingChannelID + conn.pingChannelID, conn.pingEchoChannelID = conn.pingEchoChannelID, conn.pingChannelID } - err = c.init() + err = conn.init() if err != nil { return nil, xerrors.Errorf("init: %w", err) } - return c, nil + return conn, nil } type ConnOpts struct { @@ -142,6 +143,10 @@ func (c *Conn) init() error { } }) c.rtc.OnConnectionStateChange(func(pcs webrtc.PeerConnectionState) { + // Close must be locked here otherwise log output can appear + // after the connection has been closed. + c.closeMutex.Lock() + defer c.closeMutex.Unlock() if c.isClosed() { return } @@ -211,12 +216,12 @@ func (c *Conn) pingEchoChannel() (*Channel, error) { if c.isClosed() { return } - _ = c.closeWithError(xerrors.Errorf("read ping echo channel: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err)) return } _, err = c.pingEchoChan.Write(data[:bytesRead]) if err != nil { - _ = c.closeWithError(xerrors.Errorf("write ping echo channel: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("write ping echo channel: %w", err)) return } } @@ -237,12 +242,12 @@ func (c *Conn) negotiate() { if c.offerrer { offer, err := c.rtc.CreateOffer(&webrtc.OfferOptions{}) if err != nil { - _ = c.closeWithError(xerrors.Errorf("create offer: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("create offer: %w", err)) return } err = c.rtc.SetLocalDescription(offer) if err != nil { - _ = c.closeWithError(xerrors.Errorf("set local description: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) return } select { @@ -261,19 +266,19 @@ func (c *Conn) negotiate() { err := c.rtc.SetRemoteDescription(remoteDescription) if err != nil { - _ = c.closeWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err)) + _ = c.CloseWithError(xerrors.Errorf("set remote description (closed %v): %w", c.isClosed(), err)) return } if !c.offerrer { answer, err := c.rtc.CreateAnswer(&webrtc.AnswerOptions{}) if err != nil { - _ = c.closeWithError(xerrors.Errorf("create answer: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("create answer: %w", err)) return } err = c.rtc.SetLocalDescription(answer) if err != nil { - _ = c.closeWithError(xerrors.Errorf("set local description: %w", err)) + _ = c.CloseWithError(xerrors.Errorf("set local description: %w", err)) return } if c.isClosed() { @@ -296,20 +301,20 @@ func (c *Conn) proxyICECandidates() func() { queue = []webrtc.ICECandidateInit{} flushed = false ) - c.rtc.OnICECandidate(func(i *webrtc.ICECandidate) { - if i == nil { + c.rtc.OnICECandidate(func(iceCandidate *webrtc.ICECandidate) { + if iceCandidate == nil { return } mut.Lock() defer mut.Unlock() if !flushed { - queue = append(queue, i.ToJSON()) + queue = append(queue, iceCandidate.ToJSON()) return } select { case <-c.closed: return - case c.localCandidateChannel <- i.ToJSON(): + case c.localCandidateChannel <- iceCandidate.ToJSON(): } }) return func() { @@ -353,7 +358,7 @@ func (c *Conn) SetConfiguration(configuration webrtc.Configuration) error { } // SetRemoteSessionDescription sets the remote description for the WebRTC connection. -func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) { +func (c *Conn) SetRemoteSessionDescription(sessionDescription webrtc.SessionDescription) { if c.isClosed() { return } @@ -361,7 +366,7 @@ func (c *Conn) SetRemoteSessionDescription(s webrtc.SessionDescription) { defer c.closeMutex.Unlock() select { case <-c.closed: - case c.remoteSessionDescriptionChannel <- s: + case c.remoteSessionDescriptionChannel <- sessionDescription: } } @@ -407,7 +412,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts) return nil, xerrors.Errorf("closed: %w", c.closeError) } - dc, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{ + dataChannel, err := c.rtc.CreateDataChannel(label, &webrtc.DataChannelInit{ ID: id, Negotiated: &opts.Negotiated, Ordered: &ordered, @@ -416,7 +421,7 @@ func (c *Conn) dialChannel(ctx context.Context, label string, opts *ChannelOpts) if err != nil { return nil, xerrors.Errorf("create data channel: %w", err) } - return newChannel(c, dc, opts), nil + return newChannel(c, dataChannel, opts), nil } // Ping returns the duration it took to round-trip data. @@ -461,12 +466,7 @@ func (c *Conn) Closed() <-chan struct{} { // Close closes the connection and frees all associated resources. func (c *Conn) Close() error { - return c.closeWithError(nil) -} - -// CloseWithError closes the connection; subsequent reads/writes will return the error err. -func (c *Conn) CloseWithError(err error) error { - return c.closeWithError(err) + return c.CloseWithError(nil) } func (c *Conn) isClosed() bool { @@ -478,7 +478,8 @@ func (c *Conn) isClosed() bool { } } -func (c *Conn) closeWithError(err error) error { +// CloseWithError closes the connection; subsequent reads/writes will return the error err. +func (c *Conn) CloseWithError(err error) error { c.closeMutex.Lock() defer c.closeMutex.Unlock() diff --git a/peer/conn_test.go b/peer/conn_test.go index c6912a24c0..7e64bf7747 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -67,13 +67,13 @@ func TestConn(t *testing.T) { _, err := server.Ping() require.NoError(t, err) // Create a channel that closes on disconnect. - ch, err := server.Dial(context.Background(), "wow", nil) + channel, err := server.Dial(context.Background(), "wow", nil) assert.NoError(t, err) err = wan.Stop() require.NoError(t, err) // Once the connection is marked as disconnected, this // channel will be closed. - _, err = ch.Read(make([]byte, 4)) + _, err = channel.Read(make([]byte, 4)) assert.ErrorIs(t, err, peer.ErrClosed) err = wan.Start() require.NoError(t, err) @@ -154,26 +154,26 @@ func TestConn(t *testing.T) { _, _ = io.Copy(nc2, nc1) }() go func() { - s := http.Server{ + server := http.Server{ Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(200) }), } - defer s.Close() - _ = s.Serve(srv) + defer server.Close() + _ = server.Serve(srv) }() - dt := http.DefaultTransport.(*http.Transport).Clone() + defaultTransport := http.DefaultTransport.(*http.Transport).Clone() var cch *peer.Channel - dt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - cch, err = client.Dial(context.Background(), "hello", &peer.ChannelOpts{}) + defaultTransport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + cch, err = client.Dial(ctx, "hello", &peer.ChannelOpts{}) if err != nil { return nil, err } return cch.NetConn(), nil } c := http.Client{ - Transport: dt, + Transport: defaultTransport, } req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil) require.NoError(t, err) @@ -183,7 +183,7 @@ func TestConn(t *testing.T) { require.Equal(t, resp.StatusCode, 200) // Triggers any connections to close. // This test below ensures the DataChannel actually closes. - dt.CloseIdleConnections() + defaultTransport.CloseIdleConnections() err = cch.Close() require.ErrorIs(t, err, peer.ErrClosed) }) @@ -226,13 +226,13 @@ func TestConn(t *testing.T) { } func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.Router) { - lf := logging.NewDefaultLoggerFactory() - lf.DefaultLogLevel = logging.LogLevelDisabled + loggingFactory := logging.NewDefaultLoggerFactory() + loggingFactory.DefaultLogLevel = logging.LogLevelDisabled vnetMutex.Lock() defer vnetMutex.Unlock() wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "1.2.3.0/24", - LoggerFactory: lf, + LoggerFactory: loggingFactory, }) require.NoError(t, err) c1Net := vnet.NewNet(&vnet.NetConfig{ @@ -250,25 +250,25 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R c1SettingEngine.SetVNet(c1Net) c1SettingEngine.SetPrflxAcceptanceMinWait(0) c1SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - c1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{ + channel1, err := peer.Client([]webrtc.ICEServer{}, &peer.ConnOpts{ SettingEngine: c1SettingEngine, Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), }) require.NoError(t, err) t.Cleanup(func() { - c1.Close() + channel1.Close() }) c2SettingEngine := webrtc.SettingEngine{} c2SettingEngine.SetVNet(c2Net) c2SettingEngine.SetPrflxAcceptanceMinWait(0) c2SettingEngine.SetICETimeouts(disconnectedTimeout, failedTimeout, keepAliveInterval) - c2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{ + channel2, err := peer.Server([]webrtc.ICEServer{}, &peer.ConnOpts{ SettingEngine: c2SettingEngine, Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug), }) require.NoError(t, err) t.Cleanup(func() { - c2.Close() + channel2.Close() }) err = wan.Start() @@ -280,11 +280,11 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R go func() { for { select { - case c := <-c2.LocalCandidate(): - _ = c1.AddRemoteCandidate(c) - case c := <-c2.LocalSessionDescription(): - c1.SetRemoteSessionDescription(c) - case <-c2.Closed(): + case c := <-channel2.LocalCandidate(): + _ = channel1.AddRemoteCandidate(c) + case c := <-channel2.LocalSessionDescription(): + channel1.SetRemoteSessionDescription(c) + case <-channel2.Closed(): return } } @@ -293,15 +293,15 @@ func createPair(t *testing.T) (client *peer.Conn, server *peer.Conn, wan *vnet.R go func() { for { select { - case c := <-c1.LocalCandidate(): - _ = c2.AddRemoteCandidate(c) - case c := <-c1.LocalSessionDescription(): - c2.SetRemoteSessionDescription(c) - case <-c1.Closed(): + case c := <-channel1.LocalCandidate(): + _ = channel2.AddRemoteCandidate(c) + case c := <-channel1.LocalSessionDescription(): + channel2.SetRemoteSessionDescription(c) + case <-channel1.Closed(): return } } }() - return c1, c2, wan + return channel1, channel2, wan } diff --git a/peer/netconn.go b/peer/netconn.go index 67c3fb55e5..e564c0ecc2 100644 --- a/peer/netconn.go +++ b/peer/netconn.go @@ -10,11 +10,11 @@ type peerAddr struct{} // Statically checks if we properly implement net.Addr. var _ net.Addr = &peerAddr{} -func (a *peerAddr) Network() string { +func (*peerAddr) Network() string { return "peer" } -func (a *peerAddr) String() string { +func (*peerAddr) String() string { return "peer/unknown-addr" } @@ -46,14 +46,14 @@ func (c *fakeNetConn) RemoteAddr() net.Addr { return c.addr } -func (c *fakeNetConn) SetDeadline(_ time.Time) error { +func (*fakeNetConn) SetDeadline(_ time.Time) error { return nil } -func (c *fakeNetConn) SetReadDeadline(_ time.Time) error { +func (*fakeNetConn) SetReadDeadline(_ time.Time) error { return nil } -func (c *fakeNetConn) SetWriteDeadline(_ time.Time) error { +func (*fakeNetConn) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/peerbroker/listen.go b/peerbroker/listen.go index c63d283d0f..03f7f00b59 100644 --- a/peerbroker/listen.go +++ b/peerbroker/listen.go @@ -161,7 +161,6 @@ func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_Nego Type: webrtc.SDPType(clientToServerMessage.GetOffer().SdpType), SDP: clientToServerMessage.GetOffer().Sdp, }) - break case clientToServerMessage.GetServers() != nil: // Convert protobuf ICE servers to the WebRTC type. iceServers := make([]webrtc.ICEServer, 0, len(clientToServerMessage.GetServers().Servers)) diff --git a/provisioner/terraform/parse.go b/provisioner/terraform/parse.go index a6162a5055..639d1039f4 100644 --- a/provisioner/terraform/parse.go +++ b/provisioner/terraform/parse.go @@ -12,7 +12,7 @@ import ( ) // Parse extracts Terraform variables from source-code. -func (t *terraform) Parse(ctx context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) { +func (*terraform) Parse(_ context.Context, request *proto.Parse_Request) (*proto.Parse_Response, error) { module, diags := tfconfig.LoadModule(request.Directory) if diags.HasErrors() { return nil, xerrors.Errorf("load module: %w", diags.Err()) diff --git a/provisioner/terraform/parse_test.go b/provisioner/terraform/parse_test.go index 07841f2559..94af39103a 100644 --- a/provisioner/terraform/parse_test.go +++ b/provisioner/terraform/parse_test.go @@ -37,7 +37,7 @@ func TestParse(t *testing.T) { }() api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) - for _, tc := range []struct { + for _, testCase := range []struct { Name string Files map[string]string Response *proto.Parse_Response @@ -83,13 +83,13 @@ func TestParse(t *testing.T) { }}, }, }} { - tc := tc - t.Run(tc.Name, func(t *testing.T) { + testCase := testCase + t.Run(testCase.Name, func(t *testing.T) { t.Parallel() // Write all files to the temporary test directory. directory := t.TempDir() - for path, content := range tc.Files { + for path, content := range testCase.Files { err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600) require.NoError(t, err) } @@ -100,7 +100,7 @@ func TestParse(t *testing.T) { require.NoError(t, err) // Ensure the want and got are equivalent! - want, err := json.Marshal(tc.Response) + want, err := json.Marshal(testCase.Response) require.NoError(t, err) got, err := json.Marshal(response) require.NoError(t, err) diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index c12197d5ef..34f57ff649 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -15,9 +15,9 @@ import ( // Provision executes `terraform apply`. func (t *terraform) Provision(ctx context.Context, request *proto.Provision_Request) (*proto.Provision_Response, error) { statefilePath := filepath.Join(request.Directory, "terraform.tfstate") - err := os.WriteFile(statefilePath, request.State, 0644) + err := os.WriteFile(statefilePath, request.State, 0600) if err != nil { - return nil, xerrors.Errorf("write statefile %q: %w", err) + return nil, xerrors.Errorf("write statefile %q: %w", statefilePath, err) } terraform, err := tfexec.NewTerraform(request.Directory, t.binaryPath) diff --git a/provisioner/terraform/provision_test.go b/provisioner/terraform/provision_test.go index 10f2fa9892..7d193033f1 100644 --- a/provisioner/terraform/provision_test.go +++ b/provisioner/terraform/provision_test.go @@ -48,7 +48,7 @@ func TestProvision(t *testing.T) { }() api := proto.NewDRPCProvisionerClient(drpcconn.New(client)) - for _, tc := range []struct { + for _, testCase := range []struct { Name string Files map[string]string Request *proto.Provision_Request @@ -93,25 +93,25 @@ func TestProvision(t *testing.T) { }, Error: true, }} { - tc := tc - t.Run(tc.Name, func(t *testing.T) { + testCase := testCase + t.Run(testCase.Name, func(t *testing.T) { t.Parallel() directory := t.TempDir() - for path, content := range tc.Files { - err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0644) + for path, content := range testCase.Files { + err := os.WriteFile(filepath.Join(directory, path), []byte(content), 0600) require.NoError(t, err) } request := &proto.Provision_Request{ Directory: directory, } - if tc.Request != nil { - request.ParameterValues = tc.Request.ParameterValues - request.State = tc.Request.State + if testCase.Request != nil { + request.ParameterValues = testCase.Request.ParameterValues + request.State = testCase.Request.State } response, err := api.Provision(ctx, request) - if tc.Error { + if testCase.Error { require.Error(t, err) return } @@ -121,7 +121,7 @@ func TestProvision(t *testing.T) { resourcesGot, err := json.Marshal(response.Resources) require.NoError(t, err) - resourcesWant, err := json.Marshal(tc.Response.Resources) + resourcesWant, err := json.Marshal(testCase.Response.Resources) require.NoError(t, err) require.Equal(t, string(resourcesWant), string(resourcesGot)) diff --git a/site/embed.go b/site/embed.go index 3cab301c0d..69326b2240 100644 --- a/site/embed.go +++ b/site/embed.go @@ -28,7 +28,7 @@ var site embed.FS // Handler returns an HTTP handler for serving the static site. func Handler() http.Handler { - f, err := fs.Sub(site, "out") + filesystem, err := fs.Sub(site, "out") if err != nil { // This can't happen... Go would throw a compilation error. panic(err) @@ -36,15 +36,15 @@ func Handler() http.Handler { // html files are handled by a text/template. Non-html files // are served by the default file server. - files, err := htmlFiles(f) + files, err := htmlFiles(filesystem) if err != nil { panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err)) } return secureHeaders(&handler{ - fs: f, + fs: filesystem, htmlFiles: files, - h: http.FileServer(http.FS(f)), // All other non-html static files + h: http.FileServer(http.FS(filesystem)), // All other non-html static files }) } @@ -61,15 +61,15 @@ type handler struct { } // filePath returns the filepath of the requested file. -func (h *handler) filePath(p string) string { +func (*handler) filePath(p string) string { if !strings.HasPrefix(p, "/") { p = "/" + p } return strings.TrimPrefix(path.Clean(p), "/") } -func (h *handler) exists(path string) bool { - f, err := h.fs.Open(path) +func (h *handler) exists(filePath string) bool { + f, err := h.fs.Open(filePath) if err == nil { _ = f.Close() } @@ -89,7 +89,7 @@ type csrfState struct { Token string } -func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // reqFile is the static file requested reqFile := h.filePath(r.URL.Path) state := htmlState{ @@ -100,13 +100,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // First check if it's a file we have in our templates - if h.serveHtml(w, r, reqFile, state) { + if h.serveHTML(rw, r, reqFile, state) { return } // If the original file path exists we serve it. if h.exists(reqFile) { - h.h.ServeHTTP(w, r) + h.h.ServeHTTP(rw, r) return } @@ -117,28 +117,28 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { reqFile = h.filePath(r.URL.Path) // All html files should be served by the htmlFile templates - if h.serveHtml(w, r, reqFile, state) { + if h.serveHTML(rw, r, reqFile, state) { return } // If we don't have the file... we should redirect to `/` // for our single-page-app. r.URL.Path = "/" - if h.serveHtml(w, r, "", state) { + if h.serveHTML(rw, r, "", state) { return } // This will send a correct 404 - h.h.ServeHTTP(w, r) + h.h.ServeHTTP(rw, r) } -func (h *handler) serveHtml(w http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool { +func (h *handler) serveHTML(rw http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool { if data, err := h.htmlFiles.renderWithState(reqPath, state); err == nil { if reqPath == "" { // Pass "index.html" to the ServeContent so the ServeContent sets the right content headers. reqPath = "index.html" } - http.ServeContent(w, r, reqPath, time.Time{}, bytes.NewReader(data)) + http.ServeContent(rw, r, reqPath, time.Time{}, bytes.NewReader(data)) return true } return false @@ -150,12 +150,12 @@ type htmlTemplates struct { // renderWithState will render the file using the given nonce if the file exists // as a template. If it does not, it will return an error. -func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, error) { +func (t *htmlTemplates) renderWithState(filePath string, state htmlState) ([]byte, error) { var buf bytes.Buffer - if path == "" { - path = "index.html" + if filePath == "" { + filePath = "index.html" } - err := t.tpls.ExecuteTemplate(&buf, path, state) + err := t.tpls.ExecuteTemplate(&buf, filePath, state) if err != nil { return nil, err } @@ -168,13 +168,6 @@ func (t *htmlTemplates) renderWithState(path string, state htmlState) ([]byte, e // All directives are semi-colon separated as a single string for the csp header. type cspDirectives map[cspFetchDirective][]string -func (s cspDirectives) append(d cspFetchDirective, values ...string) { - if _, ok := s[d]; !ok { - s[d] = make([]string, 0) - } - s[d] = append(s[d], values...) -} - // cspFetchDirective is the list of all constant fetch directives that // can be used/appended to. type cspFetchDirective string @@ -234,7 +227,7 @@ func secureHeaders(next http.Handler) http.Handler { var csp strings.Builder for src, vals := range cspSrcs { - fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " ")) + _, _ = fmt.Fprintf(&csp, "%s %s; ", src, strings.Join(vals, " ")) } // Permissions-Policy can be used to disabled various browser features that we do not use. @@ -280,16 +273,16 @@ func htmlFiles(files fs.FS) (*htmlTemplates, error) { root := template.New("") rootPath := "." - err := fs.WalkDir(files, rootPath, func(path string, d fs.DirEntry, err error) error { + err := fs.WalkDir(files, rootPath, func(path string, dirEntry fs.DirEntry, err error) error { if err != nil { return err } - if d.IsDir() { + if dirEntry.IsDir() { return nil } - if filepath.Ext(d.Name()) != ".html" { + if filepath.Ext(dirEntry.Name()) != ".html" { return nil } diff --git a/site/embed_test.go b/site/embed_test.go index 031f014685..90f1530b44 100644 --- a/site/embed_test.go +++ b/site/embed_test.go @@ -1,7 +1,9 @@ package site import ( + "context" "io" + "net/http" "net/http/httptest" "testing" @@ -13,8 +15,11 @@ func TestIndexPageRenders(t *testing.T) { srv := httptest.NewServer(Handler()) - resp, err := srv.Client().Get(srv.URL) + req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) require.NoError(t, err, "get index") + defer resp.Body.Close() data, _ := io.ReadAll(resp.Body) require.NotEmpty(t, data, "index should have contents") }