diff --git a/pkg/server/license/BUILD.bazel b/pkg/server/license/BUILD.bazel index 2eeee199046b..08875c777365 100644 --- a/pkg/server/license/BUILD.bazel +++ b/pkg/server/license/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "//pkg/sql/pgwire/pgerror", "//pkg/util/envutil", "//pkg/util/log", + "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", ], diff --git a/pkg/server/license/enforcer.go b/pkg/server/license/enforcer.go index 4cf0fc59347c..8c51c09931b3 100644 --- a/pkg/server/license/enforcer.go +++ b/pkg/server/license/enforcer.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/util/envutil" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" ) @@ -35,8 +36,11 @@ const ( // Enforcer is responsible for enforcing license policies. type Enforcer struct { - // TestingKnobs are used to control the behavior of the enforcer for testing. - TestingKnobs *TestingKnobs + mu struct { + syncutil.Mutex + // testingKnobs are used to control the behavior of the enforcer for testing. + testingKnobs *TestingKnobs + } // telemetryStatusReporter is an interface for getting the timestamp of the // last successful ping to the telemetry server. For some licenses, sending @@ -132,6 +136,19 @@ func (e *Enforcer) SetTelemetryStatusReporter(reporter TelemetryStatusReporter) e.telemetryStatusReporter = reporter } +// SetTesting Knobs will set the pointer to the testing knobs. +func (e *Enforcer) SetTestingKnobs(k *TestingKnobs) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.testingKnobs = k +} + +func (e *Enforcer) GetTestingKnobs() *TestingKnobs { + e.mu.Lock() + defer e.mu.Unlock() + return e.mu.testingKnobs +} + // Start will load the necessary metadata for the enforcer. It reads from the // KV license metadata and will populate any missing data as needed. The DB // passed in must have access to the system tenant. @@ -146,7 +163,7 @@ func (e *Enforcer) Start( // Writing the grace period initialization timestamp is currently opt-in. See // the EnableGracePeriodInitTSWrite comment for details. - if e.TestingKnobs != nil && e.TestingKnobs.EnableGracePeriodInitTSWrite { + if tk := e.GetTestingKnobs(); tk != nil && tk.EnableGracePeriodInitTSWrite { return e.maybeWriteClusterInitGracePeriodTS(ctx, db, initialStart) } return nil @@ -313,8 +330,8 @@ func (e *Enforcer) Disable(ctx context.Context) { // getStartTime returns the time when the enforcer was created. This accounts // for testing knobs that may override the time. func (e *Enforcer) getStartTime() time.Time { - if e.TestingKnobs != nil && e.TestingKnobs.OverrideStartTime != nil { - return *e.TestingKnobs.OverrideStartTime + if tk := e.GetTestingKnobs(); tk != nil && tk.OverrideStartTime != nil { + return *tk.OverrideStartTime } return e.startTime } @@ -322,8 +339,8 @@ func (e *Enforcer) getStartTime() time.Time { // getThrottleCheckTS returns the time to use when checking if we should // throttle the new transaction. func (e *Enforcer) getThrottleCheckTS() time.Time { - if e.TestingKnobs != nil && e.TestingKnobs.OverrideThrottleCheckTime != nil { - return *e.TestingKnobs.OverrideThrottleCheckTime + if tk := e.GetTestingKnobs(); tk != nil && tk.OverrideThrottleCheckTime != nil { + return *tk.OverrideThrottleCheckTime } return timeutil.Now() } diff --git a/pkg/server/license/enforcer_test.go b/pkg/server/license/enforcer_test.go index 475a21f33e04..caaf36383a7d 100644 --- a/pkg/server/license/enforcer_test.go +++ b/pkg/server/license/enforcer_test.go @@ -63,10 +63,10 @@ func TestGracePeriodInitTSCache(t *testing.T) { enforcer := &license.Enforcer{} ts2 := ts1.Add(1) ts2End := ts2.Add(7 * 24 * time.Hour) // Calculate the end of the grace period - enforcer.TestingKnobs = &license.TestingKnobs{ + enforcer.SetTestingKnobs(&license.TestingKnobs{ EnableGracePeriodInitTSWrite: true, OverrideStartTime: &ts2, - } + }) // Ensure request for the grace period init ts1 before start just returns the start // time used when the enforcer was created. require.Equal(t, ts2End, enforcer.GetClusterInitGracePeriodEndTS()) @@ -143,12 +143,11 @@ func TestThrottle(t *testing.T) { {OverTxnThreshold, license.LicTypeEvaluation, t0, t0, t15d, t46d, "License expired"}, } { t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { - e := license.Enforcer{ - TestingKnobs: &license.TestingKnobs{ - OverrideStartTime: &tc.gracePeriodInit, - OverrideThrottleCheckTime: &tc.checkTs, - }, - } + e := license.Enforcer{} + e.SetTestingKnobs(&license.TestingKnobs{ + OverrideStartTime: &tc.gracePeriodInit, + OverrideThrottleCheckTime: &tc.checkTs, + }) e.SetTelemetryStatusReporter(&mockTelemetryStatusReporter{ lastPingTime: tc.lastTelemetryPingTime, }) diff --git a/pkg/server/server_sql.go b/pkg/server/server_sql.go index 7197d2a219db..585e8e3ef9c3 100644 --- a/pkg/server/server_sql.go +++ b/pkg/server/server_sql.go @@ -1753,7 +1753,7 @@ func (s *SQLServer) startLicenseEnforcer( // is shared to provide access to the values cached from the KV read. if s.execCfg.Codec.ForSystemTenant() { if knobs.Server != nil { - s.execCfg.LicenseEnforcer.TestingKnobs = &knobs.Server.(*TestingKnobs).LicenseTestingKnobs + s.execCfg.LicenseEnforcer.SetTestingKnobs(&knobs.Server.(*TestingKnobs).LicenseTestingKnobs) } // TODO(spilchen): we need to tell the license enforcer about the // diagnostics reporter. This will be handled in CRDB-39991