vault/builtin/logical/mongodb/util.go
2016-06-29 08:33:06 -04:00

135 lines
2.9 KiB
Go

package mongodb
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"gopkg.in/mgo.v2"
)
// parseMongoURI converts a MongoDB connection URL into an mgo.DialInfo.
//
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI
// parsing logic, so we have to handle that ourselves.
// See https://github.com/go-mgo/mgo/issues/84
//
// Recognised query options are authSource, authMechanism, gssapiServiceName,
// replicaSet, maxPoolSize, ssl (any value accepted by strconv.ParseBool) and
// connect (direct or replicaSet). Any other option, or an unparseable value
// for maxPoolSize/ssl, yields an error. The returned DialInfo always uses a
// 10 second Timeout; when ssl is enabled, servers are dialed over TLS with a
// default tls.Config and that same timeout bounds each connection attempt.
func parseMongoURI(uri string) (*mgo.DialInfo, error) {
	uinfo, err := extractMongoURL(uri)
	if err != nil {
		return nil, err
	}

	var (
		direct    bool
		mechanism string
		service   string
		source    string
		setName   string
		poolLimit int
		ssl       bool
	)

	for k, v := range uinfo.options {
		switch k {
		case "authSource":
			source = v
		case "authMechanism":
			mechanism = v
		case "gssapiServiceName":
			service = v
		case "replicaSet":
			setName = v
		case "maxPoolSize":
			poolLimit, err = strconv.Atoi(v)
			if err != nil {
				return nil, errors.New("bad value for maxPoolSize: " + v)
			}
		case "ssl":
			// Reject unrecognised values rather than silently falling back
			// to a plaintext connection (e.g. "ssl=True" used to be ignored).
			ssl, err = strconv.ParseBool(v)
			if err != nil {
				return nil, errors.New("bad value for ssl: " + v)
			}
		case "connect":
			switch v {
			case "direct":
				direct = true
			case "replicaSet":
				// Accepted; nothing extra to configure.
			default:
				return nil, errors.New("unsupported connection URL option: " + k + "=" + v)
			}
		default:
			return nil, errors.New("unsupported connection URL option: " + k + "=" + v)
		}
	}

	info := mgo.DialInfo{
		Addrs:          uinfo.addrs,
		Direct:         direct,
		Database:       uinfo.db,
		Username:       uinfo.user,
		Password:       uinfo.pass,
		Mechanism:      mechanism,
		Service:        service,
		Source:         source,
		PoolLimit:      poolLimit,
		ReplicaSetName: setName,
		Timeout:        10 * time.Second,
	}

	if ssl {
		// DialServer receives only the server address, not a timeout, so
		// bound the TCP connect and TLS handshake here with the same Timeout.
		dialer := &net.Dialer{Timeout: info.Timeout}
		info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
			return tls.DialWithDialer(dialer, "tcp", addr.String(), &tls.Config{})
		}
	}

	return &info, nil
}
// extractMongoURL splits a MongoDB connection URL of the form
//
//	[mongodb://][user[:pass]@]host1[,host2,...][/db][?key=value[&|;key=value...]]
//
// into its components. The "mongodb://" prefix is optional. Each option must
// be a key=value pair with a non-empty key and value; pairs may be separated
// by '&' or ';' and their values are stored verbatim (not unescaped). The
// user name and password are percent-decoded. Unlike url.QueryUnescape, a
// literal '+' is preserved rather than being turned into a space, because
// '+' carries no special meaning in the user-info part of a URL.
// Everything left after the credentials and the "/db" suffix is split on
// ',' to produce the host list.
func extractMongoURL(s string) (*urlInfo, error) {
	s = strings.TrimPrefix(s, "mongodb://")
	info := &urlInfo{options: make(map[string]string)}

	if c := strings.Index(s, "?"); c != -1 {
		for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) {
			l := strings.SplitN(pair, "=", 2)
			if len(l) != 2 || l[0] == "" || l[1] == "" {
				return nil, errors.New("connection option must be key=value: " + pair)
			}
			info.options[l[0]] = l[1]
		}
		s = s[:c]
	}

	if c := strings.Index(s, "@"); c != -1 {
		// SplitN with n=2 yields at most two parts: user and optional password.
		pair := strings.SplitN(s[:c], ":", 2)
		if pair[0] == "" {
			return nil, errors.New("credentials must be provided as user:pass@host")
		}
		var err error
		info.user, err = unescapeUserInfo(pair[0])
		if err != nil {
			return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0])
		}
		if len(pair) > 1 {
			info.pass, err = unescapeUserInfo(pair[1])
			if err != nil {
				// The raw password is deliberately kept out of the error.
				return nil, errors.New("cannot unescape password in URL")
			}
		}
		s = s[c+1:]
	}

	if c := strings.Index(s, "/"); c != -1 {
		info.db = s[c+1:]
		s = s[:c]
	}
	info.addrs = strings.Split(s, ",")
	return info, nil
}

// unescapeUserInfo percent-decodes a user name or password taken from the
// user-info part of a URL. '+' is escaped to "%2B" first so that
// url.QueryUnescape keeps it as a literal '+' instead of decoding it as a
// space (query-string semantics do not apply to user-info).
func unescapeUserInfo(s string) (string, error) {
	return url.QueryUnescape(strings.Replace(s, "+", "%2B", -1))
}

// isOptSep reports whether c separates connection options; both '&' and ';'
// are accepted.
func isOptSep(c rune) bool {
	return c == ';' || c == '&'
}

// urlInfo holds the pieces of a MongoDB connection URL as split by
// extractMongoURL.
type urlInfo struct {
	addrs   []string          // host[:port] entries, split on ','
	user    string            // percent-decoded user name ("" if absent)
	pass    string            // percent-decoded password ("" if absent)
	db      string            // text after the first '/' in the host part
	options map[string]string // raw key=value query options
}