vault/vendor/github.com/SAP/go-hdb/internal/protocol/session.go
Tony Cai f92f4d4972 Added HANA database plugin (#2811)
* Added HANA dynamic secret backend

* Added acceptance tests for HANA secret backend

* Add HANA backend as a logical backend to server

* Added documentation to HANA secret backend

* Added vendored libraries

* Go fmt

* Migrate hana credential creation to plugin

* Removed deprecated hana logical backend

* Migrated documentation for HANA database plugin

* Updated HANA DB plugin to use role name in credential generation

* Update HANA plugin tests

* If env vars are not configured, tests will skip rather than succeed

* Fixed some improperly named string variables

* Removed unused import

* Import SAP hdb driver
2017-07-07 13:11:23 -07:00

987 lines
21 KiB
Go

/*
Copyright 2014 SAP SE
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package protocol
import (
"database/sql/driver"
"flag"
"fmt"
"log"
"math"
"net"
"os"
"time"
"github.com/SAP/go-hdb/internal/bufio"
"github.com/SAP/go-hdb/internal/unicode"
"github.com/SAP/go-hdb/internal/unicode/cesu8"
"github.com/SAP/go-hdb/driver/sqltrace"
)
const (
mnSCRAMSHA256 = "SCRAMSHA256"
mnGSS = "GSS"
mnSAML = "SAML"
)
var trace bool
func init() {
flag.BoolVar(&trace, "hdb.protocol.trace", false, "enabling hdb protocol trace")
}
var (
outLogger = log.New(os.Stdout, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
errLogger = log.New(os.Stderr, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
)
//padding
const (
padding = 8
)
func padBytes(size int) int {
if r := size % padding; r != 0 {
return padding - r
}
return 0
}
// SessionConn wraps the database tcp connection. It sets timeouts and handles driver ErrBadConn behavior.
type sessionConn struct {
addr string
timeoutDuration time.Duration
conn net.Conn
isBad bool // bad connection
badError error // error cause for session bad state
inTx bool // in transaction
}
func newSessionConn(addr string, timeout int) (*sessionConn, error) {
timeoutDuration := time.Duration(timeout) * time.Second
conn, err := net.DialTimeout("tcp", addr, timeoutDuration)
if err != nil {
return nil, err
}
return &sessionConn{
addr: addr,
timeoutDuration: timeoutDuration,
conn: conn,
}, nil
}
func (c *sessionConn) close() error {
return c.conn.Close()
}
// Read implements the io.Reader interface.
func (c *sessionConn) Read(b []byte) (int, error) {
//set timeout
if err := c.conn.SetReadDeadline(time.Now().Add(c.timeoutDuration)); err != nil {
return 0, err
}
n, err := c.conn.Read(b)
if err != nil {
errLogger.Printf("Connection read error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
c.isBad = true
c.badError = err
return n, driver.ErrBadConn
}
return n, nil
}
// Write implements the io.Writer interface.
func (c *sessionConn) Write(b []byte) (int, error) {
//set timeout
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeoutDuration)); err != nil {
return 0, err
}
n, err := c.conn.Write(b)
if err != nil {
errLogger.Printf("Connection write error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
c.isBad = true
c.badError = err
return n, driver.ErrBadConn
}
return n, nil
}
type providePart func(pk partKind) replyPart
type beforeRead func(p replyPart)
// Session represents a HDB session.
type Session struct {
prm *SessionPrm
conn *sessionConn
rd *bufio.Reader
wr *bufio.Writer
// reuse header
mh *messageHeader
sh *segmentHeader
ph *partHeader
//reuse request / reply parts
rowsAffected *rowsAffected
statementID *statementID
resultMetadata *resultMetadata
resultsetID *resultsetID
resultset *resultset
parameterMetadata *parameterMetadata
outputParameters *outputParameters
readLobRequest *readLobRequest
readLobReply *readLobReply
//standard replies
stmtCtx *statementContext
txFlags *transactionFlags
lastError *hdbError
}
// NewSession creates a new database session.
func NewSession(prm *SessionPrm) (*Session, error) {
if trace {
outLogger.Printf("%s", prm)
}
conn, err := newSessionConn(prm.Host, prm.Timeout)
if err != nil {
return nil, err
}
var rd *bufio.Reader
var wr *bufio.Writer
if prm.BufferSize > 0 {
rd = bufio.NewReaderSize(conn, prm.BufferSize)
wr = bufio.NewWriterSize(conn, prm.BufferSize)
} else {
rd = bufio.NewReader(conn)
wr = bufio.NewWriter(conn)
}
s := &Session{
prm: prm,
conn: conn,
rd: rd,
wr: wr,
mh: new(messageHeader),
sh: new(segmentHeader),
ph: new(partHeader),
rowsAffected: new(rowsAffected),
statementID: new(statementID),
resultMetadata: new(resultMetadata),
resultsetID: new(resultsetID),
resultset: new(resultset),
parameterMetadata: new(parameterMetadata),
outputParameters: new(outputParameters),
readLobRequest: new(readLobRequest),
readLobReply: new(readLobReply),
stmtCtx: newStatementContext(),
txFlags: newTransactionFlags(),
lastError: newHdbError(),
}
if err = s.init(); err != nil {
return nil, err
}
return s, nil
}
// Close closes the session.
func (s *Session) Close() error {
return s.conn.close()
}
func (s *Session) sessionID() int64 {
return s.mh.sessionID
}
// InTx indicates, if the session is in transaction mode.
func (s *Session) InTx() bool {
return s.conn.inTx
}
// SetInTx sets session in transaction mode.
func (s *Session) SetInTx(v bool) {
s.conn.inTx = v
}
// IsBad indicates, that the session is in bad state.
func (s *Session) IsBad() bool {
return s.conn.isBad
}
// BadErr returns the error, that caused the bad session state.
func (s *Session) BadErr() error {
return s.conn.badError
}
func (s *Session) init() error {
if err := s.initRequest(); err != nil {
return err
}
// TODO: detect authentication method
// - actually only basic authetication supported
authentication := mnSCRAMSHA256
switch authentication {
default:
return fmt.Errorf("invalid authentication %s", authentication)
case mnSCRAMSHA256:
if err := s.authenticateScramsha256(); err != nil {
return err
}
case mnGSS:
panic("not implemented error")
case mnSAML:
panic("not implemented error")
}
id := s.sessionID()
if id <= 0 {
return fmt.Errorf("invalid session id %d", id)
}
if trace {
outLogger.Printf("sessionId %d", id)
}
return nil
}
func (s *Session) authenticateScramsha256() error {
tr := unicode.Utf8ToCesu8Transformer
tr.Reset()
username := make([]byte, cesu8.StringSize(s.prm.Username))
if _, _, err := tr.Transform(username, []byte(s.prm.Username), true); err != nil {
return err // should never happen
}
password := make([]byte, cesu8.StringSize(s.prm.Password))
if _, _, err := tr.Transform(password, []byte(s.prm.Password), true); err != nil {
return err //should never happen
}
clientChallenge := clientChallenge()
//initial request
ireq := newScramsha256InitialRequest()
ireq.username = username
ireq.clientChallenge = clientChallenge
if err := s.writeRequest(mtAuthenticate, false, ireq); err != nil {
return err
}
irep := newScramsha256InitialReply()
f := func(pk partKind) replyPart {
switch pk {
case pkAuthentication:
return irep
default:
return nil
}
}
if err := s.readReply(f, nil); err != nil {
return err
}
//final request
freq := newScramsha256FinalRequest()
freq.username = username
freq.clientProof = clientProof(irep.salt, irep.serverChallenge, clientChallenge, password)
id := newClientID()
co := newConnectOptions()
co.set(coDistributionProtocolVersion, booleanType(false))
co.set(coSelectForUpdateSupported, booleanType(false))
co.set(coSplitBatchCommands, booleanType(true))
co.set(coDataFormatVersion, dfvBaseline)
co.set(coDataFormatVersion2, dfvBaseline)
co.set(coCompleteArrayExecution, booleanType(true))
co.set(coClientLocale, stringType(s.prm.Locale))
co.set(coClientDistributionMode, cdmOff)
if err := s.writeRequest(mtConnect, false, freq, id, co); err != nil {
return err
}
frep := newScramsha256FinalReply()
topo := newTopologyOptions()
f = func(pk partKind) replyPart {
switch pk {
case pkAuthentication:
return frep
case pkTopologyInformation:
return topo
case pkConnectOptions:
return co
default:
return nil
}
}
if err := s.readReply(f, nil); err != nil {
return err
}
return nil
}
// QueryDirect executes a query without query parameters.
func (s *Session) QueryDirect(query string) (uint64, *FieldSet, *FieldValues, PartAttributes, error) {
if err := s.writeRequest(mtExecuteDirect, false, command(query)); err != nil {
return 0, nil, nil, nil, err
}
var id uint64
var fieldSet *FieldSet
fieldValues := newFieldValues(s)
f := func(p replyPart) {
switch p := p.(type) {
case *resultsetID:
p.id = &id
case *resultMetadata:
fieldSet = newFieldSet(p.numArg)
p.fieldSet = fieldSet
case *resultset:
p.fieldSet = fieldSet
p.fieldValues = fieldValues
}
}
if err := s.readReply(nil, f); err != nil {
return 0, nil, nil, nil, err
}
attrs := s.ph.partAttributes
return id, fieldSet, fieldValues, attrs, nil
}
// ExecDirect executes a sql statement without statement parameters.
func (s *Session) ExecDirect(query string) (driver.Result, error) {
if err := s.writeRequest(mtExecuteDirect, !s.conn.inTx, command(query)); err != nil {
return nil, err
}
if err := s.readReply(nil, nil); err != nil {
return nil, err
}
if s.sh.functionCode == fcDDL {
return driver.ResultNoRows, nil
}
return driver.RowsAffected(s.rowsAffected.total()), nil
}
// Prepare prepares a sql statement.
func (s *Session) Prepare(query string) (QueryType, uint64, *FieldSet, *FieldSet, error) {
if err := s.writeRequest(mtPrepare, false, command(query)); err != nil {
return QtNone, 0, nil, nil, err
}
var id uint64
var prmFieldSet *FieldSet
var resultFieldSet *FieldSet
f := func(p replyPart) {
switch p := p.(type) {
case *statementID:
p.id = &id
case *parameterMetadata:
prmFieldSet = newFieldSet(p.numArg)
p.fieldSet = prmFieldSet
case *resultMetadata:
resultFieldSet = newFieldSet(p.numArg)
p.fieldSet = resultFieldSet
}
}
if err := s.readReply(nil, f); err != nil {
return QtNone, 0, nil, nil, err
}
return s.sh.functionCode.queryType(), id, prmFieldSet, resultFieldSet, nil
}
// Exec executes a sql statement.
func (s *Session) Exec(id uint64, parameterFieldSet *FieldSet, args []driver.Value) (driver.Result, error) {
s.statementID.id = &id
if err := s.writeRequest(mtExecute, !s.conn.inTx, s.statementID, newParameters(parameterFieldSet, args)); err != nil {
return nil, err
}
wlr := newWriteLobReply() //lob streaming
f := func(pk partKind) replyPart {
switch pk {
case pkWriteLobReply:
return wlr
default:
return nil
}
}
if err := s.readReply(f, nil); err != nil {
return nil, err
}
var result driver.Result
if s.sh.functionCode == fcDDL {
result = driver.ResultNoRows
} else {
result = driver.RowsAffected(s.rowsAffected.total())
}
if wlr.numArg > 0 {
if err := s.writeLobStream(parameterFieldSet, nil, args, wlr); err != nil {
return nil, err
}
}
return result, nil
}
// DropStatementID releases the hdb statement handle.
func (s *Session) DropStatementID(id uint64) error {
s.statementID.id = &id
if err := s.writeRequest(mtDropStatementID, false, s.statementID); err != nil {
return err
}
if err := s.readReply(nil, nil); err != nil {
return err
}
return nil
}
// Call executes a stored procedure.
func (s *Session) Call(id uint64, prmFieldSet *FieldSet, args []driver.Value) (*FieldValues, []*TableResult, error) {
s.statementID.id = &id
if err := s.writeRequest(mtExecute, false, s.statementID, newParameters(prmFieldSet, args)); err != nil {
return nil, nil, err
}
wlr := newWriteLobReply() //lob streaming
f := func(pk partKind) replyPart {
switch pk {
case pkWriteLobReply:
return wlr
default:
return nil
}
}
prmFieldValues := newFieldValues(s)
var tableResults []*TableResult
var tableResult *TableResult
g := func(p replyPart) {
switch p := p.(type) {
case *outputParameters:
p.fieldSet = prmFieldSet
p.fieldValues = prmFieldValues
// table output parameters: meta, id, result (only first param?)
case *resultMetadata:
tableResult = newTableResult(s, p.numArg)
tableResults = append(tableResults, tableResult)
p.fieldSet = tableResult.fieldSet
case *resultsetID:
p.id = &(tableResult.id)
case *resultset:
tableResult.attrs = s.ph.partAttributes
p.fieldSet = tableResult.fieldSet
p.fieldValues = tableResult.fieldValues
}
}
if err := s.readReply(f, g); err != nil {
return nil, nil, err
}
if wlr.numArg > 0 {
if err := s.writeLobStream(prmFieldSet, prmFieldValues, args, wlr); err != nil {
return nil, nil, err
}
}
return prmFieldValues, tableResults, nil
}
// Query executes a query.
func (s *Session) Query(stmtID uint64, parameterFieldSet *FieldSet, resultFieldSet *FieldSet, args []driver.Value) (uint64, *FieldValues, PartAttributes, error) {
s.statementID.id = &stmtID
if err := s.writeRequest(mtExecute, false, s.statementID, newParameters(parameterFieldSet, args)); err != nil {
return 0, nil, nil, err
}
var rsetID uint64
fieldValues := newFieldValues(s)
f := func(p replyPart) {
switch p := p.(type) {
case *resultsetID:
p.id = &rsetID
case *resultset:
p.fieldSet = resultFieldSet
p.fieldValues = fieldValues
}
}
if err := s.readReply(nil, f); err != nil {
return 0, nil, nil, err
}
attrs := s.ph.partAttributes
return rsetID, fieldValues, attrs, nil
}
// FetchNext fetches next chunk in query result set.
func (s *Session) FetchNext(id uint64, resultFieldSet *FieldSet) (*FieldValues, PartAttributes, error) {
s.resultsetID.id = &id
if err := s.writeRequest(mtFetchNext, false, s.resultsetID, fetchsize(s.prm.FetchSize)); err != nil {
return nil, nil, err
}
fieldValues := newFieldValues(s)
f := func(p replyPart) {
switch p := p.(type) {
case *resultset:
p.fieldSet = resultFieldSet
p.fieldValues = fieldValues
}
}
if err := s.readReply(nil, f); err != nil {
return nil, nil, err
}
attrs := s.ph.partAttributes
return fieldValues, attrs, nil
}
// CloseResultsetID releases the hdb resultset handle.
func (s *Session) CloseResultsetID(id uint64) error {
s.resultsetID.id = &id
if err := s.writeRequest(mtCloseResultset, false, s.resultsetID); err != nil {
return err
}
if err := s.readReply(nil, nil); err != nil {
return err
}
return nil
}
// Commit executes a database commit.
func (s *Session) Commit() error {
if err := s.writeRequest(mtCommit, false); err != nil {
return err
}
if err := s.readReply(nil, nil); err != nil {
return err
}
if trace {
outLogger.Printf("transaction flags: %s", s.txFlags)
}
s.conn.inTx = false
return nil
}
// Rollback executes a database rollback.
func (s *Session) Rollback() error {
if err := s.writeRequest(mtRollback, false); err != nil {
return err
}
if err := s.readReply(nil, nil); err != nil {
return err
}
if trace {
outLogger.Printf("transaction flags: %s", s.txFlags)
}
s.conn.inTx = false
return nil
}
// helper
func readLobStreamDone(writers []lobWriter) bool {
for _, writer := range writers {
if !writer.eof() {
return false
}
}
return true
}
//
func (s *Session) readLobStream(writers []lobWriter) error {
f := func(pk partKind) replyPart {
switch pk {
case pkReadLobReply:
return s.readLobReply
default:
return nil
}
}
for !readLobStreamDone(writers) {
s.readLobRequest.writers = writers
s.readLobReply.writers = writers
if err := s.writeRequest(mtWriteLob, false, s.readLobRequest); err != nil {
return err
}
if err := s.readReply(f, nil); err != nil {
return err
}
}
return nil
}
func (s *Session) writeLobStream(prmFieldSet *FieldSet, prmFieldValues *FieldValues, args []driver.Value, reply *writeLobReply) error {
num := reply.numArg
readers := make([]lobReader, num)
request := newWriteLobRequest(readers)
j := 0
for i, field := range prmFieldSet.fields {
if field.typeCode().isLob() && field.in() {
ptr, ok := args[i].(int64)
if !ok {
return fmt.Errorf("protocol error: invalid lob driver value type %T", args[i])
}
descr := pointerToLobWriteDescr(ptr)
if descr.r == nil {
return fmt.Errorf("protocol error: lob reader %d initial", ptr)
}
if j >= num {
return fmt.Errorf("protocol error: invalid number of lob parameter ids %d", num)
}
if field.typeCode().isCharBased() {
readers[j] = newCharLobReader(descr.r, reply.ids[j])
} else {
readers[j] = newBinaryLobReader(descr.r, reply.ids[j])
}
j++
}
}
f := func(pk partKind) replyPart {
switch pk {
case pkWriteLobReply:
return reply
default:
return nil
}
}
g := func(p replyPart) {
if p, ok := p.(*outputParameters); ok {
p.fieldSet = prmFieldSet
p.fieldValues = prmFieldValues
}
}
for reply.numArg != 0 {
if err := s.writeRequest(mtReadLob, false, request); err != nil {
return err
}
if err := s.readReply(f, g); err != nil {
return err
}
}
return nil
}
//
func (s *Session) initRequest() error {
// init
s.mh.sessionID = -1
// handshake
req := newInitRequest()
// TODO: constants
req.product.major = 4
req.product.minor = 20
req.protocol.major = 4
req.protocol.minor = 1
req.numOptions = 1
req.endianess = archEndian
if err := req.write(s.wr); err != nil {
return err
}
rep := newInitReply()
if err := rep.read(s.rd); err != nil {
return err
}
return nil
}
func (s *Session) writeRequest(messageType messageType, commit bool, requests ...requestPart) error {
partSize := make([]int, len(requests))
size := int64(segmentHeaderSize + len(requests)*partHeaderSize) //int64 to hold MaxUInt32 in 32bit OS
for i, part := range requests {
s, err := part.size()
if err != nil {
return err
}
size += int64(s + padBytes(s))
partSize[i] = s // buffer size (expensive calculation)
}
if size > math.MaxUint32 {
return fmt.Errorf("message size %d exceeds maximum message header value %d", size, int64(math.MaxUint32)) //int64: without cast overflow error in 32bit OS
}
bufferSize := size
s.mh.varPartLength = uint32(size)
s.mh.varPartSize = uint32(bufferSize)
s.mh.noOfSegm = 1
if err := s.mh.write(s.wr); err != nil {
return err
}
if size > math.MaxInt32 {
return fmt.Errorf("message size %d exceeds maximum part header value %d", size, math.MaxInt32)
}
s.sh.messageType = messageType
s.sh.commit = commit
s.sh.segmentKind = skRequest
s.sh.segmentLength = int32(size)
s.sh.segmentOfs = 0
s.sh.noOfParts = int16(len(requests))
s.sh.segmentNo = 1
if err := s.sh.write(s.wr); err != nil {
return err
}
bufferSize -= segmentHeaderSize
for i, part := range requests {
size := partSize[i]
pad := padBytes(size)
s.ph.partKind = part.kind()
numArg := part.numArg()
switch {
default:
return fmt.Errorf("maximum number of arguments %d exceeded", numArg)
case numArg <= math.MaxInt16:
s.ph.argumentCount = int16(numArg)
s.ph.bigArgumentCount = 0
// TODO: seems not to work: see bulk insert test
case numArg <= math.MaxInt32:
s.ph.argumentCount = 0
s.ph.bigArgumentCount = int32(numArg)
}
s.ph.bufferLength = int32(size)
s.ph.bufferSize = int32(bufferSize)
if err := s.ph.write(s.wr); err != nil {
return err
}
if err := part.write(s.wr); err != nil {
return err
}
if err := s.wr.WriteZeroes(pad); err != nil {
return err
}
bufferSize -= int64(partHeaderSize + size + pad)
}
if err := s.wr.Flush(); err != nil {
return err
}
return nil
}
func (s *Session) readReply(providePart providePart, beforeRead beforeRead) error {
replyError := false
if err := s.mh.read(s.rd); err != nil {
return err
}
if s.mh.noOfSegm != 1 {
return fmt.Errorf("simple message: no of segments %d - expected 1", s.mh.noOfSegm)
}
if err := s.sh.read(s.rd); err != nil {
return err
}
// TODO: protocol error (sps 82)?: message header varPartLength < segment header segmentLength (*1)
diff := int(s.mh.varPartLength) - int(s.sh.segmentLength)
if trace && diff != 0 {
outLogger.Printf("+++++diff %d", diff)
}
noOfParts := int(s.sh.noOfParts)
for i := 0; i < noOfParts; i++ {
if err := s.ph.read(s.rd); err != nil {
return err
}
numArg := int(s.ph.argumentCount)
var part replyPart
if providePart != nil {
part = providePart(s.ph.partKind)
} else {
part = nil
}
if part == nil { // use pre defined parts
switch s.ph.partKind {
case pkStatementID:
part = s.statementID
case pkResultMetadata:
part = s.resultMetadata
case pkResultsetID:
part = s.resultsetID
case pkResultset:
part = s.resultset
case pkParameterMetadata:
part = s.parameterMetadata
case pkOutputParameters:
part = s.outputParameters
case pkError:
replyError = true
part = s.lastError
case pkStatementContext:
part = s.stmtCtx
case pkTransactionFlags:
part = s.txFlags
case pkRowsAffected:
part = s.rowsAffected
default:
return fmt.Errorf("read not expected part kind %s", s.ph.partKind)
}
}
part.setNumArg(numArg)
if beforeRead != nil {
beforeRead(part)
}
if err := part.read(s.rd); err != nil {
return err
}
// TODO: workaround (see *)
if i != (noOfParts-1) || (i == (noOfParts-1) && diff == 0) {
if err := s.rd.Skip(padBytes(int(s.ph.bufferLength))); err != nil {
return err
}
}
}
if replyError {
if s.lastError.IsWarning() {
sqltrace.Traceln(s.lastError)
} else {
return s.lastError
}
}
return nil
}