Skip to content

Commit

Permalink
feat: snowflake key pair (#4781)
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr authored Jun 11, 2024
1 parent b81ee4a commit c64447e
Show file tree
Hide file tree
Showing 12 changed files with 572 additions and 180 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ require (
github.com/viney-shih/go-lock v1.1.2
github.com/xitongsys/parquet-go v1.5.1
github.com/xitongsys/parquet-go-source v0.0.0-20240122235623-d6294584ab18
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d
go.etcd.io/etcd/api/v3 v3.5.14
go.etcd.io/etcd/client/v3 v3.5.14
go.uber.org/atomic v1.11.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,8 @@ github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGC
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c h1:3lbZUMbMiGUW/LMkfsEABsc5zNT9+b1CvsJx47JzJ8g=
github.com/xtgo/uuid v0.0.0-20140804021211-a0b114877d4c/go.mod h1:UrdRz5enIKZ63MEE3IF9l2/ebyx59GyGgPi+tICQdmM=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
Expand Down
3 changes: 2 additions & 1 deletion warehouse/integrations/bigquery/bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/rudderlabs/rudder-go-kit/filemanager"
"github.com/rudderlabs/rudder-go-kit/logger"
kithelper "github.com/rudderlabs/rudder-go-kit/testhelper"

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/runner"
"github.com/rudderlabs/rudder-server/testhelper/health"
Expand All @@ -42,7 +43,7 @@ func TestIntegration(t *testing.T) {
if os.Getenv("SLOW") != "1" {
t.Skip("Skipping tests. Add 'SLOW=1' env var to run test.")
}
if !bqHelper.IsBQTestCredentialsAvailable() {
if _, exists := os.LookupEnv(bqHelper.TestKey); !exists {
t.Skipf("Skipping %s as %s is not set", t.Name(), bqHelper.TestKey)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package middleware_test

import (
"context"
"os"
"testing"
"time"

Expand All @@ -11,13 +12,14 @@ import (
"github.com/stretchr/testify/require"

"github.com/rudderlabs/rudder-go-kit/logger/mock_logger"

"github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery"
"github.com/rudderlabs/rudder-server/warehouse/integrations/bigquery/middleware"
"github.com/rudderlabs/rudder-server/warehouse/logfield"
)

func TestQueryWrapper(t *testing.T) {
if !bqHelper.IsBQTestCredentialsAvailable() {
if _, exists := os.LookupEnv(bqHelper.TestKey); !exists {
t.Skipf("Skipping %s as %s is not set", t.Name(), bqHelper.TestKey)
}

Expand Down
5 changes: 0 additions & 5 deletions warehouse/integrations/bigquery/testhelper/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ func GetBQTestCredentials() (*TestCredentials, error) {
return &credentials, nil
}

func IsBQTestCredentialsAvailable() bool {
_, err := GetBQTestCredentials()
return err == nil
}

// RetrieveRecordsFromWarehouse retrieves records from the warehouse based on the given query.
// It returns a slice of slices, where each inner slice represents a record's values.
func RetrieveRecordsFromWarehouse(
Expand Down
7 changes: 1 addition & 6 deletions warehouse/integrations/deltalake/deltalake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,11 @@ func deltaLakeTestCredentials() (*testCredentials, error) {
return &credentials, nil
}

func testCredentialsAvailable() bool {
_, err := deltaLakeTestCredentials()
return err == nil
}

func TestIntegration(t *testing.T) {
if os.Getenv("SLOW") != "1" {
t.Skip("Skipping tests. Add 'SLOW=1' env var to run test.")
}
if !testCredentialsAvailable() {
if _, exists := os.LookupEnv(testKey); !exists {
t.Skipf("Skipping %s as %s is not set", t.Name(), testKey)
}

Expand Down
5 changes: 3 additions & 2 deletions warehouse/integrations/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"

"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/warehouse/client"
azuresynapse "github.com/rudderlabs/rudder-server/warehouse/integrations/azure-synapse"
Expand Down Expand Up @@ -65,7 +66,7 @@ func New(destType string, conf *config.Config, logger logger.Logger, stats stats
case warehouseutils.BQ:
return bigquery.New(conf, logger), nil
case warehouseutils.SNOWFLAKE:
return snowflake.New(conf, logger, stats)
return snowflake.New(conf, logger, stats), nil
case warehouseutils.POSTGRES:
return postgres.New(conf, logger, stats), nil
case warehouseutils.CLICKHOUSE:
Expand All @@ -90,7 +91,7 @@ func NewWarehouseOperations(destType string, conf *config.Config, logger logger.
case warehouseutils.BQ:
return bigquery.New(conf, logger), nil
case warehouseutils.SNOWFLAKE:
return snowflake.New(conf, logger, stats)
return snowflake.New(conf, logger, stats), nil
case warehouseutils.POSTGRES:
return postgres.New(conf, logger, stats), nil
case warehouseutils.CLICKHOUSE:
Expand Down
7 changes: 1 addition & 6 deletions warehouse/integrations/redshift/redshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,11 @@ func rsTestCredentials() (*testCredentials, error) {
return &credentials, nil
}

func testCredentialsAvailable() bool {
_, err := rsTestCredentials()
return err == nil
}

func TestIntegration(t *testing.T) {
if os.Getenv("SLOW") != "1" {
t.Skip("Skipping tests. Add 'SLOW=1' env var to run test.")
}
if !testCredentialsAvailable() {
if _, exists := os.LookupEnv(testKey); !exists {
t.Skipf("Skipping %s as %s is not set", t.Name(), testKey)
}

Expand Down
104 changes: 72 additions & 32 deletions warehouse/integrations/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package snowflake
import (
"bytes"
"context"
"crypto/rsa"
"database/sql"
"encoding/csv"
"encoding/pem"
"errors"
"fmt"
"regexp"
Expand All @@ -16,6 +18,7 @@ import (

"github.com/samber/lo"
snowflake "github.com/snowflakedb/gosnowflake"
"github.com/youmark/pkcs8"

"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
Expand All @@ -37,14 +40,17 @@ const (

// String constants for snowflake destination config
const (
storageIntegration = "storageIntegration"
account = "account"
warehouse = "warehouse"
database = "database"
user = "user"
role = "role"
password = "password"
application = "Rudderstack_Warehouse"
storageIntegration = "storageIntegration"
account = "account"
warehouse = "warehouse"
database = "database"
user = "user"
role = "role"
password = "password"
useKeyPairAuth = "useKeyPairAuth"
privateKey = "privateKey"
privateKeyPassphrase = "privateKeyPassphrase"
application = "Rudderstack_Warehouse"
)

var primaryKeyMap = map[string]string{
Expand Down Expand Up @@ -127,14 +133,17 @@ var errorsMappings = []model.JobError{
}

type credentials struct {
account string
warehouse string
database string
user string
role string
password string
schemaName string
timeout time.Duration
account string
warehouse string
database string
user string
role string
password string
schemaName string
useKeyPairAuth bool
privateKey string
privateKeyPassphrase string
timeout time.Duration
}

type tableLoadResp struct {
Expand Down Expand Up @@ -179,11 +188,11 @@ type Snowflake struct {
}
}

func New(conf *config.Config, log logger.Logger, stat stats.Stats) (*Snowflake, error) {
sf := &Snowflake{}

sf.logger = log.Child("integrations").Child("snowflake")
sf.stats = stat
func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Snowflake {
sf := &Snowflake{
logger: log.Child("integrations").Child("snowflake"),
stats: stat,
}

sf.config.allowMerge = conf.GetBool("Warehouse.snowflake.allowMerge", true)
sf.config.enableDeleteByJobs = conf.GetBool("Warehouse.snowflake.enableDeleteByJobs", false)
Expand All @@ -199,7 +208,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) (*Snowflake,
return strings.ToUpper(item)
})

return sf, nil
return sf
}

func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string {
Expand Down Expand Up @@ -1058,16 +1067,26 @@ func (sf *Snowflake) connect(ctx context.Context, opts optionalCreds) (*sqlmw.DB
Account: cred.account,
User: cred.user,
Role: cred.role,
Password: cred.password,
Database: cred.database,
Schema: cred.schemaName,
Warehouse: cred.warehouse,
Application: application,
}

if cred.timeout > 0 {
urlConfig.LoginTimeout = cred.timeout
}
if cred.useKeyPairAuth {
rsaPrivateKey, err := ParsePrivateKey(cred.privateKey, cred.privateKeyPassphrase)
if err != nil {
return nil, fmt.Errorf("parsing private key: %w", err)
}

urlConfig.PrivateKey = rsaPrivateKey
urlConfig.Authenticator = snowflake.AuthTypeJwt
} else {
urlConfig.Password = cred.password
urlConfig.Authenticator = snowflake.AuthTypeSnowflake
}

var err error
dsn, err := snowflake.DSN(&urlConfig)
Expand Down Expand Up @@ -1108,6 +1127,24 @@ func (sf *Snowflake) connect(ctx context.Context, opts optionalCreds) (*sqlmw.DB
return middleware, nil
}

// ParsePrivateKey turns a PEM-encoded PKCS#8 private key into an *rsa.PrivateKey.
//
// The privateKey string is first passed through whutils.FormatPemContent and the
// result is PEM-decoded; only the first PEM block is used and any trailing data is
// ignored. A non-empty passPhrase is forwarded to the PKCS#8 parser as the
// decryption password; an empty passPhrase means no password is supplied.
//
// It returns an error when no PEM block can be decoded, or when the PKCS#8 parser
// rejects the block contents.
func ParsePrivateKey(privateKey, passPhrase string) (*rsa.PrivateKey, error) {
	pemBytes := []byte(whutils.FormatPemContent(privateKey))

	decoded, _ := pem.Decode(pemBytes)
	if decoded == nil {
		return nil, errors.New("decoding private key failed")
	}

	// The parser takes the password as an optional variadic argument, so only
	// supply one when a passphrase was actually configured.
	var passwords [][]byte
	if passPhrase != "" {
		passwords = [][]byte{[]byte(passPhrase)}
	}

	key, err := pkcs8.ParsePKCS8PrivateKeyRSA(decoded.Bytes, passwords...)
	if err != nil {
		return nil, fmt.Errorf("parsing private key: %w", err)
	}
	return key, nil
}

func (sf *Snowflake) CreateSchema(ctx context.Context) (err error) {
var schemaExists bool
schemaIdentifier := sf.schemaIdentifier()
Expand Down Expand Up @@ -1336,14 +1373,17 @@ func (sf *Snowflake) IsEmpty(ctx context.Context, warehouse model.Warehouse) (em

func (sf *Snowflake) getConnectionCredentials(opts optionalCreds) credentials {
return credentials{
account: whutils.GetConfigValue(account, sf.Warehouse),
warehouse: whutils.GetConfigValue(warehouse, sf.Warehouse),
database: whutils.GetConfigValue(database, sf.Warehouse),
user: whutils.GetConfigValue(user, sf.Warehouse),
role: whutils.GetConfigValue(role, sf.Warehouse),
password: whutils.GetConfigValue(password, sf.Warehouse),
schemaName: opts.schemaName,
timeout: sf.connectTimeout,
account: whutils.GetConfigValue(account, sf.Warehouse),
warehouse: whutils.GetConfigValue(warehouse, sf.Warehouse),
database: whutils.GetConfigValue(database, sf.Warehouse),
user: whutils.GetConfigValue(user, sf.Warehouse),
role: whutils.GetConfigValue(role, sf.Warehouse),
password: whutils.GetConfigValue(password, sf.Warehouse),
useKeyPairAuth: whutils.ReadAsBool(useKeyPairAuth, sf.Warehouse.Destination.Config),
privateKey: whutils.GetConfigValue(privateKey, sf.Warehouse),
privateKeyPassphrase: whutils.GetConfigValue(privateKeyPassphrase, sf.Warehouse),
schemaName: opts.schemaName,
timeout: sf.connectTimeout,
}
}

Expand Down
Loading

0 comments on commit c64447e

Please sign in to comment.