fix: allow posting licenses that will be valid in future (#14491)

This commit is contained in:
Spike Curtis
2024-09-03 09:09:38 +04:00
committed by GitHub
parent 0785b77d0b
commit 5bd5801286
4 changed files with 106 additions and 24 deletions

View File

@ -174,6 +174,10 @@ type LicenseOptions struct {
// ExpiresAt is the time at which the license will hard expire. // ExpiresAt is the time at which the license will hard expire.
// ExpiresAt should always be greater then GraceAt. // ExpiresAt should always be greater then GraceAt.
ExpiresAt time.Time ExpiresAt time.Time
// NotBefore is the time at which the license becomes valid. If set to the
// zero value, the `nbf` claim on the license is set to 1 minute in the
// past.
NotBefore time.Time
Features license.Features Features license.Features
} }
@ -233,13 +237,16 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
if options.GraceAt.IsZero() { if options.GraceAt.IsZero() {
options.GraceAt = time.Now().Add(time.Hour) options.GraceAt = time.Now().Add(time.Hour)
} }
if options.NotBefore.IsZero() {
options.NotBefore = time.Now().Add(-time.Minute)
}
c := &license.Claims{ c := &license.Claims{
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(), ID: uuid.NewString(),
Issuer: "test@testing.test", Issuer: "test@testing.test",
ExpiresAt: jwt.NewNumericDate(options.ExpiresAt), ExpiresAt: jwt.NewNumericDate(options.ExpiresAt),
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)), NotBefore: jwt.NewNumericDate(options.NotBefore),
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)), IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
}, },
LicenseExpires: jwt.NewNumericDate(options.GraceAt), LicenseExpires: jwt.NewNumericDate(options.GraceAt),

View File

