-
Notifications
You must be signed in to change notification settings - Fork 3
/
benchmark_jdbc_py4j.py
63 lines (50 loc) · 2.25 KB
/
benchmark_jdbc_py4j.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from pathlib import Path
from py4j.java_gateway import JavaGateway
from utils import Timer, TIMER_TEXT, NUMBER_OF_RUNS, FlightDatabaseConnection, FLIGHT_DB, BENCHMARK_SQL_STATEMENT
# Absolute directory containing this script; the JDBC driver jar is looked up
# relative to it (under "drivers/") so the benchmark works from any CWD.
SCRIPT_DIR = Path(__file__).parent.resolve()
def benchmark_jdbc_py4j(db: FlightDatabaseConnection = FLIGHT_DB,
                        query: str = BENCHMARK_SQL_STATEMENT
                        ):
    """Time a full fetch of ``query`` through the Flight SQL JDBC driver via Py4J.

    A JVM is launched with the JDBC driver jar on its classpath, a JDBC
    connection is opened against ``db``, the query is executed, and every row
    of the result set is iterated (rows are only counted, not materialized).
    The column count and the number of rows fetched are printed. The whole
    sequence, including JVM start-up and tear-down, runs inside ``Timer``.

    Args:
        db: Connection settings. The attributes read here are ``hostname``,
            ``port``, ``username``, ``password`` and
            ``disableCertificateVerification``.
        query: SQL text to prepare and execute.

    Returns:
        None. Results are reported on stdout.
    """
    # Local import: only this function needs it.
    from contextlib import ExitStack

    with Timer(name="\nJDBC - Py4J - Fetch data from lineitem table",
               text=TIMER_TEXT,
               initial_text=True
               ):
        # Open JVM interface with the JDBC Jar
        jdbc_jar_path = SCRIPT_DIR / "drivers" / "flight-sql-jdbc-driver-13.0.0.jar"
        # Set before launching the gateway so the JVM child process sees it.
        # NOTE(review): presumably needed for Arrow's java.nio access on newer
        # JDKs - confirm against the driver documentation.
        os.environ["_JAVA_OPTIONS"] = '--add-opens=java.base/java.nio=ALL-UNNAMED'
        gateway = JavaGateway.launch_gateway(classpath=jdbc_jar_path.as_posix())

        # Cleanup callbacks run in reverse registration order (result set,
        # statement, connection, gateway) even when a step below raises, so a
        # failed run does not leak JDBC handles or a stray JVM process.
        with ExitStack() as cleanup:
            cleanup.callback(gateway.shutdown)

            # Load the JDBC Jar
            jdbc_class = "org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver"
            gateway.jvm.Class.forName(jdbc_class)

            # Initiate connection
            jdbc_uri = (f"jdbc:arrow-flight-sql://{db.hostname}:{str(db.port)}?"
                        "useEncryption=true"
                        f"&user={db.username}&password={db.password}"
                        f"&disableCertificateVerification={str(db.disableCertificateVerification).lower()}"
                        )
            con = gateway.jvm.java.sql.DriverManager.getConnection(jdbc_uri)
            cleanup.callback(con.close)

            stmt = con.prepareStatement(query)
            cleanup.callback(stmt.close)

            # The fetch-size hint applies to result sets created afterwards,
            # so it must be set before the query is executed.
            stmt.setFetchSize(10000)
            rs = stmt.executeQuery()
            cleanup.callback(rs.close)

            metadata = rs.getMetaData()
            column_count = metadata.getColumnCount()
            print(f"Number of columns: {column_count}")

            # Walk the whole result set; only the row count is kept.
            row_count = 0
            while rs.next():
                row_count += 1

        print(f"Number of rows fetched: {row_count}")
if __name__ == "__main__":
    import timeit

    # Run the benchmark NUMBER_OF_RUNS times and collect the cumulative
    # wall-clock time reported by timeit.
    elapsed = timeit.timeit(
        stmt="benchmark_jdbc_py4j()",
        setup="from __main__ import benchmark_jdbc_py4j",
        number=NUMBER_OF_RUNS,
    )
    mean_elapsed = elapsed / float(NUMBER_OF_RUNS)

    # Three-line summary: run count, total time, per-run average.
    summary_lines = [
        f"Number of runs: {NUMBER_OF_RUNS}",
        f"Total time: {elapsed}",
        f"Average time: {mean_elapsed}",
    ]
    print("\n".join(summary_lines))