From 1e94bce2c0b422b578d55d456499fb40b7e72d37 Mon Sep 17 00:00:00 2001 From: adarsh-jaiss Date: Thu, 23 May 2024 01:40:47 +0530 Subject: [PATCH] fixed base64 decoding issue in mysql and postgres Signed-off-by: adarsh-jaiss --- cli/cmd/root.go | 95 ++++++++++++++++++++++++++++------ databases/mysql/mysql.go | 38 +++++--------- databases/postgres/postgres.go | 40 ++++++++------ 3 files changed, 116 insertions(+), 57 deletions(-) diff --git a/cli/cmd/root.go b/cli/cmd/root.go index cb24b70..1c14f27 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -25,7 +25,13 @@ var ( query string ) -// QueryResult represents the result of a database query. +type QueryResultInterface interface { + GetColumns() []string + GetRows() interface{} + GetTime() float64 + GetError() string +} + type QueryResult struct { Columns []string `json:"columns"` Rows [][]interface{} `json:"rows"` @@ -33,6 +39,47 @@ type QueryResult struct { Error string `json:"error"` } +func (q QueryResult) GetColumns() []string { + return q.Columns +} + +func (q QueryResult) GetRows() interface{} { + return q.Rows +} + +func (q QueryResult) GetTime() float64 { + return q.Time +} + +func (q QueryResult) GetError() string { + return q.Error +} + +type BigQueryResult struct { + Columns []string `json:"columns"` + Rows []map[string]interface{} `json:"rows"` + Time int64 `json:"time"` + Error string `json:"error"` +} + +func (b BigQueryResult) GetColumns() []string { + return b.Columns +} + +func (b BigQueryResult) GetRows() interface{} { + return b.Rows +} + +func (b BigQueryResult) GetTime() float64 { + return float64(b.Time) +} + +func (b BigQueryResult) GetError() string { + return b.Error +} + +// QueryResult represents the result of a database query. + // Command for interacting with databases var shellCmd = &cobra.Command{ Use: "shell", @@ -133,26 +180,44 @@ func queryExecute(query string, db xrayTypes.ISQL) error { return fmt.Errorf("error executing query result: %s", err) } - var result QueryResult - err = json.Unmarshal(b, &result) - if err != nil { - return fmt.Errorf("error parsing query result: %s", err) + var result QueryResultInterface + if dbType == "bigquery" { + result = &BigQueryResult{} + } else { + result = &QueryResult{} } - if len(result.Rows) == 0 { - - return fmt.Errorf("no results found") + err = json.Unmarshal(b, result) + if err != nil { + return fmt.Errorf("error parsing query result: %s", err) } table := tablewriter.NewWriter(os.Stdout) - table.SetHeader(result.Columns) - for _, row := range result.Rows { - stringRow := make([]string, len(row)) - for i, v := range row { - stringRow[i] = fmt.Sprintf("%v", v) + table.SetHeader(result.GetColumns()) // Assert the type of result and call GetColumns() instead of Columns + switch rows := result.GetRows().(type) { + case [][]interface{}: + for _, row := range rows { + stringRow := make([]string, len(row)) + for i, v := range row { + stringRow[i] = fmt.Sprintf("%v", v) + } + + table.Append(stringRow) + } + case []map[string]interface{}: + for _, rowMap := range rows { + var stringRow []string + for _, v := range rowMap { + stringRow = append(stringRow, fmt.Sprintf("%v", v)) + } + table.Append(stringRow) } + default: + return fmt.Errorf("unexpected type of rows: %T", rows) + } - table.Append(stringRow) + if table.NumLines() == 0 { + return fmt.Errorf("no results found") } // Print the table @@ -195,4 +260,4 @@ func parseDbType(s string) xrayTypes.DbType { default: return xrayTypes.MySQL } -} \ No newline at end of file +} diff --git a/databases/mysql/mysql.go b/databases/mysql/mysql.go index 48e71d8..d0a21bc 100644 --- a/databases/mysql/mysql.go +++ b/databases/mysql/mysql.go @@ -7,8 +7,6 @@ import ( "os" "strings" - "encoding/base64" - _ "github.com/go-sql-driver/mysql" "github.com/thesaas-company/xray/config" "github.com/thesaas-company/xray/types" @@ -137,20 +135,21 @@ func (m *MySQL) Execute(query string) ([]byte, error) { return nil, fmt.Errorf("error scanning row: %v", err) } - // Decode base64 data - for i, val := range values { - strVal, ok := val.(string) - if ok && isBase64(strVal) { - // Redecode the value to get the decoded result - decoded, err := base64.StdEncoding.DecodeString(strVal) - if err != nil { - return nil, fmt.Errorf("error decoding base64 data: %v", err) - } - values[i] = string(decoded) + // Convert the values to the appropriate types + stringRow := make([]interface{}, len(values)) + for i, v := range values { + switch value := v.(type) { + case []byte: + stringRow[i] = string(value) + case string: + stringRow[i] = value + default: + stringRow[i] = fmt.Sprintf("%v", value) } } - results = append(results, values) + // Append the modified row to the results + results = append(results, stringRow) } // Check for errors from iterating over rows @@ -171,19 +170,6 @@ func (m *MySQL) Execute(query string) ([]byte, error) { return jsonData, nil } -// isBase64 checks if a string is a valid base64 string. -func isBase64(s string) bool { - if len(s)%4 != 0 { - return false - } - // Try to decode the string - _, err := base64.StdEncoding.DecodeString(s) - // If decoding succeeds, err will be nil, and the function will return true - // If decoding fails, err will not be nil, and the function will return false - // Also we do not have access to decoded value, so we are not using it - return err == nil -} - // Tables retrieves the list of tables in the given database. // It takes the database name as an argument and returns a list of table names. func (m *MySQL) Tables(databaseName string) ([]string, error) { diff --git a/databases/postgres/postgres.go b/databases/postgres/postgres.go index 363ad7d..a360e95 100644 --- a/databases/postgres/postgres.go +++ b/databases/postgres/postgres.go @@ -170,18 +170,29 @@ func (p *Postgres) Execute(query string) ([]byte, error) { } // Decode base64 data - for _, val := range values { - strVal, ok := val.(*string) - if ok && strVal != nil && isBase64(*strVal) { - decoded, err := base64.StdEncoding.DecodeString(*strVal) - if err != nil { - return nil, fmt.Errorf("error decoding base64 data: %v", err) + stringRow := make([]interface{}, len(values)) + for i, val := range values { + switch v := val.(type) { + case []byte: + strVal := string(v) + if isBase64(strVal) { + decoded, err := base64.StdEncoding.DecodeString(strVal) + if err != nil { + return nil, fmt.Errorf("error decoding base64 data: %v", err) + } + stringRow[i] = string(decoded) + } else { + stringRow[i] = strVal } - *strVal = string(decoded) + case string: + stringRow[i] = v + case nil: + stringRow[i] = nil + default: + stringRow[i] = fmt.Sprintf("%v", v) } } - - results = append(results, values) + results = append(results, stringRow) } // Check for errors from iterating over rows @@ -203,15 +214,12 @@ func (p *Postgres) Execute(query string) ([]byte, error) { } func isBase64(s string) bool { - if len(s)%4 != 0 { + decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { return false } - // Try to decode the string - _, err := base64.StdEncoding.DecodeString(s) - // If decoding succeeds, err will be nil, and the function will return true - // If decoding fails, err will not be nil, and the function will return false - // Also we do not have access to decoded value, so we are not using it - return err == nil + encoded := base64.StdEncoding.EncodeToString(decoded) + return s == encoded } // Tables returns a list of all tables in the given database.