@ -287,6 +287,8 @@ var (
ErrInvalidVersion = xerrors.New("license must be version 3") ErrInvalidVersion = xerrors.New("license must be version 3")
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID) ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
ErrMissingLicenseExpires = xerrors.New("license missing license_expires") ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
ErrMissingExp = xerrors.New("exp claim missing or not parsable")
ErrMultipleIssues = xerrors.New("license has multiple issues; contact support")
) )
type Features map[codersdk.FeatureName]int64 type Features map[codersdk.FeatureName]int64
@ -336,7 +338,7 @@ func ParseRaw(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error
return nil, xerrors.New("unable to parse Claims") return nil, xerrors.New("unable to parse Claims")
} }
// ParseClaims validates a database.License record, and if valid, returns the claims. If // ParseClaims validates a raw JWT, and if valid, returns the claims. If
// unparsable or invalid, it returns an error // unparsable or invalid, it returns an error
func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) { func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims( tok, err := jwt.ParseWithClaims(
@ -348,18 +350,53 @@ func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, err
if err != nil { if err != nil {
return nil, err return nil, err
} }
if claims, ok := tok.Claims.(*Claims); ok && tok.Valid { return validateClaims(tok)
}
func validateClaims(tok *jwt.Token) (*Claims, error) {
if claims, ok := tok.Claims.(*Claims); ok {
if claims.Version != uint64(CurrentVersion) { if claims.Version != uint64(CurrentVersion) {
return nil, ErrInvalidVersion return nil, ErrInvalidVersion
} }
if claims.LicenseExpires == nil { if claims.LicenseExpires == nil {
return nil, ErrMissingLicenseExpires return nil, ErrMissingLicenseExpires
} }
if claims.ExpiresAt == nil {
return nil, ErrMissingExp
}
return claims, nil return claims, nil
} }
return nil, xerrors.New("unable to parse Claims") return nil, xerrors.New("unable to parse Claims")
} }
// ParseClaimsIgnoreNbf validates a raw JWT, but ignores `nbf` claim. If otherwise valid, it returns
// the claims. If unparsable or invalid, it returns an error. Ignoring the `nbf` (not before) is
// useful to determine if a JWT _will_ become valid at any point now or in the future.
func ParseClaimsIgnoreNbf(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims(
rawJWT,
&Claims{},
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
var vErr *jwt.ValidationError
if xerrors.As(err, &vErr) {
// zero out the NotValidYet error to check if there were other problems
vErr.Errors = vErr.Errors & (^jwt.ValidationErrorNotValidYet)
if vErr.Errors != 0 {
// There are other errors besides not being valid yet. We _could_ go
// through all the jwt.ValidationError bits and try to work out the
// correct error, but if we get here something very strange is
// going on so let's just return a generic error that says to get in
// touch with our support team.
return nil, ErrMultipleIssues
}
} else if err != nil {
return nil, err
}
return validateClaims(tok)
}
func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) { func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) {
return func(j *jwt.Token) (interface{}, error) { return func(j *jwt.Token) (interface{}, error) {
keyID, ok := j.Header[HeaderKeyID].(string) keyID, ok := j.Header[HeaderKeyID].(string)

View File

@ -86,25 +86,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
return return
} }
rawClaims, err := license.ParseRaw(addLicense.License, api.LicenseKeys) claims, err := license.ParseClaimsIgnoreNbf(addLicense.License, api.LicenseKeys)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: err.Error(),
})
return
}
exp, ok := rawClaims["exp"].(float64)
if !ok {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: "exp claim missing or not parsable",
})
return
}
expTime := time.Unix(int64(exp), 0)
claims, err := license.ParseClaims(addLicense.License, api.LicenseKeys)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license", Message: "Invalid license",
@ -134,7 +116,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
dl, err := api.Database.InsertLicense(ctx, database.InsertLicenseParams{ dl, err := api.Database.InsertLicense(ctx, database.InsertLicenseParams{
UploadedAt: dbtime.Now(), UploadedAt: dbtime.Now(),
JWT: addLicense.License, JWT: addLicense.License,
Exp: expTime, Exp: claims.ExpiresAt.Time,
UUID: id, UUID: id,
}) })
if err != nil { if err != nil {
@ -160,7 +142,15 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
// don't fail the HTTP request, since we did write it successfully to the database // don't fail the HTTP request, since we did write it successfully to the database
} }
httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, rawClaims)) c, err := decodeClaims(dl)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to decode database response",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, c))
} }
// postRefreshEntitlements forces an `updateEntitlements` call and publishes // postRefreshEntitlements forces an `updateEntitlements` call and publishes

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -82,6 +83,53 @@ func TestPostLicense(t *testing.T) {
t.Error("expected to get error status 400") t.Error("expected to get error status 400")
} }
}) })
// Test a license that isn't yet valid, but will be in the future. We should allow this so that
// operators can upload a license ahead of time.
t.Run("NotYet", func(t *testing.T) {
t.Parallel()
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountType: license.AccountTypeSalesforce,
AccountID: "testing",
Features: license.Features{
codersdk.FeatureAuditLog: 1,
},
NotBefore: time.Now().Add(time.Hour),
GraceAt: time.Now().Add(2 * time.Hour),
ExpiresAt: time.Now().Add(3 * time.Hour),
})
assert.GreaterOrEqual(t, respLic.ID, int32(0))
// just a couple spot checks for sanity
assert.Equal(t, "testing", respLic.Claims["account_id"])
features, err := respLic.FeaturesClaims()
require.NoError(t, err)
assert.EqualValues(t, 1, features[codersdk.FeatureAuditLog])
})
// Test we still reject a license that isn't valid yet, but has other issues (e.g. expired
// before it starts).
t.Run("NotEver", func(t *testing.T) {
t.Parallel()
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
lic := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AccountType: license.AccountTypeSalesforce,
AccountID: "testing",
Features: license.Features{
codersdk.FeatureAuditLog: 1,
},
NotBefore: time.Now().Add(time.Hour),
GraceAt: time.Now().Add(2 * time.Hour),
ExpiresAt: time.Now().Add(-time.Hour),
})
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: lic,
})
errResp := &codersdk.Error{}
require.ErrorAs(t, err, &errResp)
require.Equal(t, http.StatusBadRequest, errResp.StatusCode())
require.Contains(t, errResp.Detail, license.ErrMultipleIssues.Error())
})
} }
func TestGetLicense(t *testing.T) { func TestGetLicense(t *testing.T) {