Skip to content

Commit

Permalink
vtexplain: Fix passing through context for cleanup (#13900)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink authored Aug 31, 2023
1 parent eb0eb0c commit cc435db
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 63 deletions.
7 changes: 3 additions & 4 deletions go/vt/vtexplain/vtexplain_vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar
vte.explainTopo.TopoServer = memorytopo.NewServer(ctx, vtexplainCell)
vte.healthCheck = discovery.NewFakeHealthCheck(nil)

resolver := vte.newFakeResolver(opts, vte.explainTopo, vtexplainCell)
resolver := vte.newFakeResolver(ctx, opts, vte.explainTopo, vtexplainCell)

err := vte.buildTopology(ctx, opts, vSchemaStr, ksShardMapStr, opts.NumShards)
if err != nil {
Expand All @@ -80,10 +80,9 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar
return nil
}

func (vte *VTExplain) newFakeResolver(opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver {
ctx := context.Background()
func (vte *VTExplain) newFakeResolver(ctx context.Context, opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver {
gw := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell)
_ = gw.WaitForTablets([]topodatapb.TabletType{topodatapb.TabletType_REPLICA})
_ = gw.WaitForTablets(ctx, []topodatapb.TabletType{topodatapb.TabletType_REPLICA})

txMode := vtgatepb.TransactionMode_MULTI
if opts.ExecutionMode == ModeTwoPC {
Expand Down
20 changes: 9 additions & 11 deletions go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ func TestLegacyExecuteFailOnAutocommit(t *testing.T) {
}

func TestScatterConnExecuteMulti(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnExecuteMultiShard", func(sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
ctx := utils.LeakCheckContext(t)
testScatterConnGeneric(t, "TestScatterConnExecuteMultiShard", func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa")
rss, err := res.ResolveDestination(ctx, "TestScatterConnExecuteMultiShard", topodatapb.TabletType_REPLICA, key.DestinationShards(shards))
if err != nil {
Expand All @@ -130,8 +129,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
}

func TestScatterConnStreamExecuteMulti(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnStreamExecuteMulti", func(sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
ctx := utils.LeakCheckContext(t)
testScatterConnGeneric(t, "TestScatterConnStreamExecuteMulti", func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa")
rss, err := res.ResolveDestination(ctx, "TestScatterConnStreamExecuteMulti", topodatapb.TabletType_REPLICA, key.DestinationShards(shards))
if err != nil {
Expand All @@ -158,15 +156,15 @@ func verifyScatterConnError(t *testing.T, err error, wantErr string, wantCode vt
assert.Equal(t, wantCode, vterrors.Code(err))
}

func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, shards []string) (*sqltypes.Result, error)) {
func testScatterConnGeneric(t *testing.T, name string, f func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error)) {
ctx := utils.LeakCheckContext(t)

hc := discovery.NewFakeHealthCheck(nil)

// no shard
s := createSandbox(name)
sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
qr, err := f(sc, nil)
qr, err := f(ctx, sc, nil)
require.NoError(t, err)
if qr.RowsAffected != 0 {
t.Errorf("want 0, got %v", qr.RowsAffected)
Expand All @@ -177,7 +175,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc := hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
_, err = f(sc, []string{"0"})
_, err = f(ctx, sc, []string{"0"})
want := fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error", name)
// Verify server error string.
if err == nil || err.Error() != want {
Expand All @@ -196,7 +194,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sbc1 := hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
_, err = f(sc, []string{"0", "1"})
_, err = f(ctx, sc, []string{"0", "1"})
// Verify server errors are consolidated.
want = fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error\ntarget: %v.1.replica: INVALID_ARGUMENT error", name, name)
verifyScatterConnError(t, err, want, vtrpcpb.Code_INVALID_ARGUMENT)
Expand All @@ -216,7 +214,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sbc1 = hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
sbc1.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 1
_, err = f(sc, []string{"0", "1"})
_, err = f(ctx, sc, []string{"0", "1"})
// Verify server errors are consolidated.
want = fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error\ntarget: %v.1.replica: RESOURCE_EXHAUSTED error", name, name)
// We should only surface the higher priority error code
Expand All @@ -234,7 +232,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
hc.Reset()
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc = hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
_, _ = f(sc, []string{"0", "0"})
_, _ = f(ctx, sc, []string{"0", "0"})
// Ensure that we executed only once.
if execCount := sbc.ExecCount.Load(); execCount != 1 {
t.Errorf("want 1, got %v", execCount)
Expand All @@ -246,7 +244,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc0 = hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc1 = hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
qr, err = f(sc, []string{"0", "1"})
qr, err = f(ctx, sc, []string{"0", "1"})
if err != nil {
t.Fatalf("want nil, got %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/tabletgateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ func (gw *TabletGateway) RegisterStats() {
}

// WaitForTablets is part of the Gateway interface.
func (gw *TabletGateway) WaitForTablets(tabletTypesToWait []topodatapb.TabletType) (err error) {
func (gw *TabletGateway) WaitForTablets(ctx context.Context, tabletTypesToWait []topodatapb.TabletType) (err error) {
log.Infof("Gateway waiting for serving tablets of types %v ...", tabletTypesToWait)
ctx, cancel := context.WithTimeout(context.Background(), initialTabletTimeout)
ctx, cancel := context.WithTimeout(ctx, initialTabletTimeout)
defer cancel()

defer func() {
Expand Down
61 changes: 32 additions & 29 deletions go/vt/vtgate/tabletgateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,49 +37,55 @@ import (
)

func TestTabletGatewayExecute(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(context.Background(), target, "query", nil, 0, 0, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil)
return err
})
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(context.Background(), target, "query", nil, 1, 0, nil)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil)
return err
})
}

func TestTabletGatewayExecuteStream(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
err := tg.StreamExecute(context.Background(), target, "query", nil, 0, 0, nil, func(qr *sqltypes.Result) error {
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
err := tg.StreamExecute(ctx, target, "query", nil, 0, 0, nil, func(qr *sqltypes.Result) error {
return nil
})
return err
})
}

func TestTabletGatewayBegin(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Begin(context.Background(), target, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Begin(ctx, target, nil)
return err
})
}

func TestTabletGatewayCommit(t *testing.T) {
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Commit(context.Background(), target, 1)
ctx := utils.LeakCheckContext(t)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Commit(ctx, target, 1)
return err
})
}

func TestTabletGatewayRollback(t *testing.T) {
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Rollback(context.Background(), target, 1)
ctx := utils.LeakCheckContext(t)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Rollback(ctx, target, 1)
return err
})
}

func TestTabletGatewayBeginExecute(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, _, err := tg.BeginExecute(context.Background(), target, nil, "query", nil, 0, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, _, err := tg.BeginExecute(ctx, target, nil, "query", nil, 0, nil)
return err
})
}
Expand Down Expand Up @@ -167,14 +173,12 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) {
defer tg.Close(ctx)

_ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
_, err := tg.Execute(context.Background(), target, "query", nil, 1, 0, nil)
_, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil)
verifyContainsError(t, err, "query service can only be used for non-transactional queries on replicas", vtrpcpb.Code_INTERNAL)
}

func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *querypb.Target) error) {
func testTabletGatewayGeneric(t *testing.T, ctx context.Context, f func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error) {
t.Helper()
ctx := utils.LeakCheckContext(t)

keyspace := "ks"
shard := "0"
tabletType := topodatapb.TabletType_REPLICA
Expand All @@ -192,19 +196,19 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu

// no tablet
want := []string{"target: ks.0.replica", `no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA`}
err := f(tg, target)
err := f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// tablet with error
hc.Reset()
hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, false, 10, fmt.Errorf("no connection"))
err = f(tg, target)
err = f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// tablet without connection
hc.Reset()
_ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, false, 10, nil).Tablet()
err = f(tg, target)
err = f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// retry error
Expand All @@ -214,7 +218,7 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1

err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.replica", vtrpcpb.Code_FAILED_PRECONDITION)

// fatal error
Expand All @@ -223,26 +227,25 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu
sc2 = hc.AddTestTablet("cell", host, port+1, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.replica", vtrpcpb.Code_FAILED_PRECONDITION)

// server error - no retry
hc.Reset()
sc1 = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
err = f(tg, target)
err = f(ctx, tg, target)
assert.Equal(t, vtrpcpb.Code_INVALID_ARGUMENT, vterrors.Code(err))

// no failure
hc.Reset()
hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
err = f(tg, target)
err = f(ctx, tg, target)
assert.NoError(t, err)
}

func testTabletGatewayTransact(t *testing.T, f func(tg *TabletGateway, target *querypb.Target) error) {
func testTabletGatewayTransact(t *testing.T, ctx context.Context, f func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error) {
t.Helper()
ctx := utils.LeakCheckContext(t)

keyspace := "ks"
shard := "0"
Expand All @@ -267,14 +270,14 @@ func testTabletGatewayTransact(t *testing.T, f func(tg *TabletGateway, target *q
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1

err := f(tg, target)
err := f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.primary", vtrpcpb.Code_FAILED_PRECONDITION)

// server error - no retry
hc.Reset()
sc1 = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.primary", vtrpcpb.Code_INVALID_ARGUMENT)
}

Expand Down
Loading

0 comments on commit cc435db

Please sign in to comment.