package promoauth import ( "context" "fmt" "net/http" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "golang.org/x/oauth2" ) type Oauth2Source string const ( SourceValidateToken Oauth2Source = "ValidateToken" SourceExchange Oauth2Source = "Exchange" SourceTokenSource Oauth2Source = "TokenSource" SourceAppInstallations Oauth2Source = "AppInstallations" SourceAuthorizeDevice Oauth2Source = "AuthorizeDevice" SourceGitAPIAuthUser Oauth2Source = "GitAPIAuthUser" SourceGitAPIListEmails Oauth2Source = "GitAPIListEmails" SourceGitAPIOrgMemberships Oauth2Source = "GitAPIOrgMemberships" SourceGitAPITeamMemberships Oauth2Source = "GitAPITeamMemberships" ) // OAuth2Config exposes a subset of *oauth2.Config functions for easier testing. // *oauth2.Config should be used instead of implementing this in production. type OAuth2Config interface { AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource } // InstrumentedOAuth2Config extends OAuth2Config with a `Do` method that allows // external oauth related calls to be instrumented. This is to support // "ValidateToken" which is not an oauth2 specified method. // These calls still count against the api rate limit, and should be instrumented. type InstrumentedOAuth2Config interface { OAuth2Config // Do is provided as a convenience method to make a request with the oauth2 client. // It mirrors `http.Client.Do`. Do(ctx context.Context, source Oauth2Source, req *http.Request) (*http.Response, error) } var _ OAuth2Config = (*Config)(nil) // Factory allows us to have 1 set of metrics for all oauth2 providers. // Primarily to avoid any prometheus errors registering duplicate metrics. type Factory struct { metrics *metrics // optional replace now func Now func() time.Time } // metrics is the reusable metrics for all oauth2 providers. type metrics struct { externalRequestCount *prometheus.CounterVec // if the oauth supports it, rate limit metrics. // rateLimit is the defined limit per interval rateLimit *prometheus.GaugeVec // TODO: remove deprecated metrics in the future release rateLimitDeprecated *prometheus.GaugeVec rateLimitRemaining *prometheus.GaugeVec rateLimitUsed *prometheus.GaugeVec // rateLimitReset is unix time of the next interval (when the rate limit resets). rateLimitReset *prometheus.GaugeVec // rateLimitResetIn is the time in seconds until the rate limit resets. // This is included because it is sometimes more helpful to know the limit // will reset in 600seconds, rather than at 1704000000 unix time. rateLimitResetIn *prometheus.GaugeVec } func NewFactory(registry prometheus.Registerer) *Factory { factory := promauto.With(registry) return &Factory{ metrics: &metrics{ externalRequestCount: factory.NewCounterVec(prometheus.CounterOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_total", Help: "The total number of api calls made to external oauth2 providers. 'status_code' will be 0 if the request failed with no response.", }, []string{ "name", "source", "status_code", }), rateLimit: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit", Help: "The total number of allowed requests per interval.", }, []string{ "name", // Resource allows different rate limits for the same oauth2 provider. // Some IDPs have different buckets for different rate limits. "resource", }), // TODO: deprecated: remove in the future // See: https://github.com/coder/coder/issues/12999 // Deprecation reason: gauge metrics should avoid suffix `_total`` rateLimitDeprecated: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit_total", Help: "DEPRECATED: use coderd_oauth2_external_requests_rate_limit instead", }, []string{ "name", "resource", }), rateLimitRemaining: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit_remaining", Help: "The remaining number of allowed requests in this interval.", }, []string{ "name", "resource", }), rateLimitUsed: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit_used", Help: "The number of requests made in this interval.", }, []string{ "name", "resource", }), rateLimitReset: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit_next_reset_unix", Help: "Unix timestamp for when the next interval starts", }, []string{ "name", "resource", }), rateLimitResetIn: factory.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "coderd", Subsystem: "oauth2", Name: "external_requests_rate_limit_reset_in_seconds", Help: "Seconds until the next interval", }, []string{ "name", "resource", }), }, } } func (f *Factory) New(name string, under OAuth2Config) *Config { return &Config{ name: name, underlying: under, metrics: f.metrics, } } // NewGithub returns a new instrumented oauth2 config for github. It tracks // rate limits as well as just the external request counts. // //nolint:bodyclose func (f *Factory) NewGithub(name string, under OAuth2Config) *Config { cfg := f.New(name, under) cfg.interceptors = append(cfg.interceptors, func(resp *http.Response, err error) { limits, ok := githubRateLimits(resp, err) if !ok { return } labels := prometheus.Labels{ "name": cfg.name, "resource": limits.Resource, } // Default to -1 for "do not know" resetIn := float64(-1) if !limits.Reset.IsZero() { now := time.Now() if f.Now != nil { now = f.Now() } resetIn = limits.Reset.Sub(now).Seconds() if resetIn < 0 { // If it just reset, just make it 0. resetIn = 0 } } // TODO: remove this metric in v3 f.metrics.rateLimitDeprecated.With(labels).Set(float64(limits.Limit)) f.metrics.rateLimit.With(labels).Set(float64(limits.Limit)) f.metrics.rateLimitRemaining.With(labels).Set(float64(limits.Remaining)) f.metrics.rateLimitUsed.With(labels).Set(float64(limits.Used)) f.metrics.rateLimitReset.With(labels).Set(float64(limits.Reset.Unix())) f.metrics.rateLimitResetIn.With(labels).Set(resetIn) }) return cfg } type Config struct { // Name is a human friendly name to identify the oauth2 provider. This should be // deterministic from restart to restart, as it is going to be used as a label in // prometheus metrics. name string underlying OAuth2Config metrics *metrics // interceptors are called after every request made by the oauth2 client. interceptors []func(resp *http.Response, err error) } func (c *Config) Do(ctx context.Context, source Oauth2Source, req *http.Request) (*http.Response, error) { cli := c.oauthHTTPClient(ctx, source) return cli.Do(req) } func (c *Config) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { // No external requests are made when constructing the auth code url. return c.underlying.AuthCodeURL(state, opts...) } func (c *Config) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { return c.underlying.Exchange(c.wrapClient(ctx, SourceExchange), code, opts...) } func (c *Config) TokenSource(ctx context.Context, token *oauth2.Token) oauth2.TokenSource { return c.underlying.TokenSource(c.wrapClient(ctx, SourceTokenSource), token) } // InstrumentHTTPClient will always return a new http client. The new client will // match the one passed in, but will have an instrumented round tripper. func (c *Config) InstrumentHTTPClient(hc *http.Client, source Oauth2Source) *http.Client { return &http.Client{ // The new tripper will instrument every request made by the oauth2 client. Transport: newInstrumentedTripper(c, source, hc.Transport), CheckRedirect: hc.CheckRedirect, Jar: hc.Jar, Timeout: hc.Timeout, } } // wrapClient is the only way we can accurately instrument the oauth2 client. // This is because method calls to the 'OAuth2Config' interface are not 1:1 with // network requests. // // For example, the 'TokenSource' method will return a token // source that will make a network request when the 'Token' method is called on // it if the token is expired. func (c *Config) wrapClient(ctx context.Context, source Oauth2Source) context.Context { return context.WithValue(ctx, oauth2.HTTPClient, c.oauthHTTPClient(ctx, source)) } // oauthHTTPClient returns an http client that will instrument every request made. func (c *Config) oauthHTTPClient(ctx context.Context, source Oauth2Source) *http.Client { cli := &http.Client{} // Check if the context has a http client already. if hc, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { cli = hc } cli = c.InstrumentHTTPClient(cli, source) return cli } type instrumentedTripper struct { c *Config source Oauth2Source underlying http.RoundTripper } // newInstrumentedTripper intercepts a http request, and increments the // externalRequestCount metric. func newInstrumentedTripper(c *Config, source Oauth2Source, under http.RoundTripper) *instrumentedTripper { if under == nil { under = http.DefaultTransport } return &instrumentedTripper{ c: c, source: source, underlying: under, } } func (i *instrumentedTripper) RoundTrip(r *http.Request) (*http.Response, error) { resp, err := i.underlying.RoundTrip(r) var statusCode int if resp != nil { statusCode = resp.StatusCode } i.c.metrics.externalRequestCount.With(prometheus.Labels{ "name": i.c.name, "source": string(i.source), "status_code": fmt.Sprintf("%d", statusCode), }).Inc() // Handle any extra interceptors. for _, interceptor := range i.c.interceptors { interceptor(resp, err) } return resp, err }