Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/fix #3570

Merged
merged 8 commits into from
Dec 28, 2021
Merged

Fix/fix #3570

Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/ssl/SSLConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace nebula {
std::shared_ptr<wangle::SSLContextConfig> sslContextConfig() {
auto sslCfg = std::make_shared<wangle::SSLContextConfig>();
sslCfg->addCertificate(FLAGS_cert_path, FLAGS_key_path, FLAGS_password_path);
sslCfg->clientVerification = folly::SSLContext::VerifyClientCertificate::DO_NOT_REQUEST;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean that the certificate from the client could be omitted?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about using "always" and failed if verification fails

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not, it just check CA of server (which is new feature after library upgrading), so I disable it to keep same with before.

sslCfg->isDefault = true;
return sslCfg;
}
Expand Down
35 changes: 27 additions & 8 deletions tests/common/nebula_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import signal
import copy
import fcntl
import logging
from pathlib import Path
from contextlib import closing

from tests.common.constants import TMP_DIR
from tests.common.utils import get_ssl_config
from nebula2.gclient.net import ConnectionPool
from nebula2.Config import Config

Expand Down Expand Up @@ -133,6 +135,11 @@ def __init__(
self.storaged_port = 0
self.graphd_port = 0
self.ca_signed = ca_signed
self.is_graph_ssl = (
kwargs.get("enable_graph_ssl", "false").upper() == "TRUE"
or kwargs.get("enable_ssl", "false").upper() == "TRUE"
)

self.debug_log = debug_log
self.ports_per_process = 4
self.lock_file = os.path.join(TMP_DIR, "cluster_port.lock")
Expand Down Expand Up @@ -200,14 +207,14 @@ def _make_params(self, **kwargs):
'expired_time_factor': 60,
}
if self.ca_signed:
_params['ca_path'] = 'share/resources/test.ca.pem'
_params['cert_path'] = 'share/resources/test.derive.crt'
_params['key_path'] = 'share/resources/test.derive.key'
_params['ca_path'] = 'share/resources/test.ca.pem'

else:
_params['ca_path'] = 'share/resources/test.ca.pem'
_params['cert_path'] = 'share/resources/test.ca.key'
_params['key_path'] = 'share/resources/test.ca.password'
_params['cert_path'] = 'share/resources/test.ca.pem'
_params['key_path'] = 'share/resources/test.ca.key'
_params['password_path'] = 'share/resources/test.ca.password'

if self.debug_log:
_params['v'] = '4'
Expand All @@ -218,6 +225,7 @@ def _make_params(self, **kwargs):
self.graphd_param['system_memory_high_watermark_ratio'] = '0.95'
self.graphd_param['num_rows_to_check_memory'] = '4'
self.graphd_param['session_reclaim_interval_secs'] = '2'

self.storaged_param = copy.copy(_params)
self.storaged_param['local_config'] = 'false'
self.storaged_param['raft_heartbeat_interval_secs'] = '30'
Expand All @@ -244,7 +252,9 @@ def _copy_nebula_conf(self):
os.makedirs(resources_dir)

# timezone file
shutil.copy(self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir)
shutil.copy(
self.build_dir + '/../resources/date_time_zonespec.csv', resources_dir
)
shutil.copy(self.build_dir + '/../resources/gflags.json', resources_dir)
# cert files
shutil.copy(self.src_dir + '/tests/cert/test.ca.key', resources_dir)
Expand Down Expand Up @@ -365,14 +375,23 @@ def start(self):
# init connection pool
client_pool = ConnectionPool()
# assert client_pool.init([("127.0.0.1", int(self.graphd_port))], config)
assert client_pool.init([("127.0.0.1", self.graphd_processes[0].tcp_port)], config)
ssl_config = get_ssl_config(self.is_graph_ssl, self.ca_signed)
print("begin to add hosts")
assert client_pool.init(
[("127.0.0.1", self.graphd_processes[0].tcp_port)], config, ssl_config
)

cmd = "ADD HOSTS 127.0.0.1:" + str(self.storaged_processes[0].tcp_port) + " INTO NEW ZONE \"default_zone\""
print(cmd)
cmd = (
"ADD HOSTS 127.0.0.1:"
+ str(self.storaged_processes[0].tcp_port)
+ " INTO NEW ZONE \"default_zone\""
)
print("add hosts cmd is {}".format(cmd))

# get session from the pool
client = client_pool.get_session('root', 'nebula')
resp = client.execute(cmd)
assert resp.is_succeeded(), resp.error_msg()
client.release()

# wait nebula start
Expand Down
30 changes: 25 additions & 5 deletions tests/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@

import os
import re
import json
import random
import string
import time
import yaml
from typing import Pattern

from nebula2.Config import Config
from nebula2.Config import Config, SSL_config
from nebula2.common import ttypes as CommonTtypes
from nebula2.gclient.net import Session
from nebula2.gclient.net import ConnectionPool

from tests.common.constants import NB_TMP_PATH, NEBULA_HOME
from tests.common.csv_import import CSVImporter
from tests.common.path_value import PathVal
from tests.common.types import SpaceDesc
Expand Down Expand Up @@ -113,9 +115,12 @@ def compare_value(real, expect):
if eedge.type < 0:
esrc, edst = edst, esrc
# ignore props comparison
return rsrc == esrc and rdst == edst \
and redge.ranking == eedge.ranking \
return (
rsrc == esrc
and rdst == edst
and redge.ranking == eedge.ranking
and redge.name == eedge.name
)

return real == expect

Expand Down Expand Up @@ -433,13 +438,13 @@ def load_csv_data(
return space_desc


def get_conn_pool(host: str, port: int):
def get_conn_pool(host: str, port: int, ssl_config: SSL_config):
config = Config()
config.max_connection_pool_size = 20
config.timeout = 180000
# init connection pool
pool = ConnectionPool()
if not pool.init([(host, port)], config):
if not pool.init([(host, port)], config, ssl_config):
raise Exception("Fail to init connection pool.")
return pool

Expand All @@ -450,3 +455,18 @@ def parse_service_index(name: str):
if m and len(m.groups()) == 2:
return int(m.groups()[1])
return None

def get_ssl_config(is_graph_ssl: bool, ca_signed: bool):
if not is_graph_ssl:
return None
ssl_config = SSL_config()

if ca_signed:
ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem')
ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt')
ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key')
else:
ssl_config.ca_certs = os.path.join(NEBULA_HOME, 'tests/cert/test.ca.pem')
ssl_config.certfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.crt')
ssl_config.keyfile = os.path.join(NEBULA_HOME, 'tests/cert/test.derive.key')
return ssl_config
30 changes: 19 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from tests.common.configs import all_configs
from tests.common.types import SpaceDesc
from tests.common.utils import get_conn_pool
from tests.common.utils import get_conn_pool, get_ssl_config
from tests.common.constants import NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR, NEBULA_HOME
from tests.common.nebula_service import NebulaService

Expand Down Expand Up @@ -110,6 +110,17 @@ def get_ports():
raise Exception(f"Invalid port: {port}")
return port

def get_ssl_config_from_tmp():
with open(NB_TMP_PATH, "r") as f:
data = json.loads(f.readline())
is_graph_ssl = (
data.get("enable_ssl", "false").upper() == "TRUE"
or data.get("enable_graph_ssl", "false").upper() == "TRUE"
)
ca_signed = data.get("ca_signed", "false").upper() == "TRUE"
return get_ssl_config(is_graph_ssl, ca_signed)


@pytest.fixture(scope="class")
def class_fixture_variables():
"""save class scope fixture, used for session update.
Expand Down Expand Up @@ -140,7 +151,8 @@ def conn_pool_to_first_graph_service(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]]
assert len(host_addr) == 2
pool = get_conn_pool(host_addr[0], host_addr[1])
ssl_config = get_ssl_config_from_tmp()
pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config)
yield pool
pool.close()

Expand All @@ -150,7 +162,8 @@ def conn_pool_to_second_graph_service(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = ["localhost", get_ports()[1]]
assert len(host_addr) == 2
pool = get_conn_pool(host_addr[0], host_addr[1])
ssl_config = get_ssl_config_from_tmp()
pool = get_conn_pool(host_addr[0], host_addr[1], ssl_config)
yield pool
pool.close()

Expand Down Expand Up @@ -246,11 +259,6 @@ def workarround_for_class(
request.cls.drop_data()

@pytest.fixture(scope="class")
def establish_a_rare_connection(pytestconfig):
addr = pytestconfig.getoption("address")
host_addr = addr.split(":") if addr else ["localhost", get_ports()[0]]
socket = TSocket.TSocket(host_addr[0], host_addr[1])
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
return GraphService.Client(protocol)
def establish_a_rare_connection(conn_pool, pytestconfig):
conn = conn_pool.get_connection()
return conn._connection
16 changes: 6 additions & 10 deletions tests/job/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,19 @@ def test_sessions(self):

def test_the_same_id_to_different_graphd(self):
def get_connection(ip, port):
ssl_config = self.client_pool._ssl_configs
try:
socket = TSocket.TSocket(ip, port)
transport = TTransport.TBufferedTransport(socket)
protocol = TBinaryProtocol.TBinaryProtocol(transport)
transport.open()
connection = GraphService.Client(protocol)
conn = Connection()
conn.open_SSL(ip, port, 0, ssl_config)
except Exception as ex:
assert False, 'Create connection to {}:{} failed'.format(ip, port)
return connection
return conn

conn1 = get_connection(self.addr_host1, self.addr_port1)
conn2 = get_connection(self.addr_host2, self.addr_port2)

resp = conn1.authenticate('root', 'nebula')
assert resp.error_code == ttypes.ErrorCode.SUCCEEDED
session_id = resp.session_id
session_id = resp.get_session_id()

resp = conn1.execute(session_id, 'CREATE SPACE IF NOT EXISTS aSpace(partition_num=1, vid_type=FIXED_STRING(8));USE aSpace;')
self.check_resp_succeeded(ResultSet(resp, 0))
Expand Down Expand Up @@ -217,8 +214,7 @@ def test_out_of_max_connections(self):

def test_signout_and_execute(self):
try:
conn = Connection()
conn.open(self.addr_host1, self.addr_port1, 3000)
conn = self.client_pool.get_connection()
auth_result = conn.authenticate(self.user, self.password)
session_id = auth_result.get_session_id()
conn.signout(session_id)
Expand Down
25 changes: 21 additions & 4 deletions tests/nebula-test-run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
import os
import shutil
from tests.common.nebula_service import NebulaService
from tests.common.utils import get_conn_pool, load_csv_data
from tests.common.constants import NEBULA_HOME, TMP_DIR, NB_TMP_PATH, SPACE_TMP_PATH, BUILD_DIR
from tests.common.utils import get_conn_pool, load_csv_data, get_ssl_config
from tests.common.constants import (
NEBULA_HOME,
TMP_DIR,
NB_TMP_PATH,
SPACE_TMP_PATH,
BUILD_DIR,
)


CURR_PATH = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -100,8 +106,12 @@ def start_nebula(nb, configs):
address = "localhost"
ports = nb.start()

is_graph_ssl = opt_is(configs.enable_ssl, "true") or opt_is(
configs.enable_graph_ssl, "true"
)
ca_signed = opt_is(configs.enable_ssl, "true")
# Load csv data
pool = get_conn_pool(address, ports[0])
pool = get_conn_pool(address, ports[0], get_ssl_config(is_graph_ssl, ca_signed))
sess = pool.get_session(configs.user, configs.password)

if not os.path.exists(TMP_DIR):
Expand All @@ -119,7 +129,14 @@ def start_nebula(nb, configs):
f.write(json.dumps(spaces))

with open(NB_TMP_PATH, "w") as f:
data = {"ip": "localhost", "port": ports, "work_dir": nb.work_dir}
data = {
"ip": "localhost",
"port": ports,
"work_dir": nb.work_dir,
"enable_ssl": configs.enable_ssl,
"enable_graph_ssl": configs.enable_graph_ssl,
"ca_signed": configs.ca_signed,
}
f.write(json.dumps(data))
print('Start nebula successfully')

Expand Down
5 changes: 3 additions & 2 deletions tests/tck/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def given_nebulacluster_with_param(
nebula_svc.start()
graph_ip = nebula_svc.graphd_processes[0].host
graph_port = nebula_svc.graphd_processes[0].tcp_port
pool = get_conn_pool(graph_ip, graph_port)
# TODO add ssl pool if tests needed
pool = get_conn_pool(graph_ip, graph_port, None)
sess = pool.get_session(user, password)
class_fixture_variables["current_session"] = sess
class_fixture_variables["sessions"].append(sess)
Expand All @@ -352,7 +353,7 @@ def when_login_graphd(graph, user, password, class_fixture_variables, pytestconf
assert index < len(nebula_svc.graphd_processes)
graphd_process = nebula_svc.graphd_processes[index]
graph_ip, graph_port = graphd_process.host, graphd_process.tcp_port
pool = get_conn_pool(graph_ip, graph_port)
pool = get_conn_pool(graph_ip, graph_port, None)
sess = pool.get_session(user, password)
# do not release original session, as we may have cases to test multiple sessions.
# connection could be released after cluster stopped.
Expand Down