Skip to content

Commit

Permalink
feat(restore_test): introduce testHelper
Browse files Browse the repository at this point in the history
TestHelper is a reimplementation of restoreTestHelper which wasn't well-designed. It removes a lot of boilerplate from test setup and also allows for more flexibility.
  • Loading branch information
Michal-Leszczynski committed Jun 4, 2024
1 parent 8140136 commit 498eecb
Showing 1 changed file with 218 additions and 1 deletion.
219 changes: 218 additions & 1 deletion pkg/service/restore/helper_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@
package restore_test

import (
"context"
"encoding/json"
"fmt"
"strings"
"testing"

"github.com/gocql/gocql"
"github.com/pkg/errors"
"github.com/scylladb/go-log"
"github.com/scylladb/gocqlx/v2"
"github.com/scylladb/gocqlx/v2/qb"
"github.com/scylladb/scylla-manager/v3/pkg/util/uuid"
"go.uber.org/zap/zapcore"

"github.com/scylladb/scylla-manager/v3/pkg/metrics"
"github.com/scylladb/scylla-manager/v3/pkg/scyllaclient"
"github.com/scylladb/scylla-manager/v3/pkg/service/backup"
"github.com/scylladb/scylla-manager/v3/pkg/service/repair"
. "github.com/scylladb/scylla-manager/v3/pkg/service/restore"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/db"
. "github.com/scylladb/scylla-manager/v3/pkg/testutils/testhelper"
"github.com/scylladb/scylla-manager/v3/pkg/util/inexlist/ksfilter"
"github.com/scylladb/scylla-manager/v3/pkg/util/query"
"github.com/scylladb/scylla-manager/v3/pkg/util/uuid"
)

type table struct {
Expand All @@ -37,6 +48,212 @@ func randomizedName(name string) string {
return name + strings.Replace(fmt.Sprint(uuid.NewTime()), "-", "", -1)
}

type clusterHelper struct {
*CommonTestHelper
rootSession gocqlx.Session
}

func newCluster(t *testing.T, hosts []string) clusterHelper {
logger := log.NewDevelopmentWithLevel(zapcore.InfoLevel)
hrt := NewHackableRoundTripper(scyllaclient.DefaultTransport())
clientCfg := scyllaclient.TestConfig(hosts, AgentAuthToken())
client := newTestClient(t, hrt, logger.Named("client"), &clientCfg)

for _, h := range hosts {
if err := client.RcloneResetStats(context.Background(), h); err != nil {
t.Fatal("Reset rclone stats", h, err)
}
}

return clusterHelper{
CommonTestHelper: &CommonTestHelper{
Session: CreateScyllaManagerDBSession(t),
Hrt: hrt,
Client: client,
ClusterID: uuid.NewTime(),
TaskID: uuid.NewTime(),
RunID: uuid.NewTime(),
T: t,
},
rootSession: CreateSessionAndDropAllKeyspaces(t, client),
}
}

type testHelper struct {
srcCluster clusterHelper
srcBackupSvc *backup.Service

dstCluster clusterHelper
dstRestoreSvc *Service
dstUser string
dstPass string
}

func newTestHelper(t *testing.T, srcHosts, dstHosts []string) *testHelper {
srcCluster := newCluster(t, srcHosts)
dstCluster := newCluster(t, dstHosts)

user := randomizedName("helper_user_")
pass := randomizedName("helper_pass_")

dropNonSuperUsers(t, dstCluster.rootSession)
createUser(t, dstCluster.rootSession, user, pass)

return &testHelper{
srcCluster: srcCluster,
srcBackupSvc: newBackupSvc(t, srcCluster.Session, srcCluster.Client),
dstCluster: dstCluster,
dstRestoreSvc: newRestoreSvc(t, dstCluster.Session, dstCluster.Client, user, pass),
dstUser: user,
dstPass: pass,
}
}

func newBackupSvc(t *testing.T, mgrSession gocqlx.Session, client *scyllaclient.Client) *backup.Service {
svc, err := backup.NewService(
mgrSession,
backup.DefaultConfig(),
metrics.NewBackupMetrics(),
func(_ context.Context, id uuid.UUID) (string, error) {
return "test_cluster", nil
},
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
return CreateSession(t, client), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("backup"),
)
if err != nil {
t.Fatal(err)
}
return svc
}

func newRestoreSvc(t *testing.T, mgrSession gocqlx.Session, client *scyllaclient.Client, user, pass string) *Service {
repairSvc, err := repair.NewService(
mgrSession,
repair.DefaultConfig(),
metrics.NewRepairMetrics(),
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
return CreateSession(t, client), nil
},
log.NewDevelopmentWithLevel(zapcore.ErrorLevel).Named("repair"),
)
if err != nil {
t.Fatal(err)
}

svc, err := NewService(
repairSvc,
mgrSession,
defaultTestConfig(),
metrics.NewRestoreMetrics(),
func(context.Context, uuid.UUID) (*scyllaclient.Client, error) {
return client, nil
},
func(ctx context.Context, clusterID uuid.UUID) (gocqlx.Session, error) {
return CreateManagedClusterSession(t, false, client, user, pass), nil
},
log.NewDevelopmentWithLevel(zapcore.InfoLevel).Named("restore"),
)
if err != nil {
t.Fatal(err)
}

return svc
}

