Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

chore: refactor entitlements to be a safe object to use#14406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
Emyrk merged 7 commits intomainfromstevenmasley/safe_entitlements
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes fromall commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletionscoderd/coderd.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -37,6 +37,7 @@ import (
"tailscale.com/util/singleflight"

"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/quartz"
"github.com/coder/serpent"

Expand DownExpand Up@@ -157,6 +158,9 @@ type Options struct {
TrialGenerator func(ctx context.Context, body codersdk.LicensorTrialRequest) error
// RefreshEntitlements is used to set correct entitlements after creating first user and generating trial license.
RefreshEntitlements func(ctx context.Context) error
// Entitlements can come from the enterprise caller if enterprise code is
// included.
Entitlements *entitlements.Set
// PostAuthAdditionalHeadersFunc is used to add additional headers to the response
// after a successful authentication.
// This is somewhat janky, but seemingly the only reasonable way to add a header
Expand DownExpand Up@@ -263,6 +267,9 @@ func New(options *Options) *API {
if options == nil {
options = &Options{}
}
if options.Entitlements == nil {
options.Entitlements = entitlements.New()
}
if options.NewTicker == nil {
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
ticker := time.NewTicker(duration)
Expand DownExpand Up@@ -500,6 +507,7 @@ func New(options *Options) *API {
DocsURL: options.DeploymentValues.DocsURL.String(),
AppearanceFetcher: &api.AppearanceFetcher,
BuildInfo: buildInfo,
Entitlements: options.Entitlements,
})
api.SiteHandler.Experiments.Store(&experiments)

Expand Down
109 changes: 109 additions & 0 deletionscoderd/entitlements/entitlements.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
package entitlements

import (
"encoding/json"
"net/http"
"sync"
"time"

"github.com/coder/coder/v2/codersdk"
)

type Set struct {
Copy link
MemberAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

I implemented the methods as I saw them used. There might be a way to reduce the number of methods on this struct.

entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
}

func New() *Set {
return &Set{
// Some defaults for an unlicensed instance.
// These will be updated when coderd is initialized.
entitlements: codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{},
Warnings: nil,
Errors: nil,
HasLicense: false,
Trial: false,
RequireTelemetry: false,
RefreshedAt: time.Time{},
},
}
}

// AllowRefresh returns whether the entitlements are allowed to be refreshed.
// If it returns false, that means it was recently refreshed and the caller should
// wait the returned duration before trying again.
func (l *Set) AllowRefresh(now time.Time) (bool, time.Duration) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

diff := now.Sub(l.entitlements.RefreshedAt)
if diff < time.Minute {
return false, time.Minute - diff
}

return true, 0
}

func (l *Set) Feature(name codersdk.FeatureName) (codersdk.Feature, bool) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

f, ok := l.entitlements.Features[name]
return f, ok
}

func (l *Set) Enabled(feature codersdk.FeatureName) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Potential follow-up: we could replace this withf, ok := Features(name); ok && f.Enabled?

Emyrk reacted with thumbs up emoji
Copy link
MemberAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

True. Because before we had access to the whole struct, our usage of it seemed a bit arbitrary at times. Sometimes we grab it and check entitled, most times just enabled.

I'm not trying to fix all our usage right now, but it would be good to audit at some times.

johnstcn reacted with thumbs up emoji
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

f, ok := l.entitlements.Features[feature]
if !ok {
return false
}
return f.Enabled
}

// AsJSON is used to return this to the api without exposing the entitlements for
// mutation.
func (l *Set) AsJSON() json.RawMessage {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

b, _ := json.Marshal(l.entitlements)
return b
}

func (l *Set) Replace(entitlements codersdk.Entitlements) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

l.entitlements = entitlements
}

func (l *Set) Update(do func(entitlements *codersdk.Entitlements)) {
l.entitlementsMu.Lock()
defer l.entitlementsMu.Unlock()

do(&l.entitlements)
}

func (l *Set) FeatureChanged(featureName codersdk.FeatureName, newFeature codersdk.Feature) (initial, changed, enabled bool) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

oldFeature := l.entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
}

