mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-15 11:07:00 +02:00
* Added HANA dynamic secret backend * Added acceptance tests for HANA secret backend * Add HANA backend as a logical backend to server * Added documentation to HANA secret backend * Added vendored libraries * Go fmt * Migrate hana credential creation to plugin * Removed deprecated hana logical backend * Migrated documentation for HANA database plugin * Updated HANA DB plugin to use role name in credential generation * Update HANA plugin tests * If env vars are not configured, tests will skip rather than succeed * Fixed some improperly named string variables * Removed unused import * Import SAP hdb driver
987 lines
21 KiB
Go
987 lines
21 KiB
Go
/*
|
|
Copyright 2014 SAP SE
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package protocol
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/SAP/go-hdb/internal/bufio"
|
|
"github.com/SAP/go-hdb/internal/unicode"
|
|
"github.com/SAP/go-hdb/internal/unicode/cesu8"
|
|
|
|
"github.com/SAP/go-hdb/driver/sqltrace"
|
|
)
|
|
|
|
// Names of the authentication methods known to this package.
const (
	// mnSCRAMSHA256 is the only method actually used by Session.init.
	mnSCRAMSHA256 = "SCRAMSHA256"
	// mnGSS is declared but has no implementation in this file.
	mnGSS = "GSS"
	// mnSAML is declared but has no implementation in this file.
	mnSAML = "SAML"
)
|
|
|
|
// trace enables protocol tracing. It is bound to the command line flag
// "hdb.protocol.trace" in init and defaults to false.
var trace bool

// Package loggers: outLogger writes to stdout, errLogger writes to stderr.
// Both use the prefix "hdb.protocol " and print date, time and short file name.
var (
	outLogger = log.New(os.Stdout, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
	errLogger = log.New(os.Stderr, "hdb.protocol ", log.Ldate|log.Ltime|log.Lshortfile)
)

// init registers the protocol trace flag.
func init() {
	flag.BoolVar(&trace, "hdb.protocol.trace", false, "enabling hdb protocol trace")
}
|
|
|
|
// padding is the alignment, in bytes, that part buffers are padded to.
const padding = 8

// padBytes returns the number of filler bytes needed to bring size up to the
// next multiple of padding. It returns 0 when size is already a multiple.
func padBytes(size int) int {
	rem := size % padding
	if rem == 0 {
		return 0
	}
	return padding - rem
}
|
|
|
|
// sessionConn wraps the database tcp connection. Its Read and Write methods
// set timeouts and translate connection errors into driver.ErrBadConn.
type sessionConn struct {
	// addr is the address the connection was dialed to.
	addr string
	// timeoutDuration is used for dialing and for the per-call read/write deadlines.
	timeoutDuration time.Duration
	// conn is the underlying network connection.
	conn net.Conn
	// isBad is set once a read or write on conn has failed.
	isBad bool
	// badError is the error that caused the bad state.
	badError error
	// inTx reports whether the session is in a transaction.
	inTx bool
}
|
|
|
|
func newSessionConn(addr string, timeout int) (*sessionConn, error) {
|
|
timeoutDuration := time.Duration(timeout) * time.Second
|
|
conn, err := net.DialTimeout("tcp", addr, timeoutDuration)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &sessionConn{
|
|
addr: addr,
|
|
timeoutDuration: timeoutDuration,
|
|
conn: conn,
|
|
}, nil
|
|
}
|
|
|
|
func (c *sessionConn) close() error {
|
|
return c.conn.Close()
|
|
}
|
|
|
|
// Read implements the io.Reader interface.
|
|
func (c *sessionConn) Read(b []byte) (int, error) {
|
|
//set timeout
|
|
if err := c.conn.SetReadDeadline(time.Now().Add(c.timeoutDuration)); err != nil {
|
|
return 0, err
|
|
}
|
|
n, err := c.conn.Read(b)
|
|
if err != nil {
|
|
errLogger.Printf("Connection read error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
|
|
c.isBad = true
|
|
c.badError = err
|
|
return n, driver.ErrBadConn
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// Write implements the io.Writer interface.
|
|
func (c *sessionConn) Write(b []byte) (int, error) {
|
|
//set timeout
|
|
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeoutDuration)); err != nil {
|
|
return 0, err
|
|
}
|
|
n, err := c.conn.Write(b)
|
|
if err != nil {
|
|
errLogger.Printf("Connection write error local address %s remote address %s: %s", c.conn.LocalAddr(), c.conn.RemoteAddr(), err)
|
|
c.isBad = true
|
|
c.badError = err
|
|
return n, driver.ErrBadConn
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
type providePart func(pk partKind) replyPart
|
|
type beforeRead func(p replyPart)
|
|
|
|
// Session represents a HDB session.
|
|
type Session struct {
|
|
prm *SessionPrm
|
|
|
|
conn *sessionConn
|
|
rd *bufio.Reader
|
|
wr *bufio.Writer
|
|
|
|
// reuse header
|
|
mh *messageHeader
|
|
sh *segmentHeader
|
|
ph *partHeader
|
|
|
|
//reuse request / reply parts
|
|
rowsAffected *rowsAffected
|
|
statementID *statementID
|
|
resultMetadata *resultMetadata
|
|
resultsetID *resultsetID
|
|
resultset *resultset
|
|
parameterMetadata *parameterMetadata
|
|
outputParameters *outputParameters
|
|
readLobRequest *readLobRequest
|
|
readLobReply *readLobReply
|
|
|
|
//standard replies
|
|
stmtCtx *statementContext
|
|
txFlags *transactionFlags
|
|
lastError *hdbError
|
|
}
|
|
|
|
// NewSession creates a new database session.
|
|
func NewSession(prm *SessionPrm) (*Session, error) {
|
|
|
|
if trace {
|
|
outLogger.Printf("%s", prm)
|
|
}
|
|
|
|
conn, err := newSessionConn(prm.Host, prm.Timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var rd *bufio.Reader
|
|
var wr *bufio.Writer
|
|
if prm.BufferSize > 0 {
|
|
rd = bufio.NewReaderSize(conn, prm.BufferSize)
|
|
wr = bufio.NewWriterSize(conn, prm.BufferSize)
|
|
} else {
|
|
rd = bufio.NewReader(conn)
|
|
wr = bufio.NewWriter(conn)
|
|
}
|
|
|
|
s := &Session{
|
|
prm: prm,
|
|
conn: conn,
|
|
rd: rd,
|
|
wr: wr,
|
|
mh: new(messageHeader),
|
|
sh: new(segmentHeader),
|
|
ph: new(partHeader),
|
|
rowsAffected: new(rowsAffected),
|
|
statementID: new(statementID),
|
|
resultMetadata: new(resultMetadata),
|
|
resultsetID: new(resultsetID),
|
|
resultset: new(resultset),
|
|
parameterMetadata: new(parameterMetadata),
|
|
outputParameters: new(outputParameters),
|
|
readLobRequest: new(readLobRequest),
|
|
readLobReply: new(readLobReply),
|
|
stmtCtx: newStatementContext(),
|
|
txFlags: newTransactionFlags(),
|
|
lastError: newHdbError(),
|
|
}
|
|
|
|
if err = s.init(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// Close closes the session.
|
|
func (s *Session) Close() error {
|
|
return s.conn.close()
|
|
}
|
|
|
|
func (s *Session) sessionID() int64 {
|
|
return s.mh.sessionID
|
|
}
|
|
|
|
// InTx indicates, if the session is in transaction mode.
|
|
func (s *Session) InTx() bool {
|
|
return s.conn.inTx
|
|
}
|
|
|
|
// SetInTx sets session in transaction mode.
|
|
func (s *Session) SetInTx(v bool) {
|
|
s.conn.inTx = v
|
|
}
|
|
|
|
// IsBad indicates, that the session is in bad state.
|
|
func (s *Session) IsBad() bool {
|
|
return s.conn.isBad
|
|
}
|
|
|
|
// BadErr returns the error, that caused the bad session state.
|
|
func (s *Session) BadErr() error {
|
|
return s.conn.badError
|
|
}
|
|
|
|
func (s *Session) init() error {
|
|
|
|
if err := s.initRequest(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: detect authentication method
|
|
// - actually only basic authetication supported
|
|
|
|
authentication := mnSCRAMSHA256
|
|
|
|
switch authentication {
|
|
default:
|
|
return fmt.Errorf("invalid authentication %s", authentication)
|
|
|
|
case mnSCRAMSHA256:
|
|
if err := s.authenticateScramsha256(); err != nil {
|
|
return err
|
|
}
|
|
case mnGSS:
|
|
panic("not implemented error")
|
|
case mnSAML:
|
|
panic("not implemented error")
|
|
}
|
|
|
|
id := s.sessionID()
|
|
if id <= 0 {
|
|
return fmt.Errorf("invalid session id %d", id)
|
|
}
|
|
|
|
if trace {
|
|
outLogger.Printf("sessionId %d", id)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) authenticateScramsha256() error {
|
|
tr := unicode.Utf8ToCesu8Transformer
|
|
tr.Reset()
|
|
|
|
username := make([]byte, cesu8.StringSize(s.prm.Username))
|
|
if _, _, err := tr.Transform(username, []byte(s.prm.Username), true); err != nil {
|
|
return err // should never happen
|
|
}
|
|
|
|
password := make([]byte, cesu8.StringSize(s.prm.Password))
|
|
if _, _, err := tr.Transform(password, []byte(s.prm.Password), true); err != nil {
|
|
return err //should never happen
|
|
}
|
|
|
|
clientChallenge := clientChallenge()
|
|
|
|
//initial request
|
|
ireq := newScramsha256InitialRequest()
|
|
ireq.username = username
|
|
ireq.clientChallenge = clientChallenge
|
|
|
|
if err := s.writeRequest(mtAuthenticate, false, ireq); err != nil {
|
|
return err
|
|
}
|
|
|
|
irep := newScramsha256InitialReply()
|
|
|
|
f := func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkAuthentication:
|
|
return irep
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(f, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
//final request
|
|
freq := newScramsha256FinalRequest()
|
|
freq.username = username
|
|
freq.clientProof = clientProof(irep.salt, irep.serverChallenge, clientChallenge, password)
|
|
|
|
id := newClientID()
|
|
|
|
co := newConnectOptions()
|
|
co.set(coDistributionProtocolVersion, booleanType(false))
|
|
co.set(coSelectForUpdateSupported, booleanType(false))
|
|
co.set(coSplitBatchCommands, booleanType(true))
|
|
co.set(coDataFormatVersion, dfvBaseline)
|
|
co.set(coDataFormatVersion2, dfvBaseline)
|
|
co.set(coCompleteArrayExecution, booleanType(true))
|
|
co.set(coClientLocale, stringType(s.prm.Locale))
|
|
co.set(coClientDistributionMode, cdmOff)
|
|
|
|
if err := s.writeRequest(mtConnect, false, freq, id, co); err != nil {
|
|
return err
|
|
}
|
|
|
|
frep := newScramsha256FinalReply()
|
|
topo := newTopologyOptions()
|
|
|
|
f = func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkAuthentication:
|
|
return frep
|
|
case pkTopologyInformation:
|
|
return topo
|
|
case pkConnectOptions:
|
|
return co
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(f, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// QueryDirect executes a query without query parameters.
|
|
func (s *Session) QueryDirect(query string) (uint64, *FieldSet, *FieldValues, PartAttributes, error) {
|
|
|
|
if err := s.writeRequest(mtExecuteDirect, false, command(query)); err != nil {
|
|
return 0, nil, nil, nil, err
|
|
}
|
|
|
|
var id uint64
|
|
var fieldSet *FieldSet
|
|
fieldValues := newFieldValues(s)
|
|
|
|
f := func(p replyPart) {
|
|
|
|
switch p := p.(type) {
|
|
|
|
case *resultsetID:
|
|
p.id = &id
|
|
case *resultMetadata:
|
|
fieldSet = newFieldSet(p.numArg)
|
|
p.fieldSet = fieldSet
|
|
case *resultset:
|
|
p.fieldSet = fieldSet
|
|
p.fieldValues = fieldValues
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(nil, f); err != nil {
|
|
return 0, nil, nil, nil, err
|
|
}
|
|
|
|
attrs := s.ph.partAttributes
|
|
|
|
return id, fieldSet, fieldValues, attrs, nil
|
|
}
|
|
|
|
// ExecDirect executes a sql statement without statement parameters.
|
|
func (s *Session) ExecDirect(query string) (driver.Result, error) {
|
|
|
|
if err := s.writeRequest(mtExecuteDirect, !s.conn.inTx, command(query)); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.readReply(nil, nil); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if s.sh.functionCode == fcDDL {
|
|
return driver.ResultNoRows, nil
|
|
}
|
|
return driver.RowsAffected(s.rowsAffected.total()), nil
|
|
}
|
|
|
|
// Prepare prepares a sql statement.
|
|
func (s *Session) Prepare(query string) (QueryType, uint64, *FieldSet, *FieldSet, error) {
|
|
|
|
if err := s.writeRequest(mtPrepare, false, command(query)); err != nil {
|
|
return QtNone, 0, nil, nil, err
|
|
}
|
|
|
|
var id uint64
|
|
var prmFieldSet *FieldSet
|
|
var resultFieldSet *FieldSet
|
|
|
|
f := func(p replyPart) {
|
|
|
|
switch p := p.(type) {
|
|
|
|
case *statementID:
|
|
p.id = &id
|
|
case *parameterMetadata:
|
|
prmFieldSet = newFieldSet(p.numArg)
|
|
p.fieldSet = prmFieldSet
|
|
case *resultMetadata:
|
|
resultFieldSet = newFieldSet(p.numArg)
|
|
p.fieldSet = resultFieldSet
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(nil, f); err != nil {
|
|
return QtNone, 0, nil, nil, err
|
|
}
|
|
|
|
return s.sh.functionCode.queryType(), id, prmFieldSet, resultFieldSet, nil
|
|
}
|
|
|
|
// Exec executes a sql statement.
|
|
func (s *Session) Exec(id uint64, parameterFieldSet *FieldSet, args []driver.Value) (driver.Result, error) {
|
|
|
|
s.statementID.id = &id
|
|
if err := s.writeRequest(mtExecute, !s.conn.inTx, s.statementID, newParameters(parameterFieldSet, args)); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
wlr := newWriteLobReply() //lob streaming
|
|
|
|
f := func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkWriteLobReply:
|
|
return wlr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(f, nil); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var result driver.Result
|
|
if s.sh.functionCode == fcDDL {
|
|
result = driver.ResultNoRows
|
|
} else {
|
|
result = driver.RowsAffected(s.rowsAffected.total())
|
|
}
|
|
|
|
if wlr.numArg > 0 {
|
|
if err := s.writeLobStream(parameterFieldSet, nil, args, wlr); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// DropStatementID releases the hdb statement handle.
|
|
func (s *Session) DropStatementID(id uint64) error {
|
|
|
|
s.statementID.id = &id
|
|
if err := s.writeRequest(mtDropStatementID, false, s.statementID); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.readReply(nil, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Call executes a stored procedure.
|
|
func (s *Session) Call(id uint64, prmFieldSet *FieldSet, args []driver.Value) (*FieldValues, []*TableResult, error) {
|
|
|
|
s.statementID.id = &id
|
|
if err := s.writeRequest(mtExecute, false, s.statementID, newParameters(prmFieldSet, args)); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
wlr := newWriteLobReply() //lob streaming
|
|
|
|
f := func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkWriteLobReply:
|
|
return wlr
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
prmFieldValues := newFieldValues(s)
|
|
var tableResults []*TableResult
|
|
var tableResult *TableResult
|
|
|
|
g := func(p replyPart) {
|
|
|
|
switch p := p.(type) {
|
|
|
|
case *outputParameters:
|
|
p.fieldSet = prmFieldSet
|
|
p.fieldValues = prmFieldValues
|
|
|
|
// table output parameters: meta, id, result (only first param?)
|
|
case *resultMetadata:
|
|
tableResult = newTableResult(s, p.numArg)
|
|
tableResults = append(tableResults, tableResult)
|
|
p.fieldSet = tableResult.fieldSet
|
|
case *resultsetID:
|
|
p.id = &(tableResult.id)
|
|
case *resultset:
|
|
tableResult.attrs = s.ph.partAttributes
|
|
p.fieldSet = tableResult.fieldSet
|
|
p.fieldValues = tableResult.fieldValues
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(f, g); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if wlr.numArg > 0 {
|
|
if err := s.writeLobStream(prmFieldSet, prmFieldValues, args, wlr); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
return prmFieldValues, tableResults, nil
|
|
}
|
|
|
|
// Query executes a query.
|
|
func (s *Session) Query(stmtID uint64, parameterFieldSet *FieldSet, resultFieldSet *FieldSet, args []driver.Value) (uint64, *FieldValues, PartAttributes, error) {
|
|
|
|
s.statementID.id = &stmtID
|
|
if err := s.writeRequest(mtExecute, false, s.statementID, newParameters(parameterFieldSet, args)); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
|
|
var rsetID uint64
|
|
fieldValues := newFieldValues(s)
|
|
|
|
f := func(p replyPart) {
|
|
|
|
switch p := p.(type) {
|
|
|
|
case *resultsetID:
|
|
p.id = &rsetID
|
|
case *resultset:
|
|
p.fieldSet = resultFieldSet
|
|
p.fieldValues = fieldValues
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(nil, f); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
|
|
attrs := s.ph.partAttributes
|
|
|
|
return rsetID, fieldValues, attrs, nil
|
|
}
|
|
|
|
// FetchNext fetches next chunk in query result set.
|
|
func (s *Session) FetchNext(id uint64, resultFieldSet *FieldSet) (*FieldValues, PartAttributes, error) {
|
|
s.resultsetID.id = &id
|
|
if err := s.writeRequest(mtFetchNext, false, s.resultsetID, fetchsize(s.prm.FetchSize)); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
fieldValues := newFieldValues(s)
|
|
|
|
f := func(p replyPart) {
|
|
|
|
switch p := p.(type) {
|
|
|
|
case *resultset:
|
|
p.fieldSet = resultFieldSet
|
|
p.fieldValues = fieldValues
|
|
}
|
|
}
|
|
|
|
if err := s.readReply(nil, f); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
attrs := s.ph.partAttributes
|
|
|
|
return fieldValues, attrs, nil
|
|
}
|
|
|
|
// CloseResultsetID releases the hdb resultset handle.
|
|
func (s *Session) CloseResultsetID(id uint64) error {
|
|
|
|
s.resultsetID.id = &id
|
|
if err := s.writeRequest(mtCloseResultset, false, s.resultsetID); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.readReply(nil, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Commit executes a database commit.
|
|
func (s *Session) Commit() error {
|
|
|
|
if err := s.writeRequest(mtCommit, false); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.readReply(nil, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
if trace {
|
|
outLogger.Printf("transaction flags: %s", s.txFlags)
|
|
}
|
|
|
|
s.conn.inTx = false
|
|
return nil
|
|
}
|
|
|
|
// Rollback executes a database rollback.
|
|
func (s *Session) Rollback() error {
|
|
|
|
if err := s.writeRequest(mtRollback, false); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.readReply(nil, nil); err != nil {
|
|
return err
|
|
}
|
|
|
|
if trace {
|
|
outLogger.Printf("transaction flags: %s", s.txFlags)
|
|
}
|
|
|
|
s.conn.inTx = false
|
|
return nil
|
|
}
|
|
|
|
// helper
|
|
func readLobStreamDone(writers []lobWriter) bool {
|
|
for _, writer := range writers {
|
|
if !writer.eof() {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
//
|
|
|
|
func (s *Session) readLobStream(writers []lobWriter) error {
|
|
|
|
f := func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkReadLobReply:
|
|
return s.readLobReply
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
for !readLobStreamDone(writers) {
|
|
|
|
s.readLobRequest.writers = writers
|
|
s.readLobReply.writers = writers
|
|
|
|
if err := s.writeRequest(mtWriteLob, false, s.readLobRequest); err != nil {
|
|
return err
|
|
}
|
|
if err := s.readReply(f, nil); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) writeLobStream(prmFieldSet *FieldSet, prmFieldValues *FieldValues, args []driver.Value, reply *writeLobReply) error {
|
|
|
|
num := reply.numArg
|
|
readers := make([]lobReader, num)
|
|
|
|
request := newWriteLobRequest(readers)
|
|
|
|
j := 0
|
|
for i, field := range prmFieldSet.fields {
|
|
|
|
if field.typeCode().isLob() && field.in() {
|
|
|
|
ptr, ok := args[i].(int64)
|
|
if !ok {
|
|
return fmt.Errorf("protocol error: invalid lob driver value type %T", args[i])
|
|
}
|
|
|
|
descr := pointerToLobWriteDescr(ptr)
|
|
if descr.r == nil {
|
|
return fmt.Errorf("protocol error: lob reader %d initial", ptr)
|
|
}
|
|
|
|
if j >= num {
|
|
return fmt.Errorf("protocol error: invalid number of lob parameter ids %d", num)
|
|
}
|
|
|
|
if field.typeCode().isCharBased() {
|
|
readers[j] = newCharLobReader(descr.r, reply.ids[j])
|
|
} else {
|
|
readers[j] = newBinaryLobReader(descr.r, reply.ids[j])
|
|
}
|
|
|
|
j++
|
|
}
|
|
}
|
|
|
|
f := func(pk partKind) replyPart {
|
|
switch pk {
|
|
case pkWriteLobReply:
|
|
return reply
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
g := func(p replyPart) {
|
|
if p, ok := p.(*outputParameters); ok {
|
|
p.fieldSet = prmFieldSet
|
|
p.fieldValues = prmFieldValues
|
|
}
|
|
}
|
|
|
|
for reply.numArg != 0 {
|
|
if err := s.writeRequest(mtReadLob, false, request); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.readReply(f, g); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
//
|
|
|
|
func (s *Session) initRequest() error {
|
|
|
|
// init
|
|
s.mh.sessionID = -1
|
|
|
|
// handshake
|
|
req := newInitRequest()
|
|
// TODO: constants
|
|
req.product.major = 4
|
|
req.product.minor = 20
|
|
req.protocol.major = 4
|
|
req.protocol.minor = 1
|
|
req.numOptions = 1
|
|
req.endianess = archEndian
|
|
if err := req.write(s.wr); err != nil {
|
|
return err
|
|
}
|
|
|
|
rep := newInitReply()
|
|
if err := rep.read(s.rd); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) writeRequest(messageType messageType, commit bool, requests ...requestPart) error {
|
|
|
|
partSize := make([]int, len(requests))
|
|
|
|
size := int64(segmentHeaderSize + len(requests)*partHeaderSize) //int64 to hold MaxUInt32 in 32bit OS
|
|
|
|
for i, part := range requests {
|
|
s, err := part.size()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
size += int64(s + padBytes(s))
|
|
partSize[i] = s // buffer size (expensive calculation)
|
|
}
|
|
|
|
if size > math.MaxUint32 {
|
|
return fmt.Errorf("message size %d exceeds maximum message header value %d", size, int64(math.MaxUint32)) //int64: without cast overflow error in 32bit OS
|
|
}
|
|
|
|
bufferSize := size
|
|
|
|
s.mh.varPartLength = uint32(size)
|
|
s.mh.varPartSize = uint32(bufferSize)
|
|
s.mh.noOfSegm = 1
|
|
|
|
if err := s.mh.write(s.wr); err != nil {
|
|
return err
|
|
}
|
|
|
|
if size > math.MaxInt32 {
|
|
return fmt.Errorf("message size %d exceeds maximum part header value %d", size, math.MaxInt32)
|
|
}
|
|
|
|
s.sh.messageType = messageType
|
|
s.sh.commit = commit
|
|
s.sh.segmentKind = skRequest
|
|
s.sh.segmentLength = int32(size)
|
|
s.sh.segmentOfs = 0
|
|
s.sh.noOfParts = int16(len(requests))
|
|
s.sh.segmentNo = 1
|
|
|
|
if err := s.sh.write(s.wr); err != nil {
|
|
return err
|
|
}
|
|
|
|
bufferSize -= segmentHeaderSize
|
|
|
|
for i, part := range requests {
|
|
|
|
size := partSize[i]
|
|
pad := padBytes(size)
|
|
|
|
s.ph.partKind = part.kind()
|
|
numArg := part.numArg()
|
|
switch {
|
|
default:
|
|
return fmt.Errorf("maximum number of arguments %d exceeded", numArg)
|
|
case numArg <= math.MaxInt16:
|
|
s.ph.argumentCount = int16(numArg)
|
|
s.ph.bigArgumentCount = 0
|
|
|
|
// TODO: seems not to work: see bulk insert test
|
|
case numArg <= math.MaxInt32:
|
|
s.ph.argumentCount = 0
|
|
s.ph.bigArgumentCount = int32(numArg)
|
|
}
|
|
|
|
s.ph.bufferLength = int32(size)
|
|
s.ph.bufferSize = int32(bufferSize)
|
|
|
|
if err := s.ph.write(s.wr); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := part.write(s.wr); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := s.wr.WriteZeroes(pad); err != nil {
|
|
return err
|
|
}
|
|
|
|
bufferSize -= int64(partHeaderSize + size + pad)
|
|
|
|
}
|
|
|
|
if err := s.wr.Flush(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) readReply(providePart providePart, beforeRead beforeRead) error {
|
|
|
|
replyError := false
|
|
|
|
if err := s.mh.read(s.rd); err != nil {
|
|
return err
|
|
}
|
|
if s.mh.noOfSegm != 1 {
|
|
return fmt.Errorf("simple message: no of segments %d - expected 1", s.mh.noOfSegm)
|
|
}
|
|
if err := s.sh.read(s.rd); err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: protocol error (sps 82)?: message header varPartLength < segment header segmentLength (*1)
|
|
diff := int(s.mh.varPartLength) - int(s.sh.segmentLength)
|
|
if trace && diff != 0 {
|
|
outLogger.Printf("+++++diff %d", diff)
|
|
}
|
|
|
|
noOfParts := int(s.sh.noOfParts)
|
|
|
|
for i := 0; i < noOfParts; i++ {
|
|
|
|
if err := s.ph.read(s.rd); err != nil {
|
|
return err
|
|
}
|
|
|
|
numArg := int(s.ph.argumentCount)
|
|
|
|
var part replyPart
|
|
|
|
if providePart != nil {
|
|
part = providePart(s.ph.partKind)
|
|
} else {
|
|
part = nil
|
|
}
|
|
|
|
if part == nil { // use pre defined parts
|
|
|
|
switch s.ph.partKind {
|
|
|
|
case pkStatementID:
|
|
part = s.statementID
|
|
case pkResultMetadata:
|
|
part = s.resultMetadata
|
|
case pkResultsetID:
|
|
part = s.resultsetID
|
|
case pkResultset:
|
|
part = s.resultset
|
|
case pkParameterMetadata:
|
|
part = s.parameterMetadata
|
|
case pkOutputParameters:
|
|
part = s.outputParameters
|
|
case pkError:
|
|
replyError = true
|
|
part = s.lastError
|
|
case pkStatementContext:
|
|
part = s.stmtCtx
|
|
case pkTransactionFlags:
|
|
part = s.txFlags
|
|
case pkRowsAffected:
|
|
part = s.rowsAffected
|
|
default:
|
|
return fmt.Errorf("read not expected part kind %s", s.ph.partKind)
|
|
}
|
|
}
|
|
|
|
part.setNumArg(numArg)
|
|
|
|
if beforeRead != nil {
|
|
beforeRead(part)
|
|
}
|
|
|
|
if err := part.read(s.rd); err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO: workaround (see *)
|
|
if i != (noOfParts-1) || (i == (noOfParts-1) && diff == 0) {
|
|
if err := s.rd.Skip(padBytes(int(s.ph.bufferLength))); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
if replyError {
|
|
if s.lastError.IsWarning() {
|
|
sqltrace.Traceln(s.lastError)
|
|
} else {
|
|
return s.lastError
|
|
}
|
|
}
|
|
return nil
|
|
}
|