diff --git a/agent/agent.go b/agent/agent.go index 66e009a207..e9ba7672cd 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -242,7 +242,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t if err != nil { return nil, xerrors.Errorf("listen on the ssh port: %w", err) } + a.closeMutex.Lock() + a.connCloseWait.Add(1) + a.closeMutex.Unlock() go func() { + defer a.connCloseWait.Done() for { conn, err := sshListener.Accept() if err != nil { @@ -256,7 +260,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t if err != nil { return nil, xerrors.Errorf("listen for reconnecting pty: %w", err) } + a.closeMutex.Lock() + a.connCloseWait.Add(1) + a.closeMutex.Unlock() go func() { + defer a.connCloseWait.Done() for { conn, err := reconnectingPTYListener.Accept() if err != nil { @@ -290,7 +298,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t if err != nil { return nil, xerrors.Errorf("listen for speedtest: %w", err) } + a.closeMutex.Lock() + a.connCloseWait.Add(1) + a.closeMutex.Unlock() go func() { + defer a.connCloseWait.Done() for { conn, err := speedtestListener.Accept() if err != nil { @@ -311,7 +323,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t if err != nil { return nil, xerrors.Errorf("listen for statistics: %w", err) } + a.closeMutex.Lock() + a.connCloseWait.Add(1) + a.closeMutex.Unlock() go func() { + defer a.connCloseWait.Done() defer statisticsListener.Close() server := &http.Server{ Handler: a.statisticsHandler(), diff --git a/cli/deployment/config.go b/cli/deployment/config.go index e96a64746f..602f755742 100644 --- a/cli/deployment/config.go +++ b/cli/deployment/config.go @@ -462,7 +462,8 @@ func readSliceFromViper[T any](vip *viper.Viper, key string, value any) []T { if prop == "-" { prop = fve.Tag.Get("yaml") } - value := vip.Get(fmt.Sprintf("%s.%d.%s", key, entry, prop)) + configKey := fmt.Sprintf("%s.%d.%s", key, entry, prop) + value := vip.Get(configKey) if value == nil { continue } @@ -470,6 +471,11 @@ func readSliceFromViper[T any](vip *viper.Viper, key string, value any) []T { newType := reflect.Indirect(reflect.New(elementType)) instance = &newType } + switch instance.Field(i).Type().String() { + case "[]string": + value = vip.GetStringSlice(configKey) + default: + } instance.Field(i).Set(reflect.ValueOf(value)) } if instance == nil { diff --git a/cli/deployment/config_test.go b/cli/deployment/config_test.go index 73a138d193..4fd2cc5978 100644 --- a/cli/deployment/config_test.go +++ b/cli/deployment/config_test.go @@ -158,6 +158,7 @@ func TestConfig(t *testing.T) { "CODER_GITAUTH_0_AUTH_URL": "https://auth.com", "CODER_GITAUTH_0_TOKEN_URL": "https://token.com", "CODER_GITAUTH_0_REGEX": "github.com", + "CODER_GITAUTH_0_SCOPES": "read write", "CODER_GITAUTH_1_ID": "another", "CODER_GITAUTH_1_TYPE": "gitlab", @@ -177,6 +178,7 @@ func TestConfig(t *testing.T) { AuthURL: "https://auth.com", TokenURL: "https://token.com", Regex: "github.com", + Scopes: []string{"read", "write"}, }, { ID: "another", Type: "gitlab", diff --git a/coderd/gitauth/config.go b/coderd/gitauth/config.go index 8216446390..e4c6fdd97e 100644 --- a/coderd/gitauth/config.go +++ b/coderd/gitauth/config.go @@ -86,6 +86,16 @@ func ConvertConfig(entries []codersdk.GitAuthConfig, accessURL *url.URL) ([]*Con Scopes: scope[typ], } + if entry.AuthURL != "" { + oauth2Config.Endpoint.AuthURL = entry.AuthURL + } + if entry.TokenURL != "" { + oauth2Config.Endpoint.TokenURL = entry.TokenURL + } + if entry.Scopes != nil && len(entry.Scopes) > 0 { + oauth2Config.Scopes = entry.Scopes + } + var oauthConfig httpmw.OAuth2Config = oauth2Config // Azure DevOps uses JWT token authentication! if typ == codersdk.GitProviderAzureDevops { diff --git a/coderd/gitauth/config_test.go b/coderd/gitauth/config_test.go index 07ad7f1ac7..0ae1d579d4 100644 --- a/coderd/gitauth/config_test.go +++ b/coderd/gitauth/config_test.go @@ -75,4 +75,18 @@ func TestConvertYAML(t *testing.T) { require.Equal(t, tc.Output, output) }) } + + t.Run("CustomScopesAndEndpoint", func(t *testing.T) { + t.Parallel() + config, err := gitauth.ConvertConfig([]codersdk.GitAuthConfig{{ + Type: codersdk.GitProviderGitLab, + ClientID: "id", + ClientSecret: "secret", + AuthURL: "https://auth.com", + TokenURL: "https://token.com", + Scopes: []string{"read"}, + }}, &url.URL{}) + require.NoError(t, err) + require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL("")) + }) } diff --git a/codersdk/client_internal_test.go b/codersdk/client_internal_test.go index f4cb49297c..dbb96340f1 100644 --- a/codersdk/client_internal_test.go +++ b/codersdk/client_internal_test.go @@ -34,7 +34,7 @@ func Test_readBodyAsError(t *testing.T) { longResponse += "a" } - unexpectedJSON := marshalJSON(map[string]any{ + unexpectedJSON := marshal(map[string]any{ "hello": "world", "foo": "bar", }) @@ -49,7 +49,7 @@ func Test_readBodyAsError(t *testing.T) { { name: "JSONWithRequest", req: httptest.NewRequest(http.MethodGet, exampleURL, nil), - res: newResponse(http.StatusNotFound, jsonCT, marshalJSON(simpleResponse)), + res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)), assert: func(t *testing.T, err error) { sdkErr := assertSDKError(t, err) @@ -72,7 +72,7 @@ func Test_readBodyAsError(t *testing.T) { { name: "JSONWithoutRequest", req: nil, - res: newResponse(http.StatusNotFound, jsonCT, marshalJSON(simpleResponse)), + res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)), assert: func(t *testing.T, err error) { sdkErr := assertSDKError(t, err) @@ -86,7 +86,7 @@ func Test_readBodyAsError(t *testing.T) { { name: "UnauthorizedHelper", req: nil, - res: newResponse(http.StatusUnauthorized, jsonCT, marshalJSON(simpleResponse)), + res: newResponse(http.StatusUnauthorized, jsonCT, marshal(simpleResponse)), assert: func(t *testing.T, err error) { sdkErr := assertSDKError(t, err) @@ -190,7 +190,7 @@ func newResponse(status int, contentType string, body interface{}) *http.Respons } } -func marshalJSON(res any) string { +func marshal(res any) string { b, err := json.Marshal(res) if err != nil { panic(err) diff --git a/codersdk/deploymentconfig.go b/codersdk/deploymentconfig.go index d9702f5f19..35df5733d6 100644 --- a/codersdk/deploymentconfig.go +++ b/codersdk/deploymentconfig.go @@ -108,13 +108,14 @@ type TLSConfig struct { } type GitAuthConfig struct { - ID string `json:"id"` - Type string `json:"type"` - ClientID string `json:"client_id"` - ClientSecret string `json:"-" yaml:"client_secret"` - AuthURL string `json:"auth_url"` - TokenURL string `json:"token_url"` - Regex string `json:"regex"` + ID string `json:"id"` + Type string `json:"type"` + ClientID string `json:"client_id"` + ClientSecret string `json:"-" yaml:"client_secret"` + AuthURL string `json:"auth_url"` + TokenURL string `json:"token_url"` + Regex string `json:"regex"` + Scopes []string `json:"scopes"` } type Flaggable interface { diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c006c13928..f7fcecbc0f 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -354,6 +354,7 @@ export interface GitAuthConfig { readonly auth_url: string readonly token_url: string readonly regex: string + readonly scopes: string[] } // From codersdk/gitsshkey.go