Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] support column encryption #202

Merged
merged 3 commits into from
Jul 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/cectc/dbpack/pkg/executor"
"github.com/cectc/dbpack/pkg/filter"
_ "github.com/cectc/dbpack/pkg/filter/audit_log"
_ "github.com/cectc/dbpack/pkg/filter/crypto"
_ "github.com/cectc/dbpack/pkg/filter/dt"
_ "github.com/cectc/dbpack/pkg/filter/metrics"
dbpackHttp "github.com/cectc/dbpack/pkg/http"
Expand Down
9 changes: 9 additions & 0 deletions docker/conf/config_rws.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ executors:
weight: r0w10
- name: employees-slave
weight: r10w0
filters:
- cryptoFilter

data_source_cluster:
- name: employees-master
Expand Down Expand Up @@ -46,6 +48,13 @@ filters:
appid: svc
lock_retry_interval: 50ms
lock_retry_times: 30
- name: cryptoFilter
kind: CryptoFilter
conf:
column_crypto_list:
- table: departments
columns: ["dept_name"]
aeskey: 123456789abcdefg

distributed_transaction:
appid: svc
Expand Down
9 changes: 9 additions & 0 deletions docker/conf/config_sdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ executors:
mode: sdb
config:
data_source_ref: employees
filters:
- cryptoFilter

data_source_cluster:
- name: employees
Expand Down Expand Up @@ -50,6 +52,13 @@ filters:
# determines if the rotated log files should be compressed using gzip
compress: true
record_before: true
- name: cryptoFilter
kind: CryptoFilter
conf:
column_crypto_list:
- table: departments
columns: ["dept_name"]
aeskey: 123456789abcdefg

