Skip to content

Commit

Permalink
Update to use CidrExpression:
Browse files Browse the repository at this point in the history
- Add `CidrExpression`, which is a `FunctionImplementation`.
- Add `CidrExressionTest` (unit tests).
- Remove `IPUtils` (logic moved to `CidrExpression`).
- Update documentation in `ip.rst`.
- Add `IpFunctions` utility class.
- Remove previous `IPUtils` utility class.

Signed-off-by: currantw <taylor.curran@improving.com>
  • Loading branch information
currantw committed Oct 29, 2024
1 parent 809e7e6 commit 6c9c58d
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.sql.expression.aggregation.AggregatorFunction;
import org.opensearch.sql.expression.datetime.DateTimeFunction;
import org.opensearch.sql.expression.datetime.IntervalClause;
import org.opensearch.sql.expression.ip.IpFunctions;
import org.opensearch.sql.expression.operator.arthmetic.ArithmeticFunction;
import org.opensearch.sql.expression.operator.arthmetic.MathematicalFunction;
import org.opensearch.sql.expression.operator.convert.TypeCastOperator;
Expand Down Expand Up @@ -81,6 +82,7 @@ public static synchronized BuiltinFunctionRepository getInstance() {
TypeCastOperator.register(instance);
SystemFunctions.register(instance);
OpenSearchFunctions.register(instance);
IpFunctions.register(instance);
}
return instance;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package org.opensearch.sql.expression.ip;

import com.google.common.net.InetAddresses;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.FunctionName;

import java.io.Serializable;
import java.math.BigInteger;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

@ToString
@EqualsAndHashCode(callSuper = false)
public class CidrExpression extends FunctionExpression {

private final Expression addressExpression;
private final InetAddressRange range;

public CidrExpression(List<Expression> arguments) {
super(FunctionName.of("cidr"), arguments);

// Must be exactly two arguments.
if (arguments.size() != 2) {
String msg = String.format("Unexpected number of arguments to function '%s'. Expected %s, but found %s.", FunctionName.of("cidr"), 2, arguments.size());
throw new ExpressionEvaluationException(msg);
}

this.addressExpression = arguments.getFirst();
this.range = new InetAddressRange(arguments.getLast().valueOf().stringValue());
}

@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
ExprValue addressValue = addressExpression.valueOf(valueEnv);
if (addressValue.isNull() || addressValue.isMissing())
return ExprValueUtils.nullValue();

String addressString = addressValue.stringValue();
if (!InetAddresses.isInetAddress(addressString))
return ExprValueUtils.nullValue();

InetAddress address = InetAddresses.forString(addressString);
return ExprValueUtils.booleanValue(range.contains(address));
}

@Override
public ExprType type() {
return ExprCoreType.BOOLEAN;
}

