Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zjg555543 committed Dec 13, 2024
1 parent ec2c903 commit 4b1cad2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 55 deletions.
2 changes: 1 addition & 1 deletion cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ func runSqliteServiceIfNeeded(
}

server := sqldb.CreateSqliteService(cfg.Sqlite, dbPath)
log.Info(fmt.Sprintf("Starting sqlite service on %s:%d\n%v", cfg.Sqlite.Host, cfg.Sqlite.Port, allDBPath))
log.Info(fmt.Sprintf("Starting sqlite service on %s:%d,max:%v,\n%v", cfg.Sqlite.Host, cfg.Sqlite.Port, cfg.Sqlite.MaxRequestsPerIPAndSecond, allDBPath))
go func() {
if err := server.Start(); err != nil {
log.Fatal(err)
Expand Down
2 changes: 2 additions & 0 deletions db/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@ type Config struct {
// check net/http.server.WriteTimeout
WriteTimeout types.Duration `mapstructure:"WriteTimeout"`

MaxRequestsPerIPAndSecond float64 `mapstructure:"MaxRequestsPerIPAndSecond"`

AuthMethodList string `mapstructure:"AuthMethodList"`
}
153 changes: 99 additions & 54 deletions db/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
"strings"
"time"

"github.com/0xPolygon/cdk-data-availability/rpc"
jRPC "github.com/0xPolygon/cdk-rpc/rpc"
"github.com/0xPolygon/cdk-rpc/rpc"
"github.com/0xPolygon/cdk/db/types"
"github.com/0xPolygon/cdk/log"
"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -52,7 +51,7 @@ type SqliteEndpoints struct {
func CreateSqliteService(
cfg Config,
dbMaps map[string]string,
) *jRPC.Server {
) *rpc.Server {
logger := log.WithFields("module", NAME)

meter := otel.Meter(meterName)
Expand All @@ -62,7 +61,6 @@ func CreateSqliteService(
methodList = append(methodList, s)
}
log.Info(fmt.Sprintf("Sqlite service dbMaps: %v", dbMaps))
time.Sleep(10 * time.Second)
sqlDBs := make(map[string]*dbSql.DB)
for k, dbPath := range dbMaps {
log.Info(fmt.Sprintf("Sqlite service: %s, %s", k, dbPath))
Expand All @@ -72,8 +70,9 @@ func CreateSqliteService(
}
sqlDBs[k] = db
}
log.Info(fmt.Sprintf("Sqlite service sqlDBs sucess %v", sqlDBs))

services := []jRPC.Service{
services := []rpc.Service{
{
Name: NAME,
Service: &SqliteEndpoints{
Expand All @@ -88,12 +87,13 @@ func CreateSqliteService(
},
}

return jRPC.NewServer(jRPC.Config{
Host: cfg.Host,
Port: cfg.Port,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
}, services, jRPC.WithLogger(logger.GetSugaredLogger()))
return rpc.NewServer(rpc.Config{
Host: cfg.Host,
Port: cfg.Port,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
MaxRequestsPerIPAndSecond: cfg.MaxRequestsPerIPAndSecond,
}, services, rpc.WithLogger(logger.GetSugaredLogger()))
}

type A struct {
Expand All @@ -104,13 +104,16 @@ func (b *SqliteEndpoints) Select(
db string,
sql string,
) (interface{}, rpc.Error) {
log.Info(fmt.Sprintf("zjg----1"))
err, dbCon := b.checkAndGetDB(db, sql, METHOD_SELECT)
if err != nil {
return zeroHex, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("invalid sql: %s", err.Error()))
log.Info(fmt.Sprintf("zjg----1-2---error"))
return zeroHex, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("check params invalid: %s", err.Error()))
}

log.Info(fmt.Sprintf("zjg----2"))
ctx, cancel := context.WithTimeout(context.Background(), b.readTimeout)
defer cancel()
log.Info(fmt.Sprintf("zjg----2--1"))
rows, err := dbCon.QueryContext(ctx, sql)
if err != nil {
if errors.Is(err, dbSql.ErrNoRows) {
Expand All @@ -119,7 +122,13 @@ func (b *SqliteEndpoints) Select(
}
return nil, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("failed to query: %s", err.Error()))
}
log.Info(fmt.Sprintf("zjg----3"))

err, result := getResults(rows)
if err != nil {
return nil, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("failed to get results: %s", err.Error()))
}
log.Info(fmt.Sprintf("zjg----14"))

return result, nil
}
Expand Down Expand Up @@ -147,13 +156,7 @@ func (b *SqliteEndpoints) Delete(
db string,
sql string,
) (interface{}, rpc.Error) {
ctx, cancel := context.WithTimeout(context.Background(), b.readTimeout)
defer cancel()
c, merr := b.meter.Int64Counter("claim_proof")
if merr != nil {
b.logger.Warnf("failed to create claim_proof counter: %s", merr)
}
c.Add(ctx, 1)
log.Info(fmt.Sprintf("Sqlite service Delete: %s, %s", db, sql))

return types.SqliteData{
ProofLocalExitRoot: "ProofLocalExitRoot",
Expand All @@ -162,85 +165,127 @@ func (b *SqliteEndpoints) Delete(
}, nil
}

func (b *SqliteEndpoints) GetDbList() (interface{}, rpc.Error) {
ctx, cancel := context.WithTimeout(context.Background(), b.readTimeout)
defer cancel()
c, merr := b.meter.Int64Counter("claim_proof")
if merr != nil {
b.logger.Warnf("failed to create claim_proof counter: %s", merr)
func (b *SqliteEndpoints) GetDbs() (interface{}, rpc.Error) {
//var dbList []string
dbList := make(map[string][]string)
for k, dbPath := range b.dbMaps {
log.Info(fmt.Sprintf("Sqlite service: %s, %s", k, dbPath))

sql := "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name;"
err, dbCon := b.checkAndGetDB(k, sql, METHOD_SELECT)
if err != nil {
log.Info(fmt.Sprintf("zjg----1-2---error"))
return zeroHex, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("%s", err.Error()))
}
log.Info(fmt.Sprintf("zjg----2"))
ctx, cancel := context.WithTimeout(context.Background(), b.readTimeout)
defer cancel()
log.Info(fmt.Sprintf("zjg----2--1"))
rows, err := dbCon.QueryContext(ctx, sql)
if err != nil {
if errors.Is(err, dbSql.ErrNoRows) {
return nil, rpc.NewRPCError(
rpc.DefaultErrorCode, fmt.Sprintf("No rows"), ErrNotFound)
}
return nil, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("failed to query: %s", err.Error()))
}
log.Info(fmt.Sprintf("zjg----3"))

err, result := getTables(rows)
if err != nil {
return nil, rpc.NewRPCError(rpc.DefaultErrorCode, fmt.Sprintf("failed to get results: %s", err.Error()))
}

dbList[k] = result
}
c.Add(ctx, 1)

return types.SqliteData{
ProofLocalExitRoot: "ProofLocalExitRoot",
ProofRollupExitRoot: "ProofRollupExitRoot",
L1InfoTreeLeaf: "L1InfoTreeLeaf",
}, nil
return dbList, nil
}

func (b *SqliteEndpoints) checkAndGetDB(db string, sql string, method string) (error, *dbSql.DB) {
log.Info("zjg, chenckAndGetDB: -----------1")
if len(sql) <= LIMIT_SQL_LEN {
log.Info("zjg, chenckAndGetDB: -----------1--1")
return fmt.Errorf("sql length is too short"), nil
}
log.Info("zjg, chenckAndGetDB: -----------2")

sqlMethod := strings.ToLower(sql[:6])
if sqlMethod != method {
return fmt.Errorf("sql method is not valid"), nil
}
log.Info("zjg, chenckAndGetDB: -----------3")

found := false
for _, str := range b.authMethods {
log.Info("zjg, chenckAndGetDB: -----------4")

if str == method {
found = true
break
}
}
if !found {
log.Info("zjg, chenckAndGetDB: -----------5")
return fmt.Errorf("sql method is not authorized"), nil
}
log.Info("zjg, chenckAndGetDB: -----------6")

dbCon, ok := b.sqlDBs[db]
log.Info("zjg, chenckAndGetDB: -----------7")
if !ok {
return fmt.Errorf("sql db is not valid"), nil
}

log.Info("zjg, chenckAndGetDB: -----------8")
return nil, dbCon
}

func getResults(rows *dbSql.Rows) (error, []A) {
var result []A

// 获取列名
columns, err := rows.Columns()
if err != nil {
log.Fatalf("Failed to get columns: %v", err)
}
for rows.Next() {
var a A
a.Fields = make(map[string]interface{})
err := rows.Scan()
if err != nil {
log.Info("Error scanning row:", err)
return err, nil
}

columns, err := rows.Columns()
if err != nil {
fmt.Println("Error getting columns:", err)
return err, nil
}
record := A{Fields: make(map[string]interface{})}

// 创建值和指针的切片
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
values[i] = &values[i]
valuePtrs[i] = &values[i]
}

err = rows.Scan(values...)
if err != nil {
fmt.Println("Error scanning row:", err)
return err, nil
// 扫描结果
if err := rows.Scan(valuePtrs...); err != nil {
log.Fatalf("Failed to scan row: %v", err)
}

for i, col := range columns {
a.Fields[col] = values[i]
// 将每列数据存入 map
for i, colName := range columns {
record.Fields[colName] = values[i]
}

result = append(result, a)
// 输出结果或处理
fmt.Printf("Record: %+v\n", record.Fields)
result = append(result, record)
}

return nil, result
}

func getTables(rows *dbSql.Rows) (error, []string) {
var result []string

for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
return err, nil
}
log.Info(fmt.Sprintf("Table name: %s", tableName))
result = append(result, tableName)
}

return nil, result
Expand Down

0 comments on commit 4b1cad2

Please sign in to comment.