diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..15167cd --- /dev/null +++ b/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at http://tip.golang.org/AUTHORS. diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 0000000..778a94d --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,37 @@ +# This is the official list of people who can contribute +# (and typically have contributed) code to the Go repository. +# The AUTHORS file lists the copyright holders; this file +# lists people. For example, Google employees are listed here +# but not in AUTHORS, because Google holds the copyright. +# +# The submission process automatically checks to make sure +# that people submitting code are listed in this file (by email address). +# +# Names should be added to this file only after verifying that +# the individual or the individual's organization has agreed to +# the appropriate Contributor License Agreement, found here: +# +# http://code.google.com/legal/individual-cla-v1.0.html +# http://code.google.com/legal/corporate-cla-v1.0.html +# +# The agreement for individuals can be filled out on the web. +# +# When adding J Random Contributor's name to this file, +# either J's name or J's organization's name should be +# added to the AUTHORS file, depending on whether the +# individual or corporate CLA was used. + +# Names should be added to this file like so: +# Name +# +# An entry with two email addresses specifies that the +# first address should be used in the submit logs and +# that the second address should be recognized as the +# same person when interacting with Rietveld. + +# Please keep the list sorted. + +Alex Brainman +Andrew Gerrand +Chris Hines +Luke Mauldin diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7448756 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..bbf81b2 --- /dev/null +++ b/Makefile @@ -0,0 +1,44 @@ + +DB_NAME=test +PASSWORD=Passw0rd + +help: + echo "use start or stop target" + +# Microsoft SQL Server + +MSSQL_DB_FILES=/tmp/mssql_temp +MSSQL_CONTAINER_NAME=mssql_test +MSSQL_SA_PASSWORD=$(PASSWORD) + +start-mssql: + docker run --name=$(MSSQL_CONTAINER_NAME) -e 'ACCEPT_EULA=Y' -e 'MSSQL_SA_PASSWORD=$(MSSQL_SA_PASSWORD)' -e 'MSSQL_PID=Developer' --cap-add SYS_PTRACE -v $(MSSQL_DB_FILES):/var/opt/mssql -d -p 1433:1433 microsoft/mssql-server-linux + echo -n "starting $(MSSQL_CONTAINER_NAME) "; while ! docker logs $(MSSQL_CONTAINER_NAME) 2>&1 | grep SQL.Server.is.now.ready.for.client.connections >/dev/null ; do echo -n .; sleep 1; done; echo " done" + docker exec $(MSSQL_CONTAINER_NAME) /opt/mssql-tools/bin/sqlcmd -S localhost -U SA -P '$(MSSQL_SA_PASSWORD)' -Q 'create database $(DB_NAME)' + +test-mssql: + go test -v -mssrv=localhost -msdb=$(DB_NAME) -msuser=sa -mspass=$(MSSQL_SA_PASSWORD) -run=TestMSSQL + +test-mssql-race: + go test -v -mssrv=localhost -msdb=$(DB_NAME) -msuser=sa -mspass=$(MSSQL_SA_PASSWORD) -run=TestMSSQL --race + +stop-mssql: + docker stop $(MSSQL_CONTAINER_NAME) + docker rm $(MSSQL_CONTAINER_NAME) + +# MySQL + +MYSQL_CONTAINER_NAME=mysql_test +MYSQL_ROOT_PASSWORD=$(PASSWORD) + +start-mysql: + docker run --name=$(MYSQL_CONTAINER_NAME) -e 'MYSQL_ROOT_PASSWORD=$(MYSQL_ROOT_PASSWORD)' -d -p 127.0.0.1:3306:3306 mysql + echo -n "starting $(MYSQL_CONTAINER_NAME) "; while ! docker logs $(MYSQL_CONTAINER_NAME) 2>&1 | grep ^Version.*port:.3306 >/dev/null ; do echo -n .; sleep 1; done; echo " done" + docker exec $(MYSQL_CONTAINER_NAME) sh -c 'echo "create database $(DB_NAME)" | MYSQL_PWD=$(MYSQL_ROOT_PASSWORD) mysql -hlocalhost -uroot' + +test-mysql: + go test -v -mydb=$(DB_NAME) -mypass=$(MYSQL_ROOT_PASSWORD) -mysrv=127.0.0.1 -myuser=root -run=MYSQL + +stop-mysql: + docker stop $(MYSQL_CONTAINER_NAME) + docker rm $(MYSQL_CONTAINER_NAME) diff --git a/README.md b/README.md new file mode 100644 index 0000000..ffbb04f --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +odbc driver written in go. Implements database driver interface as used by standard database/sql package. It calls into odbc dll on Windows, and uses cgo (unixODBC) everywhere else. + +To get started using odbc, have a look at the [wiki](../../wiki) pages. diff --git a/access_test.go b/access_test.go new file mode 100644 index 0000000..577ecb1 --- /dev/null +++ b/access_test.go @@ -0,0 +1,71 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "testing" + + ole "github.com/go-ole/go-ole" + "github.com/go-ole/go-ole/oleutil" +) + +func TestAccessMemo(t *testing.T) { + tmpdir, err := ioutil.TempDir("", "TestAccessMemo") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpdir) + + dbfilename := filepath.Join(tmpdir, "db.mdb") + createAccessDB(t, dbfilename) + + db, err := sql.Open("odbc", fmt.Sprintf("DRIVER={Microsoft Access Driver (*.mdb)};DBQ=%s;", dbfilename)) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Ping() + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("create table mytable (m memo)") + if err != nil { + t.Fatal(err) + } + for s := ""; len(s) < 1000; s += "0123456789" { + _, err = db.Exec("insert into mytable (m) values (?)", s) + if err != nil { + t.Fatal(err) + } + } +} + +func createAccessDB(t *testing.T, dbfilename string) { + err := ole.CoInitialize(0) + if err != nil { + t.Fatal(err) + } + defer ole.CoUninitialize() + + unk, err := oleutil.CreateObject("adox.catalog") + if err != nil { + t.Fatal(err) + } + cat, err := unk.QueryInterface(ole.IID_IDispatch) + if err != nil { + t.Fatal(err) + } + _, err = oleutil.CallMethod(cat, "create", fmt.Sprintf("provider=microsoft.jet.oledb.4.0;data source=%s;", dbfilename)) + if err != nil { + t.Fatal(err) + } +} diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..65c214a --- /dev/null +++ b/api/api.go @@ -0,0 +1,86 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package api + +//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zapi_windows.go api.go + +//go:generate sh -c "./mksyscall_unix.pl api.go | gofmt > zapi_unix.go" + +import ( + "unicode/utf16" +) + +type ( + SQL_DATE_STRUCT struct { + Year SQLSMALLINT + Month SQLUSMALLINT + Day SQLUSMALLINT + } + + SQL_TIME_STRUCT struct { + Hour SQLUSMALLINT + Minute SQLUSMALLINT + Second SQLUSMALLINT + } + + SQL_SS_TIME2_STRUCT struct { + Hour SQLUSMALLINT + Minute SQLUSMALLINT + Second SQLUSMALLINT + Fraction SQLUINTEGER + } + + SQL_TIMESTAMP_STRUCT struct { + Year SQLSMALLINT + Month SQLUSMALLINT + Day SQLUSMALLINT + Hour SQLUSMALLINT + Minute SQLUSMALLINT + Second SQLUSMALLINT + Fraction SQLUINTEGER + } +) + +//sys SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) = odbc32.SQLAllocHandle +//sys SQLBindCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindCol +//sys SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, inputOutputType SQLSMALLINT, valueType SQLSMALLINT, parameterType SQLSMALLINT, columnSize SQLULEN, decimalDigits SQLSMALLINT, parameterValue SQLPOINTER, bufferLength SQLLEN, ind *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindParameter +//sys SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCloseCursor +//sys SQLDescribeCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, columnName *SQLWCHAR, bufferLength SQLSMALLINT, nameLengthPtr *SQLSMALLINT, dataTypePtr *SQLSMALLINT, columnSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeColW +//sys SQLDescribeParam(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, dataTypePtr *SQLSMALLINT, parameterSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeParam +//sys SQLDisconnect(connectionHandle SQLHDBC) (ret SQLRETURN) = odbc32.SQLDisconnect +//sys SQLDriverConnect(connectionHandle SQLHDBC, windowHandle SQLHWND, inConnectionString *SQLWCHAR, stringLength1 SQLSMALLINT, outConnectionString *SQLWCHAR, bufferLength SQLSMALLINT, stringLength2Ptr *SQLSMALLINT, driverCompletion SQLUSMALLINT) (ret SQLRETURN) = odbc32.SQLDriverConnectW +//sys SQLEndTran(handleType SQLSMALLINT, handle SQLHANDLE, completionType SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLEndTran +//sys SQLExecute(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLExecute +//sys SQLFetch(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLFetch +//sys SQLFreeHandle(handleType SQLSMALLINT, handle SQLHANDLE) (ret SQLRETURN) = odbc32.SQLFreeHandle +//sys SQLGetData(statementHandle SQLHSTMT, colOrParamNum SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) = odbc32.SQLGetData +//sys SQLGetDiagRec(handleType SQLSMALLINT, handle SQLHANDLE, recNumber SQLSMALLINT, sqlState *SQLWCHAR, nativeErrorPtr *SQLINTEGER, messageText *SQLWCHAR, bufferLength SQLSMALLINT, textLengthPtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLGetDiagRecW +//sys SQLNumParams(statementHandle SQLHSTMT, parameterCountPtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLNumParams +//sys SQLMoreResults(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLMoreResults +//sys SQLNumResultCols(statementHandle SQLHSTMT, columnCountPtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLNumResultCols +//sys SQLPrepare(statementHandle SQLHSTMT, statementText *SQLWCHAR, textLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLPrepareW +//sys SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) = odbc32.SQLRowCount +//sys SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetEnvAttr +//sys SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) = odbc32.SQLSetConnectAttrW + +// UTF16ToString returns the UTF-8 encoding of the UTF-16 sequence s, +// with a terminating NUL removed. +func UTF16ToString(s []uint16) string { + for i, v := range s { + if v == 0 { + s = s[0:i] + break + } + } + return string(utf16.Decode(s)) +} + +// StringToUTF16 returns the UTF-16 encoding of the UTF-8 string s, +// with a terminating NUL added. +func StringToUTF16(s string) []uint16 { return utf16.Encode([]rune(s + "\x00")) } + +// StringToUTF16Ptr returns pointer to the UTF-16 encoding of +// the UTF-8 string s, with a terminating NUL added. +func StringToUTF16Ptr(s string) *uint16 { return &StringToUTF16(s)[0] } diff --git a/api/api_unix.go b/api/api_unix.go new file mode 100644 index 0000000..4f25e76 --- /dev/null +++ b/api/api_unix.go @@ -0,0 +1,162 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin linux freebsd +// +build cgo + +package api + +// #cgo darwin LDFLAGS: -L /usr/local/opt/unixodbc/lib -lodbc +// #cgo darwin CFLAGS: -I /usr/local/opt/unixodbc/include +// #cgo linux LDFLAGS: -lodbc +// #cgo freebsd LDFLAGS: -L /usr/local/lib -lodbc +// #cgo freebsd CFLAGS: -I/usr/local/include +// #include +// #include +// #include +/* +SQLRETURN sqlSetEnvUIntPtrAttr(SQLHENV environmentHandle, SQLINTEGER attribute, uintptr_t valuePtr, SQLINTEGER stringLength) { + return SQLSetEnvAttr(environmentHandle, attribute, (SQLPOINTER)valuePtr, stringLength); +} + +SQLRETURN sqlSetConnectUIntPtrAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, uintptr_t valuePtr, SQLINTEGER stringLength) { + return SQLSetConnectAttr(connectionHandle, attribute, (SQLPOINTER)valuePtr, stringLength); +} +*/ +import "C" + +const ( + SQL_OV_ODBC3 = uintptr(C.SQL_OV_ODBC3) + + SQL_ATTR_ODBC_VERSION = C.SQL_ATTR_ODBC_VERSION + + SQL_DRIVER_NOPROMPT = C.SQL_DRIVER_NOPROMPT + + SQL_HANDLE_ENV = C.SQL_HANDLE_ENV + SQL_HANDLE_DBC = C.SQL_HANDLE_DBC + SQL_HANDLE_STMT = C.SQL_HANDLE_STMT + + SQL_SUCCESS = C.SQL_SUCCESS + SQL_SUCCESS_WITH_INFO = C.SQL_SUCCESS_WITH_INFO + SQL_INVALID_HANDLE = C.SQL_INVALID_HANDLE + SQL_NO_DATA = C.SQL_NO_DATA + SQL_NO_TOTAL = C.SQL_NO_TOTAL + SQL_NTS = C.SQL_NTS + SQL_MAX_MESSAGE_LENGTH = C.SQL_MAX_MESSAGE_LENGTH + SQL_NULL_HANDLE = uintptr(C.SQL_NULL_HANDLE) + SQL_NULL_HENV = uintptr(C.SQL_NULL_HENV) + SQL_NULL_HDBC = uintptr(C.SQL_NULL_HDBC) + SQL_NULL_HSTMT = uintptr(C.SQL_NULL_HSTMT) + + SQL_PARAM_INPUT = C.SQL_PARAM_INPUT + + SQL_NULL_DATA = C.SQL_NULL_DATA + SQL_DATA_AT_EXEC = C.SQL_DATA_AT_EXEC + + SQL_UNKNOWN_TYPE = C.SQL_UNKNOWN_TYPE + SQL_CHAR = C.SQL_CHAR + SQL_NUMERIC = C.SQL_NUMERIC + SQL_DECIMAL = C.SQL_DECIMAL + SQL_INTEGER = C.SQL_INTEGER + SQL_SMALLINT = C.SQL_SMALLINT + SQL_FLOAT = C.SQL_FLOAT + SQL_REAL = C.SQL_REAL + SQL_DOUBLE = C.SQL_DOUBLE + SQL_DATETIME = C.SQL_DATETIME + SQL_DATE = C.SQL_DATE + SQL_TIME = C.SQL_TIME + SQL_VARCHAR = C.SQL_VARCHAR + SQL_TYPE_DATE = C.SQL_TYPE_DATE + SQL_TYPE_TIME = C.SQL_TYPE_TIME + SQL_TYPE_TIMESTAMP = C.SQL_TYPE_TIMESTAMP + SQL_TIMESTAMP = C.SQL_TIMESTAMP + SQL_LONGVARCHAR = C.SQL_LONGVARCHAR + SQL_BINARY = C.SQL_BINARY + SQL_VARBINARY = C.SQL_VARBINARY + SQL_LONGVARBINARY = C.SQL_LONGVARBINARY + SQL_BIGINT = C.SQL_BIGINT + SQL_TINYINT = C.SQL_TINYINT + SQL_BIT = C.SQL_BIT + SQL_WCHAR = C.SQL_WCHAR + SQL_WVARCHAR = C.SQL_WVARCHAR + SQL_WLONGVARCHAR = C.SQL_WLONGVARCHAR + SQL_GUID = C.SQL_GUID + SQL_SIGNED_OFFSET = C.SQL_SIGNED_OFFSET + SQL_UNSIGNED_OFFSET = C.SQL_UNSIGNED_OFFSET + + // TODO(lukemauldin): Not defined in sqlext.h. Using windows value, but it is not supported. + SQL_SS_XML = -152 + SQL_SS_TIME2 = -154 + + SQL_C_CHAR = C.SQL_C_CHAR + SQL_C_LONG = C.SQL_C_LONG + SQL_C_SHORT = C.SQL_C_SHORT + SQL_C_FLOAT = C.SQL_C_FLOAT + SQL_C_DOUBLE = C.SQL_C_DOUBLE + SQL_C_NUMERIC = C.SQL_C_NUMERIC + SQL_C_DATE = C.SQL_C_DATE + SQL_C_TIME = C.SQL_C_TIME + SQL_C_TYPE_TIMESTAMP = C.SQL_C_TYPE_TIMESTAMP + SQL_C_TIMESTAMP = C.SQL_C_TIMESTAMP + SQL_C_BINARY = C.SQL_C_BINARY + SQL_C_BIT = C.SQL_C_BIT + SQL_C_WCHAR = C.SQL_C_WCHAR + SQL_C_DEFAULT = C.SQL_C_DEFAULT + SQL_C_SBIGINT = C.SQL_C_SBIGINT + SQL_C_UBIGINT = C.SQL_C_UBIGINT + SQL_C_GUID = C.SQL_C_GUID + + SQL_COMMIT = C.SQL_COMMIT + SQL_ROLLBACK = C.SQL_ROLLBACK + + SQL_AUTOCOMMIT = C.SQL_AUTOCOMMIT + SQL_ATTR_AUTOCOMMIT = C.SQL_ATTR_AUTOCOMMIT + SQL_AUTOCOMMIT_OFF = C.SQL_AUTOCOMMIT_OFF + SQL_AUTOCOMMIT_ON = C.SQL_AUTOCOMMIT_ON + SQL_AUTOCOMMIT_DEFAULT = C.SQL_AUTOCOMMIT_DEFAULT + + SQL_IS_UINTEGER = C.SQL_IS_UINTEGER + + //Connection pooling + SQL_ATTR_CONNECTION_POOLING = C.SQL_ATTR_CONNECTION_POOLING + SQL_ATTR_CP_MATCH = C.SQL_ATTR_CP_MATCH + SQL_CP_OFF = uintptr(C.SQL_CP_OFF) + SQL_CP_ONE_PER_DRIVER = uintptr(C.SQL_CP_ONE_PER_DRIVER) + SQL_CP_ONE_PER_HENV = uintptr(C.SQL_CP_ONE_PER_HENV) + SQL_CP_DEFAULT = SQL_CP_OFF + SQL_CP_STRICT_MATCH = uintptr(C.SQL_CP_STRICT_MATCH) + SQL_CP_RELAXED_MATCH = uintptr(C.SQL_CP_RELAXED_MATCH) +) + +type ( + SQLHANDLE C.SQLHANDLE + SQLHENV C.SQLHENV + SQLHDBC C.SQLHDBC + SQLHSTMT C.SQLHSTMT + SQLHWND uintptr + + SQLWCHAR C.SQLWCHAR + SQLSCHAR C.SQLSCHAR + SQLSMALLINT C.SQLSMALLINT + SQLUSMALLINT C.SQLUSMALLINT + SQLINTEGER C.SQLINTEGER + SQLUINTEGER C.SQLUINTEGER + SQLPOINTER C.SQLPOINTER + SQLRETURN C.SQLRETURN + + SQLLEN C.SQLLEN + SQLULEN C.SQLULEN + + SQLGUID C.SQLGUID +) + +func SQLSetEnvUIntPtrAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr uintptr, stringLength SQLINTEGER) (ret SQLRETURN) { + r := C.sqlSetEnvUIntPtrAttr(C.SQLHENV(environmentHandle), C.SQLINTEGER(attribute), C.uintptr_t(valuePtr), C.SQLINTEGER(stringLength)) + return SQLRETURN(r) +} + +func SQLSetConnectUIntPtrAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr uintptr, stringLength SQLINTEGER) (ret SQLRETURN) { + r := C.sqlSetConnectUIntPtrAttr(C.SQLHDBC(connectionHandle), C.SQLINTEGER(attribute), C.uintptr_t(valuePtr), C.SQLINTEGER(stringLength)) + return SQLRETURN(r) +} diff --git a/api/api_windows.go b/api/api_windows.go new file mode 100644 index 0000000..30e978f --- /dev/null +++ b/api/api_windows.go @@ -0,0 +1,147 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package api + +import ( + "syscall" + "unsafe" +) + +const ( + SQL_OV_ODBC3 = uintptr(3) + + SQL_ATTR_ODBC_VERSION = 200 + + SQL_DRIVER_NOPROMPT = 0 + + SQL_HANDLE_ENV = 1 + SQL_HANDLE_DBC = 2 + SQL_HANDLE_STMT = 3 + + SQL_SUCCESS = 0 + SQL_SUCCESS_WITH_INFO = 1 + SQL_INVALID_HANDLE = -2 + SQL_NO_DATA = 100 + SQL_NO_TOTAL = -4 + SQL_NTS = -3 + SQL_MAX_MESSAGE_LENGTH = 512 + SQL_NULL_HANDLE = 0 + SQL_NULL_HENV = 0 + SQL_NULL_HDBC = 0 + SQL_NULL_HSTMT = 0 + + SQL_PARAM_INPUT = 1 + + SQL_NULL_DATA = -1 + SQL_DATA_AT_EXEC = -2 + + SQL_UNKNOWN_TYPE = 0 + SQL_CHAR = 1 + SQL_NUMERIC = 2 + SQL_DECIMAL = 3 + SQL_INTEGER = 4 + SQL_SMALLINT = 5 + SQL_FLOAT = 6 + SQL_REAL = 7 + SQL_DOUBLE = 8 + SQL_DATETIME = 9 + SQL_DATE = 9 + SQL_TIME = 10 + SQL_VARCHAR = 12 + SQL_TYPE_DATE = 91 + SQL_TYPE_TIME = 92 + SQL_TYPE_TIMESTAMP = 93 + SQL_TIMESTAMP = 11 + SQL_LONGVARCHAR = -1 + SQL_BINARY = -2 + SQL_VARBINARY = -3 + SQL_LONGVARBINARY = -4 + SQL_BIGINT = -5 + SQL_TINYINT = -6 + SQL_BIT = -7 + SQL_WCHAR = -8 + SQL_WVARCHAR = -9 + SQL_WLONGVARCHAR = -10 + SQL_GUID = -11 + SQL_SIGNED_OFFSET = -20 + SQL_UNSIGNED_OFFSET = -22 + SQL_SS_XML = -152 + SQL_SS_TIME2 = -154 + + SQL_C_CHAR = SQL_CHAR + SQL_C_LONG = SQL_INTEGER + SQL_C_SHORT = SQL_SMALLINT + SQL_C_FLOAT = SQL_REAL + SQL_C_DOUBLE = SQL_DOUBLE + SQL_C_NUMERIC = SQL_NUMERIC + SQL_C_DATE = SQL_DATE + SQL_C_TIME = SQL_TIME + SQL_C_TYPE_TIMESTAMP = SQL_TYPE_TIMESTAMP + SQL_C_TIMESTAMP = SQL_TIMESTAMP + SQL_C_BINARY = SQL_BINARY + SQL_C_BIT = SQL_BIT + SQL_C_WCHAR = SQL_WCHAR + SQL_C_DEFAULT = 99 + SQL_C_SBIGINT = SQL_BIGINT + SQL_SIGNED_OFFSET + SQL_C_UBIGINT = SQL_BIGINT + SQL_UNSIGNED_OFFSET + SQL_C_GUID = SQL_GUID + + SQL_COMMIT = 0 + SQL_ROLLBACK = 1 + + SQL_AUTOCOMMIT = 102 + SQL_ATTR_AUTOCOMMIT = SQL_AUTOCOMMIT + SQL_AUTOCOMMIT_OFF = 0 + SQL_AUTOCOMMIT_ON = 1 + SQL_AUTOCOMMIT_DEFAULT = SQL_AUTOCOMMIT_ON + + SQL_IS_UINTEGER = -5 + + //Connection pooling + SQL_ATTR_CONNECTION_POOLING = 201 + SQL_ATTR_CP_MATCH = 202 + SQL_CP_OFF = 0 + SQL_CP_ONE_PER_DRIVER = 1 + SQL_CP_ONE_PER_HENV = uintptr(2) + SQL_CP_DEFAULT = SQL_CP_OFF + SQL_CP_STRICT_MATCH = 0 + SQL_CP_RELAXED_MATCH = uintptr(1) +) + +type ( + SQLHANDLE uintptr + SQLHENV SQLHANDLE + SQLHDBC SQLHANDLE + SQLHSTMT SQLHANDLE + SQLHWND uintptr + + SQLWCHAR uint16 + SQLSCHAR int8 + SQLSMALLINT int16 + SQLUSMALLINT uint16 + SQLINTEGER int32 + SQLUINTEGER uint32 + SQLPOINTER unsafe.Pointer + SQLRETURN SQLSMALLINT + + SQLGUID struct { + Data1 uint32 + Data2 uint16 + Data3 uint16 + Data4 [8]byte + } +) + +func SQLSetEnvUIntPtrAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr uintptr, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetEnvAttr.Addr(), 4, uintptr(environmentHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLSetConnectUIntPtrAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr uintptr, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetConnectAttrW.Addr(), 4, uintptr(connectionHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) + ret = SQLRETURN(r0) + return +} diff --git a/api/api_windows_386.go b/api/api_windows_386.go new file mode 100644 index 0000000..2edf92a --- /dev/null +++ b/api/api_windows_386.go @@ -0,0 +1,10 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package api + +type ( + SQLLEN SQLINTEGER + SQLULEN SQLUINTEGER +) diff --git a/api/api_windows_amd64.go b/api/api_windows_amd64.go new file mode 100644 index 0000000..b603663 --- /dev/null +++ b/api/api_windows_amd64.go @@ -0,0 +1,10 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package api + +type ( + SQLLEN int64 + SQLULEN uint64 +) diff --git a/api/mksyscall_unix.pl b/api/mksyscall_unix.pl new file mode 100755 index 0000000..6fce762 --- /dev/null +++ b/api/mksyscall_unix.pl @@ -0,0 +1,135 @@ +#!/usr/bin/env perl +# Copyright 2012 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# This program is based on $GOROOT/src/pkg/syscall/mksyscall_windows.pl. + +use strict; + +my $cmdline = "mksyscall_unix.pl " . join(' ', @ARGV); +my $errors = 0; + +binmode STDOUT; + +if($ARGV[0] =~ /^-/) { + print STDERR "usage: mksyscall_unix.pl [file ...]\n"; + exit 1; +} + +sub parseparamlist($) { + my ($list) = @_; + $list =~ s/^\s*//; + $list =~ s/\s*$//; + if($list eq "") { + return (); + } + return split(/\s*,\s*/, $list); +} + +sub parseparam($) { + my ($p) = @_; + if($p !~ /^(\S*) (\S*)$/) { + print STDERR "$ARGV:$.: malformed parameter: $p\n"; + $errors = 1; + return ("xx", "int"); + } + return ($1, $2); +} + +my $package = ""; +my $text = ""; +while(<>) { + chomp; + s/\s+/ /g; + s/^\s+//; + s/\s+$//; + $package = $1 if !$package && /^package (\S+)$/; + next if !/^\/\/sys /; + + # Line must be of the form + # func Open(path string, mode int, perm int) (fd int, err error) + # Split into name, in params, out params. + if(!/^\/\/sys (\w+)\(([^()]*)\)\s*(?:\(([^()]+)\))?\s*(?:\[failretval(.*)\])?\s*(?:=\s*(?:(\w*)\.)?(\w*))?$/) { + print STDERR "$ARGV:$.: malformed //sys declaration\n"; + $errors = 1; + next; + } + my ($func, $in, $out, $failcond, $modname, $sysname) = ($1, $2, $3, $4, $5, $6); + + # Split argument lists on comma. + my @in = parseparamlist($in); + my @out = parseparamlist($out); + + # System call name. + if($sysname eq "") { + $sysname = "$func"; + } + + # Go function header. + $out = join(', ', @out); + if($out ne "") { + $out = " ($out)"; + } + if($text ne "") { + $text .= "\n" + } + $text .= sprintf "func %s(%s)%s {\n", $func, join(', ', @in), $out; + + # Prepare arguments. + my @sqlin= (); + my @pin= (); + foreach my $p (@in) { + my ($name, $type) = parseparam($p); + + if($type =~ /^\*(SQLCHAR)/) { + push @sqlin, sprintf "(*C.%s)(unsafe.Pointer(%s))", $1, $name; + } elsif($type =~ /^\*(SQLWCHAR)/) { + push @sqlin, sprintf "(*C.%s)(unsafe.Pointer(%s))", $1, $name; + } elsif($type =~ /^\*(.*)$/) { + push @sqlin, sprintf "(*C.%s)(%s)", $1, $name; + } else { + push @sqlin, sprintf "C.%s(%s)", $type, $name; + } + push @pin, sprintf "\"%s=\", %s, ", $name, $name; + } + + $text .= sprintf "\tr := C.%s(%s)\n", $sysname, join(',', @sqlin); + if(0) { + $text .= sprintf "println(\"SYSCALL: %s(\", %s\") (\", r, \")\")\n", $func, join('", ", ', @pin); + } + $text .= "\treturn SQLRETURN(r)\n"; + $text .= "}\n"; +} + +if($errors) { + exit 1; +} + +print < +// #include +import "C" + +$text + +EOF +exit 0; diff --git a/api/zapi_unix.go b/api/zapi_unix.go new file mode 100644 index 0000000..e36d53f --- /dev/null +++ b/api/zapi_unix.go @@ -0,0 +1,126 @@ +// mksyscall_unix.pl api.go +// MACHINE GENERATED BY THE COMMAND ABOVE; DO NOT EDIT + +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin linux freebsd +// +build cgo + +package api + +import "unsafe" + +// #cgo darwin LDFLAGS: -lodbc +// #cgo linux LDFLAGS: -lodbc +// #cgo freebsd LDFLAGS: -L /usr/local/lib -lodbc +// #cgo freebsd CFLAGS: -I/usr/local/include +// #include +// #include +import "C" + +func SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) { + r := C.SQLAllocHandle(C.SQLSMALLINT(handleType), C.SQLHANDLE(inputHandle), (*C.SQLHANDLE)(outputHandle)) + return SQLRETURN(r) +} + +func SQLBindCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) { + r := C.SQLBindCol(C.SQLHSTMT(statementHandle), C.SQLUSMALLINT(columnNumber), C.SQLSMALLINT(targetType), C.SQLPOINTER(targetValuePtr), C.SQLLEN(bufferLength), (*C.SQLLEN)(vallen)) + return SQLRETURN(r) +} + +func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, inputOutputType SQLSMALLINT, valueType SQLSMALLINT, parameterType SQLSMALLINT, columnSize SQLULEN, decimalDigits SQLSMALLINT, parameterValue SQLPOINTER, bufferLength SQLLEN, ind *SQLLEN) (ret SQLRETURN) { + r := C.SQLBindParameter(C.SQLHSTMT(statementHandle), C.SQLUSMALLINT(parameterNumber), C.SQLSMALLINT(inputOutputType), C.SQLSMALLINT(valueType), C.SQLSMALLINT(parameterType), C.SQLULEN(columnSize), C.SQLSMALLINT(decimalDigits), C.SQLPOINTER(parameterValue), C.SQLLEN(bufferLength), (*C.SQLLEN)(ind)) + return SQLRETURN(r) +} + +func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLCloseCursor(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + +func SQLDescribeCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, columnName *SQLWCHAR, bufferLength SQLSMALLINT, nameLengthPtr *SQLSMALLINT, dataTypePtr *SQLSMALLINT, columnSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLDescribeColW(C.SQLHSTMT(statementHandle), C.SQLUSMALLINT(columnNumber), (*C.SQLWCHAR)(unsafe.Pointer(columnName)), C.SQLSMALLINT(bufferLength), (*C.SQLSMALLINT)(nameLengthPtr), (*C.SQLSMALLINT)(dataTypePtr), (*C.SQLULEN)(columnSizePtr), (*C.SQLSMALLINT)(decimalDigitsPtr), (*C.SQLSMALLINT)(nullablePtr)) + return SQLRETURN(r) +} + +func SQLDescribeParam(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, dataTypePtr *SQLSMALLINT, parameterSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLDescribeParam(C.SQLHSTMT(statementHandle), C.SQLUSMALLINT(parameterNumber), (*C.SQLSMALLINT)(dataTypePtr), (*C.SQLULEN)(parameterSizePtr), (*C.SQLSMALLINT)(decimalDigitsPtr), (*C.SQLSMALLINT)(nullablePtr)) + return SQLRETURN(r) +} + +func SQLDisconnect(connectionHandle SQLHDBC) (ret SQLRETURN) { + r := C.SQLDisconnect(C.SQLHDBC(connectionHandle)) + return SQLRETURN(r) +} + +func SQLDriverConnect(connectionHandle SQLHDBC, windowHandle SQLHWND, inConnectionString *SQLWCHAR, stringLength1 SQLSMALLINT, outConnectionString *SQLWCHAR, bufferLength SQLSMALLINT, stringLength2Ptr *SQLSMALLINT, driverCompletion SQLUSMALLINT) (ret SQLRETURN) { + r := C.SQLDriverConnectW(C.SQLHDBC(connectionHandle), C.SQLHWND(windowHandle), (*C.SQLWCHAR)(unsafe.Pointer(inConnectionString)), C.SQLSMALLINT(stringLength1), (*C.SQLWCHAR)(unsafe.Pointer(outConnectionString)), C.SQLSMALLINT(bufferLength), (*C.SQLSMALLINT)(stringLength2Ptr), C.SQLUSMALLINT(driverCompletion)) + return SQLRETURN(r) +} + +func SQLEndTran(handleType SQLSMALLINT, handle SQLHANDLE, completionType SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLEndTran(C.SQLSMALLINT(handleType), C.SQLHANDLE(handle), C.SQLSMALLINT(completionType)) + return SQLRETURN(r) +} + +func SQLExecute(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLExecute(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + +func SQLFetch(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLFetch(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + +func SQLFreeHandle(handleType SQLSMALLINT, handle SQLHANDLE) (ret SQLRETURN) { + r := C.SQLFreeHandle(C.SQLSMALLINT(handleType), C.SQLHANDLE(handle)) + return SQLRETURN(r) +} + +func SQLGetData(statementHandle SQLHSTMT, colOrParamNum SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) { + r := C.SQLGetData(C.SQLHSTMT(statementHandle), C.SQLUSMALLINT(colOrParamNum), C.SQLSMALLINT(targetType), C.SQLPOINTER(targetValuePtr), C.SQLLEN(bufferLength), (*C.SQLLEN)(vallen)) + return SQLRETURN(r) +} + +func SQLGetDiagRec(handleType SQLSMALLINT, handle SQLHANDLE, recNumber SQLSMALLINT, sqlState *SQLWCHAR, nativeErrorPtr *SQLINTEGER, messageText *SQLWCHAR, bufferLength SQLSMALLINT, textLengthPtr *SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLGetDiagRecW(C.SQLSMALLINT(handleType), C.SQLHANDLE(handle), C.SQLSMALLINT(recNumber), (*C.SQLWCHAR)(unsafe.Pointer(sqlState)), (*C.SQLINTEGER)(nativeErrorPtr), (*C.SQLWCHAR)(unsafe.Pointer(messageText)), C.SQLSMALLINT(bufferLength), (*C.SQLSMALLINT)(textLengthPtr)) + return SQLRETURN(r) +} + +func SQLNumParams(statementHandle SQLHSTMT, parameterCountPtr *SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLNumParams(C.SQLHSTMT(statementHandle), (*C.SQLSMALLINT)(parameterCountPtr)) + return SQLRETURN(r) +} + +func SQLMoreResults(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLMoreResults(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + +func SQLNumResultCols(statementHandle SQLHSTMT, columnCountPtr *SQLSMALLINT) (ret SQLRETURN) { + r := C.SQLNumResultCols(C.SQLHSTMT(statementHandle), (*C.SQLSMALLINT)(columnCountPtr)) + return SQLRETURN(r) +} + +func SQLPrepare(statementHandle SQLHSTMT, statementText *SQLWCHAR, textLength SQLINTEGER) (ret SQLRETURN) { + r := C.SQLPrepareW(C.SQLHSTMT(statementHandle), (*C.SQLWCHAR)(unsafe.Pointer(statementText)), C.SQLINTEGER(textLength)) + return SQLRETURN(r) +} + +func SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) { + r := C.SQLRowCount(C.SQLHSTMT(statementHandle), (*C.SQLLEN)(rowCountPtr)) + return SQLRETURN(r) +} + +func SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r := C.SQLSetEnvAttr(C.SQLHENV(environmentHandle), C.SQLINTEGER(attribute), C.SQLPOINTER(valuePtr), C.SQLINTEGER(stringLength)) + return SQLRETURN(r) +} + +func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r := C.SQLSetConnectAttrW(C.SQLHDBC(connectionHandle), C.SQLINTEGER(attribute), C.SQLPOINTER(valuePtr), C.SQLINTEGER(stringLength)) + return SQLRETURN(r) +} diff --git a/api/zapi_windows.go b/api/zapi_windows.go new file mode 100644 index 0000000..3657da3 --- /dev/null +++ b/api/zapi_windows.go @@ -0,0 +1,189 @@ +// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT + +package api + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return nil + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + mododbc32 = windows.NewLazySystemDLL("odbc32.dll") + + procSQLAllocHandle = mododbc32.NewProc("SQLAllocHandle") + procSQLBindCol = mododbc32.NewProc("SQLBindCol") + procSQLBindParameter = mododbc32.NewProc("SQLBindParameter") + procSQLCloseCursor = mododbc32.NewProc("SQLCloseCursor") + procSQLDescribeColW = mododbc32.NewProc("SQLDescribeColW") + procSQLDescribeParam = mododbc32.NewProc("SQLDescribeParam") + procSQLDisconnect = mododbc32.NewProc("SQLDisconnect") + procSQLDriverConnectW = mododbc32.NewProc("SQLDriverConnectW") + procSQLEndTran = mododbc32.NewProc("SQLEndTran") + procSQLExecute = mododbc32.NewProc("SQLExecute") + procSQLFetch = mododbc32.NewProc("SQLFetch") + procSQLFreeHandle = mododbc32.NewProc("SQLFreeHandle") + procSQLGetData = mododbc32.NewProc("SQLGetData") + procSQLGetDiagRecW = mododbc32.NewProc("SQLGetDiagRecW") + procSQLNumParams = mododbc32.NewProc("SQLNumParams") + procSQLMoreResults = mododbc32.NewProc("SQLMoreResults") + procSQLNumResultCols = mododbc32.NewProc("SQLNumResultCols") + procSQLPrepareW = mododbc32.NewProc("SQLPrepareW") + procSQLRowCount = mododbc32.NewProc("SQLRowCount") + procSQLSetEnvAttr = mododbc32.NewProc("SQLSetEnvAttr") + procSQLSetConnectAttrW = mododbc32.NewProc("SQLSetConnectAttrW") +) + +func SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLAllocHandle.Addr(), 3, uintptr(handleType), uintptr(inputHandle), uintptr(unsafe.Pointer(outputHandle))) + ret = SQLRETURN(r0) + return +} + +func SQLBindCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLBindCol.Addr(), 6, uintptr(statementHandle), uintptr(columnNumber), uintptr(targetType), uintptr(targetValuePtr), uintptr(bufferLength), uintptr(unsafe.Pointer(vallen))) + ret = SQLRETURN(r0) + return +} + +func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, inputOutputType SQLSMALLINT, valueType SQLSMALLINT, parameterType SQLSMALLINT, columnSize SQLULEN, decimalDigits SQLSMALLINT, parameterValue SQLPOINTER, bufferLength SQLLEN, ind *SQLLEN) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall12(procSQLBindParameter.Addr(), 10, uintptr(statementHandle), uintptr(parameterNumber), uintptr(inputOutputType), uintptr(valueType), uintptr(parameterType), uintptr(columnSize), uintptr(decimalDigits), uintptr(parameterValue), uintptr(bufferLength), uintptr(unsafe.Pointer(ind)), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLCloseCursor.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLDescribeCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, columnName *SQLWCHAR, bufferLength SQLSMALLINT, nameLengthPtr *SQLSMALLINT, dataTypePtr *SQLSMALLINT, columnSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall9(procSQLDescribeColW.Addr(), 9, uintptr(statementHandle), uintptr(columnNumber), uintptr(unsafe.Pointer(columnName)), uintptr(bufferLength), uintptr(unsafe.Pointer(nameLengthPtr)), uintptr(unsafe.Pointer(dataTypePtr)), uintptr(unsafe.Pointer(columnSizePtr)), uintptr(unsafe.Pointer(decimalDigitsPtr)), uintptr(unsafe.Pointer(nullablePtr))) + ret = SQLRETURN(r0) + return +} + +func SQLDescribeParam(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, dataTypePtr *SQLSMALLINT, parameterSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLDescribeParam.Addr(), 6, uintptr(statementHandle), uintptr(parameterNumber), uintptr(unsafe.Pointer(dataTypePtr)), uintptr(unsafe.Pointer(parameterSizePtr)), uintptr(unsafe.Pointer(decimalDigitsPtr)), uintptr(unsafe.Pointer(nullablePtr))) + ret = SQLRETURN(r0) + return +} + +func SQLDisconnect(connectionHandle SQLHDBC) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLDisconnect.Addr(), 1, uintptr(connectionHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLDriverConnect(connectionHandle SQLHDBC, windowHandle SQLHWND, inConnectionString *SQLWCHAR, stringLength1 SQLSMALLINT, outConnectionString *SQLWCHAR, bufferLength SQLSMALLINT, stringLength2Ptr *SQLSMALLINT, driverCompletion SQLUSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall9(procSQLDriverConnectW.Addr(), 8, uintptr(connectionHandle), uintptr(windowHandle), uintptr(unsafe.Pointer(inConnectionString)), uintptr(stringLength1), uintptr(unsafe.Pointer(outConnectionString)), uintptr(bufferLength), uintptr(unsafe.Pointer(stringLength2Ptr)), uintptr(driverCompletion), 0) + ret = SQLRETURN(r0) + return +} + +func SQLEndTran(handleType SQLSMALLINT, handle SQLHANDLE, completionType SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLEndTran.Addr(), 3, uintptr(handleType), uintptr(handle), uintptr(completionType)) + ret = SQLRETURN(r0) + return +} + +func SQLExecute(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLExecute.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLFetch(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLFetch.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLFreeHandle(handleType SQLSMALLINT, handle SQLHANDLE) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLFreeHandle.Addr(), 2, uintptr(handleType), uintptr(handle), 0) + ret = SQLRETURN(r0) + return +} + +func SQLGetData(statementHandle SQLHSTMT, colOrParamNum SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLGetData.Addr(), 6, uintptr(statementHandle), uintptr(colOrParamNum), uintptr(targetType), uintptr(targetValuePtr), uintptr(bufferLength), uintptr(unsafe.Pointer(vallen))) + ret = SQLRETURN(r0) + return +} + +func SQLGetDiagRec(handleType SQLSMALLINT, handle SQLHANDLE, recNumber SQLSMALLINT, sqlState *SQLWCHAR, nativeErrorPtr *SQLINTEGER, messageText *SQLWCHAR, bufferLength SQLSMALLINT, textLengthPtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall9(procSQLGetDiagRecW.Addr(), 8, uintptr(handleType), uintptr(handle), uintptr(recNumber), uintptr(unsafe.Pointer(sqlState)), uintptr(unsafe.Pointer(nativeErrorPtr)), uintptr(unsafe.Pointer(messageText)), uintptr(bufferLength), uintptr(unsafe.Pointer(textLengthPtr)), 0) + ret = SQLRETURN(r0) + return +} + +func SQLNumParams(statementHandle SQLHSTMT, parameterCountPtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLNumParams.Addr(), 2, uintptr(statementHandle), uintptr(unsafe.Pointer(parameterCountPtr)), 0) + ret = SQLRETURN(r0) + return +} + +func SQLMoreResults(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLMoreResults.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLNumResultCols(statementHandle SQLHSTMT, columnCountPtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLNumResultCols.Addr(), 2, uintptr(statementHandle), uintptr(unsafe.Pointer(columnCountPtr)), 0) + ret = SQLRETURN(r0) + return +} + +func SQLPrepare(statementHandle SQLHSTMT, statementText *SQLWCHAR, textLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLPrepareW.Addr(), 3, uintptr(statementHandle), uintptr(unsafe.Pointer(statementText)), uintptr(textLength)) + ret = SQLRETURN(r0) + return +} + +func SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLRowCount.Addr(), 2, uintptr(statementHandle), uintptr(unsafe.Pointer(rowCountPtr)), 0) + ret = SQLRETURN(r0) + return +} + +func SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetEnvAttr.Addr(), 4, uintptr(environmentHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) + ret = SQLRETURN(r0) + return +} + +func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetConnectAttrW.Addr(), 4, uintptr(connectionHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) + ret = SQLRETURN(r0) + return +} diff --git a/column.go b/column.go new file mode 100644 index 0000000..ee05d0a --- /dev/null +++ b/column.go @@ -0,0 +1,385 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "errors" + "fmt" + "time" + "unsafe" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type BufferLen api.SQLLEN + +func (l *BufferLen) IsNull() bool { + return *l == api.SQL_NULL_DATA +} + +func (l *BufferLen) GetData(h api.SQLHSTMT, idx int, ctype api.SQLSMALLINT, buf []byte) api.SQLRETURN { + return api.SQLGetData(h, api.SQLUSMALLINT(idx+1), ctype, + api.SQLPOINTER(unsafe.Pointer(&buf[0])), api.SQLLEN(len(buf)), + (*api.SQLLEN)(l)) +} + +func (l *BufferLen) Bind(h api.SQLHSTMT, idx int, ctype api.SQLSMALLINT, buf []byte) api.SQLRETURN { + return api.SQLBindCol(h, api.SQLUSMALLINT(idx+1), ctype, + api.SQLPOINTER(unsafe.Pointer(&buf[0])), api.SQLLEN(len(buf)), + (*api.SQLLEN)(l)) +} + +// Column provides access to row columns. +type Column interface { + Name() string + DatabaseTypeName() string + Bind(h api.SQLHSTMT, idx int) (bool, error) + Value(h api.SQLHSTMT, idx int) (driver.Value, error) +} + +func describeColumn(h api.SQLHSTMT, idx int, namebuf []uint16) (namelen int, sqltype api.SQLSMALLINT, size api.SQLULEN, ret api.SQLRETURN) { + var l, decimal, nullable api.SQLSMALLINT + ret = api.SQLDescribeCol(h, api.SQLUSMALLINT(idx+1), + (*api.SQLWCHAR)(unsafe.Pointer(&namebuf[0])), + api.SQLSMALLINT(len(namebuf)), &l, + &sqltype, &size, &decimal, &nullable) + return int(l), sqltype, size, ret +} + +// TODO(brainman): did not check for MS SQL timestamp + +func NewColumn(h api.SQLHSTMT, idx int) (Column, error) { + namebuf := make([]uint16, 150) + namelen, sqltype, size, ret := describeColumn(h, idx, namebuf) + if ret == api.SQL_SUCCESS_WITH_INFO && namelen > len(namebuf) { + // try again with bigger buffer + namebuf = make([]uint16, namelen) + namelen, sqltype, size, ret = describeColumn(h, idx, namebuf) + } + if IsError(ret) { + return nil, NewError("SQLDescribeCol", h) + } + if namelen > len(namebuf) { + // still complaining about buffer size + return nil, errors.New("Failed to allocate column name buffer") + } + b := &BaseColumn{ + name: api.UTF16ToString(namebuf[:namelen]), + SQLType: sqltype, + } + switch sqltype { + case api.SQL_BIT: + return NewBindableColumn(b, api.SQL_C_BIT, 1), nil + case api.SQL_TINYINT, api.SQL_SMALLINT, api.SQL_INTEGER: + return NewBindableColumn(b, api.SQL_C_LONG, 4), nil + case api.SQL_BIGINT: + return NewBindableColumn(b, api.SQL_C_SBIGINT, 8), nil + case api.SQL_NUMERIC, api.SQL_DECIMAL, api.SQL_FLOAT, api.SQL_REAL, api.SQL_DOUBLE: + return NewBindableColumn(b, api.SQL_C_DOUBLE, 8), nil + case api.SQL_TYPE_TIMESTAMP: + var v api.SQL_TIMESTAMP_STRUCT + return NewBindableColumn(b, api.SQL_C_TYPE_TIMESTAMP, int(unsafe.Sizeof(v))), nil + case api.SQL_TYPE_DATE: + var v api.SQL_DATE_STRUCT + return NewBindableColumn(b, api.SQL_C_DATE, int(unsafe.Sizeof(v))), nil + case api.SQL_TYPE_TIME: + var v api.SQL_TIME_STRUCT + return NewBindableColumn(b, api.SQL_C_TIME, int(unsafe.Sizeof(v))), nil + case api.SQL_SS_TIME2: + var v api.SQL_SS_TIME2_STRUCT + return NewBindableColumn(b, api.SQL_C_BINARY, int(unsafe.Sizeof(v))), nil + case api.SQL_GUID: + var v api.SQLGUID + return NewBindableColumn(b, api.SQL_C_GUID, int(unsafe.Sizeof(v))), nil + case api.SQL_CHAR, api.SQL_VARCHAR: + return NewVariableWidthColumn(b, api.SQL_C_CHAR, size) + case api.SQL_WCHAR, api.SQL_WVARCHAR: + return NewVariableWidthColumn(b, api.SQL_C_WCHAR, size) + case api.SQL_BINARY, api.SQL_VARBINARY: + return NewVariableWidthColumn(b, api.SQL_C_BINARY, size) + case api.SQL_LONGVARCHAR: + return NewVariableWidthColumn(b, api.SQL_C_CHAR, 0) + case api.SQL_WLONGVARCHAR, api.SQL_SS_XML: + return NewVariableWidthColumn(b, api.SQL_C_WCHAR, 0) + case api.SQL_LONGVARBINARY: + return NewVariableWidthColumn(b, api.SQL_C_BINARY, 0) + default: + return nil, fmt.Errorf("unsupported column type %d", sqltype) + } +} + +// BaseColumn implements common column functionality. +type BaseColumn struct { + name string + SQLType api.SQLSMALLINT + CType api.SQLSMALLINT +} + +func (c *BaseColumn) Name() string { + return c.name +} + +func (c *BaseColumn) DatabaseTypeName() string { + switch c.SQLType { + case api.SQL_CHAR: + return "CHAR" + case api.SQL_NUMERIC: + return "NUMERIC" + case api.SQL_DECIMAL: + return "DECIMAL" + case api.SQL_INTEGER: + return "INTEGER" + case api.SQL_SMALLINT: + return "SMALLINT" + case api.SQL_FLOAT: + return "FLOAT" + case api.SQL_REAL: + return "REAL" + case api.SQL_DOUBLE: + return "DOUBLE" + case api.SQL_DATETIME: + return "DATETIME" + case api.SQL_TIME: + return "TIME" + case api.SQL_VARCHAR: + return "VARCHAR" + case api.SQL_TYPE_DATE: + return "TYPE_DATE" + case api.SQL_TYPE_TIME: + return "TYPE_TIME" + case api.SQL_TYPE_TIMESTAMP: + return "TYPE_TIMESTAMP" + case api.SQL_TIMESTAMP: + return "TIMESTAMP" + case api.SQL_LONGVARCHAR: + return "LONGVARCHAR" + case api.SQL_BINARY: + return "BINARY" + case api.SQL_VARBINARY: + return "VARBINARY" + case api.SQL_LONGVARBINARY: + return "LONGVARBINARY" + case api.SQL_BIGINT: + return "BIGINT" + case api.SQL_TINYINT: + return "TINYINT" + case api.SQL_BIT: + return "BIT" + case api.SQL_WCHAR: + return "WCHAR" + case api.SQL_WVARCHAR: + return "WVARCHAR" + case api.SQL_WLONGVARCHAR: + return "WLONGVARCHAR" + case api.SQL_GUID: + return "GUID" + case api.SQL_SIGNED_OFFSET: + return "SIGNED_OFFSET" + case api.SQL_UNSIGNED_OFFSET: + return "UNSIGNED_OFFSET" + case api.SQL_UNKNOWN_TYPE: + return "" + default: + return "" + } +} + +func (c *BaseColumn) Value(buf []byte) (driver.Value, error) { + var p unsafe.Pointer + if len(buf) > 0 { + p = unsafe.Pointer(&buf[0]) + } + switch c.CType { + case api.SQL_C_BIT: + return buf[0] != 0, nil + case api.SQL_C_LONG: + return *((*int32)(p)), nil + case api.SQL_C_SBIGINT: + return *((*int64)(p)), nil + case api.SQL_C_DOUBLE: + return *((*float64)(p)), nil + case api.SQL_C_CHAR: + return buf, nil + case api.SQL_C_WCHAR: + if p == nil { + return buf, nil + } + s := (*[1 << 28]uint16)(p)[: len(buf)/2 : len(buf)/2] + return utf16toutf8(s), nil + case api.SQL_C_TYPE_TIMESTAMP: + t := (*api.SQL_TIMESTAMP_STRUCT)(p) + r := time.Date(int(t.Year), time.Month(t.Month), int(t.Day), + int(t.Hour), int(t.Minute), int(t.Second), int(t.Fraction), + time.Local) + return r, nil + case api.SQL_C_GUID: + t := (*api.SQLGUID)(p) + var p1, p2 string + for _, d := range t.Data4[:2] { + p1 += fmt.Sprintf("%02x", d) + } + for _, d := range t.Data4[2:] { + p2 += fmt.Sprintf("%02x", d) + } + r := fmt.Sprintf("%08x-%04x-%04x-%s-%s", + t.Data1, t.Data2, t.Data3, p1, p2) + return r, nil + case api.SQL_C_DATE: + t := (*api.SQL_DATE_STRUCT)(p) + r := time.Date(int(t.Year), time.Month(t.Month), int(t.Day), + 0, 0, 0, 0, time.Local) + return r, nil + case api.SQL_C_TIME: + t := (*api.SQL_TIME_STRUCT)(p) + r := time.Date(1, time.January, 1, + int(t.Hour), int(t.Minute), int(t.Second), 0, time.Local) + return r, nil + case api.SQL_C_BINARY: + if c.SQLType == api.SQL_SS_TIME2 { + t := (*api.SQL_SS_TIME2_STRUCT)(p) + r := time.Date(1, time.January, 1, + int(t.Hour), int(t.Minute), int(t.Second), int(t.Fraction), + time.Local) + return r, nil + } + return buf, nil + } + return nil, fmt.Errorf("unsupported column ctype %d", c.CType) +} + +// BindableColumn allows access to columns that can have their buffers +// bound. Once bound at start, they are written to by odbc driver every +// time it fetches new row. This saves on syscall and, perhaps, some +// buffer copying. BindableColumn can be left unbound, then it behaves +// like NonBindableColumn when user reads data from it. +type BindableColumn struct { + *BaseColumn + IsBound bool + IsVariableWidth bool + Size int + Len BufferLen + Buffer []byte +} + +// TODO(brainman): BindableColumn.Buffer is used by external code after external code returns - that needs to be avoided in the future + +func NewBindableColumn(b *BaseColumn, ctype api.SQLSMALLINT, bufSize int) *BindableColumn { + b.CType = ctype + c := &BindableColumn{BaseColumn: b, Size: bufSize} + l := 8 // always use small starting buffer + if c.Size > l { + l = c.Size + } + c.Buffer = make([]byte, l) + return c +} + +func NewVariableWidthColumn(b *BaseColumn, ctype api.SQLSMALLINT, colWidth api.SQLULEN) (Column, error) { + if colWidth == 0 || colWidth > 1024 { + b.CType = ctype + return &NonBindableColumn{b}, nil + } + l := int(colWidth) + switch ctype { + case api.SQL_C_WCHAR: + l += 1 // room for null-termination character + l *= 2 // wchars take 2 bytes each + case api.SQL_C_CHAR: + l += 1 // room for null-termination character + case api.SQL_C_BINARY: + // nothing to do + default: + return nil, fmt.Errorf("do not know how wide column of ctype %d is", ctype) + } + c := NewBindableColumn(b, ctype, l) + c.IsVariableWidth = true + return c, nil +} + +func (c *BindableColumn) Bind(h api.SQLHSTMT, idx int) (bool, error) { + ret := c.Len.Bind(h, idx, c.CType, c.Buffer) + if IsError(ret) { + return false, NewError("SQLBindCol", h) + } + c.IsBound = true + return true, nil +} + +func (c *BindableColumn) Value(h api.SQLHSTMT, idx int) (driver.Value, error) { + if !c.IsBound { + ret := c.Len.GetData(h, idx, c.CType, c.Buffer) + if IsError(ret) { + return nil, NewError("SQLGetData", h) + } + } + if c.Len.IsNull() { + // is NULL + return nil, nil + } + if !c.IsVariableWidth && int(c.Len) != c.Size { + return nil, fmt.Errorf("wrong column #%d length %d returned, %d expected", idx, c.Len, c.Size) + } + return c.BaseColumn.Value(c.Buffer[:c.Len]) +} + +// NonBindableColumn provide access to columns, that can't be bound. +// These are of character or binary type, and, usually, there is no +// limit for their width. +type NonBindableColumn struct { + *BaseColumn +} + +func (c *NonBindableColumn) Bind(h api.SQLHSTMT, idx int) (bool, error) { + return false, nil +} + +func (c *NonBindableColumn) Value(h api.SQLHSTMT, idx int) (driver.Value, error) { + var l BufferLen + var total []byte + b := make([]byte, 1024) +loop: + for { + ret := l.GetData(h, idx, c.CType, b) + switch ret { + case api.SQL_SUCCESS: + if l.IsNull() { + // is NULL + return nil, nil + } + if int(l) > len(b) { + return nil, fmt.Errorf("too much data returned: %d bytes returned, but buffer size is %d", l, cap(b)) + } + total = append(total, b[:l]...) + break loop + case api.SQL_SUCCESS_WITH_INFO: + err := NewError("SQLGetData", h).(*Error) + if len(err.Diag) > 0 && err.Diag[0].State != "01004" { + return nil, err + } + i := len(b) + switch c.CType { + case api.SQL_C_WCHAR: + i -= 2 // remove wchar (2 bytes) null-termination character + case api.SQL_C_CHAR: + i-- // remove null-termination character + } + total = append(total, b[:i]...) + if l != api.SQL_NO_TOTAL { + // odbc gives us a hint about remaining data, + // lets get it in one go. + n := int(l) // total bytes for our data + n -= i // subtract already received + n += 2 // room for biggest (wchar) null-terminator + if len(b) < n { + b = make([]byte, n) + } + } + default: + return nil, NewError("SQLGetData", h) + } + } + return c.BaseColumn.Value(total) +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..0c135ad --- /dev/null +++ b/conn.go @@ -0,0 +1,74 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "strings" + "unsafe" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Conn struct { + h api.SQLHDBC + tx *Tx + bad bool + isMSAccessDriver bool +} + +var accessDriverSubstr = strings.ToUpper(strings.Replace("DRIVER={Microsoft Access Driver", " ", "", -1)) + +func (d *Driver) Open(dsn string) (driver.Conn, error) { + if d.initErr != nil { + return nil, d.initErr + } + + var out api.SQLHANDLE + ret := api.SQLAllocHandle(api.SQL_HANDLE_DBC, api.SQLHANDLE(d.h), &out) + if IsError(ret) { + return nil, NewError("SQLAllocHandle", d.h) + } + h := api.SQLHDBC(out) + drv.Stats.updateHandleCount(api.SQL_HANDLE_DBC, 1) + + b := api.StringToUTF16(dsn) + ret = api.SQLDriverConnect(h, 0, + (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS, + nil, 0, nil, api.SQL_DRIVER_NOPROMPT) + if IsError(ret) { + defer releaseHandle(h) + return nil, NewError("SQLDriverConnect", h) + } + isAccess := strings.Contains(strings.ToUpper(strings.Replace(dsn, " ", "", -1)), accessDriverSubstr) + return &Conn{h: h, isMSAccessDriver: isAccess}, nil +} + +func (c *Conn) Close() (err error) { + if c.tx != nil { + c.tx.Rollback() + } + h := c.h + defer func() { + c.h = api.SQLHDBC(api.SQL_NULL_HDBC) + e := releaseHandle(h) + if err == nil { + err = e + } + }() + ret := api.SQLDisconnect(c.h) + if IsError(ret) { + return c.newError("SQLDisconnect", h) + } + return err +} + +func (c *Conn) newError(apiName string, handle interface{}) error { + err := NewError(apiName, handle) + if err == driver.ErrBadConn { + c.bad = true + } + return err +} diff --git a/driver.go b/driver.go new file mode 100644 index 0000000..2422d89 --- /dev/null +++ b/driver.go @@ -0,0 +1,79 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package odbc implements database/sql driver to access data via odbc interface. +// +package odbc + +import ( + "database/sql" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +var drv Driver + +type Driver struct { + Stats + h api.SQLHENV // environment handle + initErr error +} + +func initDriver() error { + + //Allocate environment handle + var out api.SQLHANDLE + in := api.SQLHANDLE(api.SQL_NULL_HANDLE) + ret := api.SQLAllocHandle(api.SQL_HANDLE_ENV, in, &out) + if IsError(ret) { + return NewError("SQLAllocHandle", api.SQLHENV(in)) + } + drv.h = api.SQLHENV(out) + err := drv.Stats.updateHandleCount(api.SQL_HANDLE_ENV, 1) + if err != nil { + return err + } + + // will use ODBC v3 + ret = api.SQLSetEnvUIntPtrAttr(drv.h, api.SQL_ATTR_ODBC_VERSION, api.SQL_OV_ODBC3, 0) + if IsError(ret) { + defer releaseHandle(drv.h) + return NewError("SQLSetEnvUIntPtrAttr", drv.h) + } + + //TODO: find a way to make this attribute changeable at runtime + //Enable connection pooling + ret = api.SQLSetEnvUIntPtrAttr(drv.h, api.SQL_ATTR_CONNECTION_POOLING, api.SQL_CP_ONE_PER_HENV, api.SQL_IS_UINTEGER) + if IsError(ret) { + defer releaseHandle(drv.h) + return NewError("SQLSetEnvUIntPtrAttr", drv.h) + } + + //Set relaxed connection pool matching + ret = api.SQLSetEnvUIntPtrAttr(drv.h, api.SQL_ATTR_CP_MATCH, api.SQL_CP_RELAXED_MATCH, api.SQL_IS_UINTEGER) + if IsError(ret) { + defer releaseHandle(drv.h) + return NewError("SQLSetEnvUIntPtrAttr", drv.h) + } + + //TODO: it would be nice if we could call "drv.SetMaxIdleConns(0)" here but from the docs it looks like + //the user must call this function after db.Open + + return nil +} + +func (d *Driver) Close() error { + // TODO(brainman): who will call (*Driver).Close (to dispose all opened handles)? + h := d.h + d.h = api.SQLHENV(api.SQL_NULL_HENV) + return releaseHandle(h) +} + +func init() { + err := initDriver() + if err != nil { + drv.initErr = err + } + sql.Register("odbc", &drv) +} diff --git a/error.go b/error.go new file mode 100644 index 0000000..18dfa43 --- /dev/null +++ b/error.go @@ -0,0 +1,74 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "fmt" + "strings" + "unsafe" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +func IsError(ret api.SQLRETURN) bool { + return !(ret == api.SQL_SUCCESS || ret == api.SQL_SUCCESS_WITH_INFO) +} + +type DiagRecord struct { + State string + NativeError int + Message string +} + +func (r *DiagRecord) String() string { + return fmt.Sprintf("{%s} %s", r.State, r.Message) +} + +type Error struct { + APIName string + Diag []DiagRecord +} + +func (e *Error) Error() string { + ss := make([]string, len(e.Diag)) + for i, r := range e.Diag { + ss[i] = r.String() + } + return e.APIName + ": " + strings.Join(ss, "\n") +} + +func NewError(apiName string, handle interface{}) error { + h, ht, herr := ToHandleAndType(handle) + if herr != nil { + return herr + } + err := &Error{APIName: apiName} + var ne api.SQLINTEGER + state := make([]uint16, 6) + msg := make([]uint16, api.SQL_MAX_MESSAGE_LENGTH) + for i := 1; ; i++ { + ret := api.SQLGetDiagRec(ht, h, api.SQLSMALLINT(i), + (*api.SQLWCHAR)(unsafe.Pointer(&state[0])), &ne, + (*api.SQLWCHAR)(unsafe.Pointer(&msg[0])), + api.SQLSMALLINT(len(msg)), nil) + if ret == api.SQL_NO_DATA { + break + } + if IsError(ret) { + return fmt.Errorf("SQLGetDiagRec failed: ret=%d", ret) + } + r := DiagRecord{ + State: api.UTF16ToString(state), + NativeError: int(ne), + Message: api.UTF16ToString(msg), + } + if r.State == "08S01" { + return driver.ErrBadConn + } + err.Diag = append(err.Diag, r) + } + return err +} diff --git a/foxpro_test.go b/foxpro_test.go new file mode 100644 index 0000000..6390e44 --- /dev/null +++ b/foxpro_test.go @@ -0,0 +1,157 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc_test + +import ( + "database/sql" + "flag" + "fmt" + "testing" + "time" + + _ "github.com/taoikaihatsu-dev/odbc" +) + +var ( + fox = flag.String("fox", "testdata", "directory where foxpro tables reside") +) + +func TestFoxPro(t *testing.T) { + conn := fmt.Sprintf("driver={Microsoft dBASE Driver (*.dbf)};driverid=277;dbq=%s;", + *fox) + + db, err := sql.Open("odbc", conn) + if err != nil { + t.Fatal(err) + } + defer db.Close() + if err := db.Ping(); err != nil { + t.Skipf("skipping test: %v", err) + } + + type row struct { + char string + num_2_0 sql.NullFloat64 + num_20_0 float64 + num_6_3 float32 + date time.Time + float_2_0 float32 + float_20_0 float64 + float_6_3 float32 + logical bool + memo sql.NullString + } + + var tests = []row{ + { + char: "123", + num_2_0: sql.NullFloat64{Float64: 1, Valid: true}, + num_20_0: 1232543, + num_6_3: 12.73, + date: time.Date(2012, 5, 19, 0, 0, 0, 0, time.Local), + float_2_0: 23, + float_20_0: 12345678901234560, + float_6_3: 12.345, + logical: true, + memo: sql.NullString{String: "Hello", Valid: true}, + }, + { + char: "abcdef", + num_2_0: sql.NullFloat64{Float64: 23, Valid: true}, + num_20_0: 4564568, + num_6_3: 2, + date: time.Date(2012, 5, 20, 0, 0, 0, 0, time.Local), + float_2_0: 1, + float_20_0: 234, + float_6_3: 0.123, + logical: false, + memo: sql.NullString{String: "", Valid: false}, + }, + { + char: "346546", + num_2_0: sql.NullFloat64{Float64: 4, Valid: true}, + num_20_0: 1234567890123456000, + num_6_3: 99.99, + date: time.Date(2012, 5, 21, 0, 0, 0, 0, time.Local), + float_2_0: 23, + float_20_0: 457768, + float_6_3: 99, + logical: true, + memo: sql.NullString{String: "World", Valid: true}, + }, + { + char: "asasds", + num_2_0: sql.NullFloat64{Float64: 0, Valid: false}, + num_20_0: 234456, + num_6_3: 0.123, + date: time.Date(2012, 5, 22, 0, 0, 0, 0, time.Local), + float_2_0: 65, + float_20_0: 234, + float_6_3: 1, + logical: false, + memo: sql.NullString{String: "12398y345 sdflkjdsfsd fds;lkdsfgl;sd", Valid: true}, + }, + } + + const query = `select id, + char, num_2_0, num_20_0, num_6_3, date, + float_2_0, float_20_0, float_6_3, logical, memo + from fldtest` + rows, err := db.Query(query) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + var id int + var r row + err = rows.Scan(&id, + &r.char, &r.num_2_0, &r.num_20_0, &r.num_6_3, &r.date, + &r.float_2_0, &r.float_20_0, &r.float_6_3, &r.logical, &r.memo) + if err != nil { + t.Fatal(err) + } + + if id < 0 || len(tests) < id { + t.Errorf("unexpected row with id %d", id) + continue + } + + x := tests[id] + if x.char != r.char { + t.Errorf("row %d: char expected %v, but received %v", id, x.char, r.char) + } + if x.num_2_0 != r.num_2_0 { + t.Errorf("row %d: num_2_0 expected %v, but received %v", id, x.num_2_0, r.num_2_0) + } + if x.num_20_0 != r.num_20_0 { + t.Errorf("row %d: num_20_0 expected %v, but received %v", id, x.num_20_0, r.num_20_0) + } + if x.num_6_3 != r.num_6_3 { + t.Errorf("row %d: num_6_3 expected %v, but received %v", id, x.num_6_3, r.num_6_3) + } + if x.date != r.date { + t.Errorf("row %d: date expected %v, but received %v", id, x.date, r.date) + } + if x.float_2_0 != r.float_2_0 { + t.Errorf("row %d: float_2_0 expected %v, but received %v", id, x.float_2_0, r.float_2_0) + } + if x.float_20_0 != r.float_20_0 { + t.Errorf("row %d: float_20_0 expected %v, but received %v", id, x.float_20_0, r.float_20_0) + } + if x.float_6_3 != r.float_6_3 { + t.Errorf("row %d: float_6_3 expected %v, but received %v", id, x.float_6_3, r.float_6_3) + } + if x.logical != r.logical { + t.Errorf("row %d: logical expected %v, but received %v", id, x.logical, r.logical) + } + if x.memo != r.memo { + t.Errorf("row %d: memo expected %v, but received %v", id, x.memo, r.memo) + } + } + err = rows.Err() + if err != nil { + t.Fatal(err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..417230c --- /dev/null +++ b/go.mod @@ -0,0 +1,6 @@ +module github.com/taoikaihatsu-dev/odbc + +require ( + github.com/go-ole/go-ole v1.2.5 + golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..2712e4d --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/go-ole/go-ole v1.2.5 h1:t4MGB5xEDZvXI+0rMjjsfBsD7yAgp/s9ZDkL1JndXwY= +github.com/go-ole/go-ole v1.2.5/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3 h1:7TYNF4UdlohbFwpNH04CoPMp1cHUZgO1Ebq5r2hIjfo= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/handle.go b/handle.go new file mode 100644 index 0000000..4b95291 --- /dev/null +++ b/handle.go @@ -0,0 +1,47 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "fmt" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +func ToHandleAndType(handle interface{}) (h api.SQLHANDLE, ht api.SQLSMALLINT, err error) { + switch v := handle.(type) { + case api.SQLHENV: + if v == api.SQLHENV(api.SQL_NULL_HANDLE) { + ht = 0 + } else { + ht = api.SQL_HANDLE_ENV + } + h = api.SQLHANDLE(v) + case api.SQLHDBC: + ht = api.SQL_HANDLE_DBC + h = api.SQLHANDLE(v) + case api.SQLHSTMT: + ht = api.SQL_HANDLE_STMT + h = api.SQLHANDLE(v) + default: + err = fmt.Errorf("unexpected handle type %T", v) + } + return h, ht, err +} + +func releaseHandle(handle interface{}) error { + h, ht, err := ToHandleAndType(handle) + if err != nil { + return err + } + ret := api.SQLFreeHandle(ht, h) + if ret == api.SQL_INVALID_HANDLE { + return fmt.Errorf("SQLFreeHandle(%d, %d) returns SQL_INVALID_HANDLE", ht, h) + } + if IsError(ret) { + return NewError("SQLFreeHandle", handle) + } + return drv.Stats.updateHandleCount(ht, -1) +} diff --git a/mssql_test.go b/mssql_test.go new file mode 100644 index 0000000..e2a2a9c --- /dev/null +++ b/mssql_test.go @@ -0,0 +1,1919 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "errors" + "flag" + "fmt" + "io" + "net" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +var ( + mssrv = flag.String("mssrv", "server", "ms sql server name") + msdb = flag.String("msdb", "dbname", "ms sql server database name") + msuser = flag.String("msuser", "", "ms sql server user name") + mspass = flag.String("mspass", "", "ms sql server password") + msdriver = flag.String("msdriver", defaultDriver(), "ms sql odbc driver name") + msport = flag.String("msport", "1433", "ms sql server port number") +) + +func defaultDriver() string { + if runtime.GOOS == "windows" { + return "sql server" + } else { + return "freetds" + } +} + +func isFreeTDS() bool { + return *msdriver == "freetds" +} + +type connParams map[string]string + +func newConnParams() connParams { + params := connParams{ + "driver": *msdriver, + "server": *mssrv, + "database": *msdb, + } + if isFreeTDS() { + params["uid"] = *msuser + params["pwd"] = *mspass + params["port"] = *msport + params["TDS_Version"] = "8.0" + //params["clientcharset"] = "UTF-8" + //params["debugflags"] = "0xffff" + } else { + if len(*msuser) == 0 { + params["trusted_connection"] = "yes" + } else { + params["uid"] = *msuser + params["pwd"] = *mspass + } + } + a := strings.SplitN(params["server"], ",", -1) + if len(a) == 2 { + params["server"] = a[0] + params["port"] = a[1] + } + return params +} + +func (params connParams) getConnAddress() (string, error) { + port, ok := params["port"] + if !ok { + return "", errors.New("no port number provided.") + } + host, ok := params["server"] + if !ok { + return "", errors.New("no host name provided.") + } + return host + ":" + port, nil +} + +func (params connParams) updateConnAddress(address string) error { + a := strings.SplitN(address, ":", -1) + if len(a) != 2 { + return fmt.Errorf("listen address must have 2 fields, but %d found", len(a)) + } + params["server"] = a[0] + params["port"] = a[1] + return nil +} + +func (params connParams) makeODBCConnectionString() string { + if port, ok := params["port"]; ok { + params["server"] += "," + port + delete(params, "port") + } + var c string + for n, v := range params { + c += n + "=" + v + ";" + } + return c +} + +func mssqlConnectWithParams(params connParams) (db *sql.DB, stmtCount int, err error) { + db, err = sql.Open("odbc", params.makeODBCConnectionString()) + if err != nil { + return nil, 0, err + } + stats := db.Driver().(*Driver).Stats + return db, stats.StmtCount, nil +} + +func mssqlConnect() (db *sql.DB, stmtCount int, err error) { + return mssqlConnectWithParams(newConnParams()) +} + +func closeDB(t *testing.T, db *sql.DB, shouldStmtCount, ignoreIfStmtCount int) { + s := db.Driver().(*Driver).Stats + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } + switch s.StmtCount { + case shouldStmtCount: + // all good + case ignoreIfStmtCount: + t.Logf("ignoring unexpected StmtCount of %v", ignoreIfStmtCount) + default: + t.Errorf("unexpected StmtCount: should=%v, is=%v", ignoreIfStmtCount, s.StmtCount) + } +} + +// as per http://www.mssqltips.com/sqlservertip/2198/determine-which-version-of-sql-server-data-access-driver-is-used-by-an-application/ +func connProtoVersion(db *sql.DB) ([]byte, error) { + var p []byte + if err := db.QueryRow("select cast(protocol_version as binary(4)) from master.sys.dm_exec_connections where session_id = @@spid").Scan(&p); err != nil { + return nil, err + } + if len(p) != 4 { + return nil, errors.New("failed to fetch connection protocol") + } + return p, nil +} + +// as per http://msdn.microsoft.com/en-us/library/dd339982.aspx +func isProto2008OrLater(db *sql.DB) (bool, error) { + p, err := connProtoVersion(db) + if err != nil { + return false, err + } + return p[0] >= 0x73, nil +} + +// as per http://www.mssqltips.com/sqlservertip/2563/understanding-the-sql-server-select-version-command/ +func serverVersion(db *sql.DB) (sqlVersion, sqlPartNumber, osVersion string, err error) { + var v string + if err = db.QueryRow("select @@version").Scan(&v); err != nil { + return "", "", "", err + } + a := strings.SplitN(v, "\n", -1) + if len(a) < 4 { + return "", "", "", errors.New("SQL Server version string must have at least 4 lines: " + v) + } + for i := range a { + a[i] = strings.Trim(a[i], " \t") + } + l1 := strings.SplitN(a[0], "- ", -1) + if len(l1) != 2 { + return "", "", "", errors.New("SQL Server version first line must have - in it: " + v) + } + i := strings.Index(a[3], " on ") + if i < 0 { + return "", "", "", errors.New("SQL Server version fourth line must have 'on' in it: " + v) + } + sqlVersion = l1[0] + a[3][:i] + osVersion = a[3][i+4:] + sqlPartNumber = strings.Trim(l1[1], " ") + l12 := strings.SplitN(sqlPartNumber, " ", -1) + if len(l12) < 2 { + return "", "", "", errors.New("SQL Server version first line must have space after part number in it: " + v) + } + sqlPartNumber = l12[0] + return sqlVersion, sqlPartNumber, osVersion, nil +} + +// as per http://www.mssqltips.com/sqlservertip/2563/understanding-the-sql-server-select-version-command/ +func isSrv2008OrLater(db *sql.DB) (bool, error) { + _, sqlPartNumber, _, err := serverVersion(db) + if err != nil { + return false, err + } + a := strings.SplitN(sqlPartNumber, ".", -1) + if len(a) != 4 { + return false, errors.New("SQL Server part number must have 4 numbers in it: " + sqlPartNumber) + } + n, err := strconv.ParseInt(a[0], 10, 0) + if err != nil { + return false, errors.New("SQL Server invalid part number: " + sqlPartNumber) + } + return n >= 10, nil +} + +func is2008OrLater(db *sql.DB) bool { + b, err := isSrv2008OrLater(db) + if err != nil || !b { + return false + } + b, err = isProto2008OrLater(db) + if err != nil || !b { + return false + } + return true +} + +func exec(t *testing.T, db *sql.DB, query string) { + // TODO(brainman): make sure https://github.com/golang/go/issues/3678 is fixed + //r, err := db.Exec(query, a...) + s, err := db.Prepare(query) + if err != nil { + t.Fatalf("db.Prepare(%q) failed: %v", query, err) + } + defer s.Close() + r, err := s.Exec() + if err != nil { + t.Fatalf("s.Exec(%q ...) failed: %v", query, err) + } + _, err = r.RowsAffected() + if err != nil { + t.Fatalf("r.RowsAffected(%q ...) failed: %v", query, err) + } +} + +func driverExec(t *testing.T, dc driver.Conn, query string) { + st, err := dc.Prepare(query) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := st.Close(); err != nil && t != nil { + t.Fatal(err) + } + }() + + r, err := st.Exec([]driver.Value{}) + if err != nil { + if t != nil { + t.Fatal(err) + } + return + } + _, err = r.RowsAffected() + if err != nil { + if t != nil { + t.Fatalf("r.RowsAffected(%q ...) failed: %v", query, err) + } + return + } +} + +func TestMSSQLCreateInsertDelete(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + type friend struct { + age int + isGirl bool + weight float64 + dob time.Time + data []byte + canBeNull sql.NullString + } + var friends = map[string]friend{ + "glenda": { + age: 5, + isGirl: true, + weight: 15.5, + dob: time.Date(2000, 5, 10, 11, 1, 1, 0, time.Local), + data: []byte{0x0, 0x0, 0xb, 0xad, 0xc0, 0xde}, + canBeNull: sql.NullString{"aa", true}, + }, + "gopher": { + age: 3, + isGirl: false, + weight: 26.12, + dob: time.Date(2009, 5, 10, 11, 1, 1, 123e6, time.Local), + data: []byte{0x0}, + canBeNull: sql.NullString{"bbb", true}, + }, + } + + // create table + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (name varchar(20), age int, isGirl bit, weight decimal(5,2), dob datetime, data varbinary(10) null, canBeNull varchar(10) null)") + func() { + s, err := db.Prepare("insert into dbo.temp (name, age, isGirl, weight, dob, data, canBeNull) values (?, ?, ?, ?, ?, cast(? as varbinary(10)), ?)") + if err != nil { + t.Fatal(err) + } + defer s.Close() + for name, f := range friends { + _, err := s.Exec(name, f.age, f.isGirl, f.weight, f.dob, f.data, f.canBeNull) + if err != nil { + t.Fatal(err) + } + } + _, err = s.Exec("chris", 25, 0, 50, time.Date(2015, 12, 25, 0, 0, 0, 0, time.Local), "ccc", nil) + if err != nil { + t.Fatal(err) + } + _, err = s.Exec("null", 0, 0, 0, time.Date(2015, 12, 25, 1, 2, 3, 0, time.Local), nil, nil) + if err != nil { + t.Fatal(err) + } + }() + + // read from the table and verify returned results + rows, err := db.Query("select name, age, isGirl, weight, dob, data, canBeNull from dbo.temp") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var name string + var is friend + err = rows.Scan(&name, &is.age, &is.isGirl, &is.weight, &is.dob, &is.data, &is.canBeNull) + if err != nil { + t.Fatal(err) + } + want, ok := friends[name] + if !ok { + switch name { + case "chris": + // we know about chris, we just do not like him + case "null": + if is.canBeNull.Valid { + t.Errorf("null's canBeNull is suppose to be NULL, but is %v", is.canBeNull) + } + default: + t.Errorf("found %s, who is not my friend", name) + } + continue + } + if is.age < want.age { + t.Errorf("I did not know, that %s is so young (%d, but %d expected)", name, is.age, want.age) + continue + } + if is.age > want.age { + t.Errorf("I did not know, that %s is so old (%d, but %d expected)", name, is.age, want.age) + continue + } + if is.isGirl != want.isGirl { + if is.isGirl { + t.Errorf("I did not know, that %s is a girl", name) + } else { + t.Errorf("I did not know, that %s is a boy", name) + } + continue + } + if is.weight != want.weight { + t.Errorf("I did not know, that %s weighs %fkg (%fkg expected)", name, is.weight, want.weight) + continue + } + if !is.dob.Equal(want.dob) { + t.Errorf("I did not know, that %s's date of birth is %v (%v expected)", name, is.dob, want.dob) + continue + } + if !bytes.Equal(is.data, want.data) { + t.Errorf("I did not know, that %s's data is %v (%v expected)", name, is.data, want.data) + continue + } + if is.canBeNull != want.canBeNull { + t.Errorf("canBeNull for %s is wrong (%v, but %v expected)", name, is.canBeNull, want.canBeNull) + continue + } + } + err = rows.Err() + if err != nil { + t.Fatal(err) + } + + // clean after ourselves + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLTransactions(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (name varchar(20))") + + var was, is int + err = db.QueryRow("select count(*) from dbo.temp").Scan(&was) + if err != nil { + t.Fatal(err) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + _, err = tx.Exec("insert into dbo.temp (name) values ('tx1')") + if err != nil { + t.Fatal(err) + } + err = tx.QueryRow("select count(*) from dbo.temp").Scan(&is) + if err != nil { + t.Fatal(err) + } + if was+1 != is { + t.Fatalf("is(%d) should be 1 more then was(%d)", is, was) + } + ch := make(chan error) + go func() { + // this will block until our transaction is finished + err = db.QueryRow("select count(*) from dbo.temp").Scan(&is) + if err != nil { + ch <- err + } + if was+1 != is { + ch <- fmt.Errorf("is(%d) should be 1 more then was(%d)", is, was) + } + ch <- nil + }() + time.Sleep(100 * time.Millisecond) + tx.Commit() + err = <-ch + if err != nil { + t.Fatal(err) + } + err = db.QueryRow("select count(*) from dbo.temp").Scan(&is) + if err != nil { + t.Fatal(err) + } + if was+1 != is { + t.Fatalf("is(%d) should be 1 more then was(%d)", is, was) + } + + was = is + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + _, err = tx.Exec("insert into dbo.temp (name) values ('tx2')") + if err != nil { + t.Fatal(err) + } + err = tx.QueryRow("select count(*) from dbo.temp").Scan(&is) + if err != nil { + t.Fatal(err) + } + if was+1 != is { + t.Fatalf("is(%d) should be 1 more then was(%d)", is, was) + } + tx.Rollback() + err = db.QueryRow("select count(*) from dbo.temp").Scan(&is) + if err != nil { + t.Fatal(err) + } + if was != is { + t.Fatalf("is(%d) should be equal to was(%d)", is, was) + } + + exec(t, db, "drop table dbo.temp") +} + +type matchFunc func(v interface{}) error + +func match(a interface{}) matchFunc { + return func(b interface{}) error { + switch got := b.(type) { + case nil: + switch expect := a.(type) { + case nil: + // matching + default: + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case bool: + expect, ok := a.(bool) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if got != expect { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case int32: + expect, ok := a.(int32) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if got != expect { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case int64: + expect, ok := a.(int64) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if got != expect { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case float64: + switch expect := a.(type) { + case float64: + if got != expect { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case int64: + if got != float64(expect) { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + default: + return fmt.Errorf("unsupported type %T", expect) + } + case string: + expect, ok := a.(string) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if got != expect { + return fmt.Errorf("expect %q, but got %q", expect, got) + } + case []byte: + expect, ok := a.([]byte) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if !bytes.Equal(got, expect) { + return fmt.Errorf("expect %v, but got %v", expect, got) + } + case time.Time: + expect, ok := a.(time.Time) + if !ok { + return fmt.Errorf("couldn't convert expected value %v(%T) to %T", a, a, got) + } + if !got.Equal(expect) { + return fmt.Errorf("expect %q, but got %q", expect, got) + } + default: + return fmt.Errorf("unsupported type %T", got) + } + return nil + } +} + +type typeTest struct { + query string + match matchFunc +} + +var veryLongString = strings.Repeat("abcd ", 206) + +var typeTests = []typeTest{ + // bool + {"select cast(1 as bit)", match(true)}, + {"select cast(2 as bit)", match(true)}, + {"select cast(0 as bit)", match(false)}, + {"select cast(NULL as bit)", match(nil)}, + + // int + {"select cast(0 as int)", match(int32(0))}, + {"select cast(123 as int)", match(int32(123))}, + {"select cast(-4 as int)", match(int32(-4))}, + {"select cast(NULL as int)", match(nil)}, + {"select cast(0 as tinyint)", match(int32(0))}, + {"select cast(255 as tinyint)", match(int32(255))}, + {"select cast(-32768 as smallint)", match(int32(-32768))}, + {"select cast(32767 as smallint)", match(int32(32767))}, + {"select cast(-9223372036854775808 as bigint)", match(int64(-9223372036854775808))}, + {"select cast(9223372036854775807 as bigint)", match(int64(9223372036854775807))}, + + // decimal, float, real + {"select cast(123 as decimal(5, 0))", match(float64(123))}, + {"select cast(-123 as decimal(5, 0))", match(float64(-123))}, + {"select cast(123.5 as decimal(5, 0))", match(float64(124))}, + {"select cast(NULL as decimal(5, 0))", match(nil)}, + {"select cast(123.45 as decimal(5, 2))", match(123.45)}, + {"select cast(-123.45 as decimal(5, 2))", match(-123.45)}, + {"select cast(123.456 as decimal(5, 2))", match(123.46)}, + {"select cast(0.123456789 as float)", match(0.123456789)}, + {"select cast(NULL as float)", match(nil)}, + {"select cast(3.6666667461395264 as real)", match(3.6666667461395264)}, + {"select cast(NULL as real)", match(nil)}, + {"select cast(1.2333333504e+10 as real)", match(1.2333333504e+10)}, + + // money + {"select cast(12 as money)", match(float64(12))}, + {"select cast(-12 as money)", match(float64(-12))}, + {"select cast(0.01 as money)", match(0.01)}, + {"select cast(0.0123 as money)", match(0.0123)}, + {"select cast(NULL as money)", match(nil)}, + {"select cast(1 as smallmoney)", match(float64(1))}, + {"select cast(0.0123 as smallmoney)", match(0.0123)}, + {"select cast(NULL as smallmoney)", match(nil)}, + + // strings + {"select cast(123 as varchar(21))", match([]byte("123"))}, + {"select cast(123 as char(5))", match([]byte("123 "))}, + {"select cast('abcde' as varchar(3))", match([]byte("abc"))}, + {"select cast('' as varchar(5))", match([]byte(""))}, + {"select cast(NULL as varchar(5))", match(nil)}, + {"select cast(123 as nvarchar(21))", match([]byte("123"))}, + {"select cast('abcde' as nvarchar(3))", match([]byte("abc"))}, + {"select cast('' as nvarchar(5))", match([]byte(""))}, + {"select cast(NULL as nvarchar(5))", match(nil)}, + + // datetime, smalldatetime + {"select cast('20151225' as datetime)", match(time.Date(2015, 12, 25, 0, 0, 0, 0, time.Local))}, + {"select cast('2007-05-08 12:35:29.123' as datetime)", match(time.Date(2007, 5, 8, 12, 35, 29, 123e6, time.Local))}, + {"select cast(NULL as datetime)", match(nil)}, + {"select cast('2007-05-08 12:35:29.123' as smalldatetime)", match(time.Date(2007, 5, 8, 12, 35, 0, 0, time.Local))}, + + // uniqueidentifier + {"select cast('0e984725-c51c-4bf4-9960-e1c80e27aba0' as uniqueidentifier)", match("0e984725-c51c-4bf4-9960-e1c80e27aba0")}, + {"select cast(NULL as uniqueidentifier)", match(nil)}, + + // string blobs + {"select cast('abc' as varchar(max))", match([]byte("abc"))}, + {"select cast('' as varchar(max))", match([]byte(""))}, + {fmt.Sprintf("select cast('%s' as varchar(max))", veryLongString), match([]byte(veryLongString))}, + {"select cast(NULL as varchar(max))", match(nil)}, + {"select cast('abc' as nvarchar(max))", match([]byte("abc"))}, + {"select cast('' as nvarchar(max))", match([]byte(""))}, + {fmt.Sprintf("select cast('%s' as nvarchar(max))", veryLongString), match([]byte(veryLongString))}, + {"select cast(NULL as nvarchar(max))", match(nil)}, + {"select cast('abc' as text)", match([]byte("abc"))}, + {"select cast('' as text)", match([]byte(""))}, + {fmt.Sprintf("select cast('%s' as text)", veryLongString), match([]byte(veryLongString))}, + {"select cast(NULL as text)", match(nil)}, + {"select cast('abc' as ntext)", match([]byte("abc"))}, + {"select cast('' as ntext)", match([]byte(""))}, + {fmt.Sprintf("select cast('%s' as ntext)", veryLongString), match([]byte(veryLongString))}, + {"select cast(NULL as ntext)", match(nil)}, + + // xml + {"select cast(N'hello' as xml)", match([]byte("hello"))}, + {"select cast(N'dd' as xml)", match([]byte("dd"))}, + + // binary blobs + {"select cast('abc' as binary(5))", match([]byte{'a', 'b', 'c', 0, 0})}, + {"select cast('' as binary(5))", match([]byte{0, 0, 0, 0, 0})}, + {"select cast(NULL as binary(5))", match(nil)}, + {"select cast('abc' as varbinary(5))", match([]byte{'a', 'b', 'c'})}, + {"select cast('' as varbinary(5))", match([]byte(""))}, + {"select cast(NULL as varbinary(5))", match(nil)}, + {"select cast('abc' as varbinary(max))", match([]byte{'a', 'b', 'c'})}, + {"select cast('' as varbinary(max))", match([]byte(""))}, + {fmt.Sprintf("select cast('%s' as varbinary(max))", veryLongString), match([]byte(veryLongString))}, + {"select cast(NULL as varbinary(max))", match(nil)}, +} + +// TODO(brainman): see why typeMSSpecificTests do not work on freetds + +var typeMSSpecificTests = []typeTest{ + {"select cast(N'\u0421\u0430\u0448\u0430' as nvarchar(5))", match([]byte("\u0421\u0430\u0448\u0430"))}, + {"select cast(N'\u0421\u0430\u0448\u0430' as nvarchar(max))", match([]byte("\u0421\u0430\u0448\u0430"))}, + {"select cast(N'\u0421\u0430\u0448\u0430' as ntext)", match([]byte("\u0421\u0430\u0448\u0430"))}, +} + +var typeMSSQL2008Tests = []typeTest{ + // datetime2 + {"select cast('20151225' as datetime2)", match(time.Date(2015, 12, 25, 0, 0, 0, 0, time.Local))}, + {"select cast('2007-05-08 12:35:29.1234567' as datetime2)", match(time.Date(2007, 5, 8, 12, 35, 29, 1234567e2, time.Local))}, + {"select cast(NULL as datetime2)", match(nil)}, + + // time(7) + {"select cast('12:35:29.1234567' as time(7))", match(time.Date(1, 1, 1, 12, 35, 29, 1234567e2, time.Local))}, + {"select cast(NULL as time(7))", match(nil)}, +} + +var typeTestsToFail = []string{ + // int + "select cast(-1 as tinyint)", + "select cast(256 as tinyint)", + "select cast(-32769 as smallint)", + "select cast(32768 as smallint)", + "select cast(-9223372036854775809 as bigint)", + "select cast(9223372036854775808 as bigint)", + + // decimal + "select cast(1234.5 as decimal(5, 2))", + + // uniqueidentifier + "select cast('0x984725-c51c-4bf4-9960-e1c80e27aba0' as uniqueidentifier)", +} + +func TestMSSQLTypes(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + tests := typeTests + if !isFreeTDS() { + tests = append(tests, typeMSSpecificTests...) + } + if is2008OrLater(db) { + tests = append(tests, typeMSSQL2008Tests...) + } + for _, r := range tests { + func() { + rows, err := db.Query(r.query) + if err != nil { + t.Errorf("db.Query(%q) failed: %v", r.query, err) + return + } + defer rows.Close() + for rows.Next() { + var got interface{} + err := rows.Scan(&got) + if err != nil { + t.Errorf("rows.Scan for %q failed: %v", r.query, err) + return + } + err = r.match(got) + if err != nil { + t.Errorf("test %q failed: %v", r.query, err) + } + } + err = rows.Err() + if err != nil { + t.Error(err) + return + } + }() + } + + for _, query := range typeTestsToFail { + rows, err := db.Query(query) + if err != nil { + continue + } + rows.Close() + t.Errorf("test %q passed, but should fail", query) + } +} + +// TestMSSQLIntAfterText verify that non-bindable column can +// precede bindable column. +func TestMSSQLIntAfterText(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + const query = "select cast('abc' as text), cast(123 as int)" + rows, err := db.Query(query) + if err != nil { + t.Fatalf("db.Query(%q) failed: %v", query, err) + } + defer rows.Close() + for rows.Next() { + var i int + var text string + err = rows.Scan(&text, &i) + if err != nil { + t.Fatalf("rows.Scan for %q failed: %v", query, err) + } + if text != "abc" { + t.Errorf("expected \"abc\", but received %v", text) + } + if i != 123 { + t.Errorf("expected 123, but received %v", i) + } + } + err = rows.Err() + if err != nil { + t.Fatal(err) + } +} + +func TestMSSQLStmtAndRows(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer func() { + // not checking resources usage here, because these are + // unpredictable due to use of goroutines. + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } + }() + + var staff = map[string][]string{ + "acc": {"John", "Mary", "Moe"}, + "eng": {"Bar", "Foo", "Uno"}, + "sls": {"Scrudge", "Sls2", "Sls3"}, + } + + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (dept char(3), name varchar(20))") + + func() { + // test 1 Stmt and many Exec's + s, err := db.Prepare("insert into dbo.temp (dept, name) values (?, ?)") + if err != nil { + t.Fatal(err) + } + defer s.Close() + for dept, people := range staff { + for _, person := range people { + _, err := s.Exec(dept, person) + if err != nil { + t.Fatal(err) + } + } + } + }() + + func() { + // test Stmt is closed before Rows are + s, err := db.Prepare("select name from dbo.temp") + if err != nil { + t.Fatal(err) + } + + r, err := s.Query() + if err != nil { + s.Close() + t.Fatal(err) + } + defer r.Close() + + // TODO(brainman): dangling statement(s) bug reported + // https://github.com/golang/go/issues/3865 + err = s.Close() + if err != nil { + t.Fatal(err) + } + + n := 0 + for r.Next() { + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } + n++ + } + err = r.Err() + if err != nil { + t.Fatal(err) + } + const should = 9 + if n != should { + t.Fatalf("expected %v, but received %v", should, n) + } + }() + + if db.Driver().(*Driver).Stats.StmtCount != sc { + t.Fatalf("invalid statement count: expected %v, is %v", sc, db.Driver().(*Driver).Stats.StmtCount) + } + + // no resource tracking past this point + + func() { + // test 1 Stmt and many Query's executed one after the other + s, err := db.Prepare("select name from dbo.temp where dept = ? order by name") + if err != nil { + t.Fatal(err) + } + defer s.Close() + for dept, people := range staff { + func() { + r, err := s.Query(dept) + if err != nil { + t.Fatal(err) + } + defer r.Close() + i := 0 + for r.Next() { + var is string + err = r.Scan(&is) + if err != nil { + t.Fatal(err) + } + if people[i] != is { + t.Fatalf("expected %v, but received %v", people[i], is) + } + i++ + } + err = r.Err() + if err != nil { + t.Fatal(err) + } + }() + } + // test 1 Stmt and many simultaneous Query's + eof := fmt.Errorf("eof") + ch := make(map[string]chan error) + for dept, people := range staff { + c := make(chan error) + go func(c chan error, dept string, people []string) { + c <- nil + // NOTE(brainman): this could actually re-prepare since + // we are running it simultaneously in multiple goroutines + r, err := s.Query(dept) + if err != nil { + c <- fmt.Errorf("%v", err) + return + } + defer r.Close() + i := 0 + c <- nil + for r.Next() { + var is string + c <- nil + err = r.Scan(&is) + if err != nil { + c <- fmt.Errorf("%v", err) + return + } + c <- nil + if people[i] != is { + c <- fmt.Errorf("expected %v, but received %v", people[i], is) + return + } + i++ + } + err = r.Err() + if err != nil { + c <- fmt.Errorf("%v", err) + return + } + c <- eof + }(c, dept, people) + ch[dept] = c + } + for len(ch) > 0 { + for dept, c := range ch { + err := <-c + if err != nil { + if err != eof { + t.Errorf("dept=%v: %v", dept, err) + } + delete(ch, dept) + } + } + } + }() + + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLIssue5(t *testing.T) { + testingIssue5 = true + defer func() { + testingIssue5 = false + }() + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + + const nworkers = 8 + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, ` + create table dbo.temp ( + id int, + value int, + constraint [pk_id] primary key ([id]) + ) + `) + + var count int32 + + runCycle := func(waitch <-chan struct{}, errch chan<- error) (reterr error) { + defer func() { + errch <- reterr + }() + stmt, err := db.Prepare("insert into dbo.temp (id, value) values (?, ?)") + if err != nil { + return fmt.Errorf("Prepare failed: %v", err) + } + defer stmt.Close() + errch <- nil + <-waitch + for { + i := (int)(atomic.AddInt32(&count, 1)) + _, err := stmt.Exec(i, i) + if err != nil { + return fmt.Errorf("Exec failed i=%d: %v", i, err) + } + runtime.GC() + if i >= 100 { + break + } + } + return + } + + waitch := make(chan struct{}) + errch := make(chan error, nworkers) + for i := 0; i < nworkers; i++ { + go runCycle(waitch, errch) + } + for i := 0; i < nworkers; i++ { + if err := <-errch; err != nil { + t.Error(err) + } + } + if t.Failed() { + return + } + close(waitch) + for i := 0; i < nworkers; i++ { + if err := <-errch; err != nil { + t.Fatal(err) + } + } + // TODO: maybe I should verify dbo.temp records here + + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLDeleteNonExistent(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (name varchar(20))") + _, err = db.Exec("insert into dbo.temp (name) values ('alex')") + if err != nil { + t.Fatal(err) + } + + r, err := db.Exec("delete from dbo.temp where name = 'bob'") + if err != nil { + t.Fatalf("Exec failed: %v", err) + } + cnt, err := r.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + if cnt != 0 { + t.Fatalf("RowsAffected returns %d, but 0 expected", cnt) + } + + exec(t, db, "drop table dbo.temp") +} + +// https://github.com/alexbrainman/odbc/issues/14 +func TestMSSQLDatetime2Param(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + if !is2008OrLater(db) { + t.Skip("skipping test; needs MS SQL Server 2008 or later") + } + + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (dt datetime2)") + + expect := time.Date(2007, 5, 8, 12, 35, 29, 1234567e2, time.Local) + _, err = db.Exec("insert into dbo.temp (dt) values (?)", expect) + if err != nil { + t.Fatal(err) + } + var got time.Time + err = db.QueryRow("select top 1 dt from dbo.temp").Scan(&got) + if err != nil { + t.Fatal(err) + } + if expect != got { + t.Fatalf("expect %v, but got %v", expect, got) + } + + exec(t, db, "drop table dbo.temp") +} + +// https://github.com/alexbrainman/odbc/issues/19 +func TestMSSQLMerge(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + if !is2008OrLater(db) { + t.Skip("skipping test; needs MS SQL Server 2008 or later") + } + + db.Exec("drop table dbo.temp") + exec(t, db, ` + create table dbo.temp ( + id int not null, + name varchar(20), + constraint pk_temp primary key(id) + ) + `) + for i := 0; i < 5; i++ { + _, err = db.Exec("insert into dbo.temp (id, name) values (?, ?)", i, fmt.Sprintf("gordon%d", i)) + if err != nil { + t.Fatal(err) + } + } + + s, err := db.Prepare(` + merge into dbo.temp as dest + using ( values (?, ?) ) as src (id, name) on src.id = dest.id + when matched then update set dest.name = src.name + when not matched then insert values (src.id, src.name); + `) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + var tests = []struct { + id int + name string + }{ + {id: 1, name: "new name1"}, + {id: 8, name: "hohoho"}, + } + for _, test := range tests { + _, err = s.Exec(test.id, test.name) + if err != nil { + t.Fatal(err) + } + } + + for _, test := range tests { + var got string + err = db.QueryRow("select name from dbo.temp where id = ?", test.id).Scan(&got) + if err != nil { + t.Fatal(err) + } + if test.name != got { + t.Fatalf("expect %v, but got %v", test.name, got) + } + } + + exec(t, db, "drop table dbo.temp") +} + +// https://github.com/alexbrainman/odbc/issues/20 +func TestMSSQLSelectInt(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + const expect = 123456 + var got int + if err := db.QueryRow("select ?", expect).Scan(&got); err != nil { + t.Fatal(err) + } + if expect != got { + t.Fatalf("expect %v, but got %v", expect, got) + } +} + +// https://github.com/alexbrainman/odbc/issues/21 +func TestMSSQLTextColumnParam(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, `create table dbo.temp(id int primary key not null, v1 text, v2 text, v3 text, v4 text, v5 text, v6 text, v7 text, v8 text)`) + + s, err := db.Prepare(`insert into dbo.temp(id, v1, v2, v3, v4, v5, v6, v7, v8) values (?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + t.Fatal(err) + } + defer s.Close() + + b := "string string string string string string string string string" + for i := 0; i < 100; i++ { + _, err := s.Exec(i, b, b, b, b, b, b, b, b) + if err != nil { + t.Fatal(err) + } + } + + exec(t, db, "drop table dbo.temp") +} + +func digestString(s string) string { + if len(s) < 40 { + return s + } + return fmt.Sprintf("%s ... (%d bytes long)", s[:40], len(s)) +} + +func digestBytes(b []byte) string { + if len(b) < 20 { + return fmt.Sprintf("%v", b) + } + s := "" + for _, v := range b[:20] { + if s != "" { + s += " " + } + s += fmt.Sprintf("%d", v) + } + return fmt.Sprintf("[%v ...] (%d bytes long)", s, len(b)) +} + +var paramTypeTests = []struct { + description string + sqlType string + value interface{} +}{ + // nil parameters + {"NULL for bit", "bit", nil}, + {"NULL for text", "text", nil}, + {"NULL for int", "int", nil}, + // strings + {"non empty string", "varchar(10)", "abc"}, + {"one character string", "varchar(10)", "a"}, + {"empty string", "varchar(10)", ""}, + {"empty unicode string", "nvarchar(10)", ""}, + {"3999 large unicode string", "nvarchar(max)", strings.Repeat("a", 3999)}, + {"4000 large unicode string", "nvarchar(max)", strings.Repeat("a", 4000)}, + {"4000 large non-ascii unicode string", "nvarchar(max)", strings.Repeat("\u0421", 4000)}, + {"4001 large unicode string", "nvarchar(max)", strings.Repeat("a", 4001)}, + {"4001 large non-ascii unicode string", "nvarchar(max)", strings.Repeat("\u0421", 4001)}, + {"10000 large unicode string", "nvarchar(max)", strings.Repeat("a", 10000)}, + {"empty unicode null string", "nvarchar(10) null", ""}, + {"3999 large string value", "text", strings.Repeat("a", 3999)}, + {"4000 large string value", "text", strings.Repeat("a", 4000)}, + {"4000 large unicode string value", "ntext", strings.Repeat("\u0421", 4000)}, + {"4001 large string value", "text", strings.Repeat("a", 4001)}, + {"4001 large unicode string value", "ntext", strings.Repeat("\u0421", 4001)}, + {"very large string value", "text", strings.Repeat("a", 10000)}, + // datetime + {"datetime overflow", "datetime", time.Date(2013, 9, 9, 14, 07, 15, 123e6, time.Local)}, + // binary blobs + {"small blob", "varbinary", make([]byte, 1)}, + {"very large blob", "varbinary(max)", make([]byte, 100000)}, + {"7999 large image", "image", make([]byte, 7999)}, + {"8000 large image", "image", make([]byte, 8000)}, + {"8001 large image", "image", make([]byte, 8001)}, + {"very large image", "image", make([]byte, 10000)}, +} + +func TestMSSQLTextColumnParamTypes(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + for _, test := range paramTypeTests { + db.Exec("drop table dbo.temp") + exec(t, db, fmt.Sprintf("create table dbo.temp(v %s)", test.sqlType)) + _, err = db.Exec("insert into dbo.temp(v) values(?)", test.value) + if err != nil { + t.Errorf("%s insert test failed: %s", test.description, err) + } + var v interface{} + err = db.QueryRow("select v from dbo.temp").Scan(&v) + if err != nil { + t.Errorf("%s select test failed: %s", test.description, err) + continue + } + switch want := test.value.(type) { + case string: + have := string(v.([]byte)) + if have != want { + t.Errorf("%s wrong return value: have %q; want %q", test.description, digestString(have), digestString(want)) + } + case []byte: + have := v.([]byte) + if !bytes.Equal(have, want) { + t.Errorf("%s wrong return value: have %v; want %v", test.description, digestBytes(have), digestBytes(want)) + } + case time.Time: + have := v.(time.Time) + if have != want { + t.Errorf("%s wrong return value: have %v; want %v", test.description, have, want) + } + case nil: + if v != nil { + t.Errorf("%s wrong return value: have %v; want nil", test.description, v) + } + } + } + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLLongColumnNames(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + query := fmt.Sprintf("select 'hello' as %s", strings.Repeat("a", 110)) + var s string + err = db.QueryRow(query).Scan(&s) + if err != nil { + t.Fatal(err) + } + if s != "hello" { + t.Errorf("expected \"hello\", but received %v", s) + } +} + +func TestMSSQLRawBytes(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, `create table dbo.temp(ascii char(7), utf16 nchar(7), blob binary(3))`) + _, err = db.Exec(`insert into dbo.temp (ascii, utf16, blob) values (?, ?, ?)`, "alex", "alex", []byte{1, 2, 3}) + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("select ascii, utf16, blob from dbo.temp") + if err != nil { + t.Fatalf("Query: %v", err) + } + defer rows.Close() + + for rows.Next() { + var ascii, utf16 sql.RawBytes + var blob []byte + err = rows.Scan(&ascii, &utf16, &blob) + if err != nil { + t.Fatalf("Scan: %v", err) + } + } + err = rows.Err() + if err != nil { + t.Fatal(err) + } + + exec(t, db, "drop table dbo.temp") +} + +// https://github.com/alexbrainman/odbc/issues/27 +func TestMSSQLUTF16ToUTF8(t *testing.T) { + s := []uint16{0x47, 0x75, 0x73, 0x74, 0x61, 0x66, 0x27, 0x73, 0x20, 0x4b, 0x6e, 0xe4, 0x63, 0x6b, 0x65, 0x62, 0x72, 0xf6, 0x64} + if api.UTF16ToString(s) != string(utf16toutf8(s)) { + t.Fatal("comparison fails") + } +} + +func TestMSSQLExecStoredProcedure(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop procedure dbo.temp") + exec(t, db, ` +create procedure dbo.temp + @a int, + @b int +as +begin + return @a + @b +end +`) + qry := ` +declare @ret int +exec @ret = dbo.temp @a = ?, @b = ? +select @ret +` + var ret int64 + if err := db.QueryRow(qry, 2, 3).Scan(&ret); err != nil { + t.Fatal(err) + } + if ret != 5 { + t.Fatalf("unexpected return value: should=5, is=%v", ret) + } + exec(t, db, `drop procedure dbo.temp`) +} + +func TestMSSQLSingleCharParam(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, `create table dbo.temp(name nvarchar(50), age int)`) + + rows, err := db.Query("select age from dbo.temp where name=?", "v") + if err != nil { + t.Fatal(err) + } + rows.Close() + + exec(t, db, "drop table dbo.temp") +} + +type tcpProxy struct { + mu sync.Mutex + stopped bool + conns []net.Conn +} + +func (p *tcpProxy) run(ln net.Listener, remote string) { + for { + defer p.pause() + c1, err := ln.Accept() + if err != nil { + return + } + go func(c1 net.Conn) { + defer c1.Close() + + if p.paused() { + return + } + + p.addConn(c1) + + c2, err := net.Dial("tcp", remote) + if err != nil { + panic(err) + } + p.addConn(c2) + defer c2.Close() + + go func() { + io.Copy(c2, c1) + }() + io.Copy(c1, c2) + }(c1) + } +} + +func (p *tcpProxy) pause() { + p.mu.Lock() + defer p.mu.Unlock() + p.stopped = true + for _, c := range p.conns { + c.Close() + } + p.conns = p.conns[:0] +} + +func (p *tcpProxy) paused() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.stopped +} + +func (p *tcpProxy) addConn(c net.Conn) { + p.mu.Lock() + defer p.mu.Unlock() + p.conns = append(p.conns, c) +} + +func (p *tcpProxy) restart() { + p.mu.Lock() + defer p.mu.Unlock() + p.stopped = false +} + +func TestMSSQLReconnect(t *testing.T) { + params := newConnParams() + address, err := params.getConnAddress() + if err != nil { + t.Skipf("Skipping test: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + err = params.updateConnAddress(ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + proxy := new(tcpProxy) + go proxy.run(ln, address) + + db, sc, err := mssqlConnectWithParams(params) + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + testConn := func() error { + var n int64 + err := db.QueryRow("select count(*) from dbo.temp").Scan(&n) + if err != nil { + return err + } + if n != 1 { + return fmt.Errorf("unexpected return value: should=1, is=%v", n) + } + return nil + } + + db.Exec("drop table dbo.temp") + exec(t, db, `create table dbo.temp (name varchar(50))`) + exec(t, db, `insert into dbo.temp (name) values ('alex')`) + + err = testConn() + if err != nil { + t.Fatal(err) + } + + proxy.pause() + time.Sleep(100 * time.Millisecond) + + err = testConn() + if err == nil { + t.Fatal("database IO should fail, but succeeded") + } + + proxy.restart() + + err = testConn() + if err != nil { + t.Fatal(err) + } + + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLMarkTxBadConn(t *testing.T) { + params := newConnParams() + + address, err := params.getConnAddress() + if err != nil { + t.Skipf("Skipping test: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + err = params.updateConnAddress(ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + + proxy := new(tcpProxy) + go proxy.run(ln, address) + + testFn := func(endTx func(driver.Tx) error, nextFn func(driver.Conn) error) { + proxy.restart() + + cc, sc := drv.Stats.ConnCount, drv.Stats.StmtCount + defer func() { + if should, is := sc, drv.Stats.StmtCount; should != is { + t.Errorf("leaked statement, should=%d, is=%d", should, is) + } + if should, is := cc, drv.Stats.ConnCount; should != is { + t.Errorf("leaked connection, should=%d, is=%d", should, is) + } + }() + + dc, err := drv.Open(params.makeODBCConnectionString()) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := dc.Close(); err != nil { + t.Fatal(err) + } + }() + + driverExec(nil, dc, "drop table dbo.temp") + driverExec(t, dc, `create table dbo.temp (name varchar(50))`) + + tx, err := dc.Begin() + if err != nil { + t.Fatal(err) + } + + driverExec(t, dc, `insert into dbo.temp (name) values ('alex')`) + + proxy.pause() + time.Sleep(100 * time.Millisecond) + + // the connection is broken, ending the transaction should fail + if err := endTx(tx); err == nil { + t.Fatal("unexpected success, expected error") + } + + // database/sql might return the broken driver.Conn to the pool in + // that case the next operation must fail. + if err := nextFn(dc); err == nil { + t.Fatal("unexpected success, expected error") + } + } + + beginFn := func(dc driver.Conn) error { + tx, err := dc.Begin() + if err != nil { + return err + } + tx.Rollback() + return nil + } + + prepareFn := func(dc driver.Conn) error { + st, err := dc.Prepare(`insert into dbo.temp (name) values ('alex')`) + if err != nil { + return err + } + st.Close() + return nil + } + + // Test all the permutations. + for _, endTx := range []func(driver.Tx) error{ + driver.Tx.Commit, + driver.Tx.Rollback, + } { + for _, nextFn := range []func(driver.Conn) error{ + beginFn, + prepareFn, + } { + testFn(endTx, nextFn) + } + } +} + +func TestMSSQLMarkBeginBadConn(t *testing.T) { + params := newConnParams() + + testFn := func(label string, nextFn func(driver.Conn) error) { + cc, sc := drv.Stats.ConnCount, drv.Stats.StmtCount + defer func() { + if should, is := sc, drv.Stats.StmtCount; should != is { + t.Errorf("leaked statement, should=%d, is=%d", should, is) + } + if should, is := cc, drv.Stats.ConnCount; should != is { + t.Errorf("leaked connection, should=%d, is=%d", should, is) + } + }() + + dc, err := drv.Open(params.makeODBCConnectionString()) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := dc.Close(); err != nil { + t.Fatal(err) + } + }() + + driverExec(nil, dc, "drop table dbo.temp") + driverExec(t, dc, `create table dbo.temp (name varchar(50))`) + + // force an error starting a transaction + func() { + testBeginErr = errors.New("cannot start tx") + defer func() { testBeginErr = nil }() + + if _, err := dc.Begin(); err == nil { + t.Fatal("unexpected success, expected error") + } + }() + + // database/sql might return the broken driver.Conn to the pool. The + // next operation on the driver connection must return + // driver.ErrBadConn to prevent the bad connection from getting used + // again. + if should, is := driver.ErrBadConn, nextFn(dc); should != is { + t.Errorf("%s: should=\"%v\", is=\"%v\"", label, should, is) + } + } + + beginFn := func(dc driver.Conn) error { + tx, err := dc.Begin() + if err != nil { + return err + } + tx.Rollback() + return nil + } + + prepareFn := func(dc driver.Conn) error { + st, err := dc.Prepare(`insert into dbo.temp (name) values ('alex')`) + if err != nil { + return err + } + st.Close() + return nil + } + + // Test all the permutations. + for _, next := range []struct { + label string + fn func(driver.Conn) error + }{ + {"begin", beginFn}, + {"prepare", prepareFn}, + } { + testFn(next.label, next.fn) + } +} + +func testMSSQLNextResultSet(t *testing.T, verifyBatch func(rows *sql.Rows)) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, `create table dbo.temp (name varchar(50))`) + exec(t, db, `insert into dbo.temp (name) values ('russ')`) + exec(t, db, `insert into dbo.temp (name) values ('brad')`) + + rows, err := db.Query(` +select name from dbo.temp where name = 'russ'; +select name from dbo.temp where name = 'brad'; +`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + verifyBatch(rows) + + exec(t, db, "drop table dbo.temp") +} + +func TestMSSQLNextResultSet(t *testing.T) { + checkName := func(rows *sql.Rows, name string) { + if !rows.Next() { + if err := rows.Err(); err != nil { + t.Fatalf("executing Next for %q failed: %v", name, err) + } + t.Fatalf("checking %q: at least one row expected", name) + } + var have string + err := rows.Scan(&have) + if err != nil { + t.Fatalf("executing Scan for %q failed: %v", name, err) + } + if name != have { + t.Fatalf("want %q, but %q found", name, have) + } + } + testMSSQLNextResultSet(t, + func(rows *sql.Rows) { + checkName(rows, "russ") + if !rows.NextResultSet() { + if err := rows.Err(); err != nil { + t.Fatal(err) + } + t.Fatal("more result sets expected") + } + checkName(rows, "brad") + if isFreeTDS() { // not sure why it does not work on FreeTDS + t.Log("skipping broken part of the test on FreeTDS") + return + } + if rows.NextResultSet() { + t.Fatal("unexpected result set found") + } else if err := rows.Err(); err != nil { + t.Fatal(err) + } + }) +} + +func TestMSSQLNextResultSetWithDifferentColumnsInResultSets(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + rows, err := db.Query("select 1 select 2,3") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + if !rows.Next() { + t.Fatal("expected at least 1 result") + } + var v1, v2 int + err = rows.Scan(&v1) + if err != nil { + t.Fatalf("unable to scan select 1 underlying error: %v", err) + } + if v1 != 1 { + t.Fatalf("expected: %v got %v", 1, v1) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if !rows.NextResultSet() { + t.Fatal("expected another result set") + } + if !rows.Next() { + t.Fatal("expected a single row") + } + err = rows.Scan(&v1, &v2) + if err != nil { + t.Fatalf("unable to scan select 2,3 underlying error: %v", err) + } + if v1 != 2 || v2 != 3 { + t.Fatalf("got wrong values expected v1=%v v2=%v. got v1=%v v2=%v", 2, 3, v1, v2) + } + +} + +func TestMSSQLHasNextResultSet(t *testing.T) { + checkName := func(rows *sql.Rows, name string) { + var reccount int + for rows.Next() { // reading till the end of data set to trigger call into HasNextResultSet + var have string + err := rows.Scan(&have) + if err != nil { + t.Fatalf("executing Scan for %q failed: %v", name, err) + } + if name != have { + t.Fatalf("want %q, but %q found", name, have) + } + reccount++ + } + if err := rows.Err(); err != nil { + t.Fatalf("executing Next for %q failed: %v", name, err) + } + if reccount != 1 { + t.Fatalf("checking %q: expected 1 row returned, but %v found", name, reccount) + } + } + testMSSQLNextResultSet(t, + func(rows *sql.Rows) { + checkName(rows, "russ") + if !rows.NextResultSet() { + if err := rows.Err(); err != nil { + t.Fatal(err) + } + t.Fatal("more result sets expected") + } + checkName(rows, "brad") + if rows.NextResultSet() { + t.Fatal("unexpected result set found") + } else { + if err := rows.Err(); err != nil { + t.Fatal(err) + } + } + }) +} + +func TestMSSQLIssue127(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table dbo.temp") + exec(t, db, "create table dbo.temp (id int, a varchar(255))") + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + stmt, err := tx.Prepare(` +DECLARE @id INT, @a VARCHAR(255) +SELECT @id = ?, @a = ? +UPDATE dbo.temp SET a = @a WHERE id = @id +IF @@ROWCOUNT = 0 + INSERT INTO dbo.temp (id, a) VALUES (@id, @a) +`) + if err != nil { + t.Fatal(err) + } + if _, err = stmt.Exec(1, "test"); err != nil { + t.Errorf("Failed to insert record with ID 1: %s", err) + } + if _, err = stmt.Exec(1, "test2"); err != nil { + t.Errorf("Failed to update record with ID 1: %s", err) + } + + if err = tx.Commit(); err != nil { + t.Fatal(err) + } +} diff --git a/mysql_test.go b/mysql_test.go new file mode 100644 index 0000000..3d73470 --- /dev/null +++ b/mysql_test.go @@ -0,0 +1,60 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql" + "flag" + "fmt" + "testing" + "time" +) + +var ( + mysrv = flag.String("mysrv", "server", "mysql server name") + mydb = flag.String("mydb", "dbname", "mysql database name") + myuser = flag.String("myuser", "", "mysql user name") + mypass = flag.String("mypass", "", "mysql password") +) + +func mysqlConnect() (db *sql.DB, stmtCount int, err error) { + // from https://dev.mysql.com/doc/connector-odbc/en/connector-odbc-configuration-connection-parameters.html + conn := fmt.Sprintf("driver=mysql;server=%s;database=%s;user=%s;password=%s;", + *mysrv, *mydb, *myuser, *mypass) + db, err = sql.Open("odbc", conn) + if err != nil { + return nil, 0, err + } + stats := db.Driver().(*Driver).Stats + return db, stats.StmtCount, nil +} + +func TestMYSQLTime(t *testing.T) { + db, sc, err := mysqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + db.Exec("drop table temp") + exec(t, db, "create table temp(id int not null auto_increment primary key, time time)") + now := time.Now() + // SQL_TIME_STRUCT only supports hours, minutes and seconds + now = time.Date(1, time.January, 1, now.Hour(), now.Minute(), now.Second(), 0, time.Local) + _, err = db.Exec("insert into temp (time) values(?)", now) + if err != nil { + t.Fatal(err) + } + + var ret time.Time + if err := db.QueryRow("select time from temp where id = ?", 1).Scan(&ret); err != nil { + t.Fatal(err) + } + if ret != now { + t.Fatalf("unexpected return value: want=%v, is=%v", now, ret) + } + + exec(t, db, "drop table temp") +} diff --git a/odbc.iml b/odbc.iml new file mode 100644 index 0000000..49df094 --- /dev/null +++ b/odbc.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/odbcstmt.go b/odbcstmt.go new file mode 100644 index 0000000..bd88ec3 --- /dev/null +++ b/odbcstmt.go @@ -0,0 +1,160 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "errors" + "fmt" + "sync" + "time" + "unsafe" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +// TODO(brainman): see if I could use SQLExecDirect anywhere + +type ODBCStmt struct { + h api.SQLHSTMT + Parameters []Parameter + Cols []Column + // locking/lifetime + mu sync.Mutex + usedByStmt bool + usedByRows bool +} + +func (c *Conn) PrepareODBCStmt(query string) (*ODBCStmt, error) { + var out api.SQLHANDLE + ret := api.SQLAllocHandle(api.SQL_HANDLE_STMT, api.SQLHANDLE(c.h), &out) + if IsError(ret) { + return nil, c.newError("SQLAllocHandle", c.h) + } + h := api.SQLHSTMT(out) + err := drv.Stats.updateHandleCount(api.SQL_HANDLE_STMT, 1) + if err != nil { + return nil, err + } + + b := api.StringToUTF16(query) + ret = api.SQLPrepare(h, (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS) + if IsError(ret) { + defer releaseHandle(h) + return nil, c.newError("SQLPrepare", h) + } + ps, err := ExtractParameters(h) + if err != nil { + defer releaseHandle(h) + return nil, err + } + return &ODBCStmt{ + h: h, + Parameters: ps, + usedByStmt: true, + }, nil +} + +func (s *ODBCStmt) closeByStmt() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.usedByStmt { + defer func() { s.usedByStmt = false }() + if !s.usedByRows { + return s.releaseHandle() + } + } + return nil +} + +func (s *ODBCStmt) closeByRows() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.usedByRows { + defer func() { s.usedByRows = false }() + if s.usedByStmt { + ret := api.SQLCloseCursor(s.h) + if IsError(ret) { + return NewError("SQLCloseCursor", s.h) + } + return nil + } else { + return s.releaseHandle() + } + } + return nil +} + +func (s *ODBCStmt) releaseHandle() error { + h := s.h + s.h = api.SQLHSTMT(api.SQL_NULL_HSTMT) + return releaseHandle(h) +} + +var testingIssue5 bool // used during tests + +func (s *ODBCStmt) Exec(args []driver.Value, conn *Conn) error { + if len(args) != len(s.Parameters) { + return fmt.Errorf("wrong number of arguments %d, %d expected", len(args), len(s.Parameters)) + } + for i, a := range args { + // this could be done in 2 steps: + // 1) bind vars right after prepare; + // 2) set their (vars) values here; + // but rebinding parameters for every new parameter value + // should be efficient enough for our purpose. + if err := s.Parameters[i].BindValue(s.h, i, a, conn); err != nil { + return err + } + } + if testingIssue5 { + time.Sleep(10 * time.Microsecond) + } + ret := api.SQLExecute(s.h) + if ret == api.SQL_NO_DATA { + // success but no data to report + return nil + } + if IsError(ret) { + return NewError("SQLExecute", s.h) + } + return nil +} + +func (s *ODBCStmt) BindColumns() error { + // count columns + var n api.SQLSMALLINT + ret := api.SQLNumResultCols(s.h, &n) + if IsError(ret) { + return NewError("SQLNumResultCols", s.h) + } + if n < 1 { + return errors.New("Stmt did not create a result set") + } + // fetch column descriptions + s.Cols = make([]Column, n) + binding := true + for i := range s.Cols { + c, err := NewColumn(s.h, i) + if err != nil { + return err + } + s.Cols[i] = c + // Once we found one non-bindable column, we will not bind the rest. + // http://www.easysoft.com/developer/languages/c/odbc-tutorial-fetching-results.html + // ... One common restriction is that SQLGetData may only be called on columns after the last bound column. ... + if !binding { + continue + } + bound, err := s.Cols[i].Bind(s.h, i) + if err != nil { + return err + } + if !bound { + binding = false + } + } + return nil +} diff --git a/param.go b/param.go new file mode 100644 index 0000000..c3b5090 --- /dev/null +++ b/param.go @@ -0,0 +1,209 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "fmt" + "time" + "unsafe" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Parameter struct { + SQLType api.SQLSMALLINT + Decimal api.SQLSMALLINT + Size api.SQLULEN + isDescribed bool + // Following fields store data used later by SQLExecute. + // The fields keep data alive and away from gc. + Data interface{} + StrLen_or_IndPtr api.SQLLEN +} + +// StoreStrLen_or_IndPtr stores v into StrLen_or_IndPtr field of p +// and returns address of that field. +func (p *Parameter) StoreStrLen_or_IndPtr(v api.SQLLEN) *api.SQLLEN { + p.StrLen_or_IndPtr = v + return &p.StrLen_or_IndPtr + +} + +func (p *Parameter) BindValue(h api.SQLHSTMT, idx int, v driver.Value, conn *Conn) error { + // TODO(brainman): Reuse memory for previously bound values. If memory + // is reused, we, probably, do not need to call SQLBindParameter either. + var ctype, sqltype, decimal api.SQLSMALLINT + var size api.SQLULEN + var buflen api.SQLLEN + var plen *api.SQLLEN + var buf unsafe.Pointer + switch d := v.(type) { + case nil: + ctype = api.SQL_C_WCHAR + p.Data = nil + buf = nil + size = 1 + buflen = 0 + plen = p.StoreStrLen_or_IndPtr(api.SQL_NULL_DATA) + sqltype = api.SQL_WCHAR + case string: + ctype = api.SQL_C_WCHAR + b := api.StringToUTF16(d) + p.Data = b + buf = unsafe.Pointer(&b[0]) + l := len(b) + l -= 1 // remove terminating 0 + size = api.SQLULEN(l) + if size < 1 { + // size cannot be less then 1 even for empty fields + size = 1 + } + l *= 2 // every char takes 2 bytes + buflen = api.SQLLEN(l) + plen = p.StoreStrLen_or_IndPtr(buflen) + if !conn.isMSAccessDriver { + switch { + case size >= 4000: + sqltype = api.SQL_WLONGVARCHAR + case p.isDescribed: + sqltype = p.SQLType + case size <= 1: + sqltype = api.SQL_WVARCHAR + default: + sqltype = api.SQL_WCHAR + } + } else { + // MS Acess requires SQL_WLONGVARCHAR for MEMO. + // https://docs.microsoft.com/en-us/sql/odbc/microsoft/microsoft-access-data-types + sqltype = api.SQL_WLONGVARCHAR + } + case int64: + if -0x80000000 < d && d < 0x7fffffff { + // Some ODBC drivers do not support SQL_BIGINT. + // Use SQL_INTEGER if the value fit in int32. + // See issue #78 for details. + d2 := int32(d) + ctype = api.SQL_C_LONG + p.Data = &d2 + buf = unsafe.Pointer(&d2) + sqltype = api.SQL_INTEGER + size = 4 + } else { + ctype = api.SQL_C_SBIGINT + p.Data = &d + buf = unsafe.Pointer(&d) + sqltype = api.SQL_BIGINT + size = 8 + } + case bool: + var b byte + if d { + b = 1 + } + ctype = api.SQL_C_BIT + p.Data = &b + buf = unsafe.Pointer(&b) + sqltype = api.SQL_BIT + size = 1 + case float64: + ctype = api.SQL_C_DOUBLE + p.Data = &d + buf = unsafe.Pointer(&d) + sqltype = api.SQL_DOUBLE + size = 8 + case time.Time: + ctype = api.SQL_C_TYPE_TIMESTAMP + y, m, day := d.Date() + b := api.SQL_TIMESTAMP_STRUCT{ + Year: api.SQLSMALLINT(y), + Month: api.SQLUSMALLINT(m), + Day: api.SQLUSMALLINT(day), + Hour: api.SQLUSMALLINT(d.Hour()), + Minute: api.SQLUSMALLINT(d.Minute()), + Second: api.SQLUSMALLINT(d.Second()), + Fraction: api.SQLUINTEGER(d.Nanosecond()), + } + p.Data = &b + buf = unsafe.Pointer(&b) + sqltype = api.SQL_TYPE_TIMESTAMP + if p.isDescribed && p.SQLType == api.SQL_TYPE_TIMESTAMP { + decimal = p.Decimal + } + if decimal <= 0 { + // represented as yyyy-mm-dd hh:mm:ss.fff format in ms sql server + decimal = 3 + } + size = 20 + api.SQLULEN(decimal) + case []byte: + ctype = api.SQL_C_BINARY + b := make([]byte, len(d)) + copy(b, d) + p.Data = b + buf = unsafe.Pointer(&b[0]) + buflen = api.SQLLEN(len(b)) + plen = p.StoreStrLen_or_IndPtr(buflen) + size = api.SQLULEN(len(b)) + switch { + case p.isDescribed: + sqltype = p.SQLType + case size <= 0: + sqltype = api.SQL_LONGVARBINARY + case size >= 8000: + sqltype = api.SQL_LONGVARBINARY + default: + sqltype = api.SQL_BINARY + } + default: + return fmt.Errorf("unsupported type %T", v) + } + ret := api.SQLBindParameter(h, api.SQLUSMALLINT(idx+1), + api.SQL_PARAM_INPUT, ctype, sqltype, size, decimal, + api.SQLPOINTER(buf), buflen, plen) + if IsError(ret) { + return NewError("SQLBindParameter", h) + } + return nil +} + +func ExtractParameters(h api.SQLHSTMT) ([]Parameter, error) { + // count parameters + var n, nullable api.SQLSMALLINT + ret := api.SQLNumParams(h, &n) + if IsError(ret) { + return nil, NewError("SQLNumParams", h) + } + if n <= 0 { + // no parameters + return nil, nil + } + ps := make([]Parameter, n) + // fetch param descriptions + for i := range ps { + p := &ps[i] + ret = api.SQLDescribeParam(h, api.SQLUSMALLINT(i+1), + &p.SQLType, &p.Size, &p.Decimal, &nullable) + if IsError(ret) { + // SQLDescribeParam is not implemented by freedts, + // it even fails for some statements on windows. + // Will try request without these descriptions + continue + } + p.isDescribed = true + // SQL Server MAX types (varchar(max), nvarchar(max), + // varbinary(max) are identified by size = 0 + if p.Size == 0 { + switch p.SQLType { + case api.SQL_VARBINARY: + p.SQLType = api.SQL_LONGVARBINARY + case api.SQL_VARCHAR: + p.SQLType = api.SQL_LONGVARCHAR + case api.SQL_WVARCHAR: + p.SQLType = api.SQL_WLONGVARCHAR + } + } + } + return ps, nil +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..9d4a0da --- /dev/null +++ b/result.go @@ -0,0 +1,22 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "errors" +) + +type Result struct { + rowCount int64 +} + +func (r *Result) LastInsertId() (int64, error) { + // TODO(brainman): implement (*Result).LastInsertId + return 0, errors.New("not implemented") +} + +func (r *Result) RowsAffected() (int64, error) { + return r.rowCount, nil +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..48883ed --- /dev/null +++ b/rows.go @@ -0,0 +1,80 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "io" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Rows struct { + os *ODBCStmt +} + +func (r *Rows) Columns() []string { + names := make([]string, len(r.os.Cols)) + for i := 0; i < len(names); i++ { + names[i] = r.os.Cols[i].Name() + } + return names +} + +func (r *Rows) Next(dest []driver.Value) error { + ret := api.SQLFetch(r.os.h) + if ret == api.SQL_NO_DATA { + return io.EOF + } + if IsError(ret) { + return NewError("SQLFetch", r.os.h) + } + for i := range dest { + v, err := r.os.Cols[i].Value(r.os.h, i) + if err != nil { + return err + } + dest[i] = v + } + return nil +} + +func (r *Rows) Close() error { + return r.os.closeByRows() +} + +func (r *Rows) HasNextResultSet() bool { + return true +} + +func (r *Rows) NextResultSet() error { + ret := api.SQLMoreResults(r.os.h) + if ret == api.SQL_NO_DATA { + return io.EOF + } + if IsError(ret) { + return NewError("SQLMoreResults", r.os.h) + } + + err := r.os.BindColumns() + if err != nil { + return err + } + return nil +} + +// Implement RowsColumnTypeDatabaseTypeName interface in order to return column types. +// https://github.com/golang/go/blob/e22a14b7eb1e4a172d0c20d14a0d2433fdf20e5c/src/database/sql/driver/driver.go#L469-L477 +// +// From the docs: +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +func (rs *Rows) ColumnTypeDatabaseTypeName(i int) string { + return rs.os.Cols[i].DatabaseTypeName() + +} diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..4f65050 --- /dev/null +++ b/stats.go @@ -0,0 +1,35 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "fmt" + "sync" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Stats struct { + EnvCount int + ConnCount int + StmtCount int + mu sync.Mutex +} + +func (s *Stats) updateHandleCount(handleType api.SQLSMALLINT, change int) error { + s.mu.Lock() + defer s.mu.Unlock() + switch handleType { + case api.SQL_HANDLE_ENV: + s.EnvCount += change + case api.SQL_HANDLE_DBC: + s.ConnCount += change + case api.SQL_HANDLE_STMT: + s.StmtCount += change + default: + return fmt.Errorf("unexpected handle type %d", handleType) + } + return nil +} diff --git a/stmt.go b/stmt.go new file mode 100644 index 0000000..d8b2298 --- /dev/null +++ b/stmt.go @@ -0,0 +1,108 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "errors" + "sync" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Stmt struct { + c *Conn + query string + os *ODBCStmt + mu sync.Mutex +} + +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + if c.bad { + return nil, driver.ErrBadConn + } + os, err := c.PrepareODBCStmt(query) + if err != nil { + return nil, err + } + return &Stmt{c: c, os: os, query: query}, nil +} + +func (s *Stmt) NumInput() int { + if s.os == nil { + return -1 + } + return len(s.os.Parameters) +} + +func (s *Stmt) Close() error { + if s.os == nil { + return errors.New("Stmt is already closed") + } + ret := s.os.closeByStmt() + s.os = nil + return ret +} + +func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if s.os == nil { + return nil, errors.New("Stmt is closed") + } + s.mu.Lock() + defer s.mu.Unlock() + if s.os.usedByRows { + s.os.closeByStmt() + s.os = nil + os, err := s.c.PrepareODBCStmt(s.query) + if err != nil { + return nil, err + } + s.os = os + } + err := s.os.Exec(args, s.c) + if err != nil { + return nil, err + } + var sumRowCount int64 + for { + var c api.SQLLEN + ret := api.SQLRowCount(s.os.h, &c) + if IsError(ret) { + return nil, NewError("SQLRowCount", s.os.h) + } + sumRowCount += int64(c) + if ret = api.SQLMoreResults(s.os.h); ret == api.SQL_NO_DATA { + break + } + } + return &Result{rowCount: sumRowCount}, nil +} + +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if s.os == nil { + return nil, errors.New("Stmt is closed") + } + s.mu.Lock() + defer s.mu.Unlock() + if s.os.usedByRows { + s.os.closeByStmt() + s.os = nil + os, err := s.c.PrepareODBCStmt(s.query) + if err != nil { + return nil, err + } + s.os = os + } + err := s.os.Exec(args, s.c) + if err != nil { + return nil, err + } + err = s.os.BindColumns() + if err != nil { + return nil, err + } + s.os.usedByRows = true // now both Stmt and Rows refer to it + return &Rows{os: s.os}, nil +} diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..2400779 --- /dev/null +++ b/tx.go @@ -0,0 +1,77 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "database/sql/driver" + "errors" + + "github.com/taoikaihatsu-dev/odbc/api" +) + +type Tx struct { + c *Conn +} + +var testBeginErr error // used during tests + +func (c *Conn) setAutoCommitAttr(a uintptr) error { + if testBeginErr != nil { + return testBeginErr + } + ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_AUTOCOMMIT, a, api.SQL_IS_UINTEGER) + if IsError(ret) { + return c.newError("SQLSetConnectUIntPtrAttr", c.h) + } + return nil +} + +func (c *Conn) Begin() (driver.Tx, error) { + if c.bad { + return nil, driver.ErrBadConn + } + if c.tx != nil { + return nil, errors.New("already in a transaction") + } + c.tx = &Tx{c: c} + err := c.setAutoCommitAttr(api.SQL_AUTOCOMMIT_OFF) + if err != nil { + c.bad = true + return nil, err + } + return c.tx, nil +} + +func (c *Conn) endTx(commit bool) error { + if c.tx == nil { + return errors.New("not in a transaction") + } + var howToEnd api.SQLSMALLINT + if commit { + howToEnd = api.SQL_COMMIT + } else { + howToEnd = api.SQL_ROLLBACK + } + ret := api.SQLEndTran(api.SQL_HANDLE_DBC, api.SQLHANDLE(c.h), howToEnd) + if IsError(ret) { + c.bad = true + return c.newError("SQLEndTran", c.h) + } + c.tx = nil + err := c.setAutoCommitAttr(api.SQL_AUTOCOMMIT_ON) + if err != nil { + c.bad = true + return err + } + return nil +} + +func (tx *Tx) Commit() error { + return tx.c.endTx(true) +} + +func (tx *Tx) Rollback() error { + return tx.c.endTx(false) +} diff --git a/utf16.go b/utf16.go new file mode 100644 index 0000000..8f3ddd0 --- /dev/null +++ b/utf16.go @@ -0,0 +1,55 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package odbc + +import ( + "unicode/utf16" + "unicode/utf8" +) + +const ( + replacementChar = '\uFFFD' // Unicode replacement character + + // 0xd800-0xdc00 encodes the high 10 bits of a pair. + // 0xdc00-0xe000 encodes the low 10 bits of a pair. + // the value is those 20 bits plus 0x10000. + surr1 = 0xd800 + surr2 = 0xdc00 + surr3 = 0xe000 +) + +// utf16toutf8 returns the UTF-8 encoding of the UTF-16 sequence s, +// with a terminating NUL removed. +func utf16toutf8(s []uint16) []byte { + for i, v := range s { + if v == 0 { + s = s[0:i] + break + } + } + buf := make([]byte, 0, len(s)*2) // allow 2 bytes for every rune + b := make([]byte, 4) + for i := 0; i < len(s); i++ { + var rr rune + switch r := s[i]; { + case surr1 <= r && r < surr2 && i+1 < len(s) && + surr2 <= s[i+1] && s[i+1] < surr3: + // valid surrogate sequence + rr = utf16.DecodeRune(rune(r), rune(s[i+1])) + i++ + case surr1 <= r && r < surr3: + // invalid surrogate sequence + rr = replacementChar + default: + // normal rune + rr = rune(r) + } + b := b[:cap(b)] + n := utf8.EncodeRune(b, rr) + b = b[:n] + buf = append(buf, b...) + } + return buf +}