Skip to content

Commit

Permalink
vault-12244 (hashicorp#19591)
Browse files Browse the repository at this point in the history
* vault-12244

* CL
  • Loading branch information
hghaf099 authored Mar 17, 2023
1 parent 40c4684 commit f15715f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
3 changes: 3 additions & 0 deletions changelog/19591.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
core: validate name identifiers in mssql physical storage backend prior use
```
30 changes: 25 additions & 5 deletions physical/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
Expand All @@ -21,6 +22,7 @@ import (

// Verify MSSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*MSSQLBackend)(nil)
var identifierRegex = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}@#$_]*$`)

type MSSQLBackend struct {
dbTable string
Expand All @@ -30,6 +32,13 @@ type MSSQLBackend struct {
permitPool *physical.PermitPool
}

func isInvalidIdentifier(name string) bool {
if !identifierRegex.MatchString(name) {
return true
}
return false
}

func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
username, ok := conf["username"]
if !ok {
Expand Down Expand Up @@ -71,11 +80,19 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
database = "Vault"
}

if isInvalidIdentifier(database) {
return nil, fmt.Errorf("invalid database name")
}

table, ok := conf["table"]
if !ok {
table = "Vault"
}

if isInvalidIdentifier(table) {
return nil, fmt.Errorf("invalid table name")
}

appname, ok := conf["appname"]
if !ok {
appname = "Vault"
Expand All @@ -96,6 +113,10 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
schema = "dbo"
}

if isInvalidIdentifier(schema) {
return nil, fmt.Errorf("invalid schema name")
}

connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
if username != "" {
connectionString += ";user id=" + username
Expand All @@ -116,18 +137,17 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen

db.SetMaxOpenConns(maxParInt)

if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '" + database + "') CREATE DATABASE " + database); err != nil {
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = ?) CREATE DATABASE "+database, database); err != nil {
return nil, fmt.Errorf("failed to create mssql database: %w", err)
}

dbTable := database + "." + schema + "." + table
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME='" + table + "' AND TABLE_SCHEMA='" + schema +
"') CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME=? AND TABLE_SCHEMA=?) CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"

if schema != "dbo" {

var num int
err = db.QueryRow("SELECT 1 FROM " + database + ".sys.schemas WHERE name = '" + schema + "'").Scan(&num)
err = db.QueryRow("SELECT 1 FROM "+database+".sys.schemas WHERE name = ?", schema).Scan(&num)

switch {
case err == sql.ErrNoRows:
Expand All @@ -140,7 +160,7 @@ func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backen
}
}

if _, err := db.Exec(createQuery); err != nil {
if _, err := db.Exec(createQuery, table, schema); err != nil {
return nil, fmt.Errorf("failed to create mssql table: %w", err)
}

Expand Down
44 changes: 42 additions & 2 deletions physical/mssql/mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,53 @@ import (
"os"
"testing"

_ "github.com/denisenkom/go-mssqldb"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"

_ "github.com/denisenkom/go-mssqldb"
)

// TestInvalidIdentifier checks validity of an identifier
func TestInvalidIdentifier(t *testing.T) {
testcases := map[string]bool{
"name": true,
"_name": true,
"Name": true,
"#name": false,
"?Name": false,
"9name": false,
"@name": false,
"$name": false,
" name": false,
"n ame": false,
"n4444444": true,
"_4321098765": true,
"_##$$@@__": true,
"_123name#@": true,
"name!": false,
"name%": false,
"name^": false,
"name&": false,
"name*": false,
"name(": false,
"name)": false,
"nåame": true,
"åname": true,
"name'": false,
"nam`e": false,
"пример": true,
"_#Āā@#$_ĂĄąćĈĉĊċ": true,
"ÛÜÝÞßàáâ": true,
"豈更滑a23$#@": true,
}

for i, expected := range testcases {
if !isInvalidIdentifier(i) != expected {
t.Fatalf("unexpected identifier %s: expected validity %v", i, expected)
}
}
}

func TestMSSQLBackend(t *testing.T) {
server := os.Getenv("MSSQL_SERVER")
if server == "" {
Expand Down

0 comments on commit f15715f

Please sign in to comment.