distributed_transaction:
appid: svc
Expand Down
2 changes: 1 addition & 1 deletion docker/scripts/init.sql
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ CREATE TABLE employees (
CREATE TABLE departments (
`id` bigint NOT NULL AUTO_INCREMENT,
dept_no CHAR(4) NOT NULL,
dept_name VARCHAR(40) NOT NULL,
dept_name VARCHAR(100) NOT NULL,
PRIMARY KEY (`id`),
UNIQUE KEY (dept_name)
);
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ require (
github.com/opentracing/opentracing-go v1.1.0 // indirect
github.com/pingcap/failpoint v0.0.0-20210316064728-7acb0f0a3dfd // indirect
github.com/pingcap/kvproto v0.0.0-20210806074406-317f69fb54b4 // indirect
github.com/pingcap/parser v0.0.0-20210831085004-b5390aa83f65 // indirect
github.com/pingcap/parser v0.0.0-20210831085004-b5390aa83f65
github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
Expand Down
69 changes: 69 additions & 0 deletions pkg/executor/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
package executor

import (
"io"
"strings"

"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/third_party/parser/ast"
driver "github.com/cectc/dbpack/third_party/types/parser_driver"
)
Expand All @@ -43,3 +46,69 @@ func shouldStartTransaction(stmt *ast.SetStmt) (shouldStartTransaction bool) {
}
return
}

// decodeTextResult eagerly drains the row iterator of a *mysql.Result,
// decoding every row with the text protocol, and returns the materialized
// rows wrapped in a *mysql.DecodedResult. Any other proto.Result
// implementation — including a nil result or one without rows — is passed
// through unchanged. It returns an error if fetching or decoding a row fails.
func decodeTextResult(result proto.Result) (proto.Result, error) {
	mysqlResult, ok := result.(*mysql.Result)
	if !ok || mysqlResult.Rows == nil {
		// Nothing to decode; hand the result back untouched.
		return result, nil
	}

	var decoded []proto.Row
	for {
		next, err := mysqlResult.Rows.Next()
		if err == io.EOF {
			// Iterator exhausted: all rows have been read.
			break
		}
		if err != nil {
			return nil, err
		}
		textRow := &mysql.TextRow{Row: next}
		if _, err = textRow.Decode(); err != nil {
			return nil, err
		}
		decoded = append(decoded, textRow)
	}

	return &mysql.DecodedResult{
		Fields:       mysqlResult.Fields,
		AffectedRows: mysqlResult.AffectedRows,
		InsertId:     mysqlResult.InsertId,
		Rows:         decoded,
	}, nil
}

// decodeBinaryResult eagerly drains the row iterator of a *mysql.Result,
// decoding every row with the binary (prepared-statement) protocol, and
// returns the materialized rows wrapped in a *mysql.DecodedResult. Any other
// proto.Result implementation — including a nil result or one without rows —
// is passed through unchanged. It returns an error if fetching or decoding a
// row fails.
func decodeBinaryResult(result proto.Result) (proto.Result, error) {
	mysqlResult, ok := result.(*mysql.Result)
	if !ok || mysqlResult.Rows == nil {
		// Nothing to decode; hand the result back untouched.
		return result, nil
	}

	var decoded []proto.Row
	for {
		next, err := mysqlResult.Rows.Next()
		if err == io.EOF {
			// Iterator exhausted: all rows have been read.
			break
		}
		if err != nil {
			return nil, err
		}
		binaryRow := &mysql.BinaryRow{Row: next}
		if _, err = binaryRow.Decode(); err != nil {
			return nil, err
		}
		decoded = append(decoded, binaryRow)
	}

	return &mysql.DecodedResult{
		Fields:       mysqlResult.Fields,
		AffectedRows: mysqlResult.AffectedRows,
		InsertId:     mysqlResult.InsertId,
		Rows:         decoded,
	}, nil
}
83 changes: 56 additions & 27 deletions pkg/executor/read_write_splitting.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,13 @@ import (
"github.com/cectc/dbpack/pkg/filter"
"github.com/cectc/dbpack/pkg/lb"
"github.com/cectc/dbpack/pkg/log"
"github.com/cectc/dbpack/pkg/misc"
"github.com/cectc/dbpack/pkg/mysql"
"github.com/cectc/dbpack/pkg/proto"
"github.com/cectc/dbpack/pkg/resource"
"github.com/cectc/dbpack/pkg/tracing"
"github.com/cectc/dbpack/third_party/parser/ast"
"github.com/cectc/dbpack/third_party/parser/model"
)

const (
hintUseDB = "UseDB"
"github.com/cectc/dbpack/third_party/parser/format"
)

type ReadWriteSplittingExecutor struct {
Expand Down Expand Up @@ -152,19 +149,44 @@ func (executor *ReadWriteSplittingExecutor) ExecuteFieldList(ctx context.Context
return nil, errors.New("unimplemented COM_FIELD_LIST in read write splitting mode")
}

func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context, sql string) (proto.Result, uint16, error) {
var (
db *DataSourceBrief
tx proto.Tx
result proto.Result
err error
)

func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(
ctx context.Context, _ string) (result proto.Result, warns uint16, err error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.RWSComQuery)
defer span.End()

if err = executor.doPreFilter(spanCtx); err != nil {
return nil, 0, err
}
defer func() {
if err == nil {
result, err = decodeTextResult(result)
if err != nil {
span.RecordError(err)
return
}
err = executor.doPostFilter(spanCtx, result)
} else {
span.RecordError(err)
}
}()

var (
db *DataSourceBrief
tx proto.Tx
sb strings.Builder
)

connectionID := proto.ConnectionID(spanCtx)
queryStmt := proto.QueryStmt(spanCtx)
if err := queryStmt.Restore(format.NewRestoreCtx(format.RestoreStringSingleQuotes|
format.RestoreKeyWordUppercase|
format.RestoreStringWithoutDefaultCharset, &sb)); err != nil {
return nil, 0, err
}
sql := sb.String()
spanCtx = proto.WithSqlText(spanCtx, sql)

log.Debugf("connectionID: %d, query: %s", connectionID, sql)
switch stmt := queryStmt.(type) {
case *ast.SetStmt:
if shouldStartTransaction(stmt) {
Expand Down Expand Up @@ -246,7 +268,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context
return tx.Query(spanCtx, sql)
}
withSlaveCtx := proto.WithSlave(spanCtx)
if has, dsName := hasUseDBHint(stmt.TableHints); has {
if has, dsName := misc.HasUseDBHint(stmt.TableHints); has {
protoDB := resource.GetDBManager().GetDB(dsName)
if protoDB == nil {
log.Debugf("data source %d not found", dsName)
Expand All @@ -271,11 +293,29 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComQuery(ctx context.Context
}
}

func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) {
func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute(
ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) {
spanCtx, span := tracing.GetTraceSpan(ctx, tracing.RWSComStmtExecute)
defer span.End()

if err = executor.doPreFilter(spanCtx); err != nil {
return nil, 0, err
}
defer func() {
if err == nil {
result, err = decodeBinaryResult(result)
if err != nil {
span.RecordError(err)
return
}
err = executor.doPostFilter(spanCtx, result)
} else {
span.RecordError(err)
}
}()

connectionID := proto.ConnectionID(spanCtx)
log.Debugf("connectionID: %d, prepare: %s", connectionID, stmt.SqlText)
txi, ok := executor.localTransactionMap.Load(connectionID)
if ok {
// in local transaction
Expand All @@ -288,7 +328,7 @@ func (executor *ReadWriteSplittingExecutor) ExecutorComStmtExecute(ctx context.C
return db.DB.ExecuteStmt(proto.WithMaster(spanCtx), stmt)
case *ast.SelectStmt:
var db *DataSourceBrief
if has, dsName := hasUseDBHint(st.TableHints); has {
if has, dsName := misc.HasUseDBHint(st.TableHints); has {
protoDB := resource.GetDBManager().GetDB(dsName)
if protoDB == nil {
log.Debugf("data source %d not found", dsName)
Expand Down Expand Up @@ -338,14 +378,3 @@ func (executor *ReadWriteSplittingExecutor) doPostFilter(ctx context.Context, re
}
return nil
}

// hasUseDBHint reports whether any of the optimizer hints is a UseDB hint
// (matched case-insensitively against hintUseDB) and, if so, returns the
// data source name the hint carries.
func hasUseDBHint(hints []*ast.TableOptimizerHint) (bool, string) {
	for _, h := range hints {
		if !strings.EqualFold(h.HintName.String(), hintUseDB) {
			continue
		}
		// HintData is set by the parser; for UseDB it holds the
		// data source name as a case-insensitive string.
		return true, h.HintData.(model.CIStr).String()
	}
	return false, ""
}
13 changes: 11 additions & 2 deletions pkg/executor/sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,20 @@ func (executor *ShardingExecutor) ExecutorComQuery(ctx context.Context, sql stri
return plan.Execute(spanCtx)
}

func (executor *ShardingExecutor) ExecutorComStmtExecute(ctx context.Context, stmt *proto.Stmt) (proto.Result, uint16, error) {
func (executor *ShardingExecutor) ExecutorComStmtExecute(
ctx context.Context, stmt *proto.Stmt) (result proto.Result, warns uint16, err error) {
if err = executor.doPreFilter(ctx); err != nil {
return nil, 0, err
}
defer func() {
if err == nil {
err = executor.doPostFilter(ctx, result)
}
}()

var (
args []interface{}
plan proto.Plan
err error
)

spanCtx, span := tracing.GetTraceSpan(ctx, tracing.SHDComStmtExecute)
Expand Down
Loading