fix: Allow custom Git OAuth URLs (#4758)

Fixes an issue reported in Discord where custom endpoints
weren't working.
This commit is contained in:
Kyle Carberry
2022-10-27 10:38:05 -07:00
committed by GitHub
parent 3e15ee3ba0
commit b34a67e6cb
8 changed files with 63 additions and 13 deletions

View File

@ -242,7 +242,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
if err != nil {
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go func() {
defer a.connCloseWait.Done()
for {
conn, err := sshListener.Accept()
if err != nil {
@ -256,7 +260,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
if err != nil {
return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go func() {
defer a.connCloseWait.Done()
for {
conn, err := reconnectingPTYListener.Accept()
if err != nil {
@ -290,7 +298,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
if err != nil {
return nil, xerrors.Errorf("listen for speedtest: %w", err)
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go func() {
defer a.connCloseWait.Done()
for {
conn, err := speedtestListener.Accept()
if err != nil {
@ -311,7 +323,11 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
if err != nil {
return nil, xerrors.Errorf("listen for statistics: %w", err)
}
a.closeMutex.Lock()
a.connCloseWait.Add(1)
a.closeMutex.Unlock()
go func() {
defer a.connCloseWait.Done()
defer statisticsListener.Close()
server := &http.Server{
Handler: a.statisticsHandler(),

View File

@ -462,7 +462,8 @@ func readSliceFromViper[T any](vip *viper.Viper, key string, value any) []T {
if prop == "-" {
prop = fve.Tag.Get("yaml")
}
value := vip.Get(fmt.Sprintf("%s.%d.%s", key, entry, prop))
configKey := fmt.Sprintf("%s.%d.%s", key, entry, prop)
value := vip.Get(configKey)
if value == nil {
continue
}
@ -470,6 +471,11 @@ func readSliceFromViper[T any](vip *viper.Viper, key string, value any) []T {
newType := reflect.Indirect(reflect.New(elementType))
instance = &newType
}
switch instance.Field(i).Type().String() {
case "[]string":
value = vip.GetStringSlice(configKey)
default:
}
instance.Field(i).Set(reflect.ValueOf(value))
}
if instance == nil {

View File

@ -158,6 +158,7 @@ func TestConfig(t *testing.T) {
"CODER_GITAUTH_0_AUTH_URL": "https://auth.com",
"CODER_GITAUTH_0_TOKEN_URL": "https://token.com",
"CODER_GITAUTH_0_REGEX": "github.com",
"CODER_GITAUTH_0_SCOPES": "read write",
"CODER_GITAUTH_1_ID": "another",
"CODER_GITAUTH_1_TYPE": "gitlab",
@ -177,6 +178,7 @@ func TestConfig(t *testing.T) {
AuthURL: "https://auth.com",
TokenURL: "https://token.com",
Regex: "github.com",
Scopes: []string{"read", "write"},
}, {
ID: "another",
Type: "gitlab",

View File

@ -86,6 +86,16 @@ func ConvertConfig(entries []codersdk.GitAuthConfig, accessURL *url.URL) ([]*Con
Scopes: scope[typ],
}
if entry.AuthURL != "" {
oauth2Config.Endpoint.AuthURL = entry.AuthURL
}
if entry.TokenURL != "" {
oauth2Config.Endpoint.TokenURL = entry.TokenURL
}
if entry.Scopes != nil && len(entry.Scopes) > 0 {
oauth2Config.Scopes = entry.Scopes
}
var oauthConfig httpmw.OAuth2Config = oauth2Config
// Azure DevOps uses JWT token authentication!
if typ == codersdk.GitProviderAzureDevops {

View File

@ -75,4 +75,18 @@ func TestConvertYAML(t *testing.T) {
require.Equal(t, tc.Output, output)
})
}
t.Run("CustomScopesAndEndpoint", func(t *testing.T) {
t.Parallel()
config, err := gitauth.ConvertConfig([]codersdk.GitAuthConfig{{
Type: codersdk.GitProviderGitLab,
ClientID: "id",
ClientSecret: "secret",
AuthURL: "https://auth.com",
TokenURL: "https://token.com",
Scopes: []string{"read"},
}}, &url.URL{})
require.NoError(t, err)
require.Equal(t, "https://auth.com?client_id=id&redirect_uri=%2Fgitauth%2Fgitlab%2Fcallback&response_type=code&scope=read", config[0].AuthCodeURL(""))
})
}

View File

@ -34,7 +34,7 @@ func Test_readBodyAsError(t *testing.T) {
longResponse += "a"
}
unexpectedJSON := marshalJSON(map[string]any{
unexpectedJSON := marshal(map[string]any{
"hello": "world",
"foo": "bar",
})
@ -49,7 +49,7 @@ func Test_readBodyAsError(t *testing.T) {
{
name: "JSONWithRequest",
req: httptest.NewRequest(http.MethodGet, exampleURL, nil),
res: newResponse(http.StatusNotFound, jsonCT, marshalJSON(simpleResponse)),
res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)),
assert: func(t *testing.T, err error) {
sdkErr := assertSDKError(t, err)
@ -72,7 +72,7 @@ func Test_readBodyAsError(t *testing.T) {
{
name: "JSONWithoutRequest",
req: nil,
res: newResponse(http.StatusNotFound, jsonCT, marshalJSON(simpleResponse)),
res: newResponse(http.StatusNotFound, jsonCT, marshal(simpleResponse)),
assert: func(t *testing.T, err error) {
sdkErr := assertSDKError(t, err)
@ -86,7 +86,7 @@ func Test_readBodyAsError(t *testing.T) {
{
name: "UnauthorizedHelper",
req: nil,
res: newResponse(http.StatusUnauthorized, jsonCT, marshalJSON(simpleResponse)),
res: newResponse(http.StatusUnauthorized, jsonCT, marshal(simpleResponse)),
assert: func(t *testing.T, err error) {
sdkErr := assertSDKError(t, err)
@ -190,7 +190,7 @@ func newResponse(status int, contentType string, body interface{}) *http.Respons
}
}
func marshalJSON(res any) string {
func marshal(res any) string {
b, err := json.Marshal(res)
if err != nil {
panic(err)

View File

@ -115,6 +115,7 @@ type GitAuthConfig struct {
AuthURL string `json:"auth_url"`
TokenURL string `json:"token_url"`
Regex string `json:"regex"`
Scopes []string `json:"scopes"`
}
type Flaggable interface {

View File

@ -354,6 +354,7 @@ export interface GitAuthConfig {
readonly auth_url: string
readonly token_url: string
readonly regex: string
readonly scopes: string[]
}
// From codersdk/gitsshkey.go