func (l *Set) WriteEntitlementWarningHeaders(header http.Header) {
l.entitlementsMu.RLock()
defer l.entitlementsMu.RUnlock()

for _, warning := range l.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}
}
63 changes: 63 additions & 0 deletionscoderd/entitlements/entitlements_test.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
package entitlements_test

import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/codersdk"
)

func TestUpdate(t *testing.T) {
t.Parallel()

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))

set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Features[codersdk.FeatureMultipleOrganizations] = codersdk.Feature{
Enabled: true,
Entitlement: codersdk.EntitlementEntitled,
}
})
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
}

func TestAllowRefresh(t *testing.T) {
t.Parallel()

now := time.Now()
set := entitlements.New()
set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now
})

ok, wait := set.AllowRefresh(now)
require.False(t, ok)
require.InDelta(t, time.Minute.Seconds(), wait.Seconds(), 5)

set.Update(func(entitlements *codersdk.Entitlements) {
entitlements.RefreshedAt = now.Add(time.Minute * -2)
})

ok, wait = set.AllowRefresh(now)
require.True(t, ok)
require.Equal(t, time.Duration(0), wait)
}

func TestReplace(t *testing.T) {
t.Parallel()

set := entitlements.New()
require.False(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
set.Replace(codersdk.Entitlements{
Features: map[codersdk.FeatureName]codersdk.Feature{
codersdk.FeatureMultipleOrganizations: {
Enabled: true,
},
},
})
require.True(t, set.Enabled(codersdk.FeatureMultipleOrganizations))
}
6 changes: 6 additions & 0 deletionscodersdk/deployment.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -35,6 +35,12 @@ const (
EntitlementNotEntitled Entitlement = "not_entitled"
)

// Entitled returns if the entitlement can be used. So this is true if it
// is entitled or still in it's grace period.
func (e Entitlement) Entitled() bool {
return e == EntitlementEntitled || e == EntitlementGracePeriod
}

// Weight converts the enum types to a numerical value for easier
// comparisons. Easier than sets of if statements.
func (e Entitlement) Weight() int {
Expand Down
72 changes: 30 additions & 42 deletionsenterprise/coderd/coderd.go
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -15,6 +15,7 @@ import (
"github.com/coder/coder/v2/buildinfo"
"github.com/coder/coder/v2/coderd/appearance"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/entitlements"
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/enterprise/coderd/portsharing"
Expand DownExpand Up@@ -103,19 +104,26 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
return nil, xerrors.Errorf("init database encryption: %w", err)
}

entitlementsSet := entitlements.New()
options.Database = cryptDB
api := &API{
ctx: ctx,
cancel: cancelFunc,
Options: options,
ctx: ctx,
cancel: cancelFunc,
Options: options,
entitlements: entitlementsSet,
provisionerDaemonAuth: &provisionerDaemonAuth{
psk: options.ProvisionerDaemonPSK,
authorizer: options.Authorizer,
db: options.Database,
},
licenseMetricsCollector: &license.MetricsCollector{
Entitlements: entitlementsSet,
},
}
// This must happen before coderd initialization!
options.PostAuthAdditionalHeadersFunc = api.writeEntitlementWarningsHeader
options.Options.Entitlements = api.entitlements
api.AGPL = coderd.New(options.Options)
defer func() {
if err != nil {
Expand DownExpand Up@@ -493,7 +501,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
}
api.AGPL.WorkspaceProxiesFetchUpdater.Store(&fetchUpdater)

err = api.PrometheusRegistry.Register(&api.licenseMetricsCollector)
err = api.PrometheusRegistry.Register(api.licenseMetricsCollector)
if err != nil {
return nil, xerrors.Errorf("unable to register license metrics collector")
}
Expand DownExpand Up@@ -553,13 +561,11 @@ type API struct {
// ProxyHealth checks the reachability of all workspace proxies.
ProxyHealth *proxyhealth.ProxyHealth

entitlementsUpdateMu sync.Mutex
entitlementsMu sync.RWMutex
entitlements codersdk.Entitlements
entitlements *entitlements.Set

provisionerDaemonAuth *provisionerDaemonAuth

licenseMetricsCollector license.MetricsCollector
licenseMetricsCollector*license.MetricsCollector
tailnetService *tailnet.ClientService
}

Expand DownExpand Up@@ -588,11 +594,8 @@ func (api *API) writeEntitlementWarningsHeader(a rbac.Subject, header http.Heade
// has no roles. This is a normal user!
return
}
api.entitlementsMu.RLock()
defer api.entitlementsMu.RUnlock()
for _, warning := range api.entitlements.Warnings {
header.Add(codersdk.EntitlementsWarningHeader, warning)
}

api.entitlements.WriteEntitlementWarningHeaders(header)
}

func (api *API) Close() error {
Expand All@@ -614,9 +617,6 @@ func (api *API) Close() error {
}

func (api *API) updateEntitlements(ctx context.Context) error {
api.entitlementsUpdateMu.Lock()
defer api.entitlementsUpdateMu.Unlock()

replicas := api.replicaManager.AllPrimary()
agedReplicas := make([]database.Replica, 0, len(replicas))
for _, replica := range replicas {
Expand All@@ -632,7 +632,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
agedReplicas = append(agedReplicas, replica)
}

entitlements, err := license.Entitlements(
reloadedEntitlements, err := license.Entitlements(
ctx, api.Database,
len(agedReplicas), len(api.ExternalAuthConfigs), api.LicenseKeys, map[codersdk.FeatureName]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
Expand All@@ -652,29 +652,24 @@ func (api *API) updateEntitlements(ctx context.Context) error {
return err
}

ifentitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
ifreloadedEntitlements.RequireTelemetry && !api.DeploymentValues.Telemetry.Enable.Value() {
// We can't fail because then the user couldn't remove the offending
// license w/o a restart.
//
// We don't simply append to entitlement.Errors since we don't want any
// enterprise features enabled.
api.entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
api.entitlements.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Errors = []string{
"License requires telemetry but telemetry is disabled",
}
})

api.Logger.Error(ctx, "license requires telemetry enabled")
return nil
}

featureChanged := func(featureName codersdk.FeatureName) (initial, changed, enabled bool) {
if api.entitlements.Features == nil {
return true, false, entitlements.Features[featureName].Enabled
}
oldFeature := api.entitlements.Features[featureName]
newFeature := entitlements.Features[featureName]
if oldFeature.Enabled != newFeature.Enabled {
return false, true, newFeature.Enabled
}
return false, false, newFeature.Enabled
return api.entitlements.FeatureChanged(featureName, reloadedEntitlements.Features[featureName])
}

shouldUpdate := func(initial, changed, enabled bool) bool {
Expand DownExpand Up@@ -831,20 +826,16 @@ func (api *API) updateEntitlements(ctx context.Context) error {
}

// External token encryption is soft-enforced
featureExternalTokenEncryption :=entitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption :=reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption]
featureExternalTokenEncryption.Enabled = len(api.ExternalTokenEncryption) > 0
if featureExternalTokenEncryption.Enabled && featureExternalTokenEncryption.Entitlement != codersdk.EntitlementEntitled {
msg := fmt.Sprintf("%s is enabled (due to setting external token encryption keys) but your license is not entitled to this feature.", codersdk.FeatureExternalTokenEncryption.Humanize())
api.Logger.Warn(ctx, msg)
entitlements.Warnings = append(entitlements.Warnings, msg)
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
}
entitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption

api.entitlementsMu.Lock()
defer api.entitlementsMu.Unlock()
api.entitlements = entitlements
api.licenseMetricsCollector.Entitlements.Store(&entitlements)
api.AGPL.SiteHandler.Entitlements.Store(&entitlements)
api.entitlements.Replace(reloadedEntitlements)
return nil
}

Expand DownExpand Up@@ -1024,10 +1015,7 @@ func derpMapper(logger slog.Logger, proxyHealth *proxyhealth.ProxyHealth) func(*
// @Router /entitlements [get]
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.entitlementsMu.RLock()
entitlements := api.entitlements
api.entitlementsMu.RUnlock()
httpapi.Write(ctx, rw, http.StatusOK, entitlements)
httpapi.Write(ctx, rw, http.StatusOK, api.entitlements.AsJSON())
}

func (api *API) runEntitlementsLoop(ctx context.Context) {
Expand Down
Loading

[8]ページ先頭

©2009-2025 Movatter.jp