diff --git a/scripts/neighbor_advertiser b/scripts/neighbor_advertiser index 62124c7400..a242805ed4 100644 --- a/scripts/neighbor_advertiser +++ b/scripts/neighbor_advertiser @@ -169,9 +169,11 @@ def get_loopback_addr(ip_ver): def get_vlan_interfaces(): vlan_info = config_db.get_table('VLAN') vlan_interfaces = [] - + vlan_intfs = config_db.get_table('VLAN_INTERFACE') + # Skip L2 VLANs for vlan_name in vlan_info: - vlan_interfaces.append(vlan_name) + if vlan_name in vlan_intfs: + vlan_interfaces.append(vlan_name) return vlan_interfaces @@ -479,6 +481,14 @@ def reset_mirror_tunnel(): # Set vxlan tunnel # +def check_existing_tunnel(): + vxlan_tunnel = config_db.get_table('VXLAN_TUNNEL') + if len(vxlan_tunnel): + global VXLAN_TUNNEL_NAME + VXLAN_TUNNEL_NAME = list(vxlan_tunnel.keys())[0] + return True + return False + def add_vxlan_tunnel(dst_ipv4_addr): vxlan_tunnel_info = { 'src_ip': get_loopback_addr(4), @@ -494,12 +504,12 @@ def add_vxlan_tunnel_map(): 'vni': get_vlan_interface_vxlan_id(vlan_intf_name), 'vlan': vlan_intf_name } - config_db.set_entry('VXLAN_TUNNEL_MAP', (VXLAN_TUNNEL_NAME, VXLAN_TUNNEL_MAP_PREFIX + str(index)), vxlan_tunnel_map_info) def set_vxlan_tunnel(ferret_server_ip): - add_vxlan_tunnel(ferret_server_ip) + if not check_existing_tunnel(): + add_vxlan_tunnel(ferret_server_ip) add_vxlan_tunnel_map() log.log_info('Finish setting vxlan tunnel; Ferret: {}'.format(ferret_server_ip)) diff --git a/tests/neighbor_advertiser_test.py b/tests/neighbor_advertiser_test.py index 3a3aeba39f..cb908be888 100644 --- a/tests/neighbor_advertiser_test.py +++ b/tests/neighbor_advertiser_test.py @@ -49,3 +49,12 @@ def test_neighbor_advertiser_slice(self, set_up): } ) assert output == expected_output + + def test_set_vxlan(self, set_up): + assert(neighbor_advertiser.check_existing_tunnel()) + neighbor_advertiser.add_vxlan_tunnel_map() + tunnel_mapping = neighbor_advertiser.config_db.get_table('VXLAN_TUNNEL_MAP') + expected_mapping = {("vtep1", "map_1"): {"vni": "1000", "vlan": "Vlan1000"}, ("vtep1", "map_2"): {"vni": "2000", "vlan": "Vlan2000"}} + for key in expected_mapping.keys(): + assert(key in tunnel_mapping.keys()) + assert(expected_mapping[key] == tunnel_mapping[key])