Skip to content

Commit

Permalink
Update default channel pool options when creating grpc channels (#887)
Browse files Browse the repository at this point in the history
* Updating channel pool so that it doesn't limit message size for a channel and will bypass proxies when target address is on the localhost.

* Brad's feedback.

* Formatting to conform to style guide.
  • Loading branch information
jasonmreding authored Sep 17, 2024
1 parent e3ca88b commit 4bb38e9
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@

from __future__ import annotations

import ipaddress
import re
import sys
from threading import Lock
from types import TracebackType
from typing import (
Dict,
Literal,
Optional,
Type,
TYPE_CHECKING,
)
from typing import TYPE_CHECKING, Dict, Literal, Optional, Type
from urllib.parse import urlparse

import grpc

Expand Down Expand Up @@ -57,9 +54,7 @@ def get_channel(self, target: str) -> grpc.Channel:
with self._lock:
if target not in self._channel_cache:
self._lock.release()
new_channel = grpc.insecure_channel(target)
if ClientLogger.is_enabled():
new_channel = grpc.intercept_channel(new_channel, ClientLogger())
new_channel = self._create_channel(target)
self._lock.acquire()
if target not in self._channel_cache:
self._channel_cache[target] = new_channel
Expand All @@ -78,3 +73,44 @@ def close(self) -> None:
for channel in self._channel_cache.values():
channel.close()
self._channel_cache.clear()

def _create_channel(self, target: str) -> grpc.Channel:
options = [
("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1),
]
if self._is_local(target):
options.append(("grpc.enable_http_proxy", 0))
channel = grpc.insecure_channel(target, options)
if ClientLogger.is_enabled():
channel = grpc.intercept_channel(channel, ClientLogger())
return channel

def _is_local(self, target: str) -> bool:
hostname = ""
# First, check if the target string is in URL format
parse_result = urlparse(target)
if parse_result.scheme and parse_result.hostname and parse_result.port:
hostname = parse_result.hostname
else:
# Next, check for target string in <host_name>:<port> format
match = re.match(r"^(.*):(\d+)$", target)
if match:
hostname = match.group(1)

if not hostname:
return False
if hostname == "localhost" or hostname == "LOCALHOST":
return True

# IPv6 addresses don't support parsing with leading/trailing brackets
# so we need to remove them.
match = re.match(r"^\[(.*)\]$", hostname)
if match:
hostname = match.group(1)

try:
address = ipaddress.ip_address(hostname)
return address.is_loopback
except ValueError:
return False
1 change: 1 addition & 0 deletions packages/service/tests/unit/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ni_measurement_plugin_sdk_service.grpc."""
1 change: 1 addition & 0 deletions packages/service/tests/unit/grpc/channelpool/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ni_measurement_plugin_sdk_service.grpc.channelpool."""
32 changes: 32 additions & 0 deletions packages/service/tests/unit/grpc/channelpool/test_channel_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool


@pytest.mark.parametrize(
"target,expected_result",
[
("127.0.0.1", False), # Port must be specified explicitly
("[::1]", False), # Port must be specified explicitly
("localhost", False), # Port must be specified explicitly
("127.0.0.1:100", True),
("[::1]:100", True),
("localhost:100", True),
("http://127.0.0.1", False), # Port must be specified explicitly
("http://[::1]", False), # Port must be specified explicitly
("http://localhost", False), # Port must be specified explicitly
("http://127.0.0.1:100", True),
("http://[::1]:100", True),
("http://localhost:100", True),
("1.1.1.1:100", False),
("http://www.google.com:80", False),
],
)
def test___channel_pool___is_local___returns_expected_result(
target: str, expected_result: bool
) -> None:
channel_pool = GrpcChannelPool()

result = channel_pool._is_local(target)

assert result == expected_result

0 comments on commit 4bb38e9

Please sign in to comment.