Skip to content

Commit

Permalink
Add support for ssl dial string (globalsign#184)
Browse files Browse the repository at this point in the history
* Add support for ssl dial string

* Ensure we dont override user settings

* update examples

* update ssl value parsing

* PingSsl test

* skip test requiring system certificates
  • Loading branch information
tbruyelle authored and max-konin committed Jul 5, 2018
1 parent e16f0d0 commit bc5f0ce
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
43 changes: 42 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,45 @@ func ExampleSession_concurrency() {
wg.Wait()

session.Close()
}
}

func ExampleDial_usingSSL() {
// To connect via TLS/SSL (enforced for MongoDB Atlas for example) requires
// to set the ssl query param to true.
url := "mongodb://localhost:40003?ssl=true"

session, err := Dial(url)
if err != nil {
panic(err)
}

// Use session as normal
session.Close()
}

func ExampleDial_tlsConfig() {
// You can define a custom tlsConfig, this one enables TLS, like if you have
// ssl=true in the connection string.
url := "mongodb://localhost:40003"

tlsConfig := &tls.Config{
// This can be configured to use a private root CA - see the Credential
// x509 Authentication example.
//
// Please don't set InsecureSkipVerify to true - it makes using TLS
// pointless and is never the right answer!
}

dialInfo, err := ParseURL(url)
dialInfo.DialServer = func(addr *ServerAddr) (net.Conn, error) {
return tls.Dial("tcp", addr.String(), tlsConfig)
}

session, err := DialWithInfo(dialInfo)
if err != nil {
panic(err)
}

// Use session as normal
session.Close()
}
19 changes: 19 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package mgo

import (
"crypto/md5"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
Expand Down Expand Up @@ -274,6 +275,12 @@ const (
// The identifier of this client application. This parameter is used to
// annotate logs / profiler output and cannot exceed 128 bytes.
//
// ssl=<true|false>
//
// true: Initiate the connection with TLS/SSL.
// false: Initiate the connection without TLS/SSL.
// The default value is false.
//
// Relevant documentation:
//
// http://docs.mongodb.org/manual/reference/connection-string/
Expand Down Expand Up @@ -311,6 +318,7 @@ func ParseURL(url string) (*DialInfo, error) {
if err != nil {
return nil, err
}
ssl := false
direct := false
mechanism := ""
service := ""
Expand All @@ -323,6 +331,10 @@ func ParseURL(url string) (*DialInfo, error) {
var readPreferenceTagSets []bson.D
for _, opt := range uinfo.options {
switch opt.key {
case "ssl":
if v, err := strconv.ParseBool(opt.value); err == nil && v {
ssl = true
}
case "authSource":
source = opt.value
case "authMechanism":
Expand Down Expand Up @@ -412,6 +424,13 @@ func ParseURL(url string) (*DialInfo, error) {
ReplicaSetName: setName,
SSL: ssl,
}
if ssl && info.DialServer == nil {
// Set DialServer only if nil, we don't want to override user's settings.
info.DialServer = func(addr *ServerAddr) (net.Conn, error) {
conn, err := tls.Dial("tcp", addr.String(), &tls.Config{})
return conn, err
}
}
return &info, nil
}

Expand Down
28 changes: 28 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ func (s *S) TestPing(c *C) {
c.Assert(stats.ReceivedOps, Equals, 1)
}

func (s *S) TestPingSsl(c *C) {
c.Skip("this test requires the usage of the system provided certificates")
session, err := mgo.Dial("localhost:40001?ssl=true")
c.Assert(err, IsNil)
defer session.Close()

c.Assert(session.Ping(), IsNil)
}

func (s *S) TestDialIPAddress(c *C) {
session, err := mgo.Dial("127.0.0.1:40001")
c.Assert(err, IsNil)
Expand Down Expand Up @@ -133,6 +142,25 @@ func (s *S) TestURLParsing(c *C) {
}
}

func (s *S) TestURLSsl(c *C) {
type test struct {
url string
nilDialServer bool
}

tests := []test{
{"localhost:40001", true},
{"localhost:40001?ssl=false", true},
{"localhost:40001?ssl=true", false},
}

for _, test := range tests {
info, err := mgo.ParseURL(test.url)
c.Assert(err, IsNil)
c.Assert(info.DialServer == nil, Equals, test.nilDialServer)
}
}

func (s *S) TestURLReadPreference(c *C) {
type test struct {
url string
Expand Down

0 comments on commit bc5f0ce

Please sign in to comment.