diff --git a/audit/format.go b/audit/format.go index 29ea807951..1a97078ae7 100644 --- a/audit/format.go +++ b/audit/format.go @@ -110,6 +110,7 @@ func (f *AuditFormatter) FormatRequest(ctx context.Context, w io.Writer, config Request: &AuditRequest{ ID: req.ID, + ClientID: req.ClientID, ClientToken: req.ClientToken, ClientTokenAccessor: req.ClientTokenAccessor, Operation: req.Operation, @@ -336,6 +337,7 @@ type AuditResponseEntry struct { type AuditRequest struct { ID string `json:"id,omitempty"` + ClientID string `json:"client_id,omitempty"` ReplicationCluster string `json:"replication_cluster,omitempty"` Operation logical.Operation `json:"operation,omitempty"` MountType string `json:"mount_type,omitempty"` diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 5809531474..829c155fd0 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -219,7 +219,7 @@ type Request struct { // entity, it will be the same as the EntityID . If the token has no entity, // this will be the sha256(sorted policies + namespace) associated with the // client token. - ClientID string + ClientID string `json:"client_id" structs:"client_id" mapstructure:"client_id" sentinel:""` } // Clone returns a deep copy of the request by using copystructure diff --git a/vault/activity_log.go b/vault/activity_log.go index 026ad35002..a326caea12 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -1583,31 +1583,33 @@ func (a *ActivityLog) loadConfigOrDefault(ctx context.Context) (activityConfig, return config, nil } -// HandleTokenUsage adds the TokenEntry to the current fragment of the activity log. +// HandleTokenUsage adds the TokenEntry to the current fragment of the activity log +// and returns the corresponding Client ID. // This currently occurs on token usage only. -func (a *ActivityLog) HandleTokenUsage(entry *logical.TokenEntry) { +func (a *ActivityLog) HandleTokenUsage(entry *logical.TokenEntry) string { // First, check if a is enabled, so as to avoid the cost of creating an ID for // tokens without entities in the case where it not. a.fragmentLock.RLock() if !a.enabled { a.fragmentLock.RUnlock() - return + return "" } a.fragmentLock.RUnlock() // Do not count wrapping tokens in client count if IsWrappingToken(entry) { - return + return "" } // Do not count root tokens in client count. if entry.IsRoot() { - return + return "" } // Parse an entry's client ID and add it to the activity log clientID, isTWE := a.CreateClientID(entry) a.AddClientToFragment(clientID, entry.NamespaceID, entry.CreationTime, isTWE) + return clientID } // CreateClientID returns the client ID, and a boolean which is false if the clientID @@ -1649,7 +1651,7 @@ func (a *ActivityLog) CreateClientID(entry *logical.TokenEntry) (string, bool) { // Step 5: Hash the sum hashed := sha256.Sum256([]byte(clientIDInput)) - return base64.URLEncoding.EncodeToString(hashed[:]), true + return base64.StdEncoding.EncodeToString(hashed[:]), true } func (a *ActivityLog) namespaceToLabel(ctx context.Context, nsID string) string { diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index a7fbe9d540..b0755c4214 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -1534,7 +1534,7 @@ func TestCreateClientID(t *testing.T) { string(sortedPoliciesTWEDelimiter) + "foo" + string(clientIDTWEDelimiter) + "namespaceFoo" hashed := sha256.Sum256([]byte(expectedIDPlaintext)) - expectedID := base64.URLEncoding.EncodeToString(hashed[:]) + expectedID := base64.StdEncoding.EncodeToString(hashed[:]) if expectedID != id { t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) } @@ -1559,7 +1559,7 @@ func TestCreateClientID(t *testing.T) { string(sortedPoliciesTWEDelimiter) + "foo" + string(clientIDTWEDelimiter) hashed = sha256.Sum256([]byte(expectedIDPlaintext)) - expectedID = base64.URLEncoding.EncodeToString(hashed[:]) + expectedID = base64.StdEncoding.EncodeToString(hashed[:]) if expectedID != id { t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) } @@ -1573,7 +1573,7 @@ func TestCreateClientID(t *testing.T) { expectedIDPlaintext = "namespaceFoo" hashed = sha256.Sum256([]byte(expectedIDPlaintext)) - expectedID = base64.URLEncoding.EncodeToString(hashed[:]) + expectedID = base64.StdEncoding.EncodeToString(hashed[:]) if expectedID != id { t.Fatalf("wrong ID: expected %s, found %s", expectedID, id) } diff --git a/vault/request_handling.go b/vault/request_handling.go index 42940c8908..c83bac2dbd 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -398,9 +398,7 @@ func (c *Core) checkToken(ctx context.Context, req *logical.Request, unauth bool // If it is an authenticated ( i.e with vault token ) request, increment client count if !unauth && c.activityLog != nil { - clientID, _ := c.activityLog.CreateClientID(req.TokenEntry()) - req.ClientID = clientID - c.activityLog.HandleTokenUsage(te) + req.ClientID = c.activityLog.HandleTokenUsage(te) } return auth, te, nil }