Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Implement change password during login #141

Merged
merged 8 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ jobs:
- name: Run tests against Linux SQL
run: |
go version
go get -d
export SQLCMDPASSWORD=$(date +%s|sha256sum|base64|head -c 32)
export SQLCMDUSER=sa
export SQLUSER=sa
Expand All @@ -30,4 +29,4 @@ jobs:
docker run -m 2GB -e ACCEPT_EULA=1 -d --name sqlserver -p:1433:1433 -e SA_PASSWORD=$SQLCMDPASSWORD mcr.microsoft.com/mssql/server:${{ matrix.sqlImage }}
sleep 10
sqlcmd -Q "CREATE DATABASE test"
go test -race -cpu 4 ./...
shueybubbles marked this conversation as resolved.
Show resolved Hide resolved
go test -v ./...
9 changes: 1 addition & 8 deletions .pipelines/TestSql2017.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@ variables:
steps:
- task: GoTool@0
inputs:
version: '1.19'
- task: Go@0
displayName: 'Go: get sources'
inputs:
command: 'get'
arguments: '-d'
workingDirectory: '$(Build.SourcesDirectory)'

version: '1.20'

- task: Go@0
displayName: 'Go: install gotest.tools/gotestsum'
Expand Down
1 change: 1 addition & 0 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ install:
- go get -u github.com/golang-sql/civil
- go get -u github.com/golang-sql/sqlexp
- go get -u golang.org/x/crypto/md4
- go get github.com/stretchr/testify/assert@v1.8.1
- go get -u golang.org/x/text/encoding/unicode

build_script:
Expand Down
133 changes: 85 additions & 48 deletions msdsn/conn_str.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,34 @@ const (
BrowserDAC BrowserMsg = 0x0f
)

const (
Database = "database"
Encrypt = "encrypt"
Password = "password"
ChangePassword = "change password"
UserID = "user id"
Port = "port"
TrustServerCertificate = "trustservercertificate"
Certificate = "certificate"
TLSMin = "tlsmin"
PacketSize = "packet size"
LogParam = "log"
ConnectionTimeout = "connection timeout"
HostNameInCertificate = "hostnameincertificate"
KeepAlive = "keepalive"
ServerSpn = "serverspn"
WorkstationID = "workstation id"
AppName = "app name"
ApplicationIntent = "applicationintent"
FailoverPartner = "failoverpartner"
FailOverPort = "failoverport"
DisableRetry = "disableretry"
Server = "server"
Protocol = "protocol"
DialTimeout = "dial timeout"
Pipe = "pipe"
)