func (h *testHelper) runBackup(t *testing.T, props map[string]any) string {
Printf("Run backup with properties: %v", props)
ctx := context.Background()
h.srcCluster.TaskID = uuid.NewTime()
h.srcCluster.RunID = uuid.NewTime()

rawProps, err := json.Marshal(props)
if err != nil {
t.Fatal(errors.Wrap(err, "marshal properties"))
}

target, err := h.srcBackupSvc.GetTarget(ctx, h.srcCluster.ClusterID, rawProps)
if err != nil {
t.Fatal(errors.Wrap(err, "generate target"))
}

err = h.srcBackupSvc.Backup(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID, target)
if err != nil {
t.Fatal(errors.Wrap(err, "run backup"))
}

pr, err := h.srcBackupSvc.GetProgress(ctx, h.srcCluster.ClusterID, h.srcCluster.TaskID, h.srcCluster.RunID)
if err != nil {
t.Fatal(errors.Wrap(err, "get progress"))
}

return pr.SnapshotTag
}

func (h *testHelper) runRestore(t *testing.T, props map[string]any) {
Printf("Run restore with properties: %v", props)
ctx := context.Background()
h.dstCluster.TaskID = uuid.NewTime()
h.dstCluster.RunID = uuid.NewTime()

rawProps, err := json.Marshal(props)
if err != nil {
t.Fatal(errors.Wrap(err, "marshal properties"))
}

err = h.dstRestoreSvc.Restore(ctx, h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID, rawProps)
if err != nil {
t.Fatal(errors.Wrap(err, "run restore"))
}
}

func (h *testHelper) getRestoreProgress(t *testing.T) Progress {
pr, err := h.dstRestoreSvc.GetProgress(context.Background(), h.dstCluster.ClusterID, h.dstCluster.TaskID, h.dstCluster.RunID)
if err != nil {
t.Fatal(errors.Wrap(err, "get progress"))
}
return pr
}

func (h *testHelper) validateIdenticalTables(t *testing.T, tables []table) {
pr := h.getRestoreProgress(t)
validateCompleteProgress(t, pr, tables)

views, err := query.GetAllViews(h.srcCluster.rootSession)
if err != nil {
t.Fatal(errors.Wrap(err, "get all views"))
}

Print("Validate tombstone_gc mode")
for _, tab := range tables {
// Don't validate views tombstone_gc
if views.Has(tab.ks + "." + tab.tab) {
continue
}
srcMode := tombstoneGCMode(t, h.srcCluster.rootSession, tab.ks, tab.tab)
dstMode := tombstoneGCMode(t, h.dstCluster.rootSession, tab.ks, tab.tab)
if srcMode != dstMode {
t.Fatalf("Expected %s tombstone_gc mode, got: %s", srcMode, dstMode)
}
}

Print("Validate row count")
for _, tab := range tables {
dstCnt := rowCount(t, h.dstCluster.rootSession, tab.ks, tab.tab)
srcCnt := rowCount(t, h.srcCluster.rootSession, tab.ks, tab.tab)
if dstCnt != srcCnt {
t.Fatalf("srcCount != dstCount")
}
}
}

func tombstoneGCMode(t *testing.T, s gocqlx.Session, keyspace, table string) string {
var ext map[string]string
q := qb.Select("system_schema.tables").
Expand Down

0 comments on commit 498eecb

Please sign in to comment.