Skip to content

Commit

Permalink
SNOW-1524204 handle session variables in DSN (#1177)
Browse files Browse the repository at this point in the history
* SNOW-1524204 handle session variables in DSN
  • Loading branch information
sfc-gh-dszmolka authored Jul 17, 2024
1 parent 9730225 commit f6155d8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
11 changes: 10 additions & 1 deletion dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,8 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if cfg.Params == nil {
cfg.Params = make(map[string]*string)
}
cfg.Params[param[0]] = &value
// handle session variables $variable=value
cfg.Params[urlDecodeIfNeeded(param[0])] = &value
}
}
return
Expand Down Expand Up @@ -957,3 +958,11 @@ func extractAccountName(rawAccount string) string {
}
return strings.ToUpper(rawAccount)
}

func urlDecodeIfNeeded(param string) (decodedParam string) {
unescaped, err := url.QueryUnescape(param)
if err != nil {
return param
}
return unescaped
}
50 changes: 50 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
cr "crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"net/url"
Expand Down Expand Up @@ -1817,3 +1818,52 @@ func TestExtractAccountName(t *testing.T) {
})
}
}

func TestUrlDecodeIfNeeded(t *testing.T) {
testcases := map[string]string{
"query_tag": "query_tag",
"%24my_custom_variable": "$my_custom_variable",
}
for param, expected := range testcases {
t.Run(param, func(t *testing.T) {
decodedParam := urlDecodeIfNeeded(param)
assertEqualE(t, decodedParam, expected)
})
}
}

func TestUrlDecodeIfNeededE2E(t *testing.T) {
customVarName := "CUSTOM_VARIABLE"
customVarValue := "test"
myQueryTag := "mytag"
testPort, err := strconv.Atoi(os.Getenv("SNOWFLAKE_TEST_PORT"))
if err != nil {
testPort = 443
}

cfg := &Config{
Account: os.Getenv("SNOWFLAKE_TEST_ACCOUNT"),
Host: os.Getenv("SNOWFLAKE_TEST_HOST"),
Port: testPort,
Protocol: os.Getenv("SNOWFLAKE_TEST_PROTOCOL"),
User: os.Getenv("SNOWFLAKE_TEST_USER"),
Password: os.Getenv("SNOWFLAKE_TEST_PASSWORD"),
Params: map[string]*string{"$" + customVarName: &customVarValue, "query_tag": &myQueryTag},
}
mydsn, err := DSN(cfg)
assertNilE(t, err, "TestUrlDecodeIfNeededE2E failed to create DSN from Config")
db, err := sql.Open("snowflake", mydsn)
assertNilE(t, err, "TestUrlDecodeIfNeededE2E failed to connect.")
defer db.Close()
query := "SHOW VARIABLES;"
rows, err := db.Query(query)
assertNilE(t, err, "TestUrlDecodeIfNeededE2E failed to run SHOW VARIABLES query.")
defer rows.Close()
var v1, v2, v3, v4, v5, v6, v7 any
assertTrueE(t, rows.Next(), "TestUrlDecodeIfNeededE2E query run but no rows were returned.")
err = rows.Scan(&v1, &v2, &v3, &v4, &v5, &v6, &v7)
assertNilE(t, err, "TestUrlDecodeIfNeededE2E failed to get result.")
assertDeepEqualE(t, v4, customVarName, "TestUrlDecodeIfNeededE2E variable name retrieved from the test did not match")
assertDeepEqualE(t, v5, customVarValue, "TestUrlDecodeIfNeededE2E variable value retrieved from the test did not match")
assertNilE(t, rows.Err(), "TestUrlDecodeIfNeededE2E ERROR getting rows.")
}

0 comments on commit f6155d8

Please sign in to comment.