diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java
index 2e16d5f01f..750a012684 100644
--- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java
+++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java
@@ -27,6 +27,7 @@
 import org.opensearch.sql.expression.aggregation.AggregatorFunction;
 import org.opensearch.sql.expression.datetime.DateTimeFunction;
 import org.opensearch.sql.expression.datetime.IntervalClause;
+import org.opensearch.sql.expression.ip.IpFunctions;
 import org.opensearch.sql.expression.operator.arthmetic.ArithmeticFunction;
 import org.opensearch.sql.expression.operator.arthmetic.MathematicalFunction;
 import org.opensearch.sql.expression.operator.convert.TypeCastOperator;
@@ -81,6 +82,7 @@ public static synchronized BuiltinFunctionRepository getInstance() {
       TypeCastOperator.register(instance);
       SystemFunctions.register(instance);
       OpenSearchFunctions.register(instance);
+      IpFunctions.register(instance);
     }
     return instance;
   }
diff --git a/core/src/main/java/org/opensearch/sql/expression/ip/CidrExpression.java b/core/src/main/java/org/opensearch/sql/expression/ip/CidrExpression.java
new file mode 100644
index 0000000000..7ae74a5af7
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/expression/ip/CidrExpression.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.ip;
+
+import com.google.common.net.InetAddresses;
+import lombok.EqualsAndHashCode;
+import lombok.ToString;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.data.model.ExprValueUtils;
+import org.opensearch.sql.data.type.ExprCoreType;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.FunctionExpression;
+import org.opensearch.sql.expression.env.Environment;
+import org.opensearch.sql.expression.function.FunctionName;
+
+import java.io.Serializable;
+import java.math.BigInteger;
+import java.net.Inet4Address;
+import java.net.InetAddress;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Function expression {@code cidr(address, range)} that tests whether an IP address lies within
+ * a CIDR range.
+ *
+ * <p>The range argument is evaluated once, at construction time; the address argument is
+ * evaluated against the value environment on every call to {@code valueOf}.
+ */
+@ToString
+@EqualsAndHashCode(callSuper = false)
+public class CidrExpression extends FunctionExpression {
+
+  private final Expression addressExpression;
+  private final InetAddressRange range;
+
+  /**
+   * Builds a CIDR expression from exactly two arguments: the address expression followed by the
+   * CIDR range expression.
+   *
+   * @param arguments function arguments
+   * @throws ExpressionEvaluationException if the number of arguments is not two
+   * @throws SemanticCheckException if the range is not valid CIDR notation
+   */
+  public CidrExpression(List<Expression> arguments) {
+    super(FunctionName.of("cidr"), arguments);
+
+    // Must be exactly two arguments.
+    if (arguments.size() != 2) {
+      String msg =
+          String.format(
+              "Unexpected number of arguments to function '%s'. Expected %s, but found %s.",
+              FunctionName.of("cidr"), 2, arguments.size());
+      throw new ExpressionEvaluationException(msg);
+    }
+
+    this.addressExpression = arguments.get(0);
+    this.range = new InetAddressRange(arguments.get(1).valueOf().stringValue());
+  }
+
+  /**
+   * Evaluates the address expression and checks the result against the range.
+   *
+   * @return null value if the address is null, missing, or not a valid IP address string;
+   *     otherwise a boolean value telling whether the address is within the range
+   */
+  @Override
+  public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
+    ExprValue addressValue = addressExpression.valueOf(valueEnv);
+    if (addressValue.isNull() || addressValue.isMissing()) {
+      return ExprValueUtils.nullValue();
+    }
+
+    String addressString = addressValue.stringValue();
+    if (!InetAddresses.isInetAddress(addressString)) {
+      return ExprValueUtils.nullValue();
+    }
+
+    InetAddress address = InetAddresses.forString(addressString);
+    return ExprValueUtils.booleanValue(range.contains(address));
+  }
+
+  @Override
+  public ExprType type() {
+    return ExprCoreType.BOOLEAN;
+  }
+
+  /**
+   * Represents an IP address range.
+   * Supports both IPv4 and IPv6 addresses.
+   */
+  @ToString
+  @EqualsAndHashCode
+  private static class InetAddressRange implements Serializable {
+
+    // Basic CIDR notation pattern: "<address>/<prefix length>".
+    private static final Pattern CIDR_PATTERN =
+        Pattern.compile("(?<address>.+)[/](?<prefix>[0-9]+)");
+
+    // Lower/upper bounds for the IP address range.
+    private final InetAddress lowerBound;
+    private final InetAddress upperBound;
+
+    /**
+     * Builds a new IP address range from the given CIDR notation string.
+     *
+     * @param cidr CIDR notation string (e.g. "198.51.100.0/24" or "2001:0db8::/32")
+     * @throws SemanticCheckException if the string is not valid CIDR notation
+     */
+    public InetAddressRange(String cidr) {
+
+      // Parse address and network length.
+      Matcher cidrMatcher = CIDR_PATTERN.matcher(cidr);
+      if (!cidrMatcher.matches()) {
+        throw new SemanticCheckException(
+            String.format("CIDR notation '%s' is not valid", cidr));
+      }
+
+      String addressString = cidrMatcher.group("address");
+      if (!InetAddresses.isInetAddress(addressString)) {
+        throw new SemanticCheckException(
+            String.format("IP address '%s' is not valid", addressString));
+      }
+
+      InetAddress address = InetAddresses.forString(addressString);
+      int addressLengthBits = address.getAddress().length * Byte.SIZE;
+
+      // The pattern only admits digits, so parsing can only fail when the value overflows an
+      // int; report that as a semantic error like every other invalid range.
+      String prefixString = cidrMatcher.group("prefix");
+      int networkLengthBits;
+      try {
+        networkLengthBits = Integer.parseInt(prefixString);
+      } catch (NumberFormatException e) {
+        throw new SemanticCheckException(
+            String.format("Network length of '%s' bits is not valid", prefixString));
+      }
+
+      if (networkLengthBits > addressLengthBits) {
+        throw new SemanticCheckException(
+            String.format("Network length of '%s' bits is not valid", networkLengthBits));
+      }
+
+      // Build bounds by converting the address to an integer, setting all the non-significant
+      // bits to zero for the lower bounds and one for the upper bounds, and then converting
+      // back to addresses.
+      BigInteger lowerBoundInt = InetAddresses.toBigInteger(address);
+      BigInteger upperBoundInt = InetAddresses.toBigInteger(address);
+
+      int hostLengthBits = addressLengthBits - networkLengthBits;
+      for (int bit = 0; bit < hostLengthBits; bit++) {
+        lowerBoundInt = lowerBoundInt.clearBit(bit);
+        upperBoundInt = upperBoundInt.setBit(bit);
+      }
+
+      if (address instanceof Inet4Address) {
+        lowerBound = InetAddresses.fromIPv4BigInteger(lowerBoundInt);
+        upperBound = InetAddresses.fromIPv4BigInteger(upperBoundInt);
+      } else {
+        lowerBound = InetAddresses.fromIPv6BigInteger(lowerBoundInt);
+        upperBound = InetAddresses.fromIPv6BigInteger(upperBoundInt);
+      }
+    }
+
+    /**
+     * Returns whether the IP address is contained within the range.
+     *
+     * @param address IPv4 or IPv6 address
+     * @return true if the address has the same IP version as the range and lies between the
+     *     lower and upper bounds, inclusive; else false
+     */
+    public boolean contains(InetAddress address) {
+
+      // An address of a different IP version than the range is never contained in it.
+      if ((address instanceof Inet4Address) ^ (lowerBound instanceof Inet4Address)) {
+        return false;
+      }
+
+      BigInteger addressInt = InetAddresses.toBigInteger(address);
+
+      return addressInt.compareTo(InetAddresses.toBigInteger(lowerBound)) >= 0
+          && addressInt.compareTo(InetAddresses.toBigInteger(upperBound)) <= 0;
+    }
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/expression/ip/IpFunctions.java b/core/src/main/java/org/opensearch/sql/expression/ip/IpFunctions.java
new file mode 100644
index 0000000000..5582ff89ef
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/expression/ip/IpFunctions.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.ip;
+
+import lombok.experimental.UtilityClass;
+import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.sql.expression.function.BuiltinFunctionName;
+import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
+import org.opensearch.sql.expression.function.DefaultFunctionResolver;
+import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.expression.function.FunctionSignature;
+
+import java.util.Arrays;
+
+import static org.opensearch.sql.data.type.ExprCoreType.STRING;
+import static org.opensearch.sql.expression.function.FunctionDSL.define;
+
+/**
+ * Utility class that defines and registers IP functions.
+ */
+@UtilityClass
+public class IpFunctions {
+
+  /**
+   * Registers all IP functions with the given built-in function repository.
+   */
+  public void register(BuiltinFunctionRepository repository) {
+    repository.register(cidr());
+  }
+
+  /**
+   * Builds the resolver for {@code cidr(STRING, STRING)}, which is implemented by
+   * {@link CidrExpression}.
+   */
+  private DefaultFunctionResolver cidr() {
+
+    FunctionName name = BuiltinFunctionName.CIDR.getName();
+    FunctionSignature signature = new FunctionSignature(name, Arrays.asList(STRING, STRING));
+
+    return define(
+        name,
+        funcName -> Pair.of(signature, (properties, arguments) -> new CidrExpression(arguments)));
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java
index 58482fda2c..bf6b3c22f5 100644
--- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java
+++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java
@@ -26,7 +26,6 @@
 import org.opensearch.sql.expression.function.BuiltinFunctionName;
 import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
 import org.opensearch.sql.expression.function.DefaultFunctionResolver;
-import org.opensearch.sql.utils.IPUtils;
 import org.opensearch.sql.utils.OperatorUtils;
 
 /**
@@ -56,7 +55,6 @@ public static void register(BuiltinFunctionRepository repository) {
     repository.register(like());
     repository.register(notLike());
     repository.register(regexp());
-    repository.register(cidr());
   }
 
   /**
@@ -398,12 +396,6 @@ private static DefaultFunctionResolver regexp() {
         impl(nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING));
   }
 
-  private static DefaultFunctionResolver cidr() {
-    return define(
-        BuiltinFunctionName.CIDR.getName(),
-        impl(nullMissingHandling(IPUtils::isAddressInRange), BOOLEAN, STRING, STRING));
-  }
-
   private static DefaultFunctionResolver notLike() {
     return define(
         BuiltinFunctionName.NOT_LIKE.getName(),
diff --git a/core/src/main/java/org/opensearch/sql/utils/IPUtils.java b/core/src/main/java/org/opensearch/sql/utils/IPUtils.java
deleted file mode 100644
index 065f158a54..0000000000
--- a/core/src/main/java/org/opensearch/sql/utils/IPUtils.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.utils;
-
-import com.google.common.net.InetAddresses;
-import org.opensearch.sql.data.model.ExprBooleanValue;
-import org.opensearch.sql.data.model.ExprValue;
-
-import java.util.Arrays;
-
-public class IPUtils {
-
-    /**
-     * Returns whether the given IP address is within the specified IP address range.
-     * Supports both IPv4 and IPv6 addresses.
-     *
-     * @param addressExprValue IP address (e.g. "198.51.100.14" or "2001:0db8::ff00:42:8329").
-     * @param rangeExprValue IP address range in CIDR notation (e.g. "198.51.100.0/24" or "2001:0db8::/32")
-     * @return true if address is in range; else false
-     */
-    public static ExprBooleanValue isAddressInRange(ExprValue addressExprValue, ExprValue rangeExprValue) {
-
-        try {
-            byte[] addressBytes = InetAddresses.forString(addressExprValue.stringValue()).getAddress();
-
-            String[] rangeFields = rangeExprValue.stringValue().split("/");
-            int prefixLengthBytes = Integer.parseInt(rangeFields[1]) / Byte.SIZE;
-            byte[] rangeBytes = Arrays.copyOfRange(InetAddresses.forString(rangeFields[0]).getAddress(), 0, prefixLengthBytes);
-
-            return ExprBooleanValue.of(Arrays.equals(addressBytes, 0, prefixLengthBytes, rangeBytes, 0, prefixLengthBytes));
-        } catch (Exception e) {
-            return ExprBooleanValue.of(false);
-        }
-    }
-}
diff --git a/core/src/test/java/org/opensearch/sql/expression/ip/CidrExpressionTest.java b/core/src/test/java/org/opensearch/sql/expression/ip/CidrExpressionTest.java
new file mode 100644
index 0000000000..d46b916742
--- /dev/null
+++ b/core/src/test/java/org/opensearch/sql/expression/ip/CidrExpressionTest.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.ip;
+
+import org.junit.jupiter.api.DisplayNameGeneration;
+import org.junit.jupiter.api.DisplayNameGenerator;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.data.model.ExprValueUtils;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.DSL;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.ExpressionTestBase;
+import org.opensearch.sql.expression.FunctionExpression;
+import org.opensearch.sql.expression.env.Environment;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.when;
+import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE;
+import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING;
+import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL;
+import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE;
+import static org.opensearch.sql.data.type.ExprCoreType.STRING;
+
+@ExtendWith(MockitoExtension.class)
+@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
+public class CidrExpressionTest extends ExpressionTestBase {
+
+  // IP range and address constants for testing.
+  static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24");
+  static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32");
+
+  static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1");
+  static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1");
+  static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2");
+
+  static final ExprValue IPv6AddressBelow = ExprValueUtils.stringValue("2001:0db7::ff00:42:8329");
+  static final ExprValue IPv6AddressWithin = ExprValueUtils.stringValue("2001:0db8::ff00:42:8329");
+  static final ExprValue IPv6AddressAbove = ExprValueUtils.stringValue("2001:0db9::ff00:42:8329");
+
+  // Mock value environment for testing.
+  @Mock
+  Environment<Expression, ExprValue> env;
+
+  @Test
+  public void test_invalid_num_arguments() {
+    assertThrows(ExpressionEvaluationException.class, DSL::cidr);
+    assertThrows(ExpressionEvaluationException.class, () -> DSL.cidr(DSL.literal(0), DSL.literal(0), DSL.literal(0)));
+  }
+
+  @Test
+  public void test_null_and_missing() {
+    assertEquals(LITERAL_NULL, execute(LITERAL_NULL, IPv4Range));
+    assertEquals(LITERAL_NULL, execute(LITERAL_MISSING, IPv4Range));
+  }
+
+  @Test
+  public void test_invalid_address() {
+    assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
+  }
+
+  @Test
+  public void test_invalid_range() {
+    assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
+    assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
+    assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
+    assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/99999999999")));
+  }
+
+  @Test
+  public void test_valid_ipv4() {
+    assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range));
+    assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range));
+    assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range));
+  }
+
+  @Test
+  public void test_valid_ipv4_range_bounds() {
+    assertEquals(LITERAL_TRUE, execute(ExprValueUtils.stringValue("198.51.100.0"), IPv4Range));
+    assertEquals(LITERAL_TRUE, execute(ExprValueUtils.stringValue("198.51.100.255"), IPv4Range));
+  }
+
+  @Test
+  public void test_valid_ipv6() {
+    assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range));
+    assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range));
+    assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range));
+  }
+
+  @Test
+  public void test_valid_different_versions() {
+    assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range));
+    assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range));
+  }
+
+  /**
+   * Builds and evaluates a CIDR function expression with the given field
+   * and range expression values, and returns the resulting value.
+   */
+  private ExprValue execute(ExprValue field, ExprValue range) {
+
+    final String fieldName = "ip_address";
+    FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range));
+
+    // Mock the value environment to return the specified field
+    // expression as the value for the "ip_address" field.
+    when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field);
+
+    return exp.valueOf(env);
+  }
+}
diff --git a/docs/user/ppl/functions/ip.rst b/docs/user/ppl/functions/ip.rst
index cd213918fd..8b9e370198 100644
--- a/docs/user/ppl/functions/ip.rst
+++ b/docs/user/ppl/functions/ip.rst
@@ -22,12 +22,11 @@ Return type: BOOLEAN
 
 Example::
 
-    os> source=devices | where cidr(address, "198.51.100.0/24")
-    fetched rows / total rows = 2/2
-    +----------------+----------------+
-    | name           | address        |
-    |----------------+----------------+
-    | John's Macbook | 198.51.100.2   |
-    | Iain's PC      | 198.51.100.254 |
-    +----------------+----------------+
+    os> source=weblogs | where cidr(host, "199.120.110.0/24") | fields host, method, url
+    fetched rows / total rows = 1/1
+    +----------------+--------+----------------------------------------------+
+    | host           | method | url                                          |
+    |----------------+--------+----------------------------------------------+
+    | 199.120.110.21 | GET    | /shuttle/missions/sts-73/mission-sts-73.html |
+    +----------------+--------+----------------------------------------------+
 