merge changes from master

This commit is contained in:
Vishal Nayak 2015-06-29 22:01:43 -04:00
commit 9ba1d26f4e
275 changed files with 50329 additions and 3121 deletions

106
Godeps/Godeps.json generated
View File

@ -7,74 +7,79 @@
"Deps": [
{
"ImportPath": "github.com/armon/go-metrics",
"Rev": "a54701ebec11868993bc198c3f315353e9de2ed6"
"Rev": "b2d95e5291cdbc26997d1301a5e467ecbb240e25"
},
{
"ImportPath": "github.com/armon/go-radix",
"Rev": "0bab926c3433cfd6490c6d3c504a7b471362390c"
"Rev": "fbd82e84e2b13651f3abc5ffd26b65ba71bc8f93"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/aws",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/apierr",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/endpoints",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/query",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/rest",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/restxml",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/internal/signer/v4",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/aws/aws-sdk-go/service/s3",
"Comment": "v0.6.0",
"Rev": "ea83c25c44525da47e8044bbd21e4045758ea39b"
"Comment": "v0.6.4-5-g127313c",
"Rev": "127313c1b41e534a0456a68b6b3a16712dacb35d"
},
{
"ImportPath": "github.com/coreos/go-etcd/etcd",
"Comment": "v2.0.0-7-g73a8ef7",
"Rev": "73a8ef737e8ea002281a28b4cb92a1de121ad4c6"
"Comment": "v2.0.0-18-gc904d70",
"Rev": "c904d7032a70da6551c43929f199244f6a45f4c1"
},
{
"ImportPath": "github.com/fatih/structs",
"Rev": "a9f7daa9c2729e97450c2da2feda19130a367d8f"
},
{
"ImportPath": "github.com/go-ldap/ldap",
"Comment": "v1-14-g406aa05",
"Rev": "406aa05eb8272fb8aa201e410afa6f9fdcb2bf68"
},
{
"ImportPath": "github.com/go-ldap/ldap",
"Comment": "v1-14-g406aa05",
"Rev": "406aa05eb8272fb8aa201e410afa6f9fdcb2bf68"
},
{
"ImportPath": "github.com/go-sql-driver/mysql",
"Comment": "v1.2-88-ga197e5d",
"Rev": "a197e5d40516f2e9f74dcee085a5f2d4604e94df"
"Comment": "v1.2-112-gfb72997",
"Rev": "fb7299726d2e68745a8805b14f2ff44b5c2cfa84"
},
{
"ImportPath": "github.com/google/go-github/github",
"Rev": "0aaa85be4f3087c6dd815a69e291775d4e83f9ea"
"Rev": "fccd5bb66f985db0a0d150342ca0a9529a23488a"
},
{
"ImportPath": "github.com/google/go-querystring/query",
@ -102,8 +107,8 @@
},
{
"ImportPath": "github.com/hashicorp/consul/api",
"Comment": "v0.5.0-253-g7062ecc",
"Rev": "7062ecc50fef9307e532c4a188da7ce1dd759dde"
"Comment": "v0.5.2-123-gaddb614",
"Rev": "addb6145096bbce6f9dde807a78cad2a4cea3a68"
},
{
"ImportPath": "github.com/hashicorp/errwrap",
@ -111,7 +116,7 @@
},
{
"ImportPath": "github.com/hashicorp/go-multierror",
"Rev": "fcdddc395df1ddf4247c69bd436e84cfa0733f7e"
"Rev": "56912fb08d85084aa318edcf2bba735b97cf35c5"
},
{
"ImportPath": "github.com/hashicorp/go-syslog",
@ -119,28 +124,28 @@
},
{
"ImportPath": "github.com/hashicorp/golang-lru",
"Rev": "d85392d6bc30546d352f52f2632814cde4201d44"
"Rev": "995efda3e073b6946b175ed93901d729ad47466a"
},
{
"ImportPath": "github.com/hashicorp/hcl",
"Rev": "513e04c400ee2e81e97f5e011c08fb42c6f69b84"
"Rev": "54864211433d45cb780682431585b3e573b49e4a"
},
{
"ImportPath": "github.com/hashicorp/logutils",
"Rev": "367a65d59043b4f846d179341d138f01f988c186"
"Rev": "0dc08b1671f34c4250ce212759ebd880f743d883"
},
{
"ImportPath": "github.com/kardianos/osext",
"Rev": "8fef92e41e22a70e700a96b29f066cda30ea24ef"
"Rev": "6e7f843663477789fac7c02def0d0909e969b4e5"
},
{
"ImportPath": "github.com/lib/pq",
"Comment": "go1.0-cutoff-40-g8910d1c",
"Rev": "8910d1c3a4bda5c97c50bc38543953f1f1e1f8bb"
"Comment": "go1.0-cutoff-51-ga8d8d01",
"Rev": "a8d8d01c4f91602f876bf5aa210274e8203a6b45"
},
{
"ImportPath": "github.com/mitchellh/cli",
"Rev": "6cc8bc522243675a2882b81662b0b0d2e04b99c9"
"Rev": "8102d0ed5ea2709ade1243798785888175f6e415"
},
{
"ImportPath": "github.com/mitchellh/copystructure",
@ -152,11 +157,11 @@
},
{
"ImportPath": "github.com/mitchellh/mapstructure",
"Rev": "442e588f213303bec7936deba67901f8fc8f18b1"
"Rev": "2caf8efc93669b6c43e0441cdc6aed17546c96f3"
},
{
"ImportPath": "github.com/mitchellh/reflectwalk",
"Rev": "242be0c275dedfba00a616563e6db75ab8f279ec"
"Rev": "eecf4c70c626c7cfbb95c90195bc34d386c74ac6"
},
{
"ImportPath": "github.com/ryanuber/columnize",
@ -165,15 +170,11 @@
},
{
"ImportPath": "github.com/samuel/go-zookeeper/zk",
"Rev": "d0e0d8e11f318e000a8cc434616d69e329edc374"
"Rev": "c86eba8e7e95efab81f6c0455332e49d39aed12f"
},
{
"ImportPath": "github.com/vanackere/asn1-ber",
"Rev": "295c7b21db5d9525ad959e3382610f3aff029663"
},
{
"ImportPath": "github.com/vanackere/ldap",
"Rev": "e29b797d1abde6567ccb4ab56236e033cabf845a"
"ImportPath": "github.com/ugorji/go/codec",
"Rev": "821cda7e48749cacf7cad2c6ed01e96457ca7e9d"
},
{
"ImportPath": "github.com/vaughan0/go-ini",
@ -181,15 +182,20 @@
},
{
"ImportPath": "golang.org/x/crypto/ssh/terminal",
"Rev": "59435533c88bd0b1254c738244da1fe96b59d05d"
"Rev": "cc04154d65fb9296747569b107cfd05380b1ea3e"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "a8c61998a557a37435f719980da368469c10bfed"
"Rev": "d9558e5c97f85372afee28cf2b6059d7d3818919"
},
{
"ImportPath": "golang.org/x/oauth2",
"Rev": "ec6d5d770f531108a6464462b2201b74fcd09314"
"Rev": "b5adcc2dcdf009d0391547edc6ecbaff889f5bb9"
},
{
"ImportPath": "gopkg.in/asn1-ber.v1",
"Comment": "v1",
"Rev": "9eae18c3681ae3d3c677ac2b80a8fe57de45fc09"
}
]
}

View File

