mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-11 09:07:00 +02:00
135 lines
2.9 KiB
Go
135 lines
2.9 KiB
Go
package mongodb
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"gopkg.in/mgo.v2"
|
|
)
|
|
|
|
// Unfortunately, mgo doesn't support the ssl parameter in its MongoDB URI parsing logic, so we have to handle that
|
|
// ourselves. See https://github.com/go-mgo/mgo/issues/84
|
|
func parseMongoURI(uri string) (*mgo.DialInfo, error) {
|
|
uinfo, err := extractMongoURL(uri)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
direct := false
|
|
mechanism := ""
|
|
service := ""
|
|
source := ""
|
|
setName := ""
|
|
poolLimit := 0
|
|
ssl := false
|
|
for k, v := range uinfo.options {
|
|
switch k {
|
|
case "authSource":
|
|
source = v
|
|
case "authMechanism":
|
|
mechanism = v
|
|
case "gssapiServiceName":
|
|
service = v
|
|
case "replicaSet":
|
|
setName = v
|
|
case "maxPoolSize":
|
|
poolLimit, err = strconv.Atoi(v)
|
|
if err != nil {
|
|
return nil, errors.New("bad value for maxPoolSize: " + v)
|
|
}
|
|
case "ssl":
|
|
if v == "true" {
|
|
ssl = true
|
|
}
|
|
case "connect":
|
|
if v == "direct" {
|
|
direct = true
|
|
break
|
|
}
|
|
if v == "replicaSet" {
|
|
break
|
|
}
|
|
fallthrough
|
|
default:
|
|
return nil, errors.New("unsupported connection URL option: " + k + "=" + v)
|
|
}
|
|
}
|
|
info := mgo.DialInfo{
|
|
Addrs: uinfo.addrs,
|
|
Direct: direct,
|
|
Database: uinfo.db,
|
|
Username: uinfo.user,
|
|
Password: uinfo.pass,
|
|
Mechanism: mechanism,
|
|
Service: service,
|
|
Source: source,
|
|
PoolLimit: poolLimit,
|
|
ReplicaSetName: setName,
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
if ssl {
|
|
info.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
|
|
return tls.Dial("tcp", addr.String(), &tls.Config{})
|
|
}
|
|
}
|
|
return &info, nil
|
|
}
|
|
|
|
func extractMongoURL(s string) (*urlInfo, error) {
|
|
if strings.HasPrefix(s, "mongodb://") {
|
|
s = s[10:]
|
|
}
|
|
info := &urlInfo{options: make(map[string]string)}
|
|
if c := strings.Index(s, "?"); c != -1 {
|
|
for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) {
|
|
l := strings.SplitN(pair, "=", 2)
|
|
if len(l) != 2 || l[0] == "" || l[1] == "" {
|
|
return nil, errors.New("connection option must be key=value: " + pair)
|
|
}
|
|
info.options[l[0]] = l[1]
|
|
}
|
|
s = s[:c]
|
|
}
|
|
if c := strings.Index(s, "@"); c != -1 {
|
|
pair := strings.SplitN(s[:c], ":", 2)
|
|
if len(pair) > 2 || pair[0] == "" {
|
|
return nil, errors.New("credentials must be provided as user:pass@host")
|
|
}
|
|
var err error
|
|
info.user, err = url.QueryUnescape(pair[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0])
|
|
}
|
|
if len(pair) > 1 {
|
|
info.pass, err = url.QueryUnescape(pair[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot unescape password in URL")
|
|
}
|
|
}
|
|
s = s[c+1:]
|
|
}
|
|
if c := strings.Index(s, "/"); c != -1 {
|
|
info.db = s[c+1:]
|
|
s = s[:c]
|
|
}
|
|
info.addrs = strings.Split(s, ",")
|
|
return info, nil
|
|
}
|
|
|
|
// isOptSep reports whether c separates key=value pairs in the option part of
// a connection string. Both '&' and ';' are accepted as separators.
func isOptSep(c rune) bool {
	switch c {
	case '&', ';':
		return true
	default:
		return false
	}
}
|
|
|
|
// urlInfo holds the raw components of a MongoDB connection string as split
// out by extractMongoURL; no option is interpreted here.
type urlInfo struct {
	addrs      []string          // host entries, split on ','
	user, pass string            // credentials after unescaping; empty when absent
	db         string            // text after the first '/' following the hosts
	options    map[string]string // key=value pairs from the part after '?'
}
|