Skip to content

Commit

Permalink
Fix/fix tck ssl (#3570)
Browse files Browse the repository at this point in the history
* Fix some errors.

* add tck ssl

* fix ssl client

Co-authored-by: HarrisChu <1726587+HarrisChu@users.noreply.github.com>
  • Loading branch information
Shylock-Hg and HarrisChu authored Dec 28, 2021
1 parent 7398d80 commit 5a4a36b
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 40 deletions.
1 change: 1 addition & 0 deletions src/common/ssl/SSLConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace nebula {
std::shared_ptr<wangle::SSLContextConfig> sslContextConfig() {
auto sslCfg = std::make_shared<wangle::SSLContextConfig>();
sslCfg->addCertificate(FLAGS_cert_path, FLAGS_key_path, FLAGS_password_path);
sslCfg->clientVerification = folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST;
sslCfg->isDefault = true;
return sslCfg;
}
Expand Down
35 changes: 27 additions & 8 deletions tests/common/nebula_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import signal
import copy
import fcntl
import logging
from pathlib import Path
from contextlib import closing

from tests.common.constants import TMP_DIR
from tests.common.utils import get_ssl_config
from nebula2.gclient.net import ConnectionPool
from nebula2.Config import Config

Expand Down Expand Up @@ -133,6 +135,11 @@ def __init__(
self.storaged_port = 0
self.graphd_port = 0
self.ca_signed = ca_signed
self.is_graph_ssl = (
kwargs.get("enable_graph_ssl", "false").upper() == "TRUE"
or kwargs.get("enable_ssl", "false").upper() == "TRUE"
)

self.debug_log = debug_log
self.ports_per_process = 4
self.lock_file = os.path.join(TMP_DIR, "cluster_port.lock")
Expand Down Expand Up @@ -200,14 +207,14 @@ def _make_params(self, **kwargs):
'expired_time_factor': 60,
}
if self.ca_signed:
_params['ca_path'] = 'share/resources/test.ca.pem'
_params['cert_path'] = 'share/resources/test.derive.crt'
_params['key_path'] = 'share/resources/test.derive.key'
_params['ca_path'] = 'share/resources/test.ca.pem'

else:
_params['ca_path'] = 'share/resources/test.ca.pem'
_params['cert_path'] = 'share/resources/test.ca.key'
_params['key_path'] = 'share/resources/test.ca.password'
_params['cert_path'] = 'share/resources/test.ca.pem'
_params['key_path'] = 'share/resources/test.ca.key'
_params['password_path'] = 'share/resources/test.ca.password'

if self.debug_log:
_params['v'] = '4'
Expand All @@ -218,6 +225,7 @@ def _make_params(self, **kwargs):
self.graphd_param['system_memory_high_watermark_ratio'] = '0.95'
self.graphd_param['num_rows_to_check_memory'] = '4'
self.graphd_param['session_reclaim_interval_secs'] = '2'

self.storaged_param = copy.copy(_params)
self.storaged_param['local_config'] = 'false'
self.storaged_param['raft_heartbeat_interval_secs'] = '30'
Expand All @@ -244,7 +252,9 @@ def _copy_nebula_conf(self):
os.makedirs(resources_dir)

# timezone file
shutil.copy(self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir)
shutil.copy(
self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir
)
shutil.copy(self.build_dir + '/../resources/gflags.json', resources_dir)
# cert files
shutil.copy(self.src_dir + '/tests/cert/test.ca.key', resources_dir)
Expand Down Expand Up @@ -365,14 +375,23 @@ def start(self):
# init connection pool
client_pool = ConnectionPool()
# assert client_pool.init([("127.0.0.1", int(self.graphd_port))], config)
assert client_pool.init([("127.0.0.1", self.graphd_processes[0].tcp_port)], config)
ssl_config = get_ssl_config(self.is_graph_ssl, self.ca_signed)
print("begin to add hosts")
assert client_pool.init(
[("127.0.0.1", self.graphd_processes[0].tcp_port)], config, ssl_config
)

cmd = "ADD HOSTS 127.0.0.1:" + str(self.storaged_processes[0].tcp_port) + " INTO NEW ZONE \"default_zone\""
print(cmd)
cmd = (
"ADD HOSTS 127.0.0.1:"
+ str(self.storaged_processes[0].tcp_port)
+ " INTO NEW ZONE \"default_zone\""
)
print("add hosts cmd is {}".format(cmd))

# get session from the pool
client = client_pool.get_session('root', 'nebula')
resp = client.execute(cmd)
assert resp.is_succeeded(), resp.error_msg()
client.release()

# wait nebula start
Expand Down
30 changes: 25 additions & 5 deletions tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@

import os
import re
import json
import random
import string
import time
import yaml
from typing import Pattern

from nebula2.Config import Config
from nebula2.Config import Config, SSL_config
from nebula2.common import ttypes as CommonTtypes
from nebula2.gclient.net import Session
from nebula2.gclient.net import ConnectionPool

from tests.common.constants import NB_TMP_PATH, NEBULA_HOME
from tests.common.csv_import import CSVImporter
from tests.common.path_value import PathVal
from tests.common.types import SpaceDesc
Expand Down Expand Up @@ -113,9 +115,12 @@ def compare_value(real, expect):
if eedge.type < 0:
esrc, edst = edst, esrc
# ignore props comparison
return rsrc == esrc and rdst == edst \
and redge.ranking == eedge.ranking \
return (
rsrc == esrc
and rdst == edst
and redge.ranking == eedge.ranking
and redge.name == eedge.name
)

return real == expect

Expand Down Expand Up @@ -433,13 +438,13 @@ def load_csv_data(
return space_desc


def get_conn_pool(host: str, port: int):
def get_conn_pool(host: str, port: int, ssl_config: SSL_config):
config = Config()
config.max_connection_pool_size = 20
config.timeout = 180000
# init connection pool
pool = ConnectionPool()
if not pool.init([(host, port)], config):
if not pool.init([(host, port)], config, ssl_config):
raise Exception("Fail to init connection pool.")
return pool

Expand All @@ -450,3 +455,18 @@ def parse_service_index(name: str):
if m and len(m.groups()) == 2:
return int(m.groups()[1])
return None

def get_ssl_config(is_graph_ssl: bool, ca_signed: bool):
if not is_graph_ssl:
return None
ssl_config = SSL_config()

if ca_signed:
ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem')
ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt')
ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key')
else:
ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem')
ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt')
ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key')
return ssl_config
30 changes: 19 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from tests.common.configs import all_configs
from tests.common.types import SpaceDesc
from tests.common.utils import get_conn_pool
from tests.common.utils import get_conn_pool, get_ssl_config
from tests.common.constants import NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR, NEBULA_HOME
from tests.common.nebula_service import NebulaService

Expand Down Expand Up @@ -110,6 +110,17 @@ def get_ports():
raise Exception(f"Invalid port: {port}")
return port

def get_ssl_config_from_tmp():
with open(NB_TMP_PATH, "r") as f:
data = json.loads(f.readline())
is_graph_ssl = (
data.get("enable_ssl", "false").upper() == "TRUE"
or data.get("enable_graph_ssl", "false").upper() == "TRUE"
)
ca_signed = data.get("ca_signed", "false").upper() == "TRUE"
return get_ssl_config(is_graph_ssl, ca_signed)


@pytest.fixture(scope="class")
def class_fixture_variables():
"""save class scope fixture, used for session update.
Expand Down Expand Up @@ -140,7 +151,8 @@ def conn_pool_to_first_graph_service(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]]
assert len(host_addr) == 2
pool = get_conn_pool(host_addr[0], host_addr[1])
ssl_config = get_ssl_config_from_tmp()
pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config)
yield pool
pool.close()

Expand All @@ -150,7 +162,8 @@ def conn_pool_to_second_graph_service(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = ["localhost", get_ports()[1]]
assert len(host_addr) == 2
pool = get_conn_pool(host_addr[0], host_addr[1])
ssl_config = get_ssl_config_from_tmp()
pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config)
yield pool
pool.close()

Expand Down Expand Up @@ -246,11 +259,6 @@ def workarround_for_class(
request.cls.drop_data()

@pytest.fixture(scope="class")
def establish_a_rare_connection(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]]
socket = TSocket.TSocket(host_addr[0], host_addr[1])
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
return GraphService.Client(protocol)
def establish_a_rare_connection(conn_pool, pytestconfig):
conn = conn_pool.get_connection()
return conn._connection
16 changes: 6 additions & 10 deletions tests/job/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,19 @@ def test_sessions(self):

def test_the_same_id_to_different_graphd(self):
def get_connection(ip, port):
ssl_config = self.client_pool._ssl_configs
try:
socket = TSocket.TSocket(ip, port)
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
connection = GraphService.Client(protocol)
conn = Connection()
conn.open_SSL(ip, port, 0, ssl_config)
except Exception as ex:
assert False, 'Create connection to {}:{} failed'.format(ip, port)
return connection
return conn

conn1 = get_connection(self.addr_host1, self.addr_port1)
conn2 = get_connection(self.addr_host2, self.addr_port2)

resp = conn1.authenticate('root', 'nebula')
assert resp.error_code == ttypes.ErrorCode.SUCCEEDED
session_id = resp.session_id
session_id = resp.get_session_id()

resp = conn1.execute(session_id, 'CREATE SPACE IF NOT EXISTS aSpace(partition_num=1, vid_type=FIXED_STRING(8));USE aSpace;')
self.check_resp_succeeded(ResultSet(resp, 0))
Expand Down Expand Up @@ -217,8 +214,7 @@ def test_out_of_max_connections(self):

def test_signout_and_execute(self):
try:
conn = Connection()
conn.open(self.addr_host1, self.addr_port1, 3000)
conn = self.client_pool.get_connection()
auth_result = conn.authenticate(self.user, self.password)
session_id = auth_result.get_session_id()
conn.signout(session_id)
Expand Down
25 changes: 21 additions & 4 deletions tests/nebula-test-run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
import os
import shutil
from tests.common.nebula_service import NebulaService
from tests.common.utils import get_conn_pool, load_csv_data
from tests.common.constants import NEBULA_HOME, TMP_DIR, NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR
from tests.common.utils import get_conn_pool, load_csv_data, get_ssl_config
from tests.common.constants import (
NEBULA_HOME,
TMP_DIR,
NB_TMP_PATH,
SPACE_TMP_PATH,
BUILD_DIR,
)


CURR_PATH = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -100,8 +106,12 @@ def start_nebula(nb, configs):
address = "localhost"
ports = nb.start()

is_graph_ssl = opt_is(configs.enable_ssl, "true") or opt_is(
configs.enable_graph_ssl, "true"
)
ca_signed = opt_is(configs.enable_ssl, "true")
# Load csv data
pool = get_conn_pool(address, ports[0])
pool = get_conn_pool(address, ports[0], get_ssl_config(is_graph_ssl, ca_signed))
sess = pool.get_session(configs.user, configs.password)

if not os.path.exists(TMP_DIR):
Expand All @@ -119,7 +129,14 @@ def start_nebula(nb, configs):
f.write(json.dumps(spaces))

with open(NB_TMP_PATH, "w") as f:
data = {"ip": "localhost", "port": ports, "work_dir": nb.work_dir}
data = {
"ip": "localhost",
"port": ports,
"work_dir": nb.work_dir,
"enable_ssl": configs.enable_ssl,
"enable_graph_ssl": configs.enable_graph_ssl,
"ca_signed": configs.ca_signed,
}
f.write(json.dumps(data))
print('Start nebula successfully')

Expand Down
5 changes: 3 additions & 2 deletions tests/tck/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def given_nebulacluster_with_param(
nebula_svc.start()
graph_ip = nebula_svc.graphd_processes[0].host
graph_port = nebula_svc.graphd_processes[0].tcp_port
pool = get_conn_pool(graph_ip, graph_port)
# TODO add ssl pool if tests needed
pool = get_conn_pool(graph_ip, graph_port, None)
sess = pool.get_session(user, password)
class_fixture_variables["current_session"] = sess
class_fixture_variables["sessions"].append(sess)
Expand All @@ -352,7 +353,7 @@ def when_login_graphd(graph, user, password, class_fixture_variables, pytestconf
assert index < len(nebula_svc.graphd_processes)
graphd_process = nebula_svc.graphd_processes[index]
graph_ip, graph_port = graphd_process.host, graphd_process.tcp_port
pool = get_conn_pool(graph_ip, graph_port)
pool = get_conn_pool(graph_ip, graph_port, None)
sess = pool.get_session(user, password)
# do not release original session, as we may have cases to test multiple sessions.
# connection could be released after cluster stopped.
Expand Down

0 comments on commit 5a4a36b

Please sign in to comment.