Skip to content

Allow TLS connections in the driver #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions client/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,28 @@ func NewClientTLSConfig(caPem, certPem, keyPem []byte, insecureSkipVerify bool,
panic("failed to add ca PEM")
}

cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
var config *tls.Config

config := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
// Allow cert and key to be optional
// Send through `make([]byte, 0)` for "nil"
if string(certPem) != "" && string(keyPem) != "" {
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
config = &tls.Config{
RootCAs: pool,
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
} else {
config = &tls.Config{
RootCAs: pool,
InsecureSkipVerify: insecureSkipVerify,
ServerName: serverName,
}
}

return config
}
142 changes: 119 additions & 23 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,120 @@
package driver

import (
"crypto/tls"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"io"
"strings"
"net/url"
"regexp"
"sync"

"github.com/go-mysql-org/go-mysql/client"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/pingcap/errors"
"github.com/siddontang/go/hack"
)

var customTLSMutex sync.Mutex
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
var customTLSMutex sync.Mutex
var customTLSMutex sync.Mutex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this.


// Map of dsn address (makes more sense than full dsn?) to tls Config
var customTLSConfigMap = make(map[string]*tls.Config)

type driver struct {
}

// Open: DSN user:password@addr[?db]
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
lastIndex := strings.LastIndex(dsn, "@")
seps := []string{dsn[:lastIndex], dsn[lastIndex+1:]}
if len(seps) != 2 {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
type connInfo struct {
standardDSN bool
addr string
user string
password string
db string
params url.Values
}

// ParseDSN takes a DSN string and splits it up into struct containing addr,
// user, password and db.
// It returns an error if unable to parse.
// The struct also contains a boolean indicating if the DSN is in legacy or
// standard form.
//
// Legacy form uses a `?` is used as the path separator: user:password@addr[?db]
// Standard form uses a `/`: user:password@addr/db?param=value
//
// Optional parameters are supported in the standard DSN form
func parseDSN(dsn string) (connInfo, error) {
var matchErr error
ci := connInfo{}

// If a "/" occurs after "@" and then no more "@" or "/" occur after that
ci.standardDSN, matchErr = regexp.MatchString("@[^@]+/[^@/]+", dsn)
if matchErr != nil {
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
}

// Add a prefix so we can parse with url.Parse
dsn = "mysql://" + dsn
parsedDSN, parseErr := url.Parse(dsn)
if parseErr != nil {
return ci, errors.Errorf("invalid dsn, must be user:password@addr[/db[?param=X]]")
}

var user string
var password string
var addr string
var db string
ci.addr = parsedDSN.Host
ci.user = parsedDSN.User.Username()
// We ignore the second argument as that is just a flag for existence of a password
// If not set we get empty string anyway
ci.password, _ = parsedDSN.User.Password()

if ss := strings.Split(seps[0], ":"); len(ss) == 2 {
user, password = ss[0], ss[1]
} else if len(ss) == 1 {
user = ss[0]
if ci.standardDSN {
ci.db = parsedDSN.Path[1:]
ci.params = parsedDSN.Query()
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
ci.db = parsedDSN.RawQuery
// This is the equivalent to a "nil" list of parameters
ci.params = url.Values{}
}

if ss := strings.Split(seps[1], "?"); len(ss) == 2 {
addr, db = ss[0], ss[1]
} else if len(ss) == 1 {
addr = ss[0]
} else {
return nil, errors.Errorf("invalid dsn, must user:password@addr[?db]")
return ci, nil
}

// Open takes a supplied DSN string and opens a connection
// See ParseDSN for more information on the form of the DSN
func (d driver) Open(dsn string) (sqldriver.Conn, error) {
var c *client.Conn

ci, err := parseDSN(dsn)

if err != nil {
return nil, err
}

c, err := client.Connect(addr, user, password, db)
if ci.standardDSN {
if ci.params["ssl"] != nil {
tlsConfigName := ci.params.Get("ssl")
switch tlsConfigName {
case "true":
// This actually does insecureSkipVerify
// But not even sure if it makes sense to handle false? According to
// client_test.go it doesn't - it'd result in an error
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.UseSSL(true) })
case "custom":
// I was too concerned about mimicking what go-sql-driver/mysql does which will
// allow any name for a custom tls profile and maps the query parameter value to
// that TLSConfig variable... there is no need to be that clever.
// Instead of doing that, let's store required custom TLSConfigs in a map that
// uses the DSN address as the key
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db, func(c *client.Conn) { c.SetTLSConfig(customTLSConfigMap[ci.addr]) })
default:
return nil, errors.Errorf("Supported options are ssl=true or ssl=custom")
}
} else {
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
}
} else {
// No more processing here. Let's only support url parameters with the newer style DSN
c, err = client.Connect(ci.addr, ci.user, ci.password, ci.db)
}
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -229,3 +298,30 @@ func (r *rows) Next(dest []sqldriver.Value) error {
func init() {
sql.Register("mysql", driver{})
}

// SetCustomTLSConfig sets a custom TLSConfig for the address (host:port) of the supplied DSN.
// It requires a full import of the driver (not by side-effects only).
// Example of supplying a custom CA, no client cert, no key, validating the
// certificate, and supplying a serverName for the validation:
//
// driver.SetCustomTLSConfig(CaPem, make([]byte, 0), make([]byte, 0), false, "my.domain.name")
//
func SetCustomTLSConfig(dsn string, caPem []byte, certPem []byte, keyPem []byte, insecureSkipVerify bool, serverName string) error {
// Extract addr from dsn
parsed, err := url.Parse(dsn)
if err != nil {
return errors.Errorf("Unable to parse DSN. Need to extract address to use as key for storing custom TLS config")
}
addr := parsed.Host

// I thought about using serverName instead of addr below, but decided against that as
// having multiple CA certs for one hostname is likely when you have services running on
// different ports.

customTLSMutex.Lock()
// Basic pass-through function so we can just import the driver
customTLSConfigMap[addr] = client.NewClientTLSConfig(caPem, certPem, keyPem, insecureSkipVerify, serverName)
customTLSMutex.Unlock()

return nil
}
27 changes: 27 additions & 0 deletions driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package driver
import (
"flag"
"fmt"
"net/url"
"reflect"
"testing"

"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -78,3 +80,28 @@ func (s *testDriverSuite) TestTransaction(c *C) {
err = tx.Commit()
c.Assert(err, IsNil)
}

func TestParseDSN(t *testing.T) {
// List of DSNs to test and expected results
// Use different numbered domains to more readily see what has failed - since we
// test in a loop we get the same line number on error
testDSNs := map[string]connInfo{
"user:password@localhost?db": connInfo{standardDSN: false, addr: "localhost", user: "user", password: "password", db: "db", params: url.Values{}},
"[email protected]?db": connInfo{standardDSN: false, addr: "1.domain.com", user: "user", password: "", db: "db", params: url.Values{}},
"user:[email protected]/db": connInfo{standardDSN: true, addr: "2.domain.com", user: "user", password: "password", db: "db", params: url.Values{}},
"user:[email protected]/db?ssl=true": connInfo{standardDSN: true, addr: "3.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"true"}}},
"user:[email protected]/db?ssl=custom": connInfo{standardDSN: true, addr: "4.domain.com", user: "user", password: "password", db: "db", params: url.Values{"ssl": []string{"custom"}}},
"user:[email protected]/db?unused=param": connInfo{standardDSN: true, addr: "5.domain.com", user: "user", password: "password", db: "db", params: url.Values{"unused": []string{"param"}}},
}

for supplied, expected := range testDSNs {
actual, err := parseDSN(supplied)
if err != nil {
t.Errorf("TestParseDSN failed. Got error: %s", err)
}
// Compare that with expected
if !reflect.DeepEqual(actual, expected) {
t.Errorf("TestParseDSN failed.\nExpected:\n%#v\nGot:\n%#v", expected, actual)
}
}
}