@ -65,11 +65,12 @@ func NewIntervalMetrics(intv time.Time) *IntervalMetrics {
// AggregateSample is used to hold aggregate metrics
// about a sample
type AggregateSample struct {
Count int // The count of emitted pairs
Sum float64 // The sum of values
SumSq float64 // The sum of squared values
Min float64 // Minimum value
Max float64 // Maximum value
Count int // The count of emitted pairs
Sum float64 // The sum of values
SumSq float64 // The sum of squared values
Min float64 // Minimum value
Max float64 // Maximum value
LastUpdated time.Time // When value was last updated
}
// Computes a Stddev of the values
@ -101,16 +102,17 @@ func (a *AggregateSample) Ingest(v float64) {
if v > a.Max || a.Count == 1 {
a.Max = v
}
a.LastUpdated = time.Now()
}
func (a *AggregateSample) String() string {
if a.Count == 0 {
return "Count: 0"
} else if a.Stddev() == 0 {
return fmt.Sprintf("Count: %d Sum: %0.3f", a.Count, a.Sum)
return fmt.Sprintf("Count: %d Sum: %0.3f LastUpdated: %s", a.Count, a.Sum, a.LastUpdated)
} else {
return fmt.Sprintf("Count: %d Min: %0.3f Mean: %0.3f Max: %0.3f Stddev: %0.3f Sum: %0.3f",
a.Count, a.Min, a.Mean(), a.Max, a.Stddev(), a.Sum)
return fmt.Sprintf("Count: %d Min: %0.3f Mean: %0.3f Max: %0.3f Stddev: %0.3f Sum: %0.3f LastUpdated: %s",
a.Count, a.Min, a.Mean(), a.Max, a.Stddev(), a.Sum, a.LastUpdated)
}
}

View File

@ -63,6 +63,15 @@ func TestInmemSink(t *testing.T) {
t.Fatalf("bad val: %v", agg)
}
if agg.LastUpdated.IsZero() {
t.Fatalf("agg.LastUpdated is not set: %v", agg)
}
diff := time.Now().Sub(agg.LastUpdated).Seconds()
if diff > 1 {
t.Fatalf("time diff too great: %f", diff)
}
if agg = intvM.Samples["foo.bar"]; agg == nil {
t.Fatalf("missing sample")
}

View File

@ -10,6 +10,8 @@ As a radix tree, it provides the following:
* Minimum / Maximum value lookups
* Ordered iteration
For an immutable variant, see [go-immutable-radix](https://github.com/hashicorp/go-immutable-radix).
Documentation
=============

View File

@ -42,6 +42,17 @@ type Error interface {
OrigErr() error
}
// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
if e, ok := origErr.(Error); ok && e != nil {
return e
}
return newBaseError(code, message, origErr)
}
// A RequestFailure is an interface to extract request failure information from
// an Error such as the request ID of the failed request returned by a service.
// RequestFailures may not always have a requestID value if the request failed
@ -86,3 +97,9 @@ type RequestFailure interface {
// to a connection error.
RequestID() string
}
// NewRequestFailure returns a new request error wrapper for the given Error
// provided.
func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure {
return newRequestError(err, statusCode, reqID)
}

View File

@ -1,14 +1,28 @@
// Package apierr represents API error types.
package apierr
package awserr
import "fmt"
// A BaseError wraps the code and message which defines an error. It also
// SprintError returns a string of the formatted error code.
//
// Both extra and origErr are optional. If they are included their lines
// will be added, but if they are not included their lines will be ignored.
func SprintError(code, message, extra string, origErr error) string {
msg := fmt.Sprintf("%s: %s", code, message)
if extra != "" {
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
}
if origErr != nil {
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
}
return msg
}
// A baseError wraps the code and message which defines an error. It also
// can be used to wrap an original error object.
//
// Should be used as the root for errors satisfying the awserr.Error. Also
// for any error which does not fit into a specific error wrapper type.
type BaseError struct {
type baseError struct {
// Classification of error
code string
@ -20,7 +34,7 @@ type BaseError struct {
origErr error
}
// New returns an error object for the code, message, and err.
// newBaseError returns an error object for the code, message, and err.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
@ -28,8 +42,8 @@ type BaseError struct {
// message is the free flow string containing detailed information about the error.
//
// origErr is the error object which will be nested under the new error to be returned.
func New(code, message string, origErr error) *BaseError {
return &BaseError{
func newBaseError(code, message string, origErr error) *baseError {
return &baseError{
code: code,
message: message,
origErr: origErr,
@ -41,75 +55,56 @@ func New(code, message string, origErr error) *BaseError {
// See ErrorWithExtra for formatting.
//
// Satisfies the error interface.
func (b *BaseError) Error() string {
return b.ErrorWithExtra("")
func (b baseError) Error() string {
return SprintError(b.code, b.message, "", b.origErr)
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (b *BaseError) String() string {
func (b baseError) String() string {
return b.Error()
}
// Code returns the short phrase depicting the classification of the error.
func (b *BaseError) Code() string {
func (b baseError) Code() string {
return b.code
}
// Message returns the error details message.
func (b *BaseError) Message() string {
func (b baseError) Message() string {
return b.message
}
// OrigErr returns the original error if one was set. Nil is returned if no error
// was set.
func (b *BaseError) OrigErr() error {
func (b baseError) OrigErr() error {
return b.origErr
}
// ErrorWithExtra is a helper method to add an extra string to the stratified
// error message. The extra message will be added on the next line below the
// error message like the following:
//
// <error code>: <error message>
// <extra message>
//
// If there is a original error the error will be included on a new line.
//
// <error code>: <error message>
// <extra message>
// caused by: <original error>
func (b *BaseError) ErrorWithExtra(extra string) string {
msg := fmt.Sprintf("%s: %s", b.code, b.message)
if extra != "" {
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
}
if b.origErr != nil {
msg = fmt.Sprintf("%s\ncaused by: %s", msg, b.origErr.Error())
}
return msg
}
// So that the Error interface type can be included as an anonymous field
// in the requestError struct and not conflict with the error.Error() method.
type awsError Error
// A RequestError wraps a request or service error.
// A requestError wraps a request or service error.
//
// Composed of BaseError for code, message, and original error.
type RequestError struct {
*BaseError
// Composed of baseError for code, message, and original error.
type requestError struct {
awsError
statusCode int
requestID string
}
// NewRequestError returns a wrapped error with additional information for request
// newRequestError returns a wrapped error with additional information for request
// status code, and service requestID.
//
// Should be used to wrap all request which involve service requests. Even if
// the request failed without a service response, but had an HTTP status code
// that may be meaningful.
//
// Also wraps original errors via the BaseError.
func NewRequestError(base *BaseError, statusCode int, requestID string) *RequestError {
return &RequestError{
BaseError: base,
// Also wraps original errors via the baseError.
func newRequestError(err Error, statusCode int, requestID string) *requestError {
return &requestError{
awsError: err,
statusCode: statusCode,
requestID: requestID,
}
@ -117,23 +112,24 @@ func NewRequestError(base *BaseError, statusCode int, requestID string) *Request
// Error returns the string representation of the error.
// Satisfies the error interface.
func (r *RequestError) Error() string {
return r.ErrorWithExtra(fmt.Sprintf("status code: %d, request id: [%s]",
r.statusCode, r.requestID))
func (r requestError) Error() string {
extra := fmt.Sprintf("status code: %d, request id: [%s]",
r.statusCode, r.requestID)
return SprintError(r.Code(), r.Message(), extra, r.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (r *RequestError) String() string {
func (r requestError) String() string {
return r.Error()
}
// StatusCode returns the wrapped status code for the error
func (r *RequestError) StatusCode() int {
func (r requestError) StatusCode() int {
return r.statusCode
}
// RequestID returns the wrapped requestID
func (r *RequestError) RequestID() string {
func (r requestError) RequestID() string {
return r.requestID
}

View File

@ -16,8 +16,8 @@ type Struct struct {
}
var data = Struct{
A: []Struct{Struct{C: "value1"}, Struct{C: "value2"}, Struct{C: "value3"}},
z: []Struct{Struct{C: "value1"}, Struct{C: "value2"}, Struct{C: "value3"}},
A: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
z: []Struct{{C: "value1"}, {C: "value2"}, {C: "value3"}},
B: &Struct{B: &Struct{C: "terminal"}, D: &Struct{C: "terminal2"}},
C: "initial",
}

View File

@ -68,6 +68,7 @@ func (c Config) Copy() Config {
dst.DisableSSL = c.DisableSSL
dst.ManualSend = c.ManualSend
dst.HTTPClient = c.HTTPClient
dst.LogHTTPBody = c.LogHTTPBody
dst.LogLevel = c.LogLevel
dst.Logger = c.Logger
dst.MaxRetries = c.MaxRetries
@ -90,79 +91,79 @@ func (c Config) Merge(newcfg *Config) *Config {
cfg := Config{}
if newcfg != nil && newcfg.Credentials != nil {
if newcfg.Credentials != nil {
cfg.Credentials = newcfg.Credentials
} else {
cfg.Credentials = c.Credentials
}
if newcfg != nil && newcfg.Endpoint != "" {
if newcfg.Endpoint != "" {
cfg.Endpoint = newcfg.Endpoint
} else {
cfg.Endpoint = c.Endpoint
}
if newcfg != nil && newcfg.Region != "" {
if newcfg.Region != "" {
cfg.Region = newcfg.Region
} else {
cfg.Region = c.Region
}
if newcfg != nil && newcfg.DisableSSL {
if newcfg.DisableSSL {
cfg.DisableSSL = newcfg.DisableSSL
} else {
cfg.DisableSSL = c.DisableSSL
}
if newcfg != nil && newcfg.ManualSend {
if newcfg.ManualSend {
cfg.ManualSend = newcfg.ManualSend
} else {
cfg.ManualSend = c.ManualSend
}
if newcfg != nil && newcfg.HTTPClient != nil {
if newcfg.HTTPClient != nil {
cfg.HTTPClient = newcfg.HTTPClient
} else {
cfg.HTTPClient = c.HTTPClient
}
if newcfg != nil && newcfg.LogHTTPBody {
if newcfg.LogHTTPBody {
cfg.LogHTTPBody = newcfg.LogHTTPBody
} else {
cfg.LogHTTPBody = c.LogHTTPBody
}
if newcfg != nil && newcfg.LogLevel != 0 {
if newcfg.LogLevel != 0 {
cfg.LogLevel = newcfg.LogLevel
} else {
cfg.LogLevel = c.LogLevel
}
if newcfg != nil && newcfg.Logger != nil {
if newcfg.Logger != nil {
cfg.Logger = newcfg.Logger
} else {
cfg.Logger = c.Logger
}
if newcfg != nil && newcfg.MaxRetries != DefaultRetries {
if newcfg.MaxRetries != DefaultRetries {
cfg.MaxRetries = newcfg.MaxRetries
} else {
cfg.MaxRetries = c.MaxRetries
}
if newcfg != nil && newcfg.DisableParamValidation {
if newcfg.DisableParamValidation {
cfg.DisableParamValidation = newcfg.DisableParamValidation
} else {
cfg.DisableParamValidation = c.DisableParamValidation
}
if newcfg != nil && newcfg.DisableComputeChecksums {
if newcfg.DisableComputeChecksums {
cfg.DisableComputeChecksums = newcfg.DisableComputeChecksums
} else {
cfg.DisableComputeChecksums = c.DisableComputeChecksums
}
if newcfg != nil && newcfg.S3ForcePathStyle {
if newcfg.S3ForcePathStyle {
cfg.S3ForcePathStyle = newcfg.S3ForcePathStyle
} else {
cfg.S3ForcePathStyle = c.S3ForcePathStyle

View File

@ -0,0 +1,92 @@
package aws
import (
"net/http"
"os"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var testCredentials = credentials.NewChainCredentials([]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{
Filename: "TestFilename",
Profile: "TestProfile"},
&credentials.EC2RoleProvider{ExpiryWindow: 5 * time.Minute},
})
var copyTestConfig = Config{
Credentials: testCredentials,
Endpoint: "CopyTestEndpoint",
Region: "COPY_TEST_AWS_REGION",
DisableSSL: true,
ManualSend: true,
HTTPClient: http.DefaultClient,
LogHTTPBody: true,
LogLevel: 2,
Logger: os.Stdout,
MaxRetries: DefaultRetries,
DisableParamValidation: true,
DisableComputeChecksums: true,
S3ForcePathStyle: true,
}
func TestCopy(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if !reflect.DeepEqual(got, want) {
t.Errorf("Copy() = %+v", got)
t.Errorf(" want %+v", want)
}
}
func TestCopyReturnsNewInstance(t *testing.T) {
want := copyTestConfig
got := copyTestConfig.Copy()
if &got == &want {
t.Errorf("Copy() = %p; want different instance as source %p", &got, &want)
}
}
var mergeTestZeroValueConfig = Config{MaxRetries: DefaultRetries}
var mergeTestConfig = Config{
Credentials: testCredentials,
Endpoint: "MergeTestEndpoint",
Region: "MERGE_TEST_AWS_REGION",
DisableSSL: true,
ManualSend: true,
HTTPClient: http.DefaultClient,
LogHTTPBody: true,
LogLevel: 2,
Logger: os.Stdout,
MaxRetries: 10,
DisableParamValidation: true,
DisableComputeChecksums: true,
S3ForcePathStyle: true,
}
var mergeTests = []struct {
cfg *Config
in *Config
want *Config
}{
{&Config{}, nil, &Config{}},
{&Config{}, &mergeTestZeroValueConfig, &Config{}},
{&Config{}, &mergeTestConfig, &mergeTestConfig},
}
func TestMerge(t *testing.T) {
for _, tt := range mergeTests {
got := tt.cfg.Merge(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config %+v", tt.cfg)
t.Errorf(" Merge(%+v)", tt.in)
t.Errorf(" got %+v", got)
t.Errorf(" want %+v", tt.want)
}
}
}

View File

@ -1,13 +1,13 @@
package credentials
import (
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrNoValidProvidersFoundInChain Is returned when there are no valid
// providers in the ChainProvider.
ErrNoValidProvidersFoundInChain = apierr.New("NoCredentialProviders", "no valid providers in chain", nil)
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders", "no valid providers in chain", nil)
)
// A ChainProvider will search for a provider which returns credentials
@ -36,7 +36,9 @@ var (
// &EnvProvider{},
// &EC2RoleProvider{},
// })
// creds.Retrieve()
//
// // Usage of ChainCredentials with aws.Config
// svc := ec2.New(&aws.Config{Credentials: creds})
//
type ChainProvider struct {
Providers []Provider

View File

@ -3,15 +3,15 @@ package credentials
import (
"testing"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: apierr.New("FirstError", "first provider error", nil)},
&stubProvider{err: apierr.New("SecondError", "second provider error", nil)},
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
@ -62,8 +62,8 @@ func TestChainProviderWithNoProvider(t *testing.T) {
func TestChainProviderWithNoValidProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: apierr.New("FirstError", "first provider error", nil)},
&stubProvider{err: apierr.New("SecondError", "second provider error", nil)},
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
},
}

View File

@ -93,6 +93,50 @@ type Provider interface {
IsExpired() bool
}
// A Expiry provides shared expiration logic to be used by credentials
// providers to implement expiry functionality.
//
// The best method to use this struct is as an anonymous field within the
// provider's struct.
//
// Example:
// type EC2RoleProvider struct {
// Expiry
// ...
// }
type Expiry struct {
// The date/time when to expire on
expiration time.Time
// If set will be used by IsExpired to determine the current time.
// Defaults to time.Now if CurrentTime is not set. Available for testing
// to be able to mock out the current time.
CurrentTime func() time.Time
}
// SetExpiration sets the expiration IsExpired will check when called.
//
// If window is greater than 0 the expiration time will be reduced by the
// window value.
//
// Using a window is helpful to trigger credentials to expire sooner than
// the expiration time given to ensure no requests are made with expired
// tokens.
func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
e.expiration = expiration
if window > 0 {
e.expiration = e.expiration.Add(-window)
}
}
// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
@ -173,6 +217,3 @@ func (c *Credentials) IsExpired() bool {
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}
// Provide a stub-able time.Now for unit tests so expiry can be tested.
var currentTime = time.Now

View File

@ -4,7 +4,6 @@ import (
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/stretchr/testify/assert"
)
@ -40,7 +39,7 @@ func TestCredentialsGet(t *testing.T) {
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: apierr.New("provider error", "", nil), expired: true})
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")

View File

@ -7,7 +7,7 @@ import (
"net/http"
"time"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
@ -33,6 +33,8 @@ const metadataCredentialsEndpoint = "http://169.254.169.254/latest/meta-data/iam
// }
//
type EC2RoleProvider struct {
Expiry
// Endpoint must be fully quantified URL
Endpoint string
@ -49,9 +51,6 @@ type EC2RoleProvider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// The date/time at which the credentials expire.
expiresOn time.Time
}
// NewEC2RoleCredentials returns a pointer to a new Credentials object
@ -91,7 +90,7 @@ func (m *EC2RoleProvider) Retrieve() (Value, error) {
}
if len(credsList) == 0 {
return Value{}, apierr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
return Value{}, awserr.New("EmptyEC2RoleList", "empty EC2 Role list", nil)
}
credsName := credsList[0]
@ -100,11 +99,7 @@ func (m *EC2RoleProvider) Retrieve() (Value, error) {
return Value{}, err
}
m.expiresOn = roleCreds.Expiration
if m.ExpiryWindow > 0 {
// Offset based on expiry window if set.
m.expiresOn = m.expiresOn.Add(-m.ExpiryWindow)
}
m.SetExpiration(roleCreds.Expiration, m.ExpiryWindow)
return Value{
AccessKeyID: roleCreds.AccessKeyID,
@ -113,11 +108,6 @@ func (m *EC2RoleProvider) Retrieve() (Value, error) {
}, nil
}
// IsExpired returns if the credentials are expired.
func (m *EC2RoleProvider) IsExpired() bool {
return m.expiresOn.Before(currentTime())
}
// A ec2RoleCredRespBody provides the shape for deserializing credential
// request responses.
type ec2RoleCredRespBody struct {
@ -132,7 +122,7 @@ type ec2RoleCredRespBody struct {
func requestCredList(client *http.Client, endpoint string) ([]string, error) {
resp, err := client.Get(endpoint)
if err != nil {
return nil, apierr.New("ListEC2Role", "failed to list EC2 Roles", err)
return nil, awserr.New("ListEC2Role", "failed to list EC2 Roles", err)
}
defer resp.Body.Close()
@ -143,7 +133,7 @@ func requestCredList(client *http.Client, endpoint string) ([]string, error) {
}
if err := s.Err(); err != nil {
return nil, apierr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
return nil, awserr.New("ReadEC2Role", "failed to read list of EC2 Roles", err)
}
return credsList, nil
@ -156,7 +146,7 @@ func requestCredList(client *http.Client, endpoint string) ([]string, error) {
func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredRespBody, error) {
resp, err := client.Get(endpoint + credsName)
if err != nil {
return nil, apierr.New("GetEC2RoleCredentials",
return nil, awserr.New("GetEC2RoleCredentials",
fmt.Sprintf("failed to get %s EC2 Role credentials", credsName),
err)
}
@ -164,7 +154,7 @@ func requestCred(client *http.Client, endpoint, credsName string) (*ec2RoleCredR
respCreds := &ec2RoleCredRespBody{}
if err := json.NewDecoder(resp.Body).Decode(respCreds); err != nil {
return nil, apierr.New("DecodeEC2RoleCredentials",
return nil, awserr.New("DecodeEC2RoleCredentials",
fmt.Sprintf("failed to decode %s EC2 Role credentials", credsName),
err)
}

View File

@ -45,10 +45,7 @@ func TestEC2RoleProviderIsExpired(t *testing.T) {
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL}
defer func() {
currentTime = time.Now
}()
currentTime = func() time.Time {
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
@ -59,7 +56,7 @@ func TestEC2RoleProviderIsExpired(t *testing.T) {
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
currentTime = func() time.Time {
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
@ -71,10 +68,7 @@ func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
defer server.Close()
p := &EC2RoleProvider{Client: http.DefaultClient, Endpoint: server.URL, ExpiryWindow: time.Hour * 1}
defer func() {
currentTime = time.Now
}()
currentTime = func() time.Time {
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
@ -85,7 +79,7 @@ func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
currentTime = func() time.Time {
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}

View File

@ -3,16 +3,16 @@ package credentials
import (
"os"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
ErrAccessKeyIDNotFound = apierr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
ErrSecretAccessKeyNotFound = apierr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)
// A EnvProvider retrieves credentials from the environment variables of the

View File

@ -7,12 +7,12 @@ import (
"github.com/vaughan0/go-ini"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrSharedCredentialsHomeNotFound is emitted when the user directory cannot be found.
ErrSharedCredentialsHomeNotFound = apierr.New("UserHomeNotFound", "user home directory not found.", nil)
ErrSharedCredentialsHomeNotFound = awserr.New("UserHomeNotFound", "user home directory not found.", nil)
)
// A SharedCredentialsProvider retrieves credentials from the current user's home
@ -72,20 +72,20 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.LoadFile(filename)
if err != nil {
return Value{}, apierr.New("SharedCredsLoad", "failed to load shared credentials file", err)
return Value{}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile := config.Section(profile)
id, ok := iniProfile["aws_access_key_id"]
if !ok {
return Value{}, apierr.New("SharedCredsAccessKey",
return Value{}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
nil)
}
secret, ok := iniProfile["aws_secret_access_key"]
if !ok {
return Value{}, apierr.New("SharedCredsSecret",
return Value{}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}

View File

@ -1,12 +1,12 @@
package credentials
import (
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
ErrStaticCredentialsEmpty = apierr.New("EmptyStaticCreds", "static credentials are empty", nil)
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)
// A StaticProvider is a set of credentials which are set pragmatically,

View File

@ -0,0 +1,120 @@
// Package stscreds are credential Providers to retrieve STS AWS credentials.
//
// STS provides multiple ways to retrieve credentials which can be used when making
// future AWS service API operation calls.
package stscreds
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/sts"
"time"
)
// AssumeRoler represents the minimal subset of the STS client API used by this provider.
type AssumeRoler interface {
AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error)
}
// AssumeRoleProvider retrieves temporary credentials from the STS service, and
// keeps track of their expiration time. This provider must be used explicitly,
// as it is not included in the credentials chain.
//
// Example how to configure a service to use this provider:
//
// config := &aws.Config{
// Credentials: stscreds.NewCredentials(nil, "arn-of-the-role-to-assume", 10*time.Second),
// })
// // Use config for creating your AWS service.
//
// Example how to obtain customised credentials:
//
// provider := &stscreds.Provider{
// // Extend the duration to 1 hour.
// Duration: time.Hour,
// // Custom role name.
// RoleSessionName: "custom-session-name",
// }
// creds := credentials.NewCredentials(provider)
//
type AssumeRoleProvider struct {
credentials.Expiry
// Custom STS client. If not set the default STS client will be used.
Client AssumeRoler
// Role to be assumed.
RoleARN string
// Session name, if you wish to reuse the credentials elsewhere.
RoleSessionName string
// Expiry duration of the STS credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// AssumeRoleProvider. The credentials will expire every 15 minutes and the
// role will be named after a nanosecond timestamp of this operation.
//
// The sts and roleARN parameters are used for building the "AssumeRole" call.
// Pass nil as sts to use the default client.
//
// Window is the expiry window that will be subtracted from the expiry returned
// by the role credential request. This is done so that the credentials will
// expire sooner than their actual lifespan.
func NewCredentials(client AssumeRoler, roleARN string, window time.Duration) *credentials.Credentials {
return credentials.NewCredentials(&AssumeRoleProvider{
Client: client,
RoleARN: roleARN,
ExpiryWindow: window,
})
}
// Retrieve generates a new set of temporary credentials using STS.
func (p *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
// Apply defaults where parameters are not set.
if p.Client == nil {
p.Client = sts.New(nil)
}
if p.RoleSessionName == "" {
// Try to work out a role name that will hopefully end up unique.
p.RoleSessionName = fmt.Sprintf("%d", time.Now().UTC().UnixNano())
}
if p.Duration == 0 {
// Expire as often as AWS permits.
p.Duration = 15 * time.Minute
}
roleOutput, err := p.Client.AssumeRole(&sts.AssumeRoleInput{
DurationSeconds: aws.Long(int64(p.Duration / time.Second)),
RoleARN: aws.String(p.RoleARN),
RoleSessionName: aws.String(p.RoleSessionName),
})
if err != nil {
return credentials.Value{}, err
}
// We will proactively generate new credentials before they expire.
p.SetExpiration(*roleOutput.Credentials.Expiration, p.ExpiryWindow)
return credentials.Value{
AccessKeyID: *roleOutput.Credentials.AccessKeyID,
SecretAccessKey: *roleOutput.Credentials.SecretAccessKey,
SessionToken: *roleOutput.Credentials.SessionToken,
}, nil
}

View File

@ -0,0 +1,58 @@
package stscreds
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
type stubSTS struct {
}
func (s *stubSTS) AssumeRole(input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
expiry := time.Now().Add(60 * time.Minute)
return &sts.AssumeRoleOutput{
Credentials: &sts.Credentials{
// Just reflect the role arn to the provider.
AccessKeyID: input.RoleARN,
SecretAccessKey: aws.String("assumedSecretAccessKey"),
SessionToken: aws.String("assumedSessionToken"),
Expiration: &expiry,
},
}, nil
}
func TestAssumeRoleProvider(t *testing.T) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
}
func BenchmarkAssumeRoleProvider(b *testing.B) {
stub := &stubSTS{}
p := &AssumeRoleProvider{
Client: stub,
RoleARN: "roleARN",
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := p.Retrieve()
if err != nil {
b.Fatal(err)
}
}
})
}

View File

@ -12,7 +12,6 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/apierr"
)
var sleepDelay = func(delay time.Duration) {
@ -80,8 +79,17 @@ func SendHandler(r *Request) {
return
}
}
if r.HTTPRequest == nil {
// Add a dummy request response object to ensure the HTTPResponse
// value is consistent.
r.HTTPResponse = &http.Response{
StatusCode: int(0),
Status: http.StatusText(int(0)),
Body: ioutil.NopCloser(bytes.NewReader([]byte{})),
}
}
// Catch all other request errors.
r.Error = apierr.New("RequestError", "send request failed", err)
r.Error = awserr.New("RequestError", "send request failed", err)
r.Retryable.Set(true) // network errors are retryable
}
}
@ -90,7 +98,7 @@ func SendHandler(r *Request) {
func ValidateResponseHandler(r *Request) {
if r.HTTPResponse.StatusCode == 0 || r.HTTPResponse.StatusCode >= 300 {
// this may be replaced by an UnmarshalError handler
r.Error = apierr.New("UnknownError", "unknown error", nil)
r.Error = awserr.New("UnknownError", "unknown error", nil)
}
}
@ -114,8 +122,6 @@ func AfterRetryHandler(r *Request) {
if err, ok := r.Error.(awserr.Error); ok {
if isCodeExpiredCreds(err.Code()) {
r.Config.Credentials.Expire()
// The credentials will need to be resigned with new credentials
r.signed = false
}
}
}
@ -128,11 +134,11 @@ func AfterRetryHandler(r *Request) {
var (
// ErrMissingRegion is an error that is returned if region configuration is
// not found.
ErrMissingRegion error = apierr.New("MissingRegion", "could not find region configuration", nil)
ErrMissingRegion error = awserr.New("MissingRegion", "could not find region configuration", nil)
// ErrMissingEndpoint is an error that is returned if an endpoint cannot be
// resolved for a service.
ErrMissingEndpoint error = apierr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
ErrMissingEndpoint error = awserr.New("MissingEndpoint", "'Endpoint' configuration is required for this service", nil)
)
// ValidateEndpointHandler is a request handler to validate a request had the

View File

@ -5,8 +5,8 @@ import (
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/stretchr/testify/assert"
)
@ -56,11 +56,11 @@ func TestAfterRetryRefreshCreds(t *testing.T) {
svc.Handlers.Clear()
svc.Handlers.ValidateResponse.PushBack(func(r *Request) {
r.Error = apierr.New("UnknownError", "", nil)
r.Error = awserr.New("UnknownError", "", nil)
r.HTTPResponse = &http.Response{StatusCode: 400}
})
svc.Handlers.UnmarshalError.PushBack(func(r *Request) {
r.Error = apierr.New("ExpiredTokenException", "", nil)
r.Error = awserr.New("ExpiredTokenException", "", nil)
})
svc.Handlers.AfterRetry.PushBack(func(r *Request) {
AfterRetryHandler(r)

View File

@ -5,7 +5,7 @@ import (
"reflect"
"strings"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// ValidateParameters is a request handler to validate the input parameters.
@ -18,7 +18,7 @@ func ValidateParameters(r *Request) {
if count := len(v.errors); count > 0 {
format := "%d validation errors:\n- %s"
msg := fmt.Sprintf(format, count, strings.Join(v.errors, "\n- "))
r.Error = apierr.New("InvalidParameter", msg, nil)
r.Error = awserr.New("InvalidParameter", msg, nil)
}
}
}

View File

@ -41,8 +41,8 @@ func TestNoErrors(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{},
RequiredMap: map[string]*ConditionalStructShape{
"key1": &ConditionalStructShape{Name: aws.String("Name")},
"key2": &ConditionalStructShape{Name: aws.String("Name")},
"key1": {Name: aws.String("Name")},
"key2": {Name: aws.String("Name")},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{Name: aws.String("Name")},
@ -65,10 +65,10 @@ func TestMissingRequiredParameters(t *testing.T) {
func TestNestedMissingRequiredParameters(t *testing.T) {
input := &StructShape{
RequiredList: []*ConditionalStructShape{&ConditionalStructShape{}},
RequiredList: []*ConditionalStructShape{{}},
RequiredMap: map[string]*ConditionalStructShape{
"key1": &ConditionalStructShape{Name: aws.String("Name")},
"key2": &ConditionalStructShape{},
"key1": {Name: aws.String("Name")},
"key2": {},
},
RequiredBool: aws.Boolean(true),
OptionalStruct: &ConditionalStructShape{},

View File

@ -32,8 +32,7 @@ type Request struct {
Retryable SettableBool
RetryDelay time.Duration
built bool
signed bool
built bool
}
// An Operation is the service API operation to be made.
@ -164,17 +163,12 @@ func (r *Request) Build() error {
// Send will build the request prior to signing. All Sign Handlers will
// be executed in the order they were set.
func (r *Request) Sign() error {
if r.signed {
return r.Error
}
r.Build()
if r.Error != nil {
return r.Error
}
r.Handlers.Sign.Run(r)
r.signed = r.Error != nil
return r.Error
}
@ -203,6 +197,7 @@ func (r *Request) Send() error {
if r.Error != nil {
return r.Error
}
continue
}
r.Handlers.UnmarshalMeta.Run(r)

View File

@ -19,9 +19,9 @@ func TestPagination(t *testing.T) {
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table5")}},
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
@ -71,9 +71,9 @@ func TestPaginationEachPage(t *testing.T) {
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table5")}},
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
@ -124,9 +124,9 @@ func TestPaginationEarlyExit(t *testing.T) {
reqNum := 0
resps := []*dynamodb.ListTablesOutput{
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("Table5")}},
{TableNames: []*string{aws.String("Table1"), aws.String("Table2")}, LastEvaluatedTableName: aws.String("Table2")},
{TableNames: []*string{aws.String("Table3"), aws.String("Table4")}, LastEvaluatedTableName: aws.String("Table4")},
{TableNames: []*string{aws.String("Table5")}},
}
db.Handlers.Send.Clear() // mock sending
@ -189,10 +189,10 @@ func TestPaginationTruncation(t *testing.T) {
reqNum := &count
resps := []*s3.ListObjectsOutput{
&s3.ListObjectsOutput{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{&s3.Object{Key: aws.String("Key1")}}},
&s3.ListObjectsOutput{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{&s3.Object{Key: aws.String("Key2")}}},
&s3.ListObjectsOutput{IsTruncated: aws.Boolean(false), Contents: []*s3.Object{&s3.Object{Key: aws.String("Key3")}}},
&s3.ListObjectsOutput{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{&s3.Object{Key: aws.String("Key4")}}},
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key1")}}},
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key2")}}},
{IsTruncated: aws.Boolean(false), Contents: []*s3.Object{{Key: aws.String("Key3")}}},
{IsTruncated: aws.Boolean(true), Contents: []*s3.Object{{Key: aws.String("Key4")}}},
}
client.Handlers.Send.Clear() // mock sending
@ -232,20 +232,20 @@ func TestPaginationTruncation(t *testing.T) {
// Benchmarks
var benchResps = []*dynamodb.ListTablesOutput{
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
&dynamodb.ListTablesOutput{TableNames: []*string{aws.String("TABLE")}},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE"), aws.String("NXT")}, LastEvaluatedTableName: aws.String("NXT")},
{TableNames: []*string{aws.String("TABLE")}},
}
var benchDb = func() *dynamodb.DynamoDB {

View File

@ -13,7 +13,6 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/stretchr/testify/assert"
)
@ -36,12 +35,12 @@ func unmarshal(req *Request) {
func unmarshalError(req *Request) {
bodyBytes, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
req.Error = apierr.New("UnmarshaleError", req.HTTPResponse.Status, err)
req.Error = awserr.New("UnmarshaleError", req.HTTPResponse.Status, err)
return
}
if len(bodyBytes) == 0 {
req.Error = apierr.NewRequestError(
apierr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.Error = awserr.NewRequestFailure(
awserr.New("UnmarshaleError", req.HTTPResponse.Status, fmt.Errorf("empty body")),
req.HTTPResponse.StatusCode,
"",
)
@ -49,11 +48,11 @@ func unmarshalError(req *Request) {
}
var jsonErr jsonErrorResponse
if err := json.Unmarshal(bodyBytes, &jsonErr); err != nil {
req.Error = apierr.New("UnmarshaleError", "JSON unmarshal", err)
req.Error = awserr.New("UnmarshaleError", "JSON unmarshal", err)
return
}
req.Error = apierr.NewRequestError(
apierr.New(jsonErr.Code, jsonErr.Message, nil),
req.Error = awserr.NewRequestFailure(
awserr.New(jsonErr.Code, jsonErr.Message, nil),
req.HTTPResponse.StatusCode,
"",
)
@ -68,9 +67,9 @@ type jsonErrorResponse struct {
func TestRequestRecoverRetry5xx(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
@ -94,9 +93,9 @@ func TestRequestRecoverRetry5xx(t *testing.T) {
func TestRequestRecoverRetry4xxRetryable(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
http.Response{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
{StatusCode: 400, Body: body(`{"__type":"Throttling","message":"Rate exceeded."}`)},
{StatusCode: 429, Body: body(`{"__type":"ProvisionedThroughputExceededException","message":"Rate exceeded."}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10})
@ -148,10 +147,10 @@ func TestRequestExhaustRetries(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
http.Response{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
{StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)},
}
s := NewService(&Config{MaxRetries: -1})
@ -181,8 +180,8 @@ func TestRequestExhaustRetries(t *testing.T) {
func TestRequestRecoverExpiredCreds(t *testing.T) {
reqNum := 0
reqs := []http.Response{
http.Response{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
http.Response{StatusCode: 200, Body: body(`{"data":"valid"}`)},
{StatusCode: 400, Body: body(`{"__type":"ExpiredTokenException","message":"expired token"}`)},
{StatusCode: 200, Body: body(`{"data":"valid"}`)},
}
s := NewService(&Config{MaxRetries: 10, Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "")})

View File

@ -136,17 +136,18 @@ func retryRules(r *Request) time.Duration {
// retryableCodes is a collection of service response codes which are retry-able
// without any further action.
var retryableCodes = map[string]struct{}{
"ProvisionedThroughputExceededException": struct{}{},
"Throttling": struct{}{},
"RequestError": {},
"ProvisionedThroughputExceededException": {},
"Throttling": {},
}
// credsExpiredCodes is a collection of error codes which signify the credentials
// need to be refreshed. Expired tokens require refreshing of credentials, and
// resigning before the request can be retried.
var credsExpiredCodes = map[string]struct{}{
"ExpiredToken": struct{}{},
"ExpiredTokenException": struct{}{},
"RequestExpired": struct{}{}, // EC2 Only
"ExpiredToken": {},
"ExpiredTokenException": {},
"RequestExpired": {}, // EC2 Only
}
func isCodeRetryable(code string) bool {

View File

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK
const SDKVersion = "0.6.0"
const SDKVersion = "0.6.4"

View File

@ -1,6 +1,8 @@
// Package endpoints validates regional endpoints for services.
package endpoints
//go:generate go run ../model/cli/gen-endpoints/main.go endpoints.json endpoints_map.go
//go:generate gofmt -s -w endpoints_map.go
import "strings"

View File

@ -15,74 +15,74 @@ type endpointEntry struct {
var endpointsMap = endpointStruct{
Version: 2,
Endpoints: map[string]endpointEntry{
"*/*": endpointEntry{
"*/*": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"*/cloudfront": endpointEntry{
"*/cloudfront": {
Endpoint: "cloudfront.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/cloudsearchdomain": endpointEntry{
"*/cloudsearchdomain": {
Endpoint: "",
SigningRegion: "us-east-1",
},
"*/iam": endpointEntry{
"*/iam": {
Endpoint: "iam.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/importexport": endpointEntry{
"*/importexport": {
Endpoint: "importexport.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/route53": endpointEntry{
"*/route53": {
Endpoint: "route53.amazonaws.com",
SigningRegion: "us-east-1",
},
"*/sts": endpointEntry{
"*/sts": {
Endpoint: "sts.amazonaws.com",
SigningRegion: "us-east-1",
},
"ap-northeast-1/s3": endpointEntry{
"ap-northeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-1/s3": endpointEntry{
"ap-southeast-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"ap-southeast-2/s3": endpointEntry{
"ap-southeast-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"cn-north-1/*": endpointEntry{
"cn-north-1/*": {
Endpoint: "{service}.{region}.amazonaws.com.cn",
},
"eu-central-1/s3": endpointEntry{
"eu-central-1/s3": {
Endpoint: "{service}.{region}.amazonaws.com",
},
"eu-west-1/s3": endpointEntry{
"eu-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"sa-east-1/s3": endpointEntry{
"sa-east-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-east-1/s3": endpointEntry{
"us-east-1/s3": {
Endpoint: "s3.amazonaws.com",
},
"us-east-1/sdb": endpointEntry{
"us-east-1/sdb": {
Endpoint: "sdb.amazonaws.com",
SigningRegion: "us-east-1",
},
"us-gov-west-1/iam": endpointEntry{
"us-gov-west-1/iam": {
Endpoint: "iam.us-gov.amazonaws.com",
},
"us-gov-west-1/s3": endpointEntry{
"us-gov-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-gov-west-1/sts": endpointEntry{
"us-gov-west-1/sts": {
Endpoint: "sts.us-gov-west-1.amazonaws.com",
},
"us-west-1/s3": endpointEntry{
"us-west-1/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
"us-west-2/s3": endpointEntry{
"us-west-2/s3": {
Endpoint: "s3-{region}.amazonaws.com",
},
},

View File

@ -1,3 +1,4 @@
// Package query provides serialisation of AWS query requests, and responses.
package query
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/input/query.json build_test.go
@ -6,7 +7,7 @@ import (
"net/url"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/query/queryutil"
)
@ -17,7 +18,7 @@ func Build(r *aws.Request) {
"Version": {r.Service.APIVersion},
}
if err := queryutil.Parse(body, r.Params, false); err != nil {
r.Error = apierr.New("Marshal", "failed encoding Query request", err)
r.Error = awserr.New("SerializationError", "failed encoding Query request", err)
return
}

View File

@ -1,10 +1,6 @@
package query_test
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"bytes"
"encoding/json"
"encoding/xml"
@ -15,7 +11,10 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/internal/util"
"github.com/stretchr/testify/assert"
)
@ -63,20 +62,19 @@ func (c *InputService1ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService1TestCaseOperation1 = "OperationName"
// InputService1TestCaseOperation1Request generates a request for the InputService1TestCaseOperation1 operation.
func (c *InputService1ProtocolTest) InputService1TestCaseOperation1Request(input *InputService1TestShapeInputShape) (req *aws.Request, output *InputService1TestShapeInputService1TestCaseOperation1Output) {
if opInputService1TestCaseOperation1 == nil {
opInputService1TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService1TestCaseOperation1,
}
if input == nil {
input = &InputService1TestShapeInputShape{}
}
req = c.newRequest(opInputService1TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService1TestShapeInputService1TestCaseOperation1Output{}
req.Data = output
return
@ -88,8 +86,6 @@ func (c *InputService1ProtocolTest) InputService1TestCaseOperation1(input *Input
return out, err
}
var opInputService1TestCaseOperation1 *aws.Operation
type InputService1TestShapeInputService1TestCaseOperation1Output struct {
metadataInputService1TestShapeInputService1TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -142,20 +138,19 @@ func (c *InputService2ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService2TestCaseOperation1 = "OperationName"
// InputService2TestCaseOperation1Request generates a request for the InputService2TestCaseOperation1 operation.
func (c *InputService2ProtocolTest) InputService2TestCaseOperation1Request(input *InputService2TestShapeInputShape) (req *aws.Request, output *InputService2TestShapeInputService2TestCaseOperation1Output) {
if opInputService2TestCaseOperation1 == nil {
opInputService2TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService2TestCaseOperation1,
}
if input == nil {
input = &InputService2TestShapeInputShape{}
}
req = c.newRequest(opInputService2TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService2TestShapeInputService2TestCaseOperation1Output{}
req.Data = output
return
@ -167,8 +162,6 @@ func (c *InputService2ProtocolTest) InputService2TestCaseOperation1(input *Input
return out, err
}
var opInputService2TestCaseOperation1 *aws.Operation
type InputService2TestShapeInputService2TestCaseOperation1Output struct {
metadataInputService2TestShapeInputService2TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -229,20 +222,19 @@ func (c *InputService3ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService3TestCaseOperation1 = "OperationName"
// InputService3TestCaseOperation1Request generates a request for the InputService3TestCaseOperation1 operation.
func (c *InputService3ProtocolTest) InputService3TestCaseOperation1Request(input *InputService3TestShapeInputShape) (req *aws.Request, output *InputService3TestShapeInputService3TestCaseOperation1Output) {
if opInputService3TestCaseOperation1 == nil {
opInputService3TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService3TestCaseOperation1,
}
if input == nil {
input = &InputService3TestShapeInputShape{}
}
req = c.newRequest(opInputService3TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService3TestShapeInputService3TestCaseOperation1Output{}
req.Data = output
return
@ -254,22 +246,19 @@ func (c *InputService3ProtocolTest) InputService3TestCaseOperation1(input *Input
return out, err
}
var opInputService3TestCaseOperation1 *aws.Operation
const opInputService3TestCaseOperation2 = "OperationName"
// InputService3TestCaseOperation2Request generates a request for the InputService3TestCaseOperation2 operation.
func (c *InputService3ProtocolTest) InputService3TestCaseOperation2Request(input *InputService3TestShapeInputShape) (req *aws.Request, output *InputService3TestShapeInputService3TestCaseOperation2Output) {
if opInputService3TestCaseOperation2 == nil {
opInputService3TestCaseOperation2 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService3TestCaseOperation2,
}
if input == nil {
input = &InputService3TestShapeInputShape{}
}
req = c.newRequest(opInputService3TestCaseOperation2, input, output)
req = c.newRequest(op, input, output)
output = &InputService3TestShapeInputService3TestCaseOperation2Output{}
req.Data = output
return
@ -281,8 +270,6 @@ func (c *InputService3ProtocolTest) InputService3TestCaseOperation2(input *Input
return out, err
}
var opInputService3TestCaseOperation2 *aws.Operation
type InputService3TestShapeInputService3TestCaseOperation1Output struct {
metadataInputService3TestShapeInputService3TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -341,20 +328,19 @@ func (c *InputService4ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService4TestCaseOperation1 = "OperationName"
// InputService4TestCaseOperation1Request generates a request for the InputService4TestCaseOperation1 operation.
func (c *InputService4ProtocolTest) InputService4TestCaseOperation1Request(input *InputService4TestShapeInputShape) (req *aws.Request, output *InputService4TestShapeInputService4TestCaseOperation1Output) {
if opInputService4TestCaseOperation1 == nil {
opInputService4TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService4TestCaseOperation1,
}
if input == nil {
input = &InputService4TestShapeInputShape{}
}
req = c.newRequest(opInputService4TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService4TestShapeInputService4TestCaseOperation1Output{}
req.Data = output
return
@ -366,22 +352,19 @@ func (c *InputService4ProtocolTest) InputService4TestCaseOperation1(input *Input
return out, err
}
var opInputService4TestCaseOperation1 *aws.Operation
const opInputService4TestCaseOperation2 = "OperationName"
// InputService4TestCaseOperation2Request generates a request for the InputService4TestCaseOperation2 operation.
func (c *InputService4ProtocolTest) InputService4TestCaseOperation2Request(input *InputService4TestShapeInputShape) (req *aws.Request, output *InputService4TestShapeInputService4TestCaseOperation2Output) {
if opInputService4TestCaseOperation2 == nil {
opInputService4TestCaseOperation2 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService4TestCaseOperation2,
}
if input == nil {
input = &InputService4TestShapeInputShape{}
}
req = c.newRequest(opInputService4TestCaseOperation2, input, output)
req = c.newRequest(op, input, output)
output = &InputService4TestShapeInputService4TestCaseOperation2Output{}
req.Data = output
return
@ -393,8 +376,6 @@ func (c *InputService4ProtocolTest) InputService4TestCaseOperation2(input *Input
return out, err
}
var opInputService4TestCaseOperation2 *aws.Operation
type InputService4TestShapeInputService4TestCaseOperation1Output struct {
metadataInputService4TestShapeInputService4TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -455,20 +436,19 @@ func (c *InputService5ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService5TestCaseOperation1 = "OperationName"
// InputService5TestCaseOperation1Request generates a request for the InputService5TestCaseOperation1 operation.
func (c *InputService5ProtocolTest) InputService5TestCaseOperation1Request(input *InputService5TestShapeInputShape) (req *aws.Request, output *InputService5TestShapeInputService5TestCaseOperation1Output) {
if opInputService5TestCaseOperation1 == nil {
opInputService5TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService5TestCaseOperation1,
}
if input == nil {
input = &InputService5TestShapeInputShape{}
}
req = c.newRequest(opInputService5TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService5TestShapeInputService5TestCaseOperation1Output{}
req.Data = output
return
@ -480,22 +460,19 @@ func (c *InputService5ProtocolTest) InputService5TestCaseOperation1(input *Input
return out, err
}
var opInputService5TestCaseOperation1 *aws.Operation
const opInputService5TestCaseOperation2 = "OperationName"
// InputService5TestCaseOperation2Request generates a request for the InputService5TestCaseOperation2 operation.
func (c *InputService5ProtocolTest) InputService5TestCaseOperation2Request(input *InputService5TestShapeInputShape) (req *aws.Request, output *InputService5TestShapeInputService5TestCaseOperation2Output) {
if opInputService5TestCaseOperation2 == nil {
opInputService5TestCaseOperation2 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService5TestCaseOperation2,
}
if input == nil {
input = &InputService5TestShapeInputShape{}
}
req = c.newRequest(opInputService5TestCaseOperation2, input, output)
req = c.newRequest(op, input, output)
output = &InputService5TestShapeInputService5TestCaseOperation2Output{}
req.Data = output
return
@ -507,8 +484,6 @@ func (c *InputService5ProtocolTest) InputService5TestCaseOperation2(input *Input
return out, err
}
var opInputService5TestCaseOperation2 *aws.Operation
type InputService5TestShapeInputService5TestCaseOperation1Output struct {
metadataInputService5TestShapeInputService5TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -567,20 +542,19 @@ func (c *InputService6ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService6TestCaseOperation1 = "OperationName"
// InputService6TestCaseOperation1Request generates a request for the InputService6TestCaseOperation1 operation.
func (c *InputService6ProtocolTest) InputService6TestCaseOperation1Request(input *InputService6TestShapeInputShape) (req *aws.Request, output *InputService6TestShapeInputService6TestCaseOperation1Output) {
if opInputService6TestCaseOperation1 == nil {
opInputService6TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService6TestCaseOperation1,
}
if input == nil {
input = &InputService6TestShapeInputShape{}
}
req = c.newRequest(opInputService6TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService6TestShapeInputService6TestCaseOperation1Output{}
req.Data = output
return
@ -592,8 +566,6 @@ func (c *InputService6ProtocolTest) InputService6TestCaseOperation1(input *Input
return out, err
}
var opInputService6TestCaseOperation1 *aws.Operation
type InputService6TestShapeInputService6TestCaseOperation1Output struct {
metadataInputService6TestShapeInputService6TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -644,20 +616,19 @@ func (c *InputService7ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService7TestCaseOperation1 = "OperationName"
// InputService7TestCaseOperation1Request generates a request for the InputService7TestCaseOperation1 operation.
func (c *InputService7ProtocolTest) InputService7TestCaseOperation1Request(input *InputService7TestShapeInputShape) (req *aws.Request, output *InputService7TestShapeInputService7TestCaseOperation1Output) {
if opInputService7TestCaseOperation1 == nil {
opInputService7TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService7TestCaseOperation1,
}
if input == nil {
input = &InputService7TestShapeInputShape{}
}
req = c.newRequest(opInputService7TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService7TestShapeInputService7TestCaseOperation1Output{}
req.Data = output
return
@ -669,8 +640,6 @@ func (c *InputService7ProtocolTest) InputService7TestCaseOperation1(input *Input
return out, err
}
var opInputService7TestCaseOperation1 *aws.Operation
type InputService7TestShapeInputService7TestCaseOperation1Output struct {
metadataInputService7TestShapeInputService7TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -721,20 +690,19 @@ func (c *InputService8ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService8TestCaseOperation1 = "OperationName"
// InputService8TestCaseOperation1Request generates a request for the InputService8TestCaseOperation1 operation.
func (c *InputService8ProtocolTest) InputService8TestCaseOperation1Request(input *InputService8TestShapeInputShape) (req *aws.Request, output *InputService8TestShapeInputService8TestCaseOperation1Output) {
if opInputService8TestCaseOperation1 == nil {
opInputService8TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService8TestCaseOperation1,
}
if input == nil {
input = &InputService8TestShapeInputShape{}
}
req = c.newRequest(opInputService8TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService8TestShapeInputService8TestCaseOperation1Output{}
req.Data = output
return
@ -746,8 +714,6 @@ func (c *InputService8ProtocolTest) InputService8TestCaseOperation1(input *Input
return out, err
}
var opInputService8TestCaseOperation1 *aws.Operation
type InputService8TestShapeInputService8TestCaseOperation1Output struct {
metadataInputService8TestShapeInputService8TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -798,20 +764,19 @@ func (c *InputService9ProtocolTest) newRequest(op *aws.Operation, params, data i
return req
}
const opInputService9TestCaseOperation1 = "OperationName"
// InputService9TestCaseOperation1Request generates a request for the InputService9TestCaseOperation1 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation1Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation1Output) {
if opInputService9TestCaseOperation1 == nil {
opInputService9TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation1,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation1Output{}
req.Data = output
return
@ -823,22 +788,19 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation1(input *Input
return out, err
}
var opInputService9TestCaseOperation1 *aws.Operation
const opInputService9TestCaseOperation2 = "OperationName"
// InputService9TestCaseOperation2Request generates a request for the InputService9TestCaseOperation2 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation2Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation2Output) {
if opInputService9TestCaseOperation2 == nil {
opInputService9TestCaseOperation2 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation2,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation2, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation2Output{}
req.Data = output
return
@ -850,22 +812,19 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation2(input *Input
return out, err
}
var opInputService9TestCaseOperation2 *aws.Operation
const opInputService9TestCaseOperation3 = "OperationName"
// InputService9TestCaseOperation3Request generates a request for the InputService9TestCaseOperation3 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation3Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation3Output) {
if opInputService9TestCaseOperation3 == nil {
opInputService9TestCaseOperation3 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation3,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation3, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation3Output{}
req.Data = output
return
@ -877,22 +836,19 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation3(input *Input
return out, err
}
var opInputService9TestCaseOperation3 *aws.Operation
const opInputService9TestCaseOperation4 = "OperationName"
// InputService9TestCaseOperation4Request generates a request for the InputService9TestCaseOperation4 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation4Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation4Output) {
if opInputService9TestCaseOperation4 == nil {
opInputService9TestCaseOperation4 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation4,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation4, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation4Output{}
req.Data = output
return
@ -904,22 +860,19 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation4(input *Input
return out, err
}
var opInputService9TestCaseOperation4 *aws.Operation
const opInputService9TestCaseOperation5 = "OperationName"
// InputService9TestCaseOperation5Request generates a request for the InputService9TestCaseOperation5 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation5Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation5Output) {
if opInputService9TestCaseOperation5 == nil {
opInputService9TestCaseOperation5 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation5,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation5, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation5Output{}
req.Data = output
return
@ -931,22 +884,19 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation5(input *Input
return out, err
}
var opInputService9TestCaseOperation5 *aws.Operation
const opInputService9TestCaseOperation6 = "OperationName"
// InputService9TestCaseOperation6Request generates a request for the InputService9TestCaseOperation6 operation.
func (c *InputService9ProtocolTest) InputService9TestCaseOperation6Request(input *InputService9TestShapeInputShape) (req *aws.Request, output *InputService9TestShapeInputService9TestCaseOperation6Output) {
if opInputService9TestCaseOperation6 == nil {
opInputService9TestCaseOperation6 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opInputService9TestCaseOperation6,
}
if input == nil {
input = &InputService9TestShapeInputShape{}
}
req = c.newRequest(opInputService9TestCaseOperation6, input, output)
req = c.newRequest(op, input, output)
output = &InputService9TestShapeInputService9TestCaseOperation6Output{}
req.Data = output
return
@ -958,8 +908,6 @@ func (c *InputService9ProtocolTest) InputService9TestCaseOperation6(input *Input
return out, err
}
var opInputService9TestCaseOperation6 *aws.Operation
type InputService9TestShapeInputService9TestCaseOperation1Output struct {
metadataInputService9TestShapeInputService9TestCaseOperation1Output `json:"-" xml:"-"`
}
@ -1442,10 +1390,10 @@ func TestInputService9ProtocolTestRecursiveShapesCase4(t *testing.T) {
input := &InputService9TestShapeInputShape{
RecursiveStruct: &InputService9TestShapeRecursiveStructType{
RecursiveList: []*InputService9TestShapeRecursiveStructType{
&InputService9TestShapeRecursiveStructType{
{
NoRecurse: aws.String("foo"),
},
&InputService9TestShapeRecursiveStructType{
{
NoRecurse: aws.String("bar"),
},
},
@ -1477,10 +1425,10 @@ func TestInputService9ProtocolTestRecursiveShapesCase5(t *testing.T) {
input := &InputService9TestShapeInputShape{
RecursiveStruct: &InputService9TestShapeRecursiveStructType{
RecursiveList: []*InputService9TestShapeRecursiveStructType{
&InputService9TestShapeRecursiveStructType{
{
NoRecurse: aws.String("foo"),
},
&InputService9TestShapeRecursiveStructType{
{
RecursiveStruct: &InputService9TestShapeRecursiveStructType{
NoRecurse: aws.String("bar"),
},
@ -1514,10 +1462,10 @@ func TestInputService9ProtocolTestRecursiveShapesCase6(t *testing.T) {
input := &InputService9TestShapeInputShape{
RecursiveStruct: &InputService9TestShapeRecursiveStructType{
RecursiveMap: map[string]*InputService9TestShapeRecursiveStructType{
"bar": &InputService9TestShapeRecursiveStructType{
"bar": {
NoRecurse: aws.String("bar"),
},
"foo": &InputService9TestShapeRecursiveStructType{
"foo": {
NoRecurse: aws.String("foo"),
},
},

View File

@ -6,7 +6,7 @@ import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
)
@ -17,7 +17,7 @@ func Unmarshal(r *aws.Request) {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result")
if err != nil {
r.Error = apierr.New("Unmarshal", "failed decoding Query response", err)
r.Error = awserr.New("SerializationError", "failed decoding Query response", err)
return
}
}

View File

@ -5,7 +5,7 @@ import (
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type xmlErrorResponse struct {
@ -22,10 +22,10 @@ func UnmarshalError(r *aws.Request) {
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = apierr.New("Unmarshal", "failed to decode query XML error response", err)
r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err)
} else {
r.Error = apierr.NewRequestError(
apierr.New(resp.Code, resp.Message, nil),
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
resp.RequestID,
)

View File

@ -1,10 +1,6 @@
package query_test
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"bytes"
"encoding/json"
"encoding/xml"
@ -15,7 +11,10 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/internal/util"
"github.com/stretchr/testify/assert"
)
@ -63,20 +62,19 @@ func (c *OutputService1ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService1TestCaseOperation1 = "OperationName"
// OutputService1TestCaseOperation1Request generates a request for the OutputService1TestCaseOperation1 operation.
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1Request(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (req *aws.Request, output *OutputService1TestShapeOutputShape) {
if opOutputService1TestCaseOperation1 == nil {
opOutputService1TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService1TestCaseOperation1,
}
if input == nil {
input = &OutputService1TestShapeOutputService1TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService1TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService1TestShapeOutputShape{}
req.Data = output
return
@ -88,8 +86,6 @@ func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1(input *Out
return out, err
}
var opOutputService1TestCaseOperation1 *aws.Operation
type OutputService1TestShapeOutputService1TestCaseOperation1Input struct {
metadataOutputService1TestShapeOutputService1TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -156,20 +152,19 @@ func (c *OutputService2ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService2TestCaseOperation1 = "OperationName"
// OutputService2TestCaseOperation1Request generates a request for the OutputService2TestCaseOperation1 operation.
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1Request(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (req *aws.Request, output *OutputService2TestShapeOutputShape) {
if opOutputService2TestCaseOperation1 == nil {
opOutputService2TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService2TestCaseOperation1,
}
if input == nil {
input = &OutputService2TestShapeOutputService2TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService2TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService2TestShapeOutputShape{}
req.Data = output
return
@ -181,8 +176,6 @@ func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1(input *Out
return out, err
}
var opOutputService2TestCaseOperation1 *aws.Operation
type OutputService2TestShapeOutputService2TestCaseOperation1Input struct {
metadataOutputService2TestShapeOutputService2TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -235,20 +228,19 @@ func (c *OutputService3ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService3TestCaseOperation1 = "OperationName"
// OutputService3TestCaseOperation1Request generates a request for the OutputService3TestCaseOperation1 operation.
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1Request(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (req *aws.Request, output *OutputService3TestShapeOutputShape) {
if opOutputService3TestCaseOperation1 == nil {
opOutputService3TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService3TestCaseOperation1,
}
if input == nil {
input = &OutputService3TestShapeOutputService3TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService3TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService3TestShapeOutputShape{}
req.Data = output
return
@ -260,8 +252,6 @@ func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1(input *Out
return out, err
}
var opOutputService3TestCaseOperation1 *aws.Operation
type OutputService3TestShapeOutputService3TestCaseOperation1Input struct {
metadataOutputService3TestShapeOutputService3TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -312,20 +302,19 @@ func (c *OutputService4ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService4TestCaseOperation1 = "OperationName"
// OutputService4TestCaseOperation1Request generates a request for the OutputService4TestCaseOperation1 operation.
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1Request(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (req *aws.Request, output *OutputService4TestShapeOutputShape) {
if opOutputService4TestCaseOperation1 == nil {
opOutputService4TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService4TestCaseOperation1,
}
if input == nil {
input = &OutputService4TestShapeOutputService4TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService4TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService4TestShapeOutputShape{}
req.Data = output
return
@ -337,8 +326,6 @@ func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1(input *Out
return out, err
}
var opOutputService4TestCaseOperation1 *aws.Operation
type OutputService4TestShapeOutputService4TestCaseOperation1Input struct {
metadataOutputService4TestShapeOutputService4TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -389,20 +376,19 @@ func (c *OutputService5ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService5TestCaseOperation1 = "OperationName"
// OutputService5TestCaseOperation1Request generates a request for the OutputService5TestCaseOperation1 operation.
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1Request(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (req *aws.Request, output *OutputService5TestShapeOutputShape) {
if opOutputService5TestCaseOperation1 == nil {
opOutputService5TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService5TestCaseOperation1,
}
if input == nil {
input = &OutputService5TestShapeOutputService5TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService5TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService5TestShapeOutputShape{}
req.Data = output
return
@ -414,8 +400,6 @@ func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1(input *Out
return out, err
}
var opOutputService5TestCaseOperation1 *aws.Operation
type OutputService5TestShapeOutputService5TestCaseOperation1Input struct {
metadataOutputService5TestShapeOutputService5TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -466,20 +450,19 @@ func (c *OutputService6ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService6TestCaseOperation1 = "OperationName"
// OutputService6TestCaseOperation1Request generates a request for the OutputService6TestCaseOperation1 operation.
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1Request(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (req *aws.Request, output *OutputService6TestShapeOutputShape) {
if opOutputService6TestCaseOperation1 == nil {
opOutputService6TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService6TestCaseOperation1,
}
if input == nil {
input = &OutputService6TestShapeOutputService6TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService6TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService6TestShapeOutputShape{}
req.Data = output
return
@ -491,8 +474,6 @@ func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1(input *Out
return out, err
}
var opOutputService6TestCaseOperation1 *aws.Operation
type OutputService6TestShapeOutputService6TestCaseOperation1Input struct {
metadataOutputService6TestShapeOutputService6TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -543,20 +524,19 @@ func (c *OutputService7ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService7TestCaseOperation1 = "OperationName"
// OutputService7TestCaseOperation1Request generates a request for the OutputService7TestCaseOperation1 operation.
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1Request(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (req *aws.Request, output *OutputService7TestShapeOutputShape) {
if opOutputService7TestCaseOperation1 == nil {
opOutputService7TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService7TestCaseOperation1,
}
if input == nil {
input = &OutputService7TestShapeOutputService7TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService7TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService7TestShapeOutputShape{}
req.Data = output
return
@ -568,8 +548,6 @@ func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1(input *Out
return out, err
}
var opOutputService7TestCaseOperation1 *aws.Operation
type OutputService7TestShapeOutputService7TestCaseOperation1Input struct {
metadataOutputService7TestShapeOutputService7TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -620,20 +598,19 @@ func (c *OutputService8ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService8TestCaseOperation1 = "OperationName"
// OutputService8TestCaseOperation1Request generates a request for the OutputService8TestCaseOperation1 operation.
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1Request(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (req *aws.Request, output *OutputService8TestShapeOutputShape) {
if opOutputService8TestCaseOperation1 == nil {
opOutputService8TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService8TestCaseOperation1,
}
if input == nil {
input = &OutputService8TestShapeOutputService8TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService8TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService8TestShapeOutputShape{}
req.Data = output
return
@ -645,8 +622,6 @@ func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1(input *Out
return out, err
}
var opOutputService8TestCaseOperation1 *aws.Operation
type OutputService8TestShapeOutputService8TestCaseOperation1Input struct {
metadataOutputService8TestShapeOutputService8TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -711,20 +686,19 @@ func (c *OutputService9ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService9TestCaseOperation1 = "OperationName"
// OutputService9TestCaseOperation1Request generates a request for the OutputService9TestCaseOperation1 operation.
func (c *OutputService9ProtocolTest) OutputService9TestCaseOperation1Request(input *OutputService9TestShapeOutputService9TestCaseOperation1Input) (req *aws.Request, output *OutputService9TestShapeOutputShape) {
if opOutputService9TestCaseOperation1 == nil {
opOutputService9TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService9TestCaseOperation1,
}
if input == nil {
input = &OutputService9TestShapeOutputService9TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService9TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService9TestShapeOutputShape{}
req.Data = output
return
@ -736,8 +710,6 @@ func (c *OutputService9ProtocolTest) OutputService9TestCaseOperation1(input *Out
return out, err
}
var opOutputService9TestCaseOperation1 *aws.Operation
type OutputService9TestShapeOutputService9TestCaseOperation1Input struct {
metadataOutputService9TestShapeOutputService9TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -802,20 +774,19 @@ func (c *OutputService10ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService10TestCaseOperation1 = "OperationName"
// OutputService10TestCaseOperation1Request generates a request for the OutputService10TestCaseOperation1 operation.
func (c *OutputService10ProtocolTest) OutputService10TestCaseOperation1Request(input *OutputService10TestShapeOutputService10TestCaseOperation1Input) (req *aws.Request, output *OutputService10TestShapeOutputShape) {
if opOutputService10TestCaseOperation1 == nil {
opOutputService10TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService10TestCaseOperation1,
}
if input == nil {
input = &OutputService10TestShapeOutputService10TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService10TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService10TestShapeOutputShape{}
req.Data = output
return
@ -827,8 +798,6 @@ func (c *OutputService10ProtocolTest) OutputService10TestCaseOperation1(input *O
return out, err
}
var opOutputService10TestCaseOperation1 *aws.Operation
type OutputService10TestShapeOutputService10TestCaseOperation1Input struct {
metadataOutputService10TestShapeOutputService10TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -879,20 +848,19 @@ func (c *OutputService11ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService11TestCaseOperation1 = "OperationName"
// OutputService11TestCaseOperation1Request generates a request for the OutputService11TestCaseOperation1 operation.
func (c *OutputService11ProtocolTest) OutputService11TestCaseOperation1Request(input *OutputService11TestShapeOutputService11TestCaseOperation1Input) (req *aws.Request, output *OutputService11TestShapeOutputShape) {
if opOutputService11TestCaseOperation1 == nil {
opOutputService11TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService11TestCaseOperation1,
}
if input == nil {
input = &OutputService11TestShapeOutputService11TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService11TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService11TestShapeOutputShape{}
req.Data = output
return
@ -904,8 +872,6 @@ func (c *OutputService11ProtocolTest) OutputService11TestCaseOperation1(input *O
return out, err
}
var opOutputService11TestCaseOperation1 *aws.Operation
type OutputService11TestShapeOutputService11TestCaseOperation1Input struct {
metadataOutputService11TestShapeOutputService11TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -966,20 +932,19 @@ func (c *OutputService12ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService12TestCaseOperation1 = "OperationName"
// OutputService12TestCaseOperation1Request generates a request for the OutputService12TestCaseOperation1 operation.
func (c *OutputService12ProtocolTest) OutputService12TestCaseOperation1Request(input *OutputService12TestShapeOutputService12TestCaseOperation1Input) (req *aws.Request, output *OutputService12TestShapeOutputShape) {
if opOutputService12TestCaseOperation1 == nil {
opOutputService12TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService12TestCaseOperation1,
}
if input == nil {
input = &OutputService12TestShapeOutputService12TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService12TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService12TestShapeOutputShape{}
req.Data = output
return
@ -991,8 +956,6 @@ func (c *OutputService12ProtocolTest) OutputService12TestCaseOperation1(input *O
return out, err
}
var opOutputService12TestCaseOperation1 *aws.Operation
type OutputService12TestShapeOutputService12TestCaseOperation1Input struct {
metadataOutputService12TestShapeOutputService12TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -1043,20 +1006,19 @@ func (c *OutputService13ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService13TestCaseOperation1 = "OperationName"
// OutputService13TestCaseOperation1Request generates a request for the OutputService13TestCaseOperation1 operation.
func (c *OutputService13ProtocolTest) OutputService13TestCaseOperation1Request(input *OutputService13TestShapeOutputService13TestCaseOperation1Input) (req *aws.Request, output *OutputService13TestShapeOutputShape) {
if opOutputService13TestCaseOperation1 == nil {
opOutputService13TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService13TestCaseOperation1,
}
if input == nil {
input = &OutputService13TestShapeOutputService13TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService13TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService13TestShapeOutputShape{}
req.Data = output
return
@ -1068,8 +1030,6 @@ func (c *OutputService13ProtocolTest) OutputService13TestCaseOperation1(input *O
return out, err
}
var opOutputService13TestCaseOperation1 *aws.Operation
type OutputService13TestShapeOutputService13TestCaseOperation1Input struct {
metadataOutputService13TestShapeOutputService13TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -1120,20 +1080,19 @@ func (c *OutputService14ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService14TestCaseOperation1 = "OperationName"
// OutputService14TestCaseOperation1Request generates a request for the OutputService14TestCaseOperation1 operation.
func (c *OutputService14ProtocolTest) OutputService14TestCaseOperation1Request(input *OutputService14TestShapeOutputService14TestCaseOperation1Input) (req *aws.Request, output *OutputService14TestShapeOutputShape) {
if opOutputService14TestCaseOperation1 == nil {
opOutputService14TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService14TestCaseOperation1,
}
if input == nil {
input = &OutputService14TestShapeOutputService14TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService14TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService14TestShapeOutputShape{}
req.Data = output
return
@ -1145,8 +1104,6 @@ func (c *OutputService14ProtocolTest) OutputService14TestCaseOperation1(input *O
return out, err
}
var opOutputService14TestCaseOperation1 *aws.Operation
type OutputService14TestShapeOutputService14TestCaseOperation1Input struct {
metadataOutputService14TestShapeOutputService14TestCaseOperation1Input `json:"-" xml:"-"`
}

View File

@ -1,3 +1,4 @@
// Package rest provides RESTful serialisation of AWS requests and responses.
package rest
import (
@ -13,7 +14,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// RFC822 returns an RFC822 formatted timestamp for AWS protocols
@ -101,7 +102,7 @@ func buildBody(r *aws.Request, v reflect.Value) {
case string:
r.SetStringBody(reader)
default:
r.Error = apierr.New("Marshal",
r.Error = awserr.New("SerializationError",
"failed to encode REST request",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
@ -114,7 +115,7 @@ func buildBody(r *aws.Request, v reflect.Value) {
func buildHeader(r *aws.Request, v reflect.Value, name string) {
str, err := convertType(v)
if err != nil {
r.Error = apierr.New("Marshal", "failed to encode REST request", err)
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(name, *str)
}
@ -124,7 +125,7 @@ func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
for _, key := range v.MapKeys() {
str, err := convertType(v.MapIndex(key))
if err != nil {
r.Error = apierr.New("Marshal", "failed to encode REST request", err)
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
r.HTTPRequest.Header.Add(prefix+key.String(), *str)
}
@ -134,7 +135,7 @@ func buildHeaderMap(r *aws.Request, v reflect.Value, prefix string) {
func buildURI(r *aws.Request, v reflect.Value, name string) {
value, err := convertType(v)
if err != nil {
r.Error = apierr.New("Marshal", "failed to encode REST request", err)
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if value != nil {
uri := r.HTTPRequest.URL.Path
uri = strings.Replace(uri, "{"+name+"}", EscapePath(*value, true), -1)
@ -146,7 +147,7 @@ func buildURI(r *aws.Request, v reflect.Value, name string) {
func buildQueryString(r *aws.Request, v reflect.Value, name string, query url.Values) {
str, err := convertType(v)
if err != nil {
r.Error = apierr.New("Marshal", "failed to encode REST request", err)
r.Error = awserr.New("SerializationError", "failed to encode REST request", err)
} else if str != nil {
query.Set(name, *str)
}

View File

@ -11,7 +11,7 @@ import (
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// Unmarshal unmarshals the REST component of a response in a REST service.
@ -34,14 +34,14 @@ func unmarshalBody(r *aws.Request, v reflect.Value) {
case []byte:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = apierr.New("Unmarshal", "failed to decode REST response", err)
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
payload.Set(reflect.ValueOf(b))
}
case *string:
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = apierr.New("Unmarshal", "failed to decode REST response", err)
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
} else {
str := string(b)
payload.Set(reflect.ValueOf(&str))
@ -53,7 +53,7 @@ func unmarshalBody(r *aws.Request, v reflect.Value) {
case "aws.ReadSeekCloser", "io.ReadCloser":
payload.Set(reflect.ValueOf(r.HTTPResponse.Body))
default:
r.Error = apierr.New("Unmarshal",
r.Error = awserr.New("SerializationError",
"failed to decode REST response",
fmt.Errorf("unknown payload type %s", payload.Type()))
}
@ -83,14 +83,14 @@ func unmarshalLocationElements(r *aws.Request, v reflect.Value) {
case "header":
err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name))
if err != nil {
r.Error = apierr.New("Unmarshal", "failed to decode REST response", err)
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
case "headers":
prefix := field.Tag.Get("locationName")
err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix)
if err != nil {
r.Error = apierr.New("Unmarshal", "failed to decode REST response", err)
r.Error = awserr.New("SerializationError", "failed to decode REST response", err)
break
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,5 @@
// Package restxml provides RESTful XML serialisation of AWS
// requests and responses.
package restxml
//go:generate go run ../../fixtures/protocol/generate.go ../../fixtures/protocol/input/rest-xml.json build_test.go
@ -8,7 +10,7 @@ import (
"encoding/xml"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/protocol/query"
"github.com/aws/aws-sdk-go/internal/protocol/rest"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
@ -22,7 +24,7 @@ func Build(r *aws.Request) {
var buf bytes.Buffer
err := xmlutil.BuildXML(r.Params, xml.NewEncoder(&buf))
if err != nil {
r.Error = apierr.New("Marshal", "failed to enode rest XML request", err)
r.Error = awserr.New("SerializationError", "failed to enode rest XML request", err)
return
}
r.SetBufferBody(buf.Bytes())
@ -36,7 +38,7 @@ func Unmarshal(r *aws.Request) {
decoder := xml.NewDecoder(r.HTTPResponse.Body)
err := xmlutil.UnmarshalXML(r.Data, decoder, "")
if err != nil {
r.Error = apierr.New("Unmarshal", "failed to decode REST XML response", err)
r.Error = awserr.New("SerializationError", "failed to decode REST XML response", err)
return
}
}

View File

@ -1,10 +1,6 @@
package restxml_test
import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/restxml"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"bytes"
"encoding/json"
"encoding/xml"
@ -15,7 +11,10 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/protocol/restxml"
"github.com/aws/aws-sdk-go/internal/protocol/xml/xmlutil"
"github.com/aws/aws-sdk-go/internal/signer/v4"
"github.com/aws/aws-sdk-go/internal/util"
"github.com/stretchr/testify/assert"
)
@ -63,20 +62,19 @@ func (c *OutputService1ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService1TestCaseOperation1 = "OperationName"
// OutputService1TestCaseOperation1Request generates a request for the OutputService1TestCaseOperation1 operation.
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1Request(input *OutputService1TestShapeOutputService1TestCaseOperation1Input) (req *aws.Request, output *OutputService1TestShapeOutputShape) {
if opOutputService1TestCaseOperation1 == nil {
opOutputService1TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService1TestCaseOperation1,
}
if input == nil {
input = &OutputService1TestShapeOutputService1TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService1TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService1TestShapeOutputShape{}
req.Data = output
return
@ -88,22 +86,19 @@ func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation1(input *Out
return out, err
}
var opOutputService1TestCaseOperation1 *aws.Operation
const opOutputService1TestCaseOperation2 = "OperationName"
// OutputService1TestCaseOperation2Request generates a request for the OutputService1TestCaseOperation2 operation.
func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation2Request(input *OutputService1TestShapeOutputService1TestCaseOperation2Input) (req *aws.Request, output *OutputService1TestShapeOutputShape) {
if opOutputService1TestCaseOperation2 == nil {
opOutputService1TestCaseOperation2 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService1TestCaseOperation2,
}
if input == nil {
input = &OutputService1TestShapeOutputService1TestCaseOperation2Input{}
}
req = c.newRequest(opOutputService1TestCaseOperation2, input, output)
req = c.newRequest(op, input, output)
output = &OutputService1TestShapeOutputShape{}
req.Data = output
return
@ -115,8 +110,6 @@ func (c *OutputService1ProtocolTest) OutputService1TestCaseOperation2(input *Out
return out, err
}
var opOutputService1TestCaseOperation2 *aws.Operation
type OutputService1TestShapeOutputService1TestCaseOperation1Input struct {
metadataOutputService1TestShapeOutputService1TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -195,20 +188,19 @@ func (c *OutputService2ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService2TestCaseOperation1 = "OperationName"
// OutputService2TestCaseOperation1Request generates a request for the OutputService2TestCaseOperation1 operation.
func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1Request(input *OutputService2TestShapeOutputService2TestCaseOperation1Input) (req *aws.Request, output *OutputService2TestShapeOutputShape) {
if opOutputService2TestCaseOperation1 == nil {
opOutputService2TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService2TestCaseOperation1,
}
if input == nil {
input = &OutputService2TestShapeOutputService2TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService2TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService2TestShapeOutputShape{}
req.Data = output
return
@ -220,8 +212,6 @@ func (c *OutputService2ProtocolTest) OutputService2TestCaseOperation1(input *Out
return out, err
}
var opOutputService2TestCaseOperation1 *aws.Operation
type OutputService2TestShapeOutputService2TestCaseOperation1Input struct {
metadataOutputService2TestShapeOutputService2TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -272,20 +262,19 @@ func (c *OutputService3ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService3TestCaseOperation1 = "OperationName"
// OutputService3TestCaseOperation1Request generates a request for the OutputService3TestCaseOperation1 operation.
func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1Request(input *OutputService3TestShapeOutputService3TestCaseOperation1Input) (req *aws.Request, output *OutputService3TestShapeOutputShape) {
if opOutputService3TestCaseOperation1 == nil {
opOutputService3TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService3TestCaseOperation1,
}
if input == nil {
input = &OutputService3TestShapeOutputService3TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService3TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService3TestShapeOutputShape{}
req.Data = output
return
@ -297,8 +286,6 @@ func (c *OutputService3ProtocolTest) OutputService3TestCaseOperation1(input *Out
return out, err
}
var opOutputService3TestCaseOperation1 *aws.Operation
type OutputService3TestShapeOutputService3TestCaseOperation1Input struct {
metadataOutputService3TestShapeOutputService3TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -349,20 +336,19 @@ func (c *OutputService4ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService4TestCaseOperation1 = "OperationName"
// OutputService4TestCaseOperation1Request generates a request for the OutputService4TestCaseOperation1 operation.
func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1Request(input *OutputService4TestShapeOutputService4TestCaseOperation1Input) (req *aws.Request, output *OutputService4TestShapeOutputShape) {
if opOutputService4TestCaseOperation1 == nil {
opOutputService4TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService4TestCaseOperation1,
}
if input == nil {
input = &OutputService4TestShapeOutputService4TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService4TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService4TestShapeOutputShape{}
req.Data = output
return
@ -374,8 +360,6 @@ func (c *OutputService4ProtocolTest) OutputService4TestCaseOperation1(input *Out
return out, err
}
var opOutputService4TestCaseOperation1 *aws.Operation
type OutputService4TestShapeOutputService4TestCaseOperation1Input struct {
metadataOutputService4TestShapeOutputService4TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -426,20 +410,19 @@ func (c *OutputService5ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService5TestCaseOperation1 = "OperationName"
// OutputService5TestCaseOperation1Request generates a request for the OutputService5TestCaseOperation1 operation.
func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1Request(input *OutputService5TestShapeOutputService5TestCaseOperation1Input) (req *aws.Request, output *OutputService5TestShapeOutputShape) {
if opOutputService5TestCaseOperation1 == nil {
opOutputService5TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService5TestCaseOperation1,
}
if input == nil {
input = &OutputService5TestShapeOutputService5TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService5TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService5TestShapeOutputShape{}
req.Data = output
return
@ -451,8 +434,6 @@ func (c *OutputService5ProtocolTest) OutputService5TestCaseOperation1(input *Out
return out, err
}
var opOutputService5TestCaseOperation1 *aws.Operation
type OutputService5TestShapeOutputService5TestCaseOperation1Input struct {
metadataOutputService5TestShapeOutputService5TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -503,20 +484,19 @@ func (c *OutputService6ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService6TestCaseOperation1 = "OperationName"
// OutputService6TestCaseOperation1Request generates a request for the OutputService6TestCaseOperation1 operation.
func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1Request(input *OutputService6TestShapeOutputService6TestCaseOperation1Input) (req *aws.Request, output *OutputService6TestShapeOutputShape) {
if opOutputService6TestCaseOperation1 == nil {
opOutputService6TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService6TestCaseOperation1,
}
if input == nil {
input = &OutputService6TestShapeOutputService6TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService6TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService6TestShapeOutputShape{}
req.Data = output
return
@ -528,8 +508,6 @@ func (c *OutputService6ProtocolTest) OutputService6TestCaseOperation1(input *Out
return out, err
}
var opOutputService6TestCaseOperation1 *aws.Operation
type OutputService6TestShapeOutputService6TestCaseOperation1Input struct {
metadataOutputService6TestShapeOutputService6TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -590,20 +568,19 @@ func (c *OutputService7ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService7TestCaseOperation1 = "OperationName"
// OutputService7TestCaseOperation1Request generates a request for the OutputService7TestCaseOperation1 operation.
func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1Request(input *OutputService7TestShapeOutputService7TestCaseOperation1Input) (req *aws.Request, output *OutputService7TestShapeOutputShape) {
if opOutputService7TestCaseOperation1 == nil {
opOutputService7TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService7TestCaseOperation1,
}
if input == nil {
input = &OutputService7TestShapeOutputService7TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService7TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService7TestShapeOutputShape{}
req.Data = output
return
@ -615,8 +592,6 @@ func (c *OutputService7ProtocolTest) OutputService7TestCaseOperation1(input *Out
return out, err
}
var opOutputService7TestCaseOperation1 *aws.Operation
type OutputService7TestShapeOutputService7TestCaseOperation1Input struct {
metadataOutputService7TestShapeOutputService7TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -667,20 +642,19 @@ func (c *OutputService8ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService8TestCaseOperation1 = "OperationName"
// OutputService8TestCaseOperation1Request generates a request for the OutputService8TestCaseOperation1 operation.
func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1Request(input *OutputService8TestShapeOutputService8TestCaseOperation1Input) (req *aws.Request, output *OutputService8TestShapeOutputShape) {
if opOutputService8TestCaseOperation1 == nil {
opOutputService8TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService8TestCaseOperation1,
}
if input == nil {
input = &OutputService8TestShapeOutputService8TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService8TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService8TestShapeOutputShape{}
req.Data = output
return
@ -692,8 +666,6 @@ func (c *OutputService8ProtocolTest) OutputService8TestCaseOperation1(input *Out
return out, err
}
var opOutputService8TestCaseOperation1 *aws.Operation
type OutputService8TestShapeOutputService8TestCaseOperation1Input struct {
metadataOutputService8TestShapeOutputService8TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -744,20 +716,19 @@ func (c *OutputService9ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService9TestCaseOperation1 = "OperationName"
// OutputService9TestCaseOperation1Request generates a request for the OutputService9TestCaseOperation1 operation.
func (c *OutputService9ProtocolTest) OutputService9TestCaseOperation1Request(input *OutputService9TestShapeOutputService9TestCaseOperation1Input) (req *aws.Request, output *OutputService9TestShapeOutputShape) {
if opOutputService9TestCaseOperation1 == nil {
opOutputService9TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService9TestCaseOperation1,
}
if input == nil {
input = &OutputService9TestShapeOutputService9TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService9TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService9TestShapeOutputShape{}
req.Data = output
return
@ -769,8 +740,6 @@ func (c *OutputService9ProtocolTest) OutputService9TestCaseOperation1(input *Out
return out, err
}
var opOutputService9TestCaseOperation1 *aws.Operation
type OutputService9TestShapeOutputService9TestCaseOperation1Input struct {
metadataOutputService9TestShapeOutputService9TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -833,20 +802,19 @@ func (c *OutputService10ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService10TestCaseOperation1 = "OperationName"
// OutputService10TestCaseOperation1Request generates a request for the OutputService10TestCaseOperation1 operation.
func (c *OutputService10ProtocolTest) OutputService10TestCaseOperation1Request(input *OutputService10TestShapeOutputService10TestCaseOperation1Input) (req *aws.Request, output *OutputService10TestShapeOutputShape) {
if opOutputService10TestCaseOperation1 == nil {
opOutputService10TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService10TestCaseOperation1,
}
if input == nil {
input = &OutputService10TestShapeOutputService10TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService10TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService10TestShapeOutputShape{}
req.Data = output
return
@ -858,8 +826,6 @@ func (c *OutputService10ProtocolTest) OutputService10TestCaseOperation1(input *O
return out, err
}
var opOutputService10TestCaseOperation1 *aws.Operation
type OutputService10TestShapeOutputService10TestCaseOperation1Input struct {
metadataOutputService10TestShapeOutputService10TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -910,20 +876,19 @@ func (c *OutputService11ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService11TestCaseOperation1 = "OperationName"
// OutputService11TestCaseOperation1Request generates a request for the OutputService11TestCaseOperation1 operation.
func (c *OutputService11ProtocolTest) OutputService11TestCaseOperation1Request(input *OutputService11TestShapeOutputService11TestCaseOperation1Input) (req *aws.Request, output *OutputService11TestShapeOutputShape) {
if opOutputService11TestCaseOperation1 == nil {
opOutputService11TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService11TestCaseOperation1,
}
if input == nil {
input = &OutputService11TestShapeOutputService11TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService11TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService11TestShapeOutputShape{}
req.Data = output
return
@ -935,8 +900,6 @@ func (c *OutputService11ProtocolTest) OutputService11TestCaseOperation1(input *O
return out, err
}
var opOutputService11TestCaseOperation1 *aws.Operation
type OutputService11TestShapeOutputService11TestCaseOperation1Input struct {
metadataOutputService11TestShapeOutputService11TestCaseOperation1Input `json:"-" xml:"-"`
}
@ -1003,20 +966,19 @@ func (c *OutputService12ProtocolTest) newRequest(op *aws.Operation, params, data
return req
}
const opOutputService12TestCaseOperation1 = "OperationName"
// OutputService12TestCaseOperation1Request generates a request for the OutputService12TestCaseOperation1 operation.
func (c *OutputService12ProtocolTest) OutputService12TestCaseOperation1Request(input *OutputService12TestShapeOutputService12TestCaseOperation1Input) (req *aws.Request, output *OutputService12TestShapeOutputShape) {
if opOutputService12TestCaseOperation1 == nil {
opOutputService12TestCaseOperation1 = &aws.Operation{
Name: "OperationName",
}
op := &aws.Operation{
Name: opOutputService12TestCaseOperation1,
}
if input == nil {
input = &OutputService12TestShapeOutputService12TestCaseOperation1Input{}
}
req = c.newRequest(opOutputService12TestCaseOperation1, input, output)
req = c.newRequest(op, input, output)
output = &OutputService12TestShapeOutputShape{}
req.Data = output
return
@ -1028,8 +990,6 @@ func (c *OutputService12ProtocolTest) OutputService12TestCaseOperation1(input *O
return out, err
}
var opOutputService12TestCaseOperation1 *aws.Operation
type OutputService12TestShapeOutputService12TestCaseOperation1Input struct {
metadataOutputService12TestShapeOutputService12TestCaseOperation1Input `json:"-" xml:"-"`
}

View File

@ -1,3 +1,4 @@
// Package xmlutil provides XML serialisation of AWS requests and responses.
package xmlutil
import (

View File

@ -114,7 +114,7 @@ func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error {
for _, a := range node.Attr {
if name == a.Name.Local {
// turn this into a text node for de-serializing
elems = []*XMLNode{&XMLNode{Text: a.Value}}
elems = []*XMLNode{{Text: a.Value}}
}
}
}

View File

@ -1,3 +1,4 @@
// Package v4 implements signing for AWS V4 signer
package v4
import (
@ -33,18 +34,17 @@ var ignoredHeaders = map[string]bool{
}
type signer struct {
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
AccessKeyID string
SecretAccessKey string
SessionToken string
Query url.Values
Body io.ReadSeeker
Debug uint
Logger io.Writer
Request *http.Request
Time time.Time
ExpireTime time.Duration
ServiceName string
Region string
CredValues credentials.Value
Credentials *credentials.Credentials
Query url.Values
Body io.ReadSeeker
Debug uint
Logger io.Writer
isPresign bool
formattedTime string
@ -70,11 +70,6 @@ func Sign(req *aws.Request) {
if req.Service.Config.Credentials == credentials.AnonymousCredentials {
return
}
creds, err := req.Service.Config.Credentials.Get()
if err != nil {
req.Error = err
return
}
region := req.Service.SigningRegion
if region == "" {
@ -87,56 +82,84 @@ func Sign(req *aws.Request) {
}
s := signer{
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
AccessKeyID: creds.AccessKeyID,
SecretAccessKey: creds.SecretAccessKey,
SessionToken: creds.SessionToken,
Debug: req.Service.Config.LogLevel,
Logger: req.Service.Config.Logger,
Request: req.HTTPRequest,
Time: req.Time,
ExpireTime: req.ExpireTime,
Query: req.HTTPRequest.URL.Query(),
Body: req.Body,
ServiceName: name,
Region: region,
Credentials: req.Service.Config.Credentials,
Debug: req.Service.Config.LogLevel,
Logger: req.Service.Config.Logger,
}
s.sign()
return
req.Error = s.sign()
}
func (v4 *signer) sign() {
func (v4 *signer) sign() error {
if v4.ExpireTime != 0 {
v4.isPresign = true
}
if v4.isRequestSigned() {
if !v4.Credentials.IsExpired() {
// If the request is already signed, and the credentials have not
// expired yet ignore the signing request.
return nil
}
// The credentials have expired for this request. The current signing
// is invalid, and needs to be request because the request will fail.
if v4.isPresign {
v4.removePresign()
// Update the request's query string to ensure the values stays in
// sync in the case retrieving the new credentials fails.
v4.Request.URL.RawQuery = v4.Query.Encode()
}
}
var err error
v4.CredValues, err = v4.Credentials.Get()
if err != nil {
return err
}
if v4.isPresign {
v4.Query.Set("X-Amz-Algorithm", authHeaderPrefix)
if v4.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.SessionToken)
if v4.CredValues.SessionToken != "" {
v4.Query.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
} else {
v4.Query.Del("X-Amz-Security-Token")
}
} else if v4.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.SessionToken)
} else if v4.CredValues.SessionToken != "" {
v4.Request.Header.Set("X-Amz-Security-Token", v4.CredValues.SessionToken)
}
v4.build()
if v4.Debug > 0 {
out := v4.Logger
fmt.Fprintf(out, "---[ CANONICAL STRING ]-----------------------------\n")
fmt.Fprintln(out, v4.canonicalString)
fmt.Fprintf(out, "---[ STRING TO SIGN ]--------------------------------\n")
fmt.Fprintln(out, v4.stringToSign)
if v4.isPresign {
fmt.Fprintf(out, "---[ SIGNED URL ]--------------------------------\n")
fmt.Fprintln(out, v4.Request.URL)
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
v4.logSigningInfo()
}
return nil
}
func (v4 *signer) logSigningInfo() {
out := v4.Logger
fmt.Fprintf(out, "---[ CANONICAL STRING ]-----------------------------\n")
fmt.Fprintln(out, v4.canonicalString)
fmt.Fprintf(out, "---[ STRING TO SIGN ]--------------------------------\n")
fmt.Fprintln(out, v4.stringToSign)
if v4.isPresign {
fmt.Fprintf(out, "---[ SIGNED URL ]--------------------------------\n")
fmt.Fprintln(out, v4.Request.URL)
}
fmt.Fprintf(out, "-----------------------------------------------------\n")
}
func (v4 *signer) build() {
v4.buildTime() // no depends
v4.buildCredentialString() // no depends
if v4.isPresign {
@ -151,7 +174,7 @@ func (v4 *signer) build() {
v4.Request.URL.RawQuery += "&X-Amz-Signature=" + v4.signature
} else {
parts := []string{
authHeaderPrefix + " Credential=" + v4.AccessKeyID + "/" + v4.credentialString,
authHeaderPrefix + " Credential=" + v4.CredValues.AccessKeyID + "/" + v4.credentialString,
"SignedHeaders=" + v4.signedHeaders,
"Signature=" + v4.signature,
}
@ -181,7 +204,7 @@ func (v4 *signer) buildCredentialString() {
}, "/")
if v4.isPresign {
v4.Query.Set("X-Amz-Credential", v4.AccessKeyID+"/"+v4.credentialString)
v4.Query.Set("X-Amz-Credential", v4.CredValues.AccessKeyID+"/"+v4.credentialString)
}
}
@ -268,7 +291,7 @@ func (v4 *signer) buildStringToSign() {
}
func (v4 *signer) buildSignature() {
secret := v4.SecretAccessKey
secret := v4.CredValues.SecretAccessKey
date := makeHmac([]byte("AWS4"+secret), []byte(v4.formattedShortTime))
region := makeHmac(date, []byte(v4.Region))
service := makeHmac(region, []byte(v4.ServiceName))
@ -292,6 +315,29 @@ func (v4 *signer) bodyDigest() string {
return hash
}
// isRequestSigned returns if the request is currently signed or presigned
func (v4 *signer) isRequestSigned() bool {
if v4.isPresign && v4.Query.Get("X-Amz-Signature") != "" {
return true
}
if v4.Request.Header.Get("Authorization") != "" {
return true
}
return false
}
// unsign removes signing flags for both signed and presigned requests.
func (v4 *signer) removePresign() {
v4.Query.Del("X-Amz-Algorithm")
v4.Query.Del("X-Amz-Signature")
v4.Query.Del("X-Amz-Security-Token")
v4.Query.Del("X-Amz-Date")
v4.Query.Del("X-Amz-Expires")
v4.Query.Del("X-Amz-Credential")
v4.Query.Del("X-Amz-SignedHeaders")
}
func makeHmac(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
@ -305,21 +351,10 @@ func makeSha256(data []byte) []byte {
}
func makeSha256Reader(reader io.ReadSeeker) []byte {
packet := make([]byte, 4096)
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
for {
n, err := reader.Read(packet)
if n > 0 {
hash.Write(packet[0:n])
}
if err == io.EOF || n == 0 {
break
}
}
io.Copy(hash, reader)
return hash.Sum(nil)
}

View File

@ -22,16 +22,14 @@ func buildSigner(serviceName string, region string, signTime time.Time, expireTi
req.Header.Add("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
return signer{
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
Request: req,
Time: signTime,
ExpireTime: expireTime,
Query: req.URL.Query(),
Body: reader,
ServiceName: serviceName,
Region: region,
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
}
}
@ -141,6 +139,97 @@ func TestAnonymousCredentials(t *testing.T) {
assert.Empty(t, hQ.Get("X-Amz-Date"))
}
func TestIgnoreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: "us-west-2",
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
sig := r.HTTPRequest.Header.Get("Authorization")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestIgnorePreResignRequestWithValidCreds(t *testing.T) {
r := aws.NewRequest(
aws.NewService(&aws.Config{
Credentials: credentials.NewStaticCredentials("AKID", "SECRET", "SESSION"),
Region: "us-west-2",
}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
sig := r.HTTPRequest.Header.Get("X-Amz-Signature")
Sign(r)
assert.Equal(t, sig, r.HTTPRequest.Header.Get("X-Amz-Signature"))
}
func TestResignRequestExpiredCreds(t *testing.T) {
creds := credentials.NewStaticCredentials("AKID", "SECRET", "SESSION")
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
Sign(r)
querySig := r.HTTPRequest.Header.Get("Authorization")
creds.Expire()
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.Header.Get("Authorization"))
}
func TestPreResignRequestExpiredCreds(t *testing.T) {
provider := &credentials.StaticProvider{credentials.Value{"AKID", "SECRET", "SESSION"}}
creds := credentials.NewCredentials(provider)
r := aws.NewRequest(
aws.NewService(&aws.Config{Credentials: creds}),
&aws.Operation{
Name: "BatchGetItem",
HTTPMethod: "POST",
HTTPPath: "/",
},
nil,
nil,
)
r.ExpireTime = time.Minute * 10
Sign(r)
querySig := r.HTTPRequest.URL.Query().Get("X-Amz-Signature")
creds.Expire()
r.Time = time.Now().Add(time.Hour * 48)
Sign(r)
assert.NotEqual(t, querySig, r.HTTPRequest.URL.Query().Get("X-Amz-Signature"))
}
func BenchmarkPresignRequest(b *testing.B) {
signer := buildSigner("dynamodb", "us-east-1", time.Now(), 300*time.Second, "{}")
for i := 0; i < b.N; i++ {

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,8 @@ import (
"regexp"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/apierr"
)
var reBucketLocation = regexp.MustCompile(`>([^<>]+)<\/Location`)
@ -16,7 +16,7 @@ func buildGetBucketLocation(r *aws.Request) {
out := r.Data.(*GetBucketLocationOutput)
b, err := ioutil.ReadAll(r.HTTPResponse.Body)
if err != nil {
r.Error = apierr.New("Unmarshal", "failed reading response body", err)
r.Error = awserr.New("SerializationError", "failed reading response body", err)
return
}

View File

@ -6,7 +6,7 @@ import (
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
// contentMD5 computes and sets the HTTP Content-MD5 header for requests that
@ -19,12 +19,12 @@ func contentMD5(r *aws.Request) {
// body.
_, err := io.Copy(h, r.Body)
if err != nil {
r.Error = apierr.New("ContentMD5", "failed to read body", err)
r.Error = awserr.New("ContentMD5", "failed to read body", err)
return
}
_, err = r.Body.Seek(0, 0)
if err != nil {
r.Error = apierr.New("ContentMD5", "failed to seek body", err)
r.Error = awserr.New("ContentMD5", "failed to seek body", err)
return
}

View File

@ -17,7 +17,7 @@ func init() {
}
initRequest = func(r *aws.Request) {
switch r.Operation {
switch r.Operation.Name {
case opPutBucketCORS, opPutBucketLifecycle, opPutBucketPolicy, opPutBucketTagging, opDeleteObjects:
// These S3 operations require Content-MD5 to be set
r.Handlers.Build.PushBack(contentMD5)

View File

@ -30,7 +30,7 @@ func TestMD5InPutBucketCORS(t *testing.T) {
Bucket: aws.String("bucketname"),
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
&s3.CORSRule{AllowedMethods: []*string{aws.String("GET")}},
{AllowedMethods: []*string{aws.String("GET")}},
},
},
})
@ -43,7 +43,7 @@ func TestMD5InPutBucketLifecycle(t *testing.T) {
Bucket: aws.String("bucketname"),
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.LifecycleRule{
&s3.LifecycleRule{
{
ID: aws.String("ID"),
Prefix: aws.String("Prefix"),
Status: aws.String("Enabled"),
@ -69,7 +69,7 @@ func TestMD5InPutBucketTagging(t *testing.T) {
Bucket: aws.String("bucketname"),
Tagging: &s3.Tagging{
TagSet: []*s3.Tag{
&s3.Tag{Key: aws.String("KEY"), Value: aws.String("VALUE")},
{Key: aws.String("KEY"), Value: aws.String("VALUE")},
},
},
})
@ -82,7 +82,7 @@ func TestMD5InDeleteObjects(t *testing.T) {
Bucket: aws.String("bucketname"),
Delete: &s3.Delete{
Objects: []*s3.ObjectIdentifier{
&s3.ObjectIdentifier{Key: aws.String("key")},
{Key: aws.String("key")},
},
},
})

View File

@ -36,7 +36,7 @@ func ExampleS3_AbortMultipartUpload() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -55,7 +55,7 @@ func ExampleS3_CompleteMultipartUpload() {
UploadID: aws.String("MultipartUploadId"), // Required
MultipartUpload: &s3.CompletedMultipartUpload{
Parts: []*s3.CompletedPart{
&s3.CompletedPart{ // Required
{ // Required
ETag: aws.String("ETag"),
PartNumber: aws.Long(1),
},
@ -75,7 +75,7 @@ func ExampleS3_CompleteMultipartUpload() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -135,7 +135,7 @@ func ExampleS3_CopyObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -171,7 +171,7 @@ func ExampleS3_CreateBucket() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -222,7 +222,7 @@ func ExampleS3_CreateMultipartUpload() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -249,7 +249,7 @@ func ExampleS3_DeleteBucket() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -276,7 +276,7 @@ func ExampleS3_DeleteBucketCORS() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -303,7 +303,7 @@ func ExampleS3_DeleteBucketLifecycle() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -330,7 +330,7 @@ func ExampleS3_DeleteBucketPolicy() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -357,7 +357,7 @@ func ExampleS3_DeleteBucketReplication() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -384,7 +384,7 @@ func ExampleS3_DeleteBucketTagging() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -411,7 +411,7 @@ func ExampleS3_DeleteBucketWebsite() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -442,7 +442,7 @@ func ExampleS3_DeleteObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -459,7 +459,7 @@ func ExampleS3_DeleteObjects() {
Bucket: aws.String("BucketName"), // Required
Delete: &s3.Delete{ // Required
Objects: []*s3.ObjectIdentifier{ // Required
&s3.ObjectIdentifier{ // Required
{ // Required
Key: aws.String("ObjectKey"), // Required
VersionID: aws.String("ObjectVersionId"),
},
@ -481,7 +481,7 @@ func ExampleS3_DeleteObjects() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -508,7 +508,7 @@ func ExampleS3_GetBucketACL() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -535,7 +535,7 @@ func ExampleS3_GetBucketCORS() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -562,7 +562,7 @@ func ExampleS3_GetBucketLifecycle() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -589,7 +589,7 @@ func ExampleS3_GetBucketLocation() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -616,7 +616,7 @@ func ExampleS3_GetBucketLogging() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -643,7 +643,7 @@ func ExampleS3_GetBucketNotification() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -670,7 +670,7 @@ func ExampleS3_GetBucketNotificationConfiguration() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -697,7 +697,7 @@ func ExampleS3_GetBucketPolicy() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -724,7 +724,7 @@ func ExampleS3_GetBucketReplication() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -751,7 +751,7 @@ func ExampleS3_GetBucketRequestPayment() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -778,7 +778,7 @@ func ExampleS3_GetBucketTagging() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -805,7 +805,7 @@ func ExampleS3_GetBucketVersioning() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -832,7 +832,7 @@ func ExampleS3_GetBucketWebsite() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -876,7 +876,7 @@ func ExampleS3_GetObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -906,7 +906,7 @@ func ExampleS3_GetObjectACL() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -935,7 +935,7 @@ func ExampleS3_GetObjectTorrent() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -962,7 +962,7 @@ func ExampleS3_HeadBucket() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1000,7 +1000,7 @@ func ExampleS3_HeadObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1025,7 +1025,7 @@ func ExampleS3_ListBuckets() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1058,7 +1058,7 @@ func ExampleS3_ListMultipartUploads() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1091,7 +1091,7 @@ func ExampleS3_ListObjectVersions() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1123,7 +1123,7 @@ func ExampleS3_ListObjects() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1155,7 +1155,7 @@ func ExampleS3_ListParts() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1173,7 +1173,7 @@ func ExampleS3_PutBucketACL() {
ACL: aws.String("BucketCannedACL"),
AccessControlPolicy: &s3.AccessControlPolicy{
Grants: []*s3.Grant{
&s3.Grant{ // Required
{ // Required
Grantee: &s3.Grantee{
Type: aws.String("Type"), // Required
DisplayName: aws.String("DisplayName"),
@ -1207,7 +1207,7 @@ func ExampleS3_PutBucketACL() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1224,7 +1224,7 @@ func ExampleS3_PutBucketCORS() {
Bucket: aws.String("BucketName"), // Required
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{
&s3.CORSRule{ // Required
{ // Required
AllowedHeaders: []*string{
aws.String("AllowedHeader"), // Required
// More values...
@ -1258,7 +1258,7 @@ func ExampleS3_PutBucketCORS() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1275,7 +1275,7 @@ func ExampleS3_PutBucketLifecycle() {
Bucket: aws.String("BucketName"), // Required
LifecycleConfiguration: &s3.LifecycleConfiguration{
Rules: []*s3.LifecycleRule{ // Required
&s3.LifecycleRule{ // Required
{ // Required
Prefix: aws.String("Prefix"), // Required
Status: aws.String("ExpirationStatus"), // Required
Expiration: &s3.LifecycleExpiration{
@ -1311,7 +1311,7 @@ func ExampleS3_PutBucketLifecycle() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1330,7 +1330,7 @@ func ExampleS3_PutBucketLogging() {
LoggingEnabled: &s3.LoggingEnabled{
TargetBucket: aws.String("TargetBucket"),
TargetGrants: []*s3.TargetGrant{
&s3.TargetGrant{ // Required
{ // Required
Grantee: &s3.Grantee{
Type: aws.String("Type"), // Required
DisplayName: aws.String("DisplayName"),
@ -1357,7 +1357,7 @@ func ExampleS3_PutBucketLogging() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1414,7 +1414,7 @@ func ExampleS3_PutBucketNotification() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1431,7 +1431,7 @@ func ExampleS3_PutBucketNotificationConfiguration() {
Bucket: aws.String("BucketName"), // Required
NotificationConfiguration: &s3.NotificationConfiguration{ // Required
LambdaFunctionConfigurations: []*s3.LambdaFunctionConfiguration{
&s3.LambdaFunctionConfiguration{ // Required
{ // Required
Events: []*string{ // Required
aws.String("Event"), // Required
// More values...
@ -1442,7 +1442,7 @@ func ExampleS3_PutBucketNotificationConfiguration() {
// More values...
},
QueueConfigurations: []*s3.QueueConfiguration{
&s3.QueueConfiguration{ // Required
{ // Required
Events: []*string{ // Required
aws.String("Event"), // Required
// More values...
@ -1453,7 +1453,7 @@ func ExampleS3_PutBucketNotificationConfiguration() {
// More values...
},
TopicConfigurations: []*s3.TopicConfiguration{
&s3.TopicConfiguration{ // Required
{ // Required
Events: []*string{ // Required
aws.String("Event"), // Required
// More values...
@ -1476,7 +1476,7 @@ func ExampleS3_PutBucketNotificationConfiguration() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1504,7 +1504,7 @@ func ExampleS3_PutBucketPolicy() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1522,7 +1522,7 @@ func ExampleS3_PutBucketReplication() {
ReplicationConfiguration: &s3.ReplicationConfiguration{ // Required
Role: aws.String("Role"), // Required
Rules: []*s3.ReplicationRule{ // Required
&s3.ReplicationRule{ // Required
{ // Required
Destination: &s3.Destination{ // Required
Bucket: aws.String("BucketName"), // Required
},
@ -1545,7 +1545,7 @@ func ExampleS3_PutBucketReplication() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1575,7 +1575,7 @@ func ExampleS3_PutBucketRequestPayment() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1592,7 +1592,7 @@ func ExampleS3_PutBucketTagging() {
Bucket: aws.String("BucketName"), // Required
Tagging: &s3.Tagging{ // Required
TagSet: []*s3.Tag{ // Required
&s3.Tag{ // Required
{ // Required
Key: aws.String("ObjectKey"), // Required
Value: aws.String("Value"), // Required
},
@ -1611,7 +1611,7 @@ func ExampleS3_PutBucketTagging() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1643,7 +1643,7 @@ func ExampleS3_PutBucketVersioning() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1670,7 +1670,7 @@ func ExampleS3_PutBucketWebsite() {
Protocol: aws.String("Protocol"),
},
RoutingRules: []*s3.RoutingRule{
&s3.RoutingRule{ // Required
{ // Required
Redirect: &s3.Redirect{ // Required
HTTPRedirectCode: aws.String("HttpRedirectCode"),
HostName: aws.String("HostName"),
@ -1698,7 +1698,7 @@ func ExampleS3_PutBucketWebsite() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1751,7 +1751,7 @@ func ExampleS3_PutObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1770,7 +1770,7 @@ func ExampleS3_PutObjectACL() {
ACL: aws.String("ObjectCannedACL"),
AccessControlPolicy: &s3.AccessControlPolicy{
Grants: []*s3.Grant{
&s3.Grant{ // Required
{ // Required
Grantee: &s3.Grantee{
Type: aws.String("Type"), // Required
DisplayName: aws.String("DisplayName"),
@ -1805,7 +1805,7 @@ func ExampleS3_PutObjectACL() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1838,7 +1838,7 @@ func ExampleS3_RestoreObject() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1874,7 +1874,7 @@ func ExampleS3_UploadPart() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}
@ -1917,7 +1917,7 @@ func ExampleS3_UploadPartCopy() {
fmt.Println(reqErr.Code(), reqErr.Message(), reqErr.StatusCode(), reqErr.RequestID())
}
} else {
// This case should never be hit, The SDK should alwsy return an
// This case should never be hit, the SDK should always return an
// error which satisfies the awserr.Error interface.
fmt.Println(err.Error())
}

View File

@ -129,12 +129,7 @@ func (d *downloader) download() (n int64, err error) {
}
// Queue the next range of bytes to read.
ch <- dlchunk{
dlchunkcounter: &dlchunkcounter{},
w: d.w,
start: d.pos,
size: d.opts.PartSize,
}
ch <- dlchunk{w: d.w, start: d.pos, size: d.opts.PartSize}
d.pos += d.opts.PartSize
}
@ -175,7 +170,7 @@ func (d *downloader) downloadPart(ch chan dlchunk) {
} else {
d.setTotalBytes(resp) // Set total if not yet set.
n, err := io.Copy(chunk, resp.Body)
n, err := io.Copy(&chunk, resp.Body)
resp.Body.Close()
if err != nil {
@ -242,21 +237,15 @@ func (d *downloader) seterr(e error) {
// io.WriterAt, effectively making it an io.SectionWriter (which does not
// exist).
type dlchunk struct {
*dlchunkcounter
w io.WriterAt
start int64
size int64
}
// dlchunkcounter keeps track of the current position the dlchunk struct is
// writing to.
type dlchunkcounter struct {
cur int64
cur int64
}
// Write wraps io.WriterAt for the dlchunk, writing from the dlchunk's start
// position to its end (or EOF).
func (c dlchunk) Write(p []byte) (n int, err error) {
func (c *dlchunk) Write(p []byte) (n int, err error) {
if c.cur >= c.size {
return 0, io.EOF
}

View File

@ -10,7 +10,6 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/service/s3"
)
@ -60,13 +59,17 @@ type MultiUploadFailure interface {
UploadID() string
}
// So that the Error interface type can be included as an anonymous field
// in the multiUploadError struct and not conflict with the error.Error() method.
type awsError awserr.Error
// A multiUploadError wraps the upload ID of a failed s3 multipart upload.
// Composed of BaseError for code, message, and original error
//
// Should be used for an error that occurred failing a S3 multipart upload,
// and a upload ID is available. If an uploadID is not available a more relevant
type multiUploadError struct {
*apierr.BaseError
awsError
// ID for multipart upload which failed.
uploadID string
@ -77,18 +80,19 @@ type multiUploadError struct {
// See apierr.BaseError ErrorWithExtra for output format
//
// Satisfies the error interface.
func (m *multiUploadError) Error() string {
return m.ErrorWithExtra(fmt.Sprintf("upload id: %s", m.uploadID))
func (m multiUploadError) Error() string {
extra := fmt.Sprintf("upload id: %s", m.uploadID)
return awserr.SprintError(m.Code(), m.Message(), extra, m.OrigErr())
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (m *multiUploadError) String() string {
func (m multiUploadError) String() string {
return m.Error()
}
// UploadID returns the id of the S3 upload which failed.
func (m *multiUploadError) UploadID() string {
func (m multiUploadError) UploadID() string {
return m.uploadID
}
@ -258,7 +262,7 @@ func (u *uploader) upload() (*UploadOutput, error) {
if u.opts.PartSize < MinUploadPartSize {
msg := fmt.Sprintf("part size must be at least %d bytes", MinUploadPartSize)
return nil, apierr.New("ConfigError", msg, nil)
return nil, awserr.New("ConfigError", msg, nil)
}
// Do one read to determine if we have more than one part
@ -266,7 +270,7 @@ func (u *uploader) upload() (*UploadOutput, error) {
if err == io.EOF || err == io.ErrUnexpectedEOF { // single part
return u.singlePart(buf)
} else if err != nil {
return nil, apierr.New("ReadRequestBody", "read upload data failed", err)
return nil, awserr.New("ReadRequestBody", "read upload data failed", err)
}
mu := multiuploader{uploader: u}
@ -418,7 +422,7 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
if num > int64(MaxUploadParts) {
msg := fmt.Sprintf("exceeded total allowed parts (%d). "+
"Adjust PartSize to fit in this limit", MaxUploadParts)
u.seterr(apierr.New("TotalPartsExceeded", msg, nil))
u.seterr(awserr.New("TotalPartsExceeded", msg, nil))
break
}
@ -432,7 +436,10 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
ch <- chunk{buf: buf, num: num}
if err != nil && err != io.ErrUnexpectedEOF {
u.seterr(apierr.New("ReadRequestBody", "read multipart upload data failed", err))
u.seterr(awserr.New(
"ReadRequestBody",
"read multipart upload data failed",
err))
break
}
}
@ -443,16 +450,12 @@ func (u *multiuploader) upload(firstBuf io.ReadSeeker) (*UploadOutput, error) {
complete := u.complete()
if err := u.geterr(); err != nil {
var berr *apierr.BaseError
switch t := err.(type) {
case *apierr.BaseError:
berr = t
default:
berr = apierr.New("MultipartUpload", "upload multipart failed", err)
}
return nil, &multiUploadError{
BaseError: berr,
uploadID: u.uploadID,
awsError: awserr.New(
"MultipartUpload",
"upload multipart failed",
err),
uploadID: u.uploadID,
}
}
return &UploadOutput{

View File

@ -23,11 +23,22 @@ var _ = unit.Imported
var buf12MB = make([]byte, 1024*1024*12)
var buf2MB = make([]byte, 1024*1024*2)
var emptyList = []string{}
func val(i interface{}, s string) interface{} {
return awsutil.ValuesAtPath(i, s)[0]
}
func loggingSvc() (*s3.S3, *[]string, *[]interface{}) {
func contains(src []string, s string) bool {
for _, v := range src {
if s == v {
return true
}
}
return false
}
func loggingSvc(ignoreOps []string) (*s3.S3, *[]string, *[]interface{}) {
var m sync.Mutex
partNum := 0
names := []string{}
@ -41,8 +52,10 @@ func loggingSvc() (*s3.S3, *[]string, *[]interface{}) {
m.Lock()
defer m.Unlock()
names = append(names, r.Operation.Name)
params = append(params, r.Params)
if !contains(ignoreOps, r.Operation.Name) {
names = append(names, r.Operation.Name)
params = append(params, r.Params)
}
r.HTTPResponse = &http.Response{
StatusCode: 200,
@ -70,7 +83,7 @@ func buflen(i interface{}) int {
}
func TestUploadOrderMulti(t *testing.T) {
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
@ -107,7 +120,7 @@ func TestUploadOrderMulti(t *testing.T) {
}
func TestUploadOrderMultiDifferentPartSize(t *testing.T) {
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{
S3: s,
PartSize: 1024 * 1024 * 7,
@ -131,7 +144,7 @@ func TestUploadIncreasePartSize(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
opts := &s3manager.UploadOptions{S3: s, Concurrency: 1}
mgr := s3manager.NewUploader(opts)
_, err := mgr.Upload(&s3manager.UploadInput{
@ -167,7 +180,7 @@ func TestUploadFailIfPartSizeTooSmall(t *testing.T) {
}
func TestUploadOrderSingle(t *testing.T) {
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
@ -186,7 +199,7 @@ func TestUploadOrderSingle(t *testing.T) {
}
func TestUploadOrderSingleFailure(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
r.HTTPResponse.StatusCode = 400
})
@ -203,7 +216,7 @@ func TestUploadOrderSingleFailure(t *testing.T) {
}
func TestUploadOrderZero(t *testing.T) {
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
@ -219,7 +232,7 @@ func TestUploadOrderZero(t *testing.T) {
}
func TestUploadOrderMultiFailure(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch t := r.Data.(type) {
case *s3.UploadPartOutput:
@ -241,7 +254,7 @@ func TestUploadOrderMultiFailure(t *testing.T) {
}
func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CompleteMultipartUploadOutput:
@ -249,7 +262,7 @@ func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
}
})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
@ -262,7 +275,7 @@ func TestUploadOrderMultiFailureOnComplete(t *testing.T) {
}
func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch r.Data.(type) {
case *s3.CreateMultipartUploadOutput:
@ -282,7 +295,7 @@ func TestUploadOrderMultiFailureOnCreate(t *testing.T) {
}
func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
s.Handlers.Send.PushBack(func(r *aws.Request) {
switch data := r.Data.(type) {
case *s3.UploadPartOutput:
@ -307,26 +320,26 @@ func TestUploadOrderMultiFailureLeaveParts(t *testing.T) {
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart"}, *ops)
}
var failreaderCount = 0
type failreader struct {
times int
failCount int
}
type failreader struct{ times int }
func (f failreader) Read(b []byte) (int, error) {
failreaderCount++
if failreaderCount >= f.times {
func (f *failreader) Read(b []byte) (int, error) {
f.failCount++
if f.failCount >= f.times {
return 0, fmt.Errorf("random failure")
}
return len(b), nil
}
func TestUploadOrderReadFail1(t *testing.T) {
failreaderCount = 0
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: failreader{1},
Body: &failreader{times: 1},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
@ -335,13 +348,12 @@ func TestUploadOrderReadFail1(t *testing.T) {
}
func TestUploadOrderReadFail2(t *testing.T) {
failreaderCount = 0
s, ops, _ := loggingSvc()
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: failreader{2},
Body: &failreader{times: 2},
})
assert.Equal(t, "ReadRequestBody", err.(awserr.Error).Code())
@ -349,16 +361,12 @@ func TestUploadOrderReadFail2(t *testing.T) {
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
}
type sizedReaderImpl struct {
type sizedReader struct {
size int
cur int
}
type sizedReader struct {
*sizedReaderImpl
}
func (s sizedReader) Read(p []byte) (n int, err error) {
func (s *sizedReader) Read(p []byte) (n int, err error) {
if s.cur >= s.size {
return 0, io.EOF
}
@ -373,12 +381,12 @@ func (s sizedReader) Read(p []byte) (n int, err error) {
}
func TestUploadOrderMultiBufferedReader(t *testing.T) {
s, ops, args := loggingSvc()
s, ops, args := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
_, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: sizedReader{&sizedReaderImpl{size: 1024 * 1024 * 12}},
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.NoError(t, err)
@ -397,17 +405,17 @@ func TestUploadOrderMultiBufferedReader(t *testing.T) {
func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
s3manager.MaxUploadParts = 2
defer func() { s3manager.MaxUploadParts = 10000 }()
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc([]string{"UploadPart"})
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s, Concurrency: 1})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: sizedReader{&sizedReaderImpl{size: 1024 * 1024 * 12}},
Body: &sizedReader{size: 1024 * 1024 * 12},
})
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, []string{"CreateMultipartUpload", "UploadPart", "UploadPart", "AbortMultipartUpload"}, *ops)
assert.Equal(t, []string{"CreateMultipartUpload", "AbortMultipartUpload"}, *ops)
aerr := err.(awserr.Error)
assert.Equal(t, "TotalPartsExceeded", aerr.Code())
@ -415,12 +423,12 @@ func TestUploadOrderMultiBufferedReaderExceedTotalParts(t *testing.T) {
}
func TestUploadOrderSingleBufferedReader(t *testing.T) {
s, ops, _ := loggingSvc()
s, ops, _ := loggingSvc(emptyList)
mgr := s3manager.NewUploader(&s3manager.UploadOptions{S3: s})
resp, err := mgr.Upload(&s3manager.UploadInput{
Bucket: aws.String("Bucket"),
Key: aws.String("Key"),
Body: sizedReader{&sizedReaderImpl{size: 1024 * 1024 * 2}},
Body: &sizedReader{size: 1024 * 1024 * 2},
})
assert.NoError(t, err)

View File

@ -5,11 +5,11 @@ import (
"encoding/base64"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/aws/aws-sdk-go/internal/apierr"
)
var errSSERequiresSSL = apierr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
var errSSERequiresSSL = awserr.New("ConfigError", "cannot send SSE keys over HTTP.", nil)
func validateSSERequiresSSL(r *aws.Request) {
if r.HTTPRequest.URL.Scheme != "https" {

View File

@ -6,7 +6,7 @@ import (
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/internal/apierr"
"github.com/aws/aws-sdk-go/aws/awserr"
)
type xmlErrorResponse struct {
@ -20,8 +20,8 @@ func unmarshalError(r *aws.Request) {
if r.HTTPResponse.ContentLength == int64(0) {
// No body, use status code to generate an awserr.Error
r.Error = apierr.NewRequestError(
apierr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil),
r.Error = awserr.NewRequestFailure(
awserr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil),
r.HTTPResponse.StatusCode,
"",
)
@ -31,10 +31,10 @@ func unmarshalError(r *aws.Request) {
resp := &xmlErrorResponse{}
err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp)
if err != nil && err != io.EOF {
r.Error = apierr.New("Unmarshal", "failed to decode S3 XML error response", nil)
r.Error = awserr.New("SerializationError", "failed to decode S3 XML error response", nil)
} else {
r.Error = apierr.NewRequestError(
apierr.New(resp.Code, resp.Message, nil),
r.Error = awserr.NewRequestFailure(
awserr.New(resp.Code, resp.Message, nil),
r.HTTPResponse.StatusCode,
"",
)

View File

@ -192,7 +192,7 @@ func (c *Client) Close() {
// initHTTPClient initializes a HTTP client for etcd client
func (c *Client) initHTTPClient() {
c.transport = &http.Transport{
Dial: c.dial,
Dial: c.DefaultDial,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
@ -218,7 +218,7 @@ func (c *Client) initHTTPSClient(cert, key string) error {
tr := &http.Transport{
TLSClientConfig: tlsConfig,
Dial: c.dial,
Dial: c.DefaultDial,
}
c.httpClient = &http.Client{Transport: tr}
@ -306,12 +306,16 @@ func (c *Client) GetCluster() []string {
}
// SyncCluster updates the cluster information using the internal machine list.
// If no members are found, the intenral machine list is left untouched.
func (c *Client) SyncCluster() bool {
return c.internalSyncCluster(c.cluster.Machines)
}
// internalSyncCluster syncs cluster information using the given machine list.
func (c *Client) internalSyncCluster(machines []string) bool {
// comma-separated list of machines in the cluster.
members := ""
for _, machine := range machines {
httpPath := c.createHttpPath(machine, path.Join(version, "members"))
resp, err := c.httpClient.Get(httpPath)
@ -333,8 +337,7 @@ func (c *Client) internalSyncCluster(machines []string) bool {
// try another machine in the cluster
continue
}
// update Machines List
c.cluster.updateFromStr(string(b))
members = string(b)
} else {
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
@ -354,10 +357,16 @@ func (c *Client) internalSyncCluster(machines []string) bool {
urls = append(urls, m.ClientURLs...)
}
// update Machines List
c.cluster.updateFromStr(strings.Join(urls, ","))
members = strings.Join(urls, ",")
}
// We should never do an empty cluster update.
if members == "" {
continue
}
// update Machines List
c.cluster.updateFromStr(members)
logger.Debug("sync.machines ", c.cluster.Machines)
c.saveConfig()
return true
@ -382,9 +391,9 @@ func (c *Client) createHttpPath(serverName string, _path string) string {
return u.String()
}
// dial attempts to open a TCP connection to the provided address, explicitly
// DefaultDial attempts to open a TCP connection to the provided address, explicitly
// enabling keep-alives with a one-second interval.
func (c *Client) dial(network, addr string) (net.Conn, error) {
func (c *Client) DefaultDial(network, addr string) (net.Conn, error) {
conn, err := net.DialTimeout(network, addr, c.config.DialTimeout)
if err != nil {
return nil, err

View File

@ -3,12 +3,14 @@ package etcd
import (
"math/rand"
"strings"
"sync"
)
type Cluster struct {
Leader string `json:"leader"`
Machines []string `json:"machines"`
picked int
mu sync.RWMutex
}
func NewCluster(machines []string) *Cluster {
@ -25,10 +27,22 @@ func NewCluster(machines []string) *Cluster {
}
}
func (cl *Cluster) failure() { cl.picked = rand.Intn(len(cl.Machines)) }
func (cl *Cluster) pick() string { return cl.Machines[cl.picked] }
func (cl *Cluster) failure() {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.picked = rand.Intn(len(cl.Machines))
}
func (cl *Cluster) pick() string {
cl.mu.Lock()
defer cl.mu.Unlock()
return cl.Machines[cl.picked]
}
func (cl *Cluster) updateFromStr(machines string) {
cl.mu.Lock()
defer cl.mu.Unlock()
cl.Machines = strings.Split(machines, ",")
for i := range cl.Machines {
cl.Machines[i] = strings.TrimSpace(cl.Machines[i])

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,13 @@
package etcd
//go:generate codecgen -o response.generated.go response.go
import (
"encoding/json"
"net/http"
"strconv"
"time"
"github.com/ugorji/go/codec"
)
const (
@ -28,6 +31,7 @@ var (
http.StatusNotFound: true,
http.StatusPreconditionFailed: true,
http.StatusForbidden: true,
http.StatusUnauthorized: true,
}
)
@ -39,7 +43,7 @@ func (rr *RawResponse) Unmarshal() (*Response, error) {
resp := new(Response)
err := json.Unmarshal(rr.Body, resp)
err := codec.NewDecoderBytes(rr.Body, new(codec.JsonHandle)).Decode(resp)
if err != nil {
return nil, err

View File

@ -0,0 +1,75 @@
package etcd
import (
"net/http"
"reflect"
"strings"
"testing"
"github.com/ugorji/go/codec"
)
func createTestNode(size int) *Node {
return &Node{
Key: strings.Repeat("a", 30),
Value: strings.Repeat("a", size),
TTL: 123456789,
ModifiedIndex: 123456,
CreatedIndex: 123456,
}
}
func createTestNodeWithChildren(children, size int) *Node {
node := createTestNode(size)
for i := 0; i < children; i++ {
node.Nodes = append(node.Nodes, createTestNode(size))
}
return node
}
func createTestResponse(children, size int) *Response {
return &Response{
Action: "aaaaa",
Node: createTestNodeWithChildren(children, size),
PrevNode: nil,
EtcdIndex: 123456,
RaftIndex: 123456,
RaftTerm: 123456,
}
}
func benchmarkResponseUnmarshalling(b *testing.B, children, size int) {
response := createTestResponse(children, size)
rr := RawResponse{http.StatusOK, make([]byte, 0), http.Header{}}
codec.NewEncoderBytes(&rr.Body, new(codec.JsonHandle)).Encode(response)
b.ResetTimer()
newResponse := new(Response)
var err error
for i := 0; i < b.N; i++ {
if newResponse, err = rr.Unmarshal(); err != nil {
b.Errorf("Error: %v", err)
}
}
if !reflect.DeepEqual(response.Node, newResponse.Node) {
b.Errorf("Unexpected difference in a parsed response: \n%+v\n%+v", response, newResponse)
}
}
func BenchmarkSmallResponseUnmarshal(b *testing.B) {
benchmarkResponseUnmarshalling(b, 30, 20)
}
func BenchmarkManySmallResponseUnmarshal(b *testing.B) {
benchmarkResponseUnmarshalling(b, 3000, 20)
}
func BenchmarkMediumResponseUnmarshal(b *testing.B) {
benchmarkResponseUnmarshalling(b, 300, 200)
}
func BenchmarkLargeResponseUnmarshal(b *testing.B) {
benchmarkResponseUnmarshalling(b, 3000, 2000)
}

View File

@ -0,0 +1,12 @@
language: go
go:
- 1.2
- 1.3
- tip
install:
- go get gopkg.in/asn1-ber.v1
- go get gopkg.in/ldap.v1
- go get code.google.com/p/go.tools/cmd/cover || go get golang.org/x/tools/cmd/cover
- go build -v ./...
script:
- go test -v -cover ./...

View File

@ -0,0 +1,48 @@
[![GoDoc](https://godoc.org/gopkg.in/ldap.v1?status.svg)](https://godoc.org/gopkg.in/ldap.v1) [![Build Status](https://travis-ci.org/go-ldap/ldap.svg)](https://travis-ci.org/go-ldap/ldap)
# Basic LDAP v3 functionality for the GO programming language.
## Required Librarys:
- gopkg.in/asn1-ber.v1
## Working:
- Connecting to LDAP server
- Binding to LDAP server
- Searching for entries
- Compiling string filters to LDAP filters
- Paging Search Results
- Modify Requests / Responses
## Examples:
- search
- modify
## Tests Implemented:
- Filter Compile / Decompile
## TODO:
- Add Requests / Responses
- Delete Requests / Responses
- Modify DN Requests / Responses
- Compare Requests / Responses
- Implement Tests / Benchmarks
---
This feature is disabled at the moment, because in some cases the "Search Request Done" packet will be handled before the last "Search Request Entry":
- Mulitple internal goroutines to handle network traffic
Makes library goroutine safe
Can perform multiple search requests at the same time and return
the results to the proper goroutine. All requests are blocking requests,
so the goroutine does not need special handling
---
The Go gopher was designed by Renee French. (http://reneefrench.blogspot.com/)
The design is licensed under the Creative Commons 3.0 Attributions license.
Read this article for more details: http://blog.golang.org/gopher

135
Godeps/_workspace/src/github.com/go-ldap/ldap/bind.go generated vendored Normal file
View File

@ -0,0 +1,135 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ldap
import (
"errors"
"gopkg.in/asn1-ber.v1"
)
type SimpleBindRequest struct {
Username string
Password string
Controls []Control
}
type SimpleBindResult struct {
Controls []Control
}
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{
Username: username,
Password: password,
Controls: controls,
}
}
func (bindRequest *SimpleBindRequest) encode() *ber.Packet {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, bindRequest.Username, "User Name"))
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, bindRequest.Password, "Password"))
request.AppendChild(encodeControls(bindRequest.Controls))
return request
}
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
encodedBindRequest := simpleBindRequest.encode()
packet.AppendChild(encodedBindRequest)
if l.Debug {
ber.PrintPacket(packet)
}
channel, err := l.sendMessage(packet)
if err != nil {
return nil, err
}
if channel == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
packet = <-channel
if packet == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
result := &SimpleBindResult{
Controls: make([]Control, 0),
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
result.Controls = append(result.Controls, DecodeControl(child))
}
}
resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode != 0 {
return result, NewError(resultCode, errors.New(resultDescription))
}
return result, nil
}
func (l *Conn) Bind(username, password string) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
bindRequest.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, username, "User Name"))
bindRequest.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, password, "Password"))
packet.AppendChild(bindRequest)
if l.Debug {
ber.PrintPacket(packet)
}
channel, err := l.sendMessage(packet)
if err != nil {
return err
}
if channel == nil {
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
packet = <-channel
if packet == nil {
return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return err
}
ber.PrintPacket(packet)
}
resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode != 0 {
return NewError(resultCode, errors.New(resultDescription))
}
return nil
}

View File

@ -0,0 +1,85 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// File contains Compare functionality
//
// https://tools.ietf.org/html/rfc4511
//
// CompareRequest ::= [APPLICATION 14] SEQUENCE {
// entry LDAPDN,
// ava AttributeValueAssertion }
//
// AttributeValueAssertion ::= SEQUENCE {
// attributeDesc AttributeDescription,
// assertionValue AssertionValue }
//
// AttributeDescription ::= LDAPString
// -- Constrained to <attributedescription>
// -- [RFC4512]
//
// AttributeValue ::= OCTET STRING
//
package ldap
import (
"errors"
"fmt"
"gopkg.in/asn1-ber.v1"
)
// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
// false with any error that occurs if any.
func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, dn, "DN"))
ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion")
ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "AttributeDesc"))
ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagOctetString, value, "AssertionValue"))
request.AppendChild(ava)
packet.AppendChild(request)
l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet)
if err != nil {
return false, err
}
if channel == nil {
return false, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
l.Debug.Printf("%d: waiting for response", messageID)
packet = <-channel
l.Debug.Printf("%d: got response %p", messageID, packet)
if packet == nil {
return false, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
}
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return false, err
}
ber.PrintPacket(packet)
}
if packet.Children[1].Tag == ApplicationCompareResponse {
resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode == LDAPResultCompareTrue {
return true, nil
} else if resultCode == LDAPResultCompareFalse {
return false, nil
} else {
return false, NewError(resultCode, errors.New(resultDescription))
}
}
return false, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag)
}

View File

@ -7,11 +7,12 @@ package ldap
import (
"crypto/tls"
"errors"
"fmt"
"gopkg.in/asn1-ber.v1"
"log"
"net"
"sync"
"github.com/vanackere/asn1-ber"
"time"
)
const (
@ -23,71 +24,93 @@ const (
type messagePacket struct {
Op int
MessageID uint64
MessageID int64
Packet *ber.Packet
Channel chan *ber.Packet
}
type sendMessageFlags uint
const (
startTLS sendMessageFlags = 1 << iota
)
// Conn represents an LDAP Connection
type Conn struct {
conn net.Conn
isTLS bool
Debug debugging
chanConfirm chan bool
chanResults map[uint64]chan *ber.Packet
chanMessage chan *messagePacket
chanMessageID chan uint64
wgSender sync.WaitGroup
chanDone chan struct{}
once sync.Once
conn net.Conn
isTLS bool
isClosing bool
isStartingTLS bool
Debug debugging
chanConfirm chan bool
chanResults map[int64]chan *ber.Packet
chanMessage chan *messagePacket
chanMessageID chan int64
wgSender sync.WaitGroup
wgClose sync.WaitGroup
once sync.Once
outstandingRequests uint
messageMutex sync.Mutex
}
// DefaultTimeout is a package-level variable that sets the timeout value
// used for the Dial and DialTLS methods.
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = 60 * time.Second
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
func Dial(network, addr string) (*Conn, error) {
c, err := net.Dial(network, addr)
c, err := net.DialTimeout(network, addr, DefaultTimeout)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
conn := NewConn(c)
conn.start()
conn := NewConn(c, false)
conn.Start()
return conn, nil
}
// DialTLS connects to the given address on the given network using tls.Dial
// and then returns a new Conn for the connection.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
c, err := tls.Dial(network, addr, config)
dc, err := net.DialTimeout(network, addr, DefaultTimeout)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
conn := NewConn(c)
conn.isTLS = true
conn.start()
c := tls.Client(dc, config)
err = c.Handshake()
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
conn := NewConn(c, true)
conn.Start()
return conn, nil
}
// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn) *Conn {
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
conn: conn,
chanConfirm: make(chan bool),
chanMessageID: make(chan uint64),
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
chanResults: map[uint64]chan *ber.Packet{},
chanDone: make(chan struct{}),
chanResults: map[int64]chan *ber.Packet{},
isTLS: isTLS,
}
}
func (l *Conn) start() {
func (l *Conn) Start() {
go l.reader()
go l.processMessages()
l.wgClose.Add(1)
}
// Close closes the connection.
func (l *Conn) Close() {
l.once.Do(func() {
close(l.chanDone)
l.isClosing = true
l.wgSender.Wait()
l.Debug.Printf("Sending quit message and waiting for confirmation")
@ -99,12 +122,14 @@ func (l *Conn) Close() {
if err := l.conn.Close(); err != nil {
log.Print(err)
}
l.wgClose.Done()
})
<-l.chanDone
l.wgClose.Wait()
}
// Returns the next available messageID
func (l *Conn) nextMessageID() uint64 {
func (l *Conn) nextMessageID() int64 {
if l.chanMessageID != nil {
if messageID, ok := <-l.chanMessageID; ok {
return messageID
@ -122,24 +147,28 @@ func (l *Conn) StartTLS(config *tls.Config) error {
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(messageID), "MessageID"))
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
packet.AppendChild(request)
l.Debug.PrintPacket(packet)
_, err := l.conn.Write(packet.Bytes())
channel, err := l.sendMessageWithFlags(packet, startTLS)
if err != nil {
return NewError(ErrorNetwork, err)
return err
}
if channel == nil {
return NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
packet, err = ber.ReadPacket(l.conn)
if err != nil {
return NewError(ErrorNetwork, err)
}
l.Debug.Printf("%d: waiting for response", messageID)
packet = <-channel
l.Debug.Printf("%d: got response %p", messageID, packet)
l.finishMessage(messageID)
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
l.Close()
return err
}
ber.PrintPacket(packet)
@ -147,30 +176,50 @@ func (l *Conn) StartTLS(config *tls.Config) error {
if packet.Children[1].Children[0].Value.(int64) == 0 {
conn := tls.Client(l.conn, config)
if err := conn.Handshake(); err != nil {
l.Close()
return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", err))
}
l.isTLS = true
l.conn = conn
}
go l.reader()
return nil
}
func (l *Conn) closing() bool {
select {
case <-l.chanDone:
return true
default:
return false
}
func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
return l.sendMessageWithFlags(packet, 0)
}
func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
if l.closing() {
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *ber.Packet, error) {
if l.isClosing {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
l.messageMutex.Lock()
l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
if l.isStartingTLS {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase."))
}
if flags&startTLS != 0 {
if l.outstandingRequests != 0 {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
} else {
l.isStartingTLS = true
}
}
l.outstandingRequests++
l.messageMutex.Unlock()
out := make(chan *ber.Packet)
message := &messagePacket{
Op: MessageRequest,
MessageID: uint64(packet.Children[0].Value.(int64)),
MessageID: packet.Children[0].Value.(int64),
Packet: packet,
Channel: out,
}
@ -178,10 +227,18 @@ func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
return out, nil
}
func (l *Conn) finishMessage(messageID uint64) {
if l.closing() {
func (l *Conn) finishMessage(messageID int64) {
if l.isClosing {
return
}
l.messageMutex.Lock()
l.outstandingRequests--
if l.isStartingTLS {
l.isStartingTLS = false
}
l.messageMutex.Unlock()
message := &messagePacket{
Op: MessageFinish,
MessageID: messageID,
@ -190,18 +247,20 @@ func (l *Conn) finishMessage(messageID uint64) {
}
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
l.wgSender.Add(1)
defer l.wgSender.Done()
if l.closing() {
if l.isClosing {
return false
}
l.wgSender.Add(1)
l.chanMessage <- message
l.wgSender.Done()
return true
}
func (l *Conn) processMessages() {
defer func() {
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in processMessages: %v", err)
}
for messageID, channel := range l.chanResults {
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(channel)
@ -212,7 +271,7 @@ func (l *Conn) processMessages() {
close(l.chanConfirm)
}()
var messageID uint64 = 1
var messageID int64 = 1
for {
select {
case l.chanMessageID <- messageID:
@ -257,20 +316,42 @@ func (l *Conn) processMessages() {
}
func (l *Conn) reader() {
cleanstop := false
defer func() {
l.Close()
if err := recover(); err != nil {
log.Printf("ldap: recovered panic in reader: %v", err)
}
if !cleanstop {
l.Close()
}
}()
for {
if cleanstop {
l.Debug.Printf("reader clean stopping (without closing the connection)")
return
}
packet, err := ber.ReadPacket(l.conn)
if err != nil {
l.Debug.Printf("reader: %s", err.Error())
// A read error is expected here if we are closing the connection...
if !l.isClosing {
l.Debug.Printf("reader error: %s", err.Error())
}
return
}
addLDAPDescriptions(packet)
if len(packet.Children) == 0 {
l.Debug.Printf("Received bad ldap packet")
continue
}
l.messageMutex.Lock()
if l.isStartingTLS {
cleanstop = true
}
l.messageMutex.Unlock()
message := &messagePacket{
Op: MessageResponse,
MessageID: uint64(packet.Children[0].Value.(int64)),
MessageID: packet.Children[0].Value.(int64),
Packet: packet,
}
if !l.sendProcessMessage(message) {

View File

@ -6,16 +6,21 @@ package ldap
import (
"fmt"
"strconv"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
const (
ControlTypePaging = "1.2.840.113556.1.4.319"
ControlTypePaging = "1.2.840.113556.1.4.319"
ControlTypeBeheraPasswordPolicy = "1.3.6.1.4.1.42.2.27.8.5.1"
ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4"
ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5"
)
var ControlTypeMap = map[string]string{
ControlTypePaging: "Paging",
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
}
type Control interface {
@ -40,7 +45,7 @@ func (c *ControlString) Encode() *ber.Packet {
if c.Criticality {
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
}
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlValue, "Control Value"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(c.ControlValue), "Control Value"))
return packet
}
@ -63,7 +68,7 @@ func (c *ControlPaging) Encode() *ber.Packet {
p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)")
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value")
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.PagingSize), "Paging Size"))
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(c.PagingSize), "Paging Size"))
cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
cookie.Value = c.Cookie
cookie.Data.Write(c.Cookie)
@ -88,6 +93,78 @@ func (c *ControlPaging) SetCookie(cookie []byte) {
c.Cookie = cookie
}
type ControlBeheraPasswordPolicy struct {
Expire int64
Grace int64
Error int8
ErrorString string
}
func (c *ControlBeheraPasswordPolicy) GetControlType() string {
return ControlTypeBeheraPasswordPolicy
}
func (c *ControlBeheraPasswordPolicy) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeBeheraPasswordPolicy, "Control Type ("+ControlTypeMap[ControlTypeBeheraPasswordPolicy]+")"))
return packet
}
func (c *ControlBeheraPasswordPolicy) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Expire: %d Grace: %d Error: %d, ErrorString: %s",
ControlTypeMap[ControlTypeBeheraPasswordPolicy],
ControlTypeBeheraPasswordPolicy,
false,
c.Expire,
c.Grace,
c.Error,
c.ErrorString)
}
type ControlVChuPasswordMustChange struct {
MustChange bool
}
func (c *ControlVChuPasswordMustChange) GetControlType() string {
return ControlTypeVChuPasswordMustChange
}
func (c *ControlVChuPasswordMustChange) Encode() *ber.Packet {
return nil
}
func (c *ControlVChuPasswordMustChange) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t MustChange: %b",
ControlTypeMap[ControlTypeVChuPasswordMustChange],
ControlTypeVChuPasswordMustChange,
false,
c.MustChange)
}
type ControlVChuPasswordWarning struct {
Expire int64
}
func (c *ControlVChuPasswordWarning) GetControlType() string {
return ControlTypeVChuPasswordWarning
}
func (c *ControlVChuPasswordWarning) Encode() *ber.Packet {
return nil
}
func (c *ControlVChuPasswordWarning) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Expire: %b",
ControlTypeMap[ControlTypeVChuPasswordWarning],
ControlTypeVChuPasswordWarning,
false,
c.Expire)
}
func FindControl(controls []Control, controlType string) Control {
for _, c := range controls {
if c.GetControlType() == controlType {
@ -127,6 +204,64 @@ func DecodeControl(packet *ber.Packet) Control {
c.PagingSize = uint32(value.Children[0].Value.(int64))
c.Cookie = value.Children[1].Data.Bytes()
value.Children[1].Value = c.Cookie
return c
case ControlTypeBeheraPasswordPolicy:
value.Description += " (Password Policy - Behera)"
c := NewControlBeheraPasswordPolicy()
if value.Value != nil {
valueChildren := ber.DecodePacket(value.Data.Bytes())
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
sequence := value.Children[0]
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
child := child.Children[0]
packet := ber.DecodePacket(child.Data.Bytes())
val, ok := packet.Value.(int64)
if ok {
if child.Tag == 0 {
//timeBeforeExpiration
c.Expire = val
child.Value = c.Expire
} else if child.Tag == 1 {
//graceAuthNsRemaining
c.Grace = val
child.Value = c.Grace
}
}
} else if child.Tag == 1 {
// Error
packet := ber.DecodePacket(child.Data.Bytes())
val, ok := packet.Value.(int8)
if !ok {
// what to do?
val = -1
}
c.Error = val
child.Value = c.Error
c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error]
}
}
return c
case ControlTypeVChuPasswordMustChange:
c := &ControlVChuPasswordMustChange{MustChange: true}
return c
case ControlTypeVChuPasswordWarning:
c := &ControlVChuPasswordWarning{Expire: -1}
expireStr := ber.DecodeString(value.Data.Bytes())
expire, err := strconv.ParseInt(expireStr, 10, 64)
if err != nil {
return nil
}
c.Expire = expire
value.Value = c.Expire
return c
}
c := new(ControlString)
@ -148,6 +283,14 @@ func NewControlPaging(pagingSize uint32) *ControlPaging {
return &ControlPaging{PagingSize: pagingSize}
}
func NewControlBeheraPasswordPolicy() *ControlBeheraPasswordPolicy {
return &ControlBeheraPasswordPolicy{
Expire: -1,
Grace: -1,
Error: -1,
}
}
func encodeControls(controls []Control) *ber.Packet {
packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
for _, control := range controls {

View File

@ -3,7 +3,7 @@ package ldap
import (
"log"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
// debbuging type

155
Godeps/_workspace/src/github.com/go-ldap/ldap/dn.go generated vendored Normal file
View File

@ -0,0 +1,155 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// File contains DN parsing functionallity
//
// https://tools.ietf.org/html/rfc4514
//
// distinguishedName = [ relativeDistinguishedName
// *( COMMA relativeDistinguishedName ) ]
// relativeDistinguishedName = attributeTypeAndValue
// *( PLUS attributeTypeAndValue )
// attributeTypeAndValue = attributeType EQUALS attributeValue
// attributeType = descr / numericoid
// attributeValue = string / hexstring
//
// ; The following characters are to be escaped when they appear
// ; in the value to be encoded: ESC, one of <escaped>, leading
// ; SHARP or SPACE, trailing SPACE, and NULL.
// string = [ ( leadchar / pair ) [ *( stringchar / pair )
// ( trailchar / pair ) ] ]
//
// leadchar = LUTF1 / UTFMB
// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// trailchar = TUTF1 / UTFMB
// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// stringchar = SUTF1 / UTFMB
// SUTF1 = %x01-21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// pair = ESC ( ESC / special / hexpair )
// special = escaped / SPACE / SHARP / EQUALS
// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE
// hexstring = SHARP 1*hexpair
// hexpair = HEX HEX
//
// where the productions <descr>, <numericoid>, <COMMA>, <DQUOTE>,
// <EQUALS>, <ESC>, <HEX>, <LANGLE>, <NULL>, <PLUS>, <RANGLE>, <SEMI>,
// <SPACE>, <SHARP>, and <UTFMB> are defined in [RFC4512].
//
package ldap
import (
"bytes"
"errors"
"fmt"
"strings"
enchex "encoding/hex"
ber "gopkg.in/asn1-ber.v1"
)
type AttributeTypeAndValue struct {
Type string
Value string
}
type RelativeDN struct {
Attributes []*AttributeTypeAndValue
}
type DN struct {
RDNs []*RelativeDN
}
func ParseDN(str string) (*DN, error) {
dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0)
rdn := new (RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
buffer := bytes.Buffer{}
attribute := new(AttributeTypeAndValue)
escaping := false
for i := 0; i < len(str); i++ {
char := str[i]
if escaping {
escaping = false
switch char {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
buffer.WriteByte(char)
continue
}
// Not a special character, assume hex encoded octet
if len(str) == i+1 {
return nil, errors.New("Got corrupted escaped character")
}
dst := []byte{0}
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
if err != nil {
return nil, errors.New(
fmt.Sprintf("Failed to decode escaped character: %s", err))
} else if n != 1 {
return nil, errors.New(
fmt.Sprintf("Expected 1 byte when un-escaping, got %d", n))
}
buffer.WriteByte(dst[0])
i++
} else if char == '\\' {
escaping = true
} else if char == '=' {
attribute.Type = buffer.String()
buffer.Reset()
// Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward
// and decode.
if len(str) > i+1 && str[i+1] == '#' {
i += 2
index := strings.IndexAny(str[i:], ",+")
data := str
if index > 0 {
data = str[i:i+index]
} else {
data = str[i:]
}
raw_ber, err := enchex.DecodeString(data)
if err != nil {
return nil, errors.New(
fmt.Sprintf("Failed to decode BER encoding: %s", err))
}
packet := ber.DecodePacket(raw_ber)
buffer.WriteString(packet.Data.String())
i += len(data)-1
}
} else if char == ',' || char == '+' {
// We're done with this RDN or value, push it
attribute.Value = buffer.String()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
if char == ',' {
dn.RDNs = append(dn.RDNs, rdn)
rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
}
buffer.Reset()
} else {
buffer.WriteByte(char)
}
}
if buffer.Len() > 0 {
if len(attribute.Type) == 0 {
return nil, errors.New("DN ended with incomplete type, value pair")
}
attribute.Value = buffer.String()
rdn.Attributes = append(rdn.Attributes, attribute)
dn.RDNs = append(dn.RDNs, rdn)
}
return dn, nil
}

View File

@ -0,0 +1,70 @@
package ldap
import (
"reflect"
"testing"
)
func TestSuccessfulDNParsing(t *testing.T) {
testcases := map[string]DN {
"": DN{[]*RelativeDN{}},
"cn=Jim\\2C \\22Hasse Hö\\22 Hansson!,dc=dummy,dc=com": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"cn", "Jim, \"Hasse Hö\" Hansson!"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"dc", "dummy"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"dc", "com"}, }},}},
"UID=jsmith,DC=example,DC=net": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"UID", "jsmith"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "example"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}, }},}},
"OU=Sales+CN=J. Smith,DC=example,DC=net": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{
&AttributeTypeAndValue{"OU", "Sales"},
&AttributeTypeAndValue{"CN", "J. Smith"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "example"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}, }},}},
"1.3.6.1.4.1.1466.0=#04024869": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"1.3.6.1.4.1.1466.0", "Hi"},}},}},
"1.3.6.1.4.1.1466.0=#04024869,DC=net": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"1.3.6.1.4.1.1466.0", "Hi"},}},
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"DC", "net"}, }},}},
"CN=Lu\\C4\\8Di\\C4\\87": DN{[]*RelativeDN{
&RelativeDN{[]*AttributeTypeAndValue{&AttributeTypeAndValue{"CN", "Lučić"},}},}},
}
for test, answer := range testcases {
dn, err := ParseDN(test)
if err != nil {
t.Errorf(err.Error())
continue
}
if !reflect.DeepEqual(dn, &answer) {
t.Errorf("Parsed DN %s is not equal to the expected structure", test)
for _, rdn := range dn.RDNs {
for _, attribs := range rdn.Attributes {
t.Logf("#%v\n", attribs)
}
}
}
}
}
func TestErrorDNParsing(t *testing.T) {
testcases := map[string]string {
"*": "DN ended with incomplete type, value pair",
"cn=Jim\\0Test": "Failed to decode escaped character: encoding/hex: invalid byte: U+0054 'T'",
"cn=Jim\\0": "Got corrupted escaped character",
"DC=example,=net": "DN ended with incomplete type, value pair",
"1=#0402486": "Failed to decode BER encoding: encoding/hex: odd length hex string",
}
for test, answer := range testcases {
_, err := ParseDN(test)
if err == nil {
t.Errorf("Expected %s to fail parsing but succeeded\n", test)
} else if err.Error() != answer {
t.Errorf("Unexpected error on %s:\n%s\nvs.\n%s\n", test, answer, err.Error())
}
}
}

View File

@ -0,0 +1,60 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"github.com/go-ldap/ldap"
)
var (
ldapServer string = "localhost"
ldapPort uint16 = 389
baseDN string = "dc=enterprise,dc=org"
user string = "cn=kirkj,ou=crew,dc=enterprise,dc=org"
passwd string = "*"
)
func main() {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
defer l.Close()
l.Debug = true
controls := []ldap.Control{}
controls = append(controls, ldap.NewControlBeheraPasswordPolicy())
bindRequest := ldap.NewSimpleBindRequest(user, passwd, controls)
r, err := l.SimpleBind(bindRequest)
ppolicyControl := ldap.FindControl(r.Controls, ldap.ControlTypeBeheraPasswordPolicy)
var ppolicy *ldap.ControlBeheraPasswordPolicy
if ppolicyControl != nil {
ppolicy = ppolicyControl.(*ldap.ControlBeheraPasswordPolicy)
} else {
log.Printf("ppolicyControl response not avaliable.\n")
}
if err != nil {
errStr := "ERROR: Cannot bind: " + err.Error()
if ppolicy != nil && ppolicy.Error >= 0 {
errStr += ":" + ppolicy.ErrorString
}
log.Print(errStr)
} else {
logStr := "Login Ok"
if ppolicy != nil {
if ppolicy.Expire >= 0 {
logStr += fmt.Sprintf(". Password expires in %d seconds\n", ppolicy.Expire)
} else if ppolicy.Grace >= 0 {
logStr += fmt.Sprintf(". Password expired, %d grace logins remain\n", ppolicy.Grace)
}
}
log.Print(logStr)
}
}

View File

@ -0,0 +1,39 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"github.com/go-ldap/ldap"
)
var (
ldapServer string = "localhost"
ldapPort uint16 = 389
user string = "*"
passwd string = "*"
dn string = "uid=*,cn=*,dc=*,dc=*"
attribute string = "uid"
value string = "username"
)
func main() {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
defer l.Close()
// l.Debug = true
err = l.Bind(user, passwd)
if err != nil {
log.Printf("ERROR: Cannot bind: %s\n", err.Error())
return
}
fmt.Println(l.Compare(dn, attribute, value))
}

View File

@ -0,0 +1,125 @@
package main
import (
"fmt"
"log"
"github.com/go-ldap/ldap"
)
// Example password policy. For this test pwdMinAge is 0 or subsequent password
// changes will fail.
//
// dn: cn=default,ou=policies,dc=enterprise,dc=org
// objectClass: pwdPolicy
// objectClass: person
// objectClass: top
// cn: default
// pwdAllowUserChange: TRUE
// pwdAttribute: userPassword
// pwdCheckQuality: 2
// pwdExpireWarning: 300
// pwdFailureCountInterval: 30
// pwdGraceAuthNLimit: 5
// pwdInHistory: 0
// pwdLockout: TRUE
// pwdLockoutDuration: 0
// pwdMaxAge: 300
// pwdMaxFailure: 0
// pwdMinAge: 0
// pwdMinLength: 5
// pwdMustChange: TRUE
// pwdSafeModify: TRUE
// sn: dummy value
var (
ldapServer string = "localhost"
ldapPort uint16 = 389
baseDN string = "dc=enterprise,dc=org"
adminUser string = "cn=admin,dc=enterprise,dc=org"
adminPassword string = "*"
user string = "cn=kirkj,ou=crew,dc=enterprise,dc=org"
oldPassword string = "*"
password1 string = "password123"
password2 string = "password1234"
)
const (
debug = false
)
func login(user string, password string) (*ldap.Conn, error) {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
return nil, err
}
l.Debug = debug
bindRequest := ldap.NewSimpleBindRequest(user, password, nil)
_, err = l.SimpleBind(bindRequest)
if err != nil {
return nil, err
}
return l, nil
}
func main() {
// Login as the admin and change the password of an user (without providing the old password)
log.Printf("Logging in as the admin and changing the password of user (without providing the old password")
l, err := login(adminUser, adminPassword)
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
passwordModifyRequest := ldap.NewPasswordModifyRequest(user, "", password1)
_, err = l.PasswordModify(passwordModifyRequest)
if err != nil {
l.Close()
log.Fatalf("ERROR: Cannot change password: %s\n", err)
}
log.Printf("Done")
l.Close()
// Login as the user and change the password without providing a new password.
log.Printf("Logging in as the user and changing the password without providing a new one")
l, err = login(user, password1)
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
passwordModifyRequest = ldap.NewPasswordModifyRequest("", password1, "")
passwordModifyResponse, err := l.PasswordModify(passwordModifyRequest)
if err != nil {
l.Close()
log.Fatalf("ERROR: Cannot change password: %s\n", err)
}
generatedPassword := passwordModifyResponse.GeneratedPassword
log.Printf("Done. Generated password: %s\n", generatedPassword)
l.Close()
// Login as the user with the generated password and change it to another one
log.Printf("Logging in as the user and changing the password")
l, err = login(user, generatedPassword)
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
passwordModifyRequest = ldap.NewPasswordModifyRequest("", generatedPassword, password2)
_, err = l.PasswordModify(passwordModifyRequest)
if err != nil {
l.Close()
log.Fatalf("ERROR: Cannot change password: %s\n", err)
}
log.Printf("Done")
l.Close()
}

View File

@ -1,5 +1,3 @@
// +build ignore
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
@ -7,10 +5,11 @@
package main
import (
"crypto/tls"
"fmt"
"log"
"github.com/vanackere/ldap"
"gopkg.in/ldap.v1"
)
var (
@ -22,12 +21,12 @@ var (
)
func main() {
l, err := ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort), nil)
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", LdapServer, LdapPort))
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
log.Fatalf("ERROR: %s\n", err)
}
defer l.Close()
// l.Debug = true
l.Debug = true
search := ldap.NewSearchRequest(
BaseDN,
@ -36,10 +35,24 @@ func main() {
Attributes,
nil)
// First search without tls.
sr, err := l.Search(search)
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
return
log.Printf("ERROR: %s\n", err)
}
log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))
sr.PrettyPrint(0)
// Then startTLS
err = l.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
log.Fatalf("ERROR: %s\n", err)
}
sr, err = l.Search(search)
if err != nil {
log.Printf("ERROR: %s\n", err)
}
log.Printf("Search: %s -> num of entries = %d\n", search.Filter, len(sr.Entries))

View File

@ -0,0 +1,63 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"log"
"gopkg.in/ldap.v1"
)
var (
ldapServer string = "localhost"
ldapPort uint16 = 389
baseDN string = "dc=enterprise,dc=org"
user string = "uid=kirkj,ou=crew,dc=enterprise,dc=org"
passwd string = "*"
)
func main() {
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
log.Fatalf("ERROR: %s\n", err.Error())
}
defer l.Close()
l.Debug = true
bindRequest := ldap.NewSimpleBindRequest(user, passwd, nil)
r, err := l.SimpleBind(bindRequest)
passwordMustChangeControl := ldap.FindControl(r.Controls, ldap.ControlTypeVChuPasswordMustChange)
var passwordMustChange *ldap.ControlVChuPasswordMustChange
if passwordMustChangeControl != nil {
passwordMustChange = passwordMustChangeControl.(*ldap.ControlVChuPasswordMustChange)
}
if passwordMustChange != nil && passwordMustChange.MustChange {
log.Printf("Password Must be changed.\n")
}
passwordWarningControl := ldap.FindControl(r.Controls, ldap.ControlTypeVChuPasswordWarning)
var passwordWarning *ldap.ControlVChuPasswordWarning
if passwordWarningControl != nil {
passwordWarning = passwordWarningControl.(*ldap.ControlVChuPasswordWarning)
} else {
log.Printf("ppolicyControl response not available.\n")
}
if err != nil {
log.Print("ERROR: Cannot bind: " + err.Error())
} else {
logStr := "Login Ok"
if passwordWarning != nil {
if passwordWarning.Expire >= 0 {
logStr += fmt.Sprintf(". Password expires in %d seconds\n", passwordWarning.Expire)
}
}
log.Print(logStr)
}
}

View File

@ -7,24 +7,25 @@ package ldap
import (
"errors"
"fmt"
"strings"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
const (
FilterAnd ber.Tag = 0
FilterOr ber.Tag = 1
FilterNot ber.Tag = 2
FilterEqualityMatch ber.Tag = 3
FilterSubstrings ber.Tag = 4
FilterGreaterOrEqual ber.Tag = 5
FilterLessOrEqual ber.Tag = 6
FilterPresent ber.Tag = 7
FilterApproxMatch ber.Tag = 8
FilterExtensibleMatch ber.Tag = 9
FilterAnd = 0
FilterOr = 1
FilterNot = 2
FilterEqualityMatch = 3
FilterSubstrings = 4
FilterGreaterOrEqual = 5
FilterLessOrEqual = 6
FilterPresent = 7
FilterApproxMatch = 8
FilterExtensibleMatch = 9
)
var filterMap = map[ber.Tag]string{
var FilterMap = map[uint64]string{
FilterAnd: "And",
FilterOr: "Or",
FilterNot: "Not",
@ -43,6 +44,12 @@ const (
FilterSubstringsFinal = 2
)
var FilterSubstringsMap = map[uint64]string{
FilterSubstringsInitial: "Substrings Initial",
FilterSubstringsAny: "Substrings Any",
FilterSubstringsFinal: "Substrings Final",
}
func CompileFilter(filter string) (*ber.Packet, error) {
if len(filter) == 0 || filter[0] != '(' {
return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
@ -97,13 +104,14 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) {
case FilterSubstrings:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
switch packet.Children[1].Children[0].Tag {
case FilterSubstringsInitial:
ret += ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
case FilterSubstringsAny:
ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes()) + "*"
case FilterSubstringsFinal:
ret += "*" + ber.DecodeString(packet.Children[1].Children[0].Data.Bytes())
for i, child := range packet.Children[1].Children {
if i == 0 && child.Tag != FilterSubstringsInitial {
ret += "*"
}
ret += ber.DecodeString(child.Data.Bytes())
if child.Tag != FilterSubstringsFinal {
ret += "*"
}
}
case FilterEqualityMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
@ -163,15 +171,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
newPos++
return packet, newPos, err
case '&':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, filterMap[FilterAnd])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
newPos, err = compileFilterSet(filter, pos+1, packet)
return packet, newPos, err
case '|':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, filterMap[FilterOr])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
newPos, err = compileFilterSet(filter, pos+1, packet)
return packet, newPos, err
case '!':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, filterMap[FilterNot])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
var child *ber.Packet
child, newPos, err = compileFilter(filter, pos+1)
packet.AppendChild(child)
@ -184,15 +192,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
case packet != nil:
condition += fmt.Sprintf("%c", filter[newPos])
case filter[newPos] == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, filterMap[FilterEqualityMatch])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
case filter[newPos] == '>' && filter[newPos+1] == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, filterMap[FilterGreaterOrEqual])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
newPos++
case filter[newPos] == '<' && filter[newPos+1] == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, filterMap[FilterLessOrEqual])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
newPos++
case filter[newPos] == '~' && filter[newPos+1] == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, filterMap[FilterLessOrEqual])
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual])
newPos++
case packet == nil:
attribute += fmt.Sprintf("%c", filter[newPos])
@ -207,40 +215,37 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
return packet, newPos, err
}
// Handle FilterEqualityMatch as a separate case (is primitive, not constructed like the other filters)
if packet.Tag == FilterEqualityMatch && condition == "*" {
packet.TagType = ber.TypePrimitive
packet.Tag = FilterPresent
packet.Description = filterMap[packet.Tag]
packet.Data.WriteString(attribute)
return packet, newPos + 1, nil
}
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
switch {
case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*':
// Any
case packet.Tag == FilterEqualityMatch && condition == "*":
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.Tag = FilterSubstrings
packet.Description = filterMap[packet.Tag]
packet.Description = FilterMap[uint64(packet.Tag)]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring"))
packet.AppendChild(seq)
case packet.Tag == FilterEqualityMatch && condition[0] == '*':
// Final
packet.Tag = FilterSubstrings
packet.Description = filterMap[packet.Tag]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring"))
packet.AppendChild(seq)
case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*':
// Initial
packet.Tag = FilterSubstrings
packet.Description = filterMap[packet.Tag]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring"))
parts := strings.Split(condition, "*")
for i, part := range parts {
if part == "" {
continue
}
var tag ber.Tag
switch i {
case 0:
tag = FilterSubstringsInitial
case len(parts) - 1:
tag = FilterSubstringsFinal
default:
tag = FilterSubstringsAny
}
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, part, FilterSubstringsMap[uint64(tag)]))
}
packet.AppendChild(seq)
default:
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, condition, "Condition"))
}
newPos++
return packet, newPos, err
}

View File

@ -1,15 +1,14 @@
package ldap
import (
"reflect"
"testing"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
type compileTest struct {
filterStr string
filterType ber.Tag
filterType int
}
var testFilters = []compileTest{
@ -20,6 +19,10 @@ var testFilters = []compileTest{
compileTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=*i*le*)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=Mi*l*r)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=Mi*le*)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn=*i*ler)", filterType: FilterSubstrings},
compileTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
compileTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
compileTest{filterStr: "(sn=*)", filterType: FilterPresent},
@ -33,8 +36,8 @@ func TestFilter(t *testing.T) {
filter, err := CompileFilter(i.filterStr)
if err != nil {
t.Errorf("Problem compiling %s - %s", i.filterStr, err.Error())
} else if filter.Tag != i.filterType {
t.Errorf("%q Expected %q got %q", i.filterStr, filterMap[i.filterType], filterMap[filter.Tag])
} else if filter.Tag != ber.Tag(i.filterType) {
t.Errorf("%q Expected %q got %q", i.filterStr, FilterMap[uint64(i.filterType)], FilterMap[uint64(filter.Tag)])
} else {
o, err := DecompileFilter(filter)
if err != nil {
@ -46,40 +49,6 @@ func TestFilter(t *testing.T) {
}
}
type binTestFilter struct {
bin []byte
str string
}
var binTestFilters = []binTestFilter{
{bin: []byte{0x87, 0x06, 0x6d, 0x65, 0x6d, 0x62, 0x65, 0x72}, str: "(member=*)"},
}
func TestFiltersDecode(t *testing.T) {
for i, test := range binTestFilters {
p := ber.DecodePacket(test.bin)
if filter, err := DecompileFilter(p); err != nil {
t.Errorf("binTestFilters[%d], DecompileFilter returned : %s", i, err)
} else if filter != test.str {
t.Errorf("binTestFilters[%d], %q expected, got %q", i, test.str, filter)
}
}
}
func TestFiltersEncode(t *testing.T) {
for i, test := range binTestFilters {
p, err := CompileFilter(test.str)
if err != nil {
t.Errorf("binTestFilters[%d], CompileFilter returned : %s", i, err)
continue
}
b := p.Bytes()
if !reflect.DeepEqual(b, test.bin) {
t.Errorf("binTestFilters[%d], %q expected for CompileFilter(%q), got %q", i, test.bin, test.str, b)
}
}
}
func BenchmarkFilterCompile(b *testing.B) {
b.StopTimer()
filters := make([]string, len(testFilters))

View File

@ -10,34 +10,34 @@ import (
"io/ioutil"
"os"
"github.com/vanackere/asn1-ber"
ber "gopkg.in/asn1-ber.v1"
)
// LDAP Application Codes
const (
ApplicationBindRequest ber.Tag = 0
ApplicationBindResponse ber.Tag = 1
ApplicationUnbindRequest ber.Tag = 2
ApplicationSearchRequest ber.Tag = 3
ApplicationSearchResultEntry ber.Tag = 4
ApplicationSearchResultDone ber.Tag = 5
ApplicationModifyRequest ber.Tag = 6
ApplicationModifyResponse ber.Tag = 7
ApplicationAddRequest ber.Tag = 8
ApplicationAddResponse ber.Tag = 9
ApplicationDelRequest ber.Tag = 10
ApplicationDelResponse ber.Tag = 11
ApplicationModifyDNRequest ber.Tag = 12
ApplicationModifyDNResponse ber.Tag = 13
ApplicationCompareRequest ber.Tag = 14
ApplicationCompareResponse ber.Tag = 15
ApplicationAbandonRequest ber.Tag = 16
ApplicationSearchResultReference ber.Tag = 19
ApplicationExtendedRequest ber.Tag = 23
ApplicationExtendedResponse ber.Tag = 24
ApplicationBindRequest = 0
ApplicationBindResponse = 1
ApplicationUnbindRequest = 2
ApplicationSearchRequest = 3
ApplicationSearchResultEntry = 4
ApplicationSearchResultDone = 5
ApplicationModifyRequest = 6
ApplicationModifyResponse = 7
ApplicationAddRequest = 8
ApplicationAddResponse = 9
ApplicationDelRequest = 10
ApplicationDelResponse = 11
ApplicationModifyDNRequest = 12
ApplicationModifyDNResponse = 13
ApplicationCompareRequest = 14
ApplicationCompareResponse = 15
ApplicationAbandonRequest = 16
ApplicationSearchResultReference = 19
ApplicationExtendedRequest = 23
ApplicationExtendedResponse = 24
)
var ApplicationMap = map[ber.Tag]string{
var ApplicationMap = map[uint8]string{
ApplicationBindRequest: "Bind Request",
ApplicationBindResponse: "Bind Response",
ApplicationUnbindRequest: "Unbind Request",
@ -102,10 +102,12 @@ const (
LDAPResultAffectsMultipleDSAs = 71
LDAPResultOther = 80
ErrorNetwork = 200
ErrorFilterCompile = 201
ErrorFilterDecompile = 202
ErrorDebugging = 203
ErrorNetwork = 200
ErrorFilterCompile = 201
ErrorFilterDecompile = 202
ErrorDebugging = 203
ErrorUnexpectedMessage = 204
ErrorUnexpectedResponse = 205
)
var LDAPResultCodeMap = map[uint8]string{
@ -150,6 +152,31 @@ var LDAPResultCodeMap = map[uint8]string{
LDAPResultOther: "Other",
}
// Ldap Behera Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10)
const (
BeheraPasswordExpired = 0
BeheraAccountLocked = 1
BeheraChangeAfterReset = 2
BeheraPasswordModNotAllowed = 3
BeheraMustSupplyOldPassword = 4
BeheraInsufficientPasswordQuality = 5
BeheraPasswordTooShort = 6
BeheraPasswordTooYoung = 7
BeheraPasswordInHistory = 8
)
var BeheraPasswordPolicyErrorMap = map[int8]string{
BeheraPasswordExpired: "Password expired",
BeheraAccountLocked: "Account locked",
BeheraChangeAfterReset: "Password must be changed",
BeheraPasswordModNotAllowed: "Policy prevents password modification",
BeheraMustSupplyOldPassword: "Policy requires old password in order to change password",
BeheraInsufficientPasswordQuality: "Password fails quality checks",
BeheraPasswordTooShort: "Password is too short for policy",
BeheraPasswordTooYoung: "Password has been changed too recently",
BeheraPasswordInHistory: "New password is in list of old passwords",
}
// Adds descriptions to an LDAP Response packet for debugging
func addLDAPDescriptions(packet *ber.Packet) (err error) {
defer func() {
@ -160,7 +187,7 @@ func addLDAPDescriptions(packet *ber.Packet) (err error) {
packet.Description = "LDAP Response"
packet.Children[0].Description = "Message ID"
application := packet.Children[1].Tag
application := uint8(packet.Children[1].Tag)
packet.Children[1].Description = ApplicationMap[application]
switch application {
@ -239,6 +266,44 @@ func addControlDescriptions(packet *ber.Packet) {
value.Children[0].Description = "Real Search Control Value"
value.Children[0].Children[0].Description = "Paging Size"
value.Children[0].Children[1].Description = "Cookie"
case ControlTypeBeheraPasswordPolicy:
value.Description += " (Password Policy - Behera Draft)"
if value.Value != nil {
valueChildren := ber.DecodePacket(value.Data.Bytes())
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
sequence := value.Children[0]
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
child := child.Children[0]
packet := ber.DecodePacket(child.Data.Bytes())
val, ok := packet.Value.(int64)
if ok {
if child.Tag == 0 {
//timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)"
child.Value = val
} else if child.Tag == 1 {
//graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)"
child.Value = val
}
}
} else if child.Tag == 1 {
// Error
packet := ber.DecodePacket(child.Data.Bytes())
val, ok := packet.Value.(int8)
if !ok {
val = -1
}
child.Description = "Error"
child.Value = val
}
}
}
}
}
@ -246,7 +311,7 @@ func addControlDescriptions(packet *ber.Packet) {
func addRequestDescriptions(packet *ber.Packet) {
packet.Description = "LDAP Request"
packet.Children[0].Description = "Message ID"
packet.Children[1].Description = ApplicationMap[packet.Children[1].Tag]
packet.Children[1].Description = ApplicationMap[uint8(packet.Children[1].Tag)]
if len(packet.Children) == 3 {
addControlDescriptions(packet.Children[2])
}
@ -294,10 +359,45 @@ func NewError(resultCode uint8, err error) error {
func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
if len(packet.Children) >= 2 {
response := packet.Children[1]
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 {
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
return uint8(response.Children[0].Value.(int64)), response.Children[2].Value.(string)
}
}
return ErrorNetwork, "Invalid packet format"
}
var hex = "0123456789abcdef"
func mustEscape(c byte) bool {
return c > 0x7f || c == '(' || c == ')' || c == '\\' || c == '*' || c == 0
}
// EscapeFilter escapes from the provided LDAP filter string the special
// characters in the set `()*\` and those out of the range 0 < c < 0x80,
// as defined in RFC4515.
func EscapeFilter(filter string) string {
escape := 0
for i := 0; i < len(filter); i++ {
if mustEscape(filter[i]) {
escape++
}
}
if escape == 0 {
return filter
}
buf := make([]byte, len(filter)+escape*2)
for i, j := 0, 0; i < len(filter); i++ {
c := filter[i]
if mustEscape(c) {
buf[j+0] = '\\'
buf[j+1] = hex[c>>4]
buf[j+2] = hex[c&0xf]
j += 3
} else {
buf[j] = c
j++
}
}
return string(buf)
}

View File

@ -0,0 +1,247 @@
package ldap
import (
"crypto/tls"
"fmt"
"testing"
)
var ldapServer = "ldap.itd.umich.edu"
var ldapPort = uint16(389)
var ldapTLSPort = uint16(636)
var baseDN = "dc=umich,dc=edu"
var filter = []string{
"(cn=cis-fac)",
"(&(owner=*)(cn=cis-fac))",
"(&(objectclass=rfc822mailgroup)(cn=*Computer*))",
"(&(objectclass=rfc822mailgroup)(cn=*Mathematics*))"}
var attributes = []string{
"cn",
"description"}
func TestDial(t *testing.T) {
fmt.Printf("TestDial: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
fmt.Printf("TestDial: finished...\n")
}
func TestDialTLS(t *testing.T) {
fmt.Printf("TestDialTLS: starting...\n")
l, err := DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
fmt.Printf("TestDialTLS: finished...\n")
}
func TestStartTLS(t *testing.T) {
fmt.Printf("TestStartTLS: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
err = l.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestStartTLS: finished...\n")
}
func TestSearch(t *testing.T) {
fmt.Printf("TestSearch: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
sr, err := l.Search(searchRequest)
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestSearch: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
}
func TestSearchStartTLS(t *testing.T) {
fmt.Printf("TestSearchStartTLS: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[0],
attributes,
nil)
sr, err := l.Search(searchRequest)
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestSearchStartTLS: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
fmt.Printf("TestSearchStartTLS: upgrading with startTLS\n")
err = l.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
sr, err = l.Search(searchRequest)
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestSearchStartTLS: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
}
func TestSearchWithPaging(t *testing.T) {
fmt.Printf("TestSearchWithPaging: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
err = l.Bind("", "")
if err != nil {
t.Errorf(err.Error())
return
}
searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[2],
attributes,
nil)
sr, err := l.SearchWithPaging(searchRequest, 5)
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestSearchWithPaging: %s -> num of entries = %d\n", searchRequest.Filter, len(sr.Entries))
}
func searchGoroutine(t *testing.T, l *Conn, results chan *SearchResult, i int) {
searchRequest := NewSearchRequest(
baseDN,
ScopeWholeSubtree, DerefAlways, 0, 0, false,
filter[i],
attributes,
nil)
sr, err := l.Search(searchRequest)
if err != nil {
t.Errorf(err.Error())
results <- nil
return
}
results <- sr
}
func testMultiGoroutineSearch(t *testing.T, TLS bool, startTLS bool) {
fmt.Printf("TestMultiGoroutineSearch: starting...\n")
var l *Conn
var err error
if TLS {
l, err = DialTLS("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapTLSPort), &tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
defer l.Close()
} else {
l, err = Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Errorf(err.Error())
return
}
if startTLS {
fmt.Printf("TestMultiGoroutineSearch: using StartTLS...\n")
err := l.StartTLS(&tls.Config{InsecureSkipVerify: true})
if err != nil {
t.Errorf(err.Error())
return
}
}
}
results := make([]chan *SearchResult, len(filter))
for i := range filter {
results[i] = make(chan *SearchResult)
go searchGoroutine(t, l, results[i], i)
}
for i := range filter {
sr := <-results[i]
if sr == nil {
t.Errorf("Did not receive results from goroutine for %q", filter[i])
} else {
fmt.Printf("TestMultiGoroutineSearch(%d): %s -> num of entries = %d\n", i, filter[i], len(sr.Entries))
}
}
}
func TestMultiGoroutineSearch(t *testing.T) {
testMultiGoroutineSearch(t, false, false)
testMultiGoroutineSearch(t, true, true)
testMultiGoroutineSearch(t, false, true)
}
func TestEscapeFilter(t *testing.T) {
if got, want := EscapeFilter("a\x00b(c)d*e\\f"), `a\00b\28c\29d\2ae\5cf`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
if got, want := EscapeFilter("Lučić"), `Lu\c4\8di\c4\87`; got != want {
t.Errorf("Got %s, expected %s", want, got)
}
}
func TestCompare(t *testing.T) {
fmt.Printf("TestCompare: starting...\n")
l, err := Dial("tcp", fmt.Sprintf("%s:%d", ldapServer, ldapPort))
if err != nil {
t.Fatal(err.Error())
}
defer l.Close()
dn := "cn=math mich,ou=User Groups,ou=Groups,dc=umich,dc=edu"
attribute := "cn"
value := "math mich"
sr, err := l.Compare(dn, attribute, value)
if err != nil {
t.Errorf(err.Error())
return
}
fmt.Printf("TestCompare: -> num of entries = %d\n", sr)
}

View File

@ -33,7 +33,7 @@ import (
"errors"
"log"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
const (
@ -83,19 +83,19 @@ func (m ModifyRequest) encode() *ber.Packet {
changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
for _, attribute := range m.addAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(AddAttribute), "Operation"))
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(AddAttribute), "Operation"))
change.AppendChild(attribute.encode())
changes.AppendChild(change)
}
for _, attribute := range m.deleteAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(DeleteAttribute), "Operation"))
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(DeleteAttribute), "Operation"))
change.AppendChild(attribute.encode())
changes.AppendChild(change)
}
for _, attribute := range m.replaceAttributes {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(ReplaceAttribute), "Operation"))
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ReplaceAttribute), "Operation"))
change.AppendChild(attribute.encode())
changes.AppendChild(change)
}
@ -114,7 +114,7 @@ func NewModifyRequest(
func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(messageID), "MessageID"))
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
packet.AppendChild(modifyRequest.encode())
l.Debug.PrintPacket(packet)

View File

@ -0,0 +1,137 @@
// This file contains the password modify extended operation as specified in rfc 3062
//
// https://tools.ietf.org/html/rfc3062
//
package ldap
import (
"errors"
"fmt"
"gopkg.in/asn1-ber.v1"
)
const (
passwordModifyOID = "1.3.6.1.4.1.4203.1.11.1"
)
type PasswordModifyRequest struct {
UserIdentity string
OldPassword string
NewPassword string
}
type PasswordModifyResult struct {
GeneratedPassword string
}
func (r *PasswordModifyRequest) encode() (*ber.Packet, error) {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID"))
extendedRequestValue := ber.Encode(ber.ClassContext, ber.TypePrimitive, 1, nil, "Extended Request Value: Password Modify Request")
passwordModifyRequestValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Password Modify Request")
if r.UserIdentity != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, r.UserIdentity, "User Identity"))
}
if r.OldPassword != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, r.OldPassword, "Old Password"))
}
if r.NewPassword != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, r.NewPassword, "New Password"))
}
extendedRequestValue.AppendChild(passwordModifyRequestValue)
request.AppendChild(extendedRequestValue)
return request, nil
}
// Create a new PasswordModifyRequest
//
// According to the RFC 3602:
// userIdentity is a string representing the user associated with the request.
// This string may or may not be an LDAPDN (RFC 2253).
// If userIdentity is empty then the operation will act on the user associated
// with the session.
//
// oldPassword is the current user's password, it can be empty or it can be
// needed depending on the session user access rights (usually an administrator
// can change a user's password without knowing the current one) and the
// password policy (see pwdSafeModify password policy's attribute)
//
// newPassword is the desired user's password. If empty the server can return
// an error or generate a new password that will be available in the
// PasswordModifyResult.GeneratedPassword
//
func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPassword string) *PasswordModifyRequest {
return &PasswordModifyRequest{
UserIdentity: userIdentity,
OldPassword: oldPassword,
NewPassword: newPassword,
}
}
func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
encodedPasswordModifyRequest, err := passwordModifyRequest.encode()
if err != nil {
return nil, err
}
packet.AppendChild(encodedPasswordModifyRequest)
l.Debug.PrintPacket(packet)
channel, err := l.sendMessage(packet)
if err != nil {
return nil, err
}
if channel == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not send message"))
}
defer l.finishMessage(messageID)
result := &PasswordModifyResult{}
l.Debug.Printf("%d: waiting for response", messageID)
packet = <-channel
l.Debug.Printf("%d: got response %p", messageID, packet)
if packet == nil {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
}
if l.Debug {
if err := addLDAPDescriptions(packet); err != nil {
return nil, err
}
ber.PrintPacket(packet)
}
if packet.Children[1].Tag == ApplicationExtendedResponse {
resultCode, resultDescription := getLDAPResultCode(packet)
if resultCode != 0 {
return nil, NewError(resultCode, errors.New(resultDescription))
}
} else {
return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("Unexpected Response: %d", packet.Children[1].Tag))
}
extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children {
if child.Tag == 11 {
passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyReponseValue.Children) == 1 {
if passwordModifyReponseValue.Children[0].Tag == 0 {
result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes())
}
}
}
}
return result, nil
}

View File

@ -64,7 +64,7 @@ import (
"fmt"
"strings"
"github.com/vanackere/asn1-ber"
"gopkg.in/asn1-ber.v1"
)
const (
@ -107,6 +107,15 @@ func (e *Entry) GetAttributeValues(attribute string) []string {
return []string{}
}
func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
for _, attr := range e.Attributes {
if attr.Name == attribute {
return attr.ByteValues
}
}
return [][]byte{}
}
func (e *Entry) GetAttributeValue(attribute string) string {
values := e.GetAttributeValues(attribute)
if len(values) == 0 {
@ -115,6 +124,14 @@ func (e *Entry) GetAttributeValue(attribute string) string {
return values[0]
}
func (e *Entry) GetRawAttributeValue(attribute string) []byte {
values := e.GetRawAttributeValues(attribute)
if len(values) == 0 {
return []byte{}
}
return values[0]
}
func (e *Entry) Print() {
fmt.Printf("DN: %s\n", e.DN)
for _, attr := range e.Attributes {
@ -130,8 +147,9 @@ func (e *Entry) PrettyPrint(indent int) {
}
type EntryAttribute struct {
Name string
Values []string
Name string
Values []string
ByteValues [][]byte
}
func (e *EntryAttribute) Print() {
@ -175,10 +193,10 @@ type SearchRequest struct {
func (s *SearchRequest) encode() (*ber.Packet, error) {
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, s.BaseDN, "Base DN"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.Scope), "Scope"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, int64(s.DerefAliases), "Deref Aliases"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.SizeLimit), "Size Limit"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(s.TimeLimit), "Time Limit"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.Scope), "Scope"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.SizeLimit), "Size Limit"))
request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(s.TimeLimit), "Time Limit"))
request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, s.TypesOnly, "Types Only"))
// compile and encode filter
filterPacket, err := CompileFilter(s.Filter)
@ -273,7 +291,7 @@ func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32)
func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
messageID := l.nextMessageID()
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(messageID), "MessageID"))
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
// encode search request
encodedSearchRequest, err := searchRequest.encode()
if err != nil {
@ -326,6 +344,7 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
attr.Name = child.Children[0].Value.(string)
for _, value := range child.Children[1].Children {
attr.Values = append(attr.Values, value.Value.(string))
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
}
entry.Attributes = append(entry.Attributes, attr)
}

View File

@ -23,6 +23,7 @@ Henri Yandell <flamefew at gmail.com>
INADA Naoki <songofacandy at gmail.com>
James Harr <james.harr at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Kamil Dziedzic <kamil at klecza.pl>
Leonardo YongUk Kim <dalinaum at gmail.com>
@ -31,6 +32,8 @@ Luke Scott <luke at webconnex.com>
Michael Woolnough <michael.woolnough at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
Runrioter Wung <runrioter at gmail.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>

View File

@ -123,6 +123,16 @@ Default: false
`allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files.
[*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)
##### `allowCleartextPasswords`
```
Type: bool
Valid Values: true, false
Default: false
```
`allowCleartextPasswords=true` allows using the [cleartext client side plugin](http://dev.mysql.com/doc/en/cleartext-authentication-plugin.html) if required by an account, such as one defined with the [PAM authentication plugin](http://dev.mysql.com/doc/en/pam-authentication-plugin.html). Sending passwords in clear text may be a security problem in some configurations. To avoid problems if there is any possibility that the password would be intercepted, clients should connect to MySQL Server using a method that protects the password. Possibilities include [TLS / SSL](#tls), IPsec, or a private network.
##### `allowOldPasswords`
```
@ -205,6 +215,8 @@ Default: UTC
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter.
Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
@ -257,7 +269,7 @@ Default: false
All other parameters are interpreted as system variables:
* `autocommit`: `"SET autocommit=<value>"`
* `time_zone`: `"SET time_zone=<value>"`
* [`time_zone`](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html): `"SET time_zone=<value>"`
* [`tx_isolation`](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation): `"SET tx_isolation=<value>"`
* `param`: `"SET <param>=<value>"`

View File

@ -34,21 +34,22 @@ type mysqlConn struct {
}
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
columnsWithAlias bool
interpolateParams bool
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
allowCleartextPasswords bool
clientFoundRows bool
columnsWithAlias bool
interpolateParams bool
}
// Handles parameters set in DSN after the connection is established

View File

@ -24,6 +24,7 @@ const (
iERR byte = 0xff
)
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
type clientFlag uint32
const (
@ -45,6 +46,13 @@ const (
clientSecureConn
clientMultiStatements
clientMultiResults
clientPSMultiResults
clientPluginAuth
clientConnectAttrs
clientPluginAuthLenEncClientData
clientCanHandleExpiredPasswords
clientSessionTrack
clientDeprecateEOF
)
const (
@ -78,6 +86,7 @@ const (
comStmtFetch
)
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
const (
fieldTypeDecimal byte = iota
fieldTypeTiny
@ -132,7 +141,6 @@ const (
)
// http://dev.mysql.com/doc/internals/en/status-flags.html
type statusFlag uint16
const (

View File

@ -107,6 +107,15 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.Close()
return nil, err
}
} else if mc.cfg != nil && mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword {
if err = mc.writeClearAuthPacket(); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err

View File

@ -780,6 +780,49 @@ func TestNULL(t *testing.T) {
})
}
func TestUint64(t *testing.T) {
const (
u0 = uint64(0)
uall = ^u0
uhigh = uall >> 1
utop = ^uhigh
s0 = int64(0)
sall = ^s0
shigh = int64(uhigh)
stop = ^shigh
)
runTests(t, dsn, func(dbt *DBTest) {
stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`)
if err != nil {
dbt.Fatal(err)
}
defer stmt.Close()
row := stmt.QueryRow(
u0, uhigh, utop, uall,
s0, shigh, stop, sall,
)
var ua, ub, uc, ud uint64
var sa, sb, sc, sd int64
err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd)
if err != nil {
dbt.Fatal(err)
}
switch {
case ua != u0,
ub != uhigh,
uc != utop,
ud != uall,
sa != s0,
sb != shigh,
sc != stop,
sd != sall:
dbt.Fatal("Unexpected result value")
}
})
}
func TestLongData(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
var maxAllowedPacketSize int

View File

@ -19,15 +19,17 @@ import (
// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("Invalid Connection")
ErrMalformPkt = errors.New("Malformed Packet")
ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS")
ErrOldPassword = errors.New("This server only supports the insecure old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
ErrPktSync = errors.New("Commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
ErrBusyBuffer = errors.New("Busy buffer")
ErrInvalidConn = errors.New("Invalid Connection")
ErrMalformPkt = errors.New("Malformed Packet")
ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS")
ErrOldPassword = errors.New("This user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrCleartextPassword = errors.New("This user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN.")
ErrUnknownPlugin = errors.New("The authentication plugin is not supported.")
ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
ErrPktSync = errors.New("Commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
ErrBusyBuffer = errors.New("Busy buffer")
)
var errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)

View File

@ -196,7 +196,11 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// return
//}
//return ErrMalformPkt
return cipher, nil
// make a memory safe copy of the cipher slice
var b [20]byte
copy(b[:], cipher)
return b[:], nil
}
// make a memory safe copy of the cipher slice
@ -214,6 +218,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientLongPassword |
clientTransactions |
clientLocalFiles |
clientPluginAuth |
mc.flags&clientLongFlag
if mc.cfg.clientFoundRows {
@ -228,7 +233,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff)
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + 21 + 1
// To specify a db name
if n := len(mc.cfg.dbname); n > 0 {
@ -294,8 +299,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
if len(mc.cfg.dbname) > 0 {
pos += copy(data[pos:], mc.cfg.dbname)
data[pos] = 0x00
pos++
}
// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
data[pos] = 0x00
// Send Auth packet
return mc.writePacket(data)
}
@ -306,7 +316,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
// Calculate the packet lenght and add a tailing 0
// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
@ -322,6 +332,25 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
return mc.writePacket(data)
}
// Client clear text authentication packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeClearAuthPacket() error {
// Calculate the packet length and add a tailing 0
pktLen := len(mc.cfg.passwd) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
}
// Add the clear password [null terminated string]
copy(data[4:], mc.cfg.passwd)
data[4+pktLen-1] = 0x00
return mc.writePacket(data)
}
/******************************************************************************
* Command Packets *
******************************************************************************/
@ -405,8 +434,20 @@ func (mc *mysqlConn) readResultOK() error {
return mc.handleOkPacket(data)
case iEOF:
// someone is using old_passwords
return ErrOldPassword
if len(data) > 1 {
plugin := string(data[1:bytes.IndexByte(data, 0x00)])
if plugin == "mysql_old_password" {
// using old_passwords
return ErrOldPassword
} else if plugin == "mysql_clear_password" {
// using clear text password
return ErrCleartextPassword
} else {
return ErrUnknownPlugin
}
} else {
return ErrOldPassword
}
default: // Error otherwise
return mc.handleErrorPacket(data)

View File

@ -10,6 +10,9 @@ package mysql
import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
)
type mysqlStmt struct {
@ -34,6 +37,10 @@ func (stmt *mysqlStmt) NumInput() int {
return stmt.paramCount
}
func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
return converter{}
}
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.netConn == nil {
errLog.Print(ErrInvalidConn)
@ -110,3 +117,34 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
return rows, err
}
type converter struct{}
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
}
return c.ConvertValue(rv.Elem().Interface())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return strconv.FormatUint(u64, 10), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

View File

@ -80,8 +80,6 @@ func parseDSN(dsn string) (cfg *config, err error) {
collation: defaultCollation,
}
// TODO: use strings.IndexByte when we can depend on Go 1.2
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
@ -201,6 +199,14 @@ func parseDSNParams(cfg *config, params string) (err error) {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use cleartext authentication mode (MySQL 5.5.10+)
case "allowCleartextPasswords":
var isBool bool
cfg.allowCleartextPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
@ -771,6 +777,10 @@ func skipLengthEncodedString(b []byte) (int, error) {
// returns the number read, whether the value is NULL and the number of bytes read
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
// See issue #349
if len(b) == 0 {
return 0, true, 1
}
switch b[0] {
// 251: NULL

View File

@ -22,19 +22,19 @@ var testDSNs = []struct {
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true allowCleartextPasswords:false clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
}
func TestDSNParser(t *testing.T) {

View File

@ -7,6 +7,12 @@ package github
import "fmt"
// StarredRepository is returned by ListStarred.
type StarredRepository struct {
StarredAt *Timestamp `json:"starred_at,omitempty"`
Repository *Repository `json:"repo,omitempty"`
}
// ListStargazers lists people who have starred the specified repo.
//
// GitHub API Docs: https://developer.github.com/v3/activity/starring/#list-stargazers
@ -49,7 +55,7 @@ type ActivityListStarredOptions struct {
// will list the starred repositories for the authenticated user.
//
// GitHub API docs: http://developer.github.com/v3/activity/starring/#list-repositories-being-starred
func (s *ActivityService) ListStarred(user string, opt *ActivityListStarredOptions) ([]Repository, *Response, error) {
func (s *ActivityService) ListStarred(user string, opt *ActivityListStarredOptions) ([]StarredRepository, *Response, error) {
var u string
if user != "" {
u = fmt.Sprintf("users/%v/starred", user)
@ -66,7 +72,10 @@ func (s *ActivityService) ListStarred(user string, opt *ActivityListStarredOptio
return nil, nil, err
}
repos := new([]Repository)
// TODO: remove custom Accept header when this API fully launches
req.Header.Set("Accept", mediaTypeStarringPreview)
repos := new([]StarredRepository)
resp, err := s.client.Do(req, repos)
if err != nil {
return nil, resp, err

View File

@ -10,6 +10,7 @@ import (
"net/http"
"reflect"
"testing"
"time"
)
func TestActivityService_ListStargazers(t *testing.T) {
@ -42,7 +43,8 @@ func TestActivityService_ListStarred_authenticatedUser(t *testing.T) {
mux.HandleFunc("/user/starred", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
fmt.Fprint(w, `[{"id":1}]`)
testHeader(t, r, "Accept", mediaTypeStarringPreview)
fmt.Fprint(w, `[{"starred_at":"2002-02-10T15:30:00Z","repo":{"id":1}}]`)
})
repos, _, err := client.Activity.ListStarred("", nil)
@ -50,7 +52,7 @@ func TestActivityService_ListStarred_authenticatedUser(t *testing.T) {
t.Errorf("Activity.ListStarred returned error: %v", err)
}
want := []Repository{{ID: Int(1)}}
want := []StarredRepository{{StarredAt: &Timestamp{time.Date(2002, time.February, 10, 15, 30, 0, 0, time.UTC)}, Repository: &Repository{ID: Int(1)}}}
if !reflect.DeepEqual(repos, want) {
t.Errorf("Activity.ListStarred returned %+v, want %+v", repos, want)
}
@ -62,12 +64,13 @@ func TestActivityService_ListStarred_specifiedUser(t *testing.T) {
mux.HandleFunc("/users/u/starred", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
testHeader(t, r, "Accept", mediaTypeStarringPreview)
testFormValues(t, r, values{
"sort": "created",
"direction": "asc",
"page": "2",
})
fmt.Fprint(w, `[{"id":2}]`)
fmt.Fprint(w, `[{"starred_at":"2002-02-10T15:30:00Z","repo":{"id":2}}]`)
})
opt := &ActivityListStarredOptions{"created", "asc", ListOptions{Page: 2}}
@ -76,7 +79,7 @@ func TestActivityService_ListStarred_specifiedUser(t *testing.T) {
t.Errorf("Activity.ListStarred returned error: %v", err)
}
want := []Repository{{ID: Int(2)}}
want := []StarredRepository{{StarredAt: &Timestamp{time.Date(2002, time.February, 10, 15, 30, 0, 0, time.UTC)}, Repository: &Repository{ID: Int(2)}}}
if !reflect.DeepEqual(repos, want) {
t.Errorf("Activity.ListStarred returned %+v, want %+v", repos, want)
}

View File

@ -35,21 +35,10 @@ use it with the oauth2 library using:
import "golang.org/x/oauth2"
// tokenSource is an oauth2.TokenSource which returns a static access token
type tokenSource struct {
token *oauth2.Token
}
// Token implements the oauth2.TokenSource interface
func (t *tokenSource) Token() (*oauth2.Token, error){
return t.token, nil
}
func main() {
ts := &tokenSource{
ts := oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: "... your access token ..."},
}
)
tc := oauth2.NewClient(oauth2.NoContext, ts)
client := github.NewClient(tc)

Some files were not shown because too many files have changed in this diff Show More