type Config struct {
Port uint64
Host string
Expand Down Expand Up @@ -88,6 +116,8 @@ type Config struct {
ProtocolParameters map[string]interface{}
// BrowserMsg is the message identifier to fetch instance data from SQL browser
BrowserMessage BrowserMsg
// ChangePassword is used to set the login's password during login. Ignored for non-SQL authentication.
ChangePassword string
//ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values
ColumnEncryption bool
}
Expand Down Expand Up @@ -130,7 +160,7 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
trustServerCert := false

var encryption Encryption = EncryptionOff
encrypt, ok := params["encrypt"]
encrypt, ok := params[Encrypt]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
encryption = EncryptionDisabled
Expand All @@ -147,7 +177,7 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
} else {
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
trust, ok := params[TrustServerCertificate]
if ok {
var err error
trustServerCert, err = strconv.ParseBool(trust)
Expand All @@ -156,9 +186,9 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
return encryption, nil, fmt.Errorf(f, trust, err.Error())
}
}
certificate := params["certificate"]
certificate := params[Certificate]
if encryption != EncryptionDisabled {
tlsMin := params["tlsmin"]
tlsMin := params[TLSMin]
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
Expand Down Expand Up @@ -194,7 +224,7 @@ func Parse(dsn string) (Config, error) {

p.Parameters = params

strlog, ok := params["log"]
strlog, ok := params[LogParam]
if ok {
flags, err := strconv.ParseUint(strlog, 10, 64)
if err != nil {
Expand All @@ -203,12 +233,12 @@ func Parse(dsn string) (Config, error) {
p.LogFlags = Log(flags)
}

p.Database = params["database"]
p.User = params["user id"]
p.Password = params["password"]

p.Database = params[Database]
p.User = params[UserID]
p.Password = params[Password]
p.ChangePassword = params[ChangePassword]
p.Port = 0
strport, ok := params["port"]
strport, ok := params[Port]
if ok {
var err error
p.Port, err = strconv.ParseUint(strport, 10, 16)
Expand All @@ -219,7 +249,7 @@ func Parse(dsn string) (Config, error) {
}

// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\
strpsize, ok := params["packet size"]
strpsize, ok := params[PacketSize]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
Expand All @@ -244,7 +274,7 @@ func Parse(dsn string) (Config, error) {
//
// Do not set a connection timeout. Use Context to manage such things.
// Default to zero, but still allow it to be set.
if strconntimeout, ok := params["connection timeout"]; ok {
if strconntimeout, ok := params[ConnectionTimeout]; ok {
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "invalid connection timeout '%v': %v"
Expand All @@ -256,7 +286,7 @@ func Parse(dsn string) (Config, error) {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.KeepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
if keepAlive, ok := params[KeepAlive]; ok {
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "invalid keepAlive value '%s': %s"
Expand All @@ -265,12 +295,12 @@ func Parse(dsn string) (Config, error) {
p.KeepAlive = time.Duration(timeout) * time.Second
}

serverSPN, ok := params["serverspn"]
serverSPN, ok := params[ServerSpn]
if ok {
p.ServerSPN = serverSPN
} // If not set by the app, ServerSPN will be set by the successful dialer.

workstation, ok := params["workstation id"]
workstation, ok := params[WorkstationID]
if ok {
p.Workstation = workstation
} else {
Expand All @@ -280,13 +310,13 @@ func Parse(dsn string) (Config, error) {
}
}

appname, ok := params["app name"]
appname, ok := params[AppName]
if !ok {
appname = "go-mssqldb"
}
p.AppName = appname

appintent, ok := params["applicationintent"]
appintent, ok := params[ApplicationIntent]
if ok {
if appintent == "ReadOnly" {
if p.Database == "" {
Expand All @@ -296,12 +326,12 @@ func Parse(dsn string) (Config, error) {
}
}

failOverPartner, ok := params["failoverpartner"]
failOverPartner, ok := params[FailoverPartner]
if ok {
p.FailOverPartner = failOverPartner
}

failOverPort, ok := params["failoverport"]
failOverPort, ok := params[FailOverPort]
if ok {
var err error
p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
Expand All @@ -311,7 +341,7 @@ func Parse(dsn string) (Config, error) {
}
}

disableRetry, ok := params["disableretry"]
disableRetry, ok := params[DisableRetry]
if ok {
var err error
p.DisableRetry, err = strconv.ParseBool(disableRetry)
Expand All @@ -323,8 +353,8 @@ func Parse(dsn string) (Config, error) {
p.DisableRetry = disableRetryDefault
}

server := params["server"]
protocol, ok := params["protocol"]
server := params[Server]
protocol, ok := params[Protocol]

for _, parser := range ProtocolParsers {
if (!ok && !parser.Hidden()) || parser.Protocol() == protocol {
Expand All @@ -350,7 +380,7 @@ func Parse(dsn string) (Config, error) {
f = 1
}
p.DialTimeout = time.Duration(15*f) * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
if strdialtimeout, ok := params[DialTimeout]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "invalid dial timeout '%v': %v"
Expand All @@ -360,7 +390,7 @@ func Parse(dsn string) (Config, error) {
p.DialTimeout = time.Duration(timeout) * time.Second
}

hostInCertificate, ok := params["hostnameincertificate"]
hostInCertificate, ok := params[HostNameInCertificate]
if ok {
p.HostInCertificateProvided = true
} else {
Expand Down Expand Up @@ -394,10 +424,10 @@ func Parse(dsn string) (Config, error) {
func (p Config) URL() *url.URL {
q := url.Values{}
if p.Database != "" {
q.Add("database", p.Database)
q.Add(Database, p.Database)
}
if p.LogFlags != 0 {
q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10))
q.Add(LogParam, strconv.FormatUint(uint64(p.LogFlags), 10))
}
host := p.Host
protocol := ""
Expand All @@ -412,20 +442,20 @@ func (p Config) URL() *url.URL {
if p.Port > 0 {
host = fmt.Sprintf("%s:%d", host, p.Port)
}
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters["protocol"]
q.Add(DisableRetry, fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters[Protocol]
if ok {
if protocol != "" && protocolParam != protocol {
panic("Mismatched protocol parameters!")
}
protocol = protocolParam
}
if protocol != "" {
q.Add("protocol", protocol)
q.Add(Protocol, protocol)
}
pipe, ok := p.Parameters["pipe"]
pipe, ok := p.Parameters[Pipe]
if ok {
q.Add("pipe", pipe)
q.Add(Pipe, pipe)
}
res := url.URL{
Scheme: "sqlserver",
Expand All @@ -435,7 +465,14 @@ func (p Config) URL() *url.URL {
if p.Instance != "" {
res.Path = p.Instance
}
q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))
q.Add(DialTimeout, strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))

switch p.Encryption {
case EncryptionDisabled:
q.Add(Encrypt, "DISABLE")
case EncryptionRequired:
q.Add(Encrypt, "true")
}
if p.ColumnEncryption {
q.Add("columnencryption", "true")
}
Expand All @@ -448,14 +485,14 @@ func (p Config) URL() *url.URL {

// ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs
var adoSynonyms = map[string]string{
"application name": "app name",
"data source": "server",
"address": "server",
"network address": "server",
"addr": "server",
"user": "user id",
"uid": "user id",
"initial catalog": "database",
"application name": AppName,
"data source": Server,
"address": Server,
"network address": Server,
"addr": Server,
"user": UserID,
"uid": UserID,
"initial catalog": Database,
"column encryption setting": "columnencryption",
}

Expand All @@ -480,18 +517,18 @@ func splitConnectionString(dsn string) (res map[string]string) {
name = synonym
}
// "server" in ADO can include a protocol and a port.
if name == "server" {
if name == Server {
for _, parser := range ProtocolParsers {
prot := parser.Protocol() + ":"
if strings.HasPrefix(value, prot) {
res["protocol"] = parser.Protocol()
res[Protocol] = parser.Protocol()
}
value = strings.TrimPrefix(value, prot)
}
serverParts := strings.Split(value, ",")
if len(serverParts) == 2 && len(serverParts[1]) > 0 {
value = serverParts[0]
res["port"] = serverParts[1]
res[Port] = serverParts[1]
}
}
res[name] = value
Expand All @@ -513,10 +550,10 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}

if u.User != nil {
res["user id"] = u.User.Username()
res[UserID] = u.User.Username()
p, exists := u.User.Password()
if exists {
res["password"] = p
res[Password] = p
}
}

Expand All @@ -526,13 +563,13 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}

if len(u.Path) > 0 {
res["server"] = host + "\\" + u.Path[1:]
res[Server] = host + "\\" + u.Path[1:]
} else {
res["server"] = host
res[Server] = host
}

if len(port) > 0 {
res["port"] = port
res[Port] = port
}

query := u.Query()
Expand Down
Loading