/**
* Represents an IP address range.
* Supports both IPv4 and IPv6 addresses.
*/
private class InetAddressRange implements Serializable {

// Basic CIDR notation pattern.
private static final Pattern cidrPattern = Pattern.compile("(?<address>.+)[/](?<prefix>[0-9]+)");

// Lower/upper bounds for the IP address range.
private final InetAddress lowerBound;
private final InetAddress upperBound;

/**
* Builds a new IP address range from the given CIDR notation string.
*
* @param cidr CIDR notation string (e.g. "198.51.100.0/24" or "2001:0db8::/32")
*/
public InetAddressRange(String cidr) {

// Parse address and network length.
Matcher cidrMatcher = cidrPattern.matcher(cidr);
if (!cidrMatcher.matches())
throw new SemanticCheckException(String.format("CIDR notation '%s' in not valid", range));

String addressString = cidrMatcher.group("address");
if (!InetAddresses.isInetAddress(addressString))
throw new SemanticCheckException(String.format("IP address '%s' in not valid", addressString));

InetAddress address = InetAddresses.forString(addressString);

int networkLengthBits = Integer.parseInt(cidrMatcher.group("prefix"));
int addressLengthBits = address.getAddress().length * Byte.SIZE;

if (networkLengthBits > addressLengthBits)
throw new SemanticCheckException(String.format("Network length of '%s' bits is not valid", networkLengthBits));

// Build bounds by converting the address to an integer, setting all the non-significant bits to
// zero for the lower bounds and one for the upper bounds, and then converting back to addresses.
BigInteger lowerBoundInt = InetAddresses.toBigInteger(address);
BigInteger upperBoundInt = InetAddresses.toBigInteger(address);

int hostLengthBits = addressLengthBits - networkLengthBits;
for (int bit = 0; bit < hostLengthBits; bit++) {
lowerBoundInt = lowerBoundInt.clearBit(bit);
upperBoundInt = upperBoundInt.setBit(bit);
}

if (address instanceof Inet4Address) {
lowerBound = InetAddresses.fromIPv4BigInteger(lowerBoundInt);
upperBound = InetAddresses.fromIPv4BigInteger(upperBoundInt);
} else {
lowerBound = InetAddresses.fromIPv6BigInteger(lowerBoundInt);
upperBound = InetAddresses.fromIPv6BigInteger(upperBoundInt);
}
}

/**
* Returns whether the IP address is contained within the range.
*
* @param address IPv4 or IPv6 address, represented as a {@link BigInteger}.
* (see {@link InetAddresses#toBigInteger(InetAddress)}).
*/
public boolean contains(InetAddress address) {

if ((address instanceof Inet4Address) ^ (lowerBound instanceof Inet4Address)) return false;

BigInteger addressInt = InetAddresses.toBigInteger(address);

if (addressInt.compareTo(InetAddresses.toBigInteger(lowerBound)) < 0) return false;
if (addressInt.compareTo(InetAddresses.toBigInteger(upperBound)) <= 0) return false;

return true;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.opensearch.sql.expression.ip;

import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.expression.function.*;

import java.util.Arrays;

import static org.opensearch.sql.data.type.ExprCoreType.*;
import static org.opensearch.sql.expression.function.FunctionDSL.define;

/**
* Utility class that defines and registers IP functions.
*/
@UtilityClass
public class IpFunctions {

/**
* Registers all IP functions with the given built-in function repository.
*/
public void register(BuiltinFunctionRepository repository) {
repository.register(cidr());
}

private DefaultFunctionResolver cidr() {

FunctionName name = BuiltinFunctionName.CIDR.getName();
FunctionSignature signature = new FunctionSignature(name, Arrays.asList(STRING, STRING));

return define(name, funcName -> Pair.of(signature,(properties, arguments) -> new CidrExpression(arguments)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
import org.opensearch.sql.utils.IPUtils;
import org.opensearch.sql.utils.OperatorUtils;

/**
Expand Down Expand Up @@ -56,7 +55,6 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(like());
repository.register(notLike());
repository.register(regexp());
repository.register(cidr());
}

/**
Expand Down Expand Up @@ -398,12 +396,6 @@ private static DefaultFunctionResolver regexp() {
impl(nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING));
}

private static DefaultFunctionResolver cidr() {
return define(
BuiltinFunctionName.CIDR.getName(),
impl(nullMissingHandling(IPUtils::isAddressInRange), BOOLEAN, STRING, STRING));
}

private static DefaultFunctionResolver notLike() {
return define(
BuiltinFunctionName.NOT_LIKE.getName(),
Expand Down
38 changes: 0 additions & 38 deletions core/src/main/java/org/opensearch/sql/utils/IPUtils.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package org.opensearch.sql.expression.ip;

import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionTestBase;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.*;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

@ExtendWith(MockitoExtension.class)
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
public class CidrExpressionTest extends ExpressionTestBase {

// IP range and address constants for testing.
static final ExprValue IPv4Range = ExprValueUtils.stringValue("198.51.100.0/24");
static final ExprValue IPv6Range = ExprValueUtils.stringValue("2001:0db8::/32");

static final ExprValue IPv4AddressBelow = ExprValueUtils.stringValue("198.51.99.1");
static final ExprValue IPv4AddressWithin = ExprValueUtils.stringValue("198.51.100.1");
static final ExprValue IPv4AddressAbove = ExprValueUtils.stringValue("198.51.101.2");

static final ExprValue IPv6AddressBelow = ExprValueUtils.stringValue("2001:0db7::ff00:42:8329");
static final ExprValue IPv6AddressWithin = ExprValueUtils.stringValue("2001:0db8::ff00:42:8329");
static final ExprValue IPv6AddressAbove = ExprValueUtils.stringValue("2001:0db9::ff00:42:8329");

// Mock value environment for testing.
@Mock
Environment<Expression, ExprValue> env;

@Test
public void test_invalid_num_arguments() {
assertThrows(ExpressionEvaluationException.class, DSL::cidr);
assertThrows(ExpressionEvaluationException.class, () -> DSL.cidr(DSL.literal(0), DSL.literal(0), DSL.literal(0)));
}

@Test
public void test_null_and_missing() {
assertEquals(LITERAL_NULL, execute(LITERAL_NULL, IPv4Range));
assertEquals(LITERAL_NULL, execute(LITERAL_MISSING, IPv4Range));
}

@Test
public void test_invalid_address() {
assertEquals(LITERAL_NULL, execute(ExprValueUtils.stringValue("INVALID"), IPv4Range));
}

@Test
public void test_invalid_range() {
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID")));
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("INVALID/32")));
assertThrows(SemanticCheckException.class, () -> execute(IPv4AddressWithin, ExprValueUtils.stringValue("198.51.100.0/33")));
}

@Test
public void test_valid_ipv4() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressBelow, IPv4Range));
assertEquals(LITERAL_TRUE, execute(IPv4AddressWithin, IPv4Range));
assertEquals(LITERAL_FALSE, execute(IPv4AddressAbove, IPv4Range));
}

@Test
public void test_valid_ipv6() {
assertEquals(LITERAL_FALSE, execute(IPv6AddressBelow, IPv6Range));
assertEquals(LITERAL_TRUE, execute(IPv6AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressAbove, IPv6Range));
}

@Test
public void test_valid_different_versions() {
assertEquals(LITERAL_FALSE, execute(IPv4AddressWithin, IPv6Range));
assertEquals(LITERAL_FALSE, execute(IPv6AddressWithin, IPv4Range));
}

/**
* Builds and evaluates a CIDR function expression with the given field
* and range expression values, and returns the resulting value.
*/
private ExprValue execute(ExprValue field, ExprValue range) {

final String fieldName = "ip_address";
FunctionExpression exp = DSL.cidr(DSL.ref(fieldName, STRING), DSL.literal(range));

// Mock the value environment to return the specified field
// expression as the value for the "ip_address" field.
when(DSL.ref(fieldName, STRING).valueOf(env)).thenReturn(field);

return exp.valueOf(env);
}
}
15 changes: 7 additions & 8 deletions docs/user/ppl/functions/ip.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ Return type: BOOLEAN

Example::

os> source=devices | where cidr(address, "198.51.100.0/24")
fetched rows / total rows = 2/2
+----------------+----------------+
| name | address |
|----------------+----------------+
| John's Macbook | 198.51.100.2 |
| Iain's PC | 198.51.100.254 |
+----------------+----------------+
os> source=weblogs | where cidr(address, "199.120.110.0/24") | fields host, method, url
fetched rows / total rows = 1/1
+----------------+--------+----------------------------------------------+
| host | method | url |
|----------------+--------+----------------------------------------------+
| 199.120.110.21 | GET | /shuttle/missions/sts-73/mission-sts-73.html |
+----------------+--------+----------------------------------------------+

0 comments on commit 6c9c58d

Please sign in to comment.