Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder #1931

Open
wants to merge 11 commits into
base: 4.x
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public String getClassName() {
@NonNull
@Override
public String asCql(boolean includeFrozen, boolean pretty) {
return String.format("'%s(%d)'", getClassName(), getDimensions());
return String.format(
"vector<%s, %d>", this.subtype.asCql(includeFrozen, pretty), getDimensions());
}

/* ============== General class implementation ============== */
Expand Down
10 changes: 10 additions & 0 deletions query-builder/revapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2772,6 +2772,16 @@
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.update.UpdateStart com.datastax.oss.driver.api.querybuilder.update.UpdateStart::usingTtl(int)",
"justification": "JAVA-2210: Add ability to set TTL for modification queries"
},
{
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(java.lang.String, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
},
{
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(com.datastax.oss.driver.api.core.CqlIdentifier, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.datastax.oss.driver.api.querybuilder.select;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
import com.datastax.oss.driver.api.querybuilder.BuildableQuery;
Expand Down Expand Up @@ -146,6 +147,16 @@ default Select orderBy(@NonNull String columnName, @NonNull ClusteringOrder orde
return orderBy(CqlIdentifier.fromCql(columnName), order);
}

/**
* Shortcut for {@link #orderByAnnOf(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
* clause
*/
@NonNull
Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann);

/** Adds the ORDER BY ... ANN OF ... clause */
@NonNull
Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann);
/**
* Adds a LIMIT clause to this query with a literal value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.cql.SimpleStatementBuilder;
import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.querybuilder.BindMarker;
import com.datastax.oss.driver.api.querybuilder.QueryBuilder;
import com.datastax.oss.driver.api.querybuilder.relation.Relation;
import com.datastax.oss.driver.api.querybuilder.select.Select;
import com.datastax.oss.driver.api.querybuilder.select.SelectFrom;
Expand Down Expand Up @@ -49,6 +51,7 @@ public class DefaultSelect implements SelectFrom, Select {
private final ImmutableList<Relation> relations;
private final ImmutableList<Selector> groupByClauses;
private final ImmutableMap<CqlIdentifier, ClusteringOrder> orderings;
private final Ann ann;
private final Object limit;
private final Object perPartitionLimit;
private final boolean allowsFiltering;
Expand All @@ -65,6 +68,7 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
ImmutableMap.of(),
null,
null,
null,
false);
}

Expand All @@ -74,6 +78,8 @@ public DefaultSelect(@Nullable CqlIdentifier keyspace, @NonNull CqlIdentifier ta
* @param selectors if it contains {@link AllSelector#INSTANCE}, that must be the only element.
* This isn't re-checked because methods that call this constructor internally already do it,
* make sure you do it yourself.
* @param ann Approximate nearest neighbor. ANN ordering does not support secondary ordering or
* ASC order.
*/
public DefaultSelect(
@Nullable CqlIdentifier keyspace,
Expand All @@ -84,6 +90,7 @@ public DefaultSelect(
@NonNull ImmutableList<Relation> relations,
@NonNull ImmutableList<Selector> groupByClauses,
@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder> orderings,
@Nullable Ann ann,
@Nullable Object limit,
@Nullable Object perPartitionLimit,
boolean allowsFiltering) {
Expand All @@ -94,6 +101,9 @@ public DefaultSelect(
|| (limit instanceof Integer && (Integer) limit > 0)
|| limit instanceof BindMarker,
"limit must be a strictly positive integer or a bind marker");
Preconditions.checkArgument(
orderings.isEmpty() || ann == null, "ANN ordering does not support secondary ordering");
this.ann = ann;
this.keyspace = keyspace;
this.table = table;
this.isJson = isJson;
Expand All @@ -117,6 +127,7 @@ public SelectFrom json() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand All @@ -134,6 +145,7 @@ public SelectFrom distinct() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand Down Expand Up @@ -193,6 +205,7 @@ public Select withSelectors(@NonNull ImmutableList<Selector> newSelectors) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand Down Expand Up @@ -221,6 +234,7 @@ public Select withRelations(@NonNull ImmutableList<Relation> newRelations) {
newRelations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand Down Expand Up @@ -249,6 +263,7 @@ public Select withGroupByClauses(@NonNull ImmutableList<Selector> newGroupByClau
relations,
newGroupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand All @@ -260,6 +275,19 @@ public Select orderBy(@NonNull CqlIdentifier columnId, @NonNull ClusteringOrder
return withOrderings(ImmutableCollections.append(orderings, columnId, order));
}

@NonNull
@Override
public Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann) {
return withAnn(new Ann(CqlIdentifier.fromCql(columnName), ann));
}

@NonNull
@Override
public Select orderByAnnOf(
@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann) {
return withAnn(new Ann(columnId, ann));
}

@NonNull
@Override
public Select orderByIds(@NonNull Map<CqlIdentifier, ClusteringOrder> newOrderings) {
Expand All @@ -277,6 +305,24 @@ public Select withOrderings(@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder
relations,
groupByClauses,
newOrderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
}

@NonNull
Select withAnn(@NonNull Ann ann) {
return new DefaultSelect(
keyspace,
table,
isJson,
isDistinct,
selectors,
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand All @@ -295,6 +341,7 @@ public Select limit(int limit) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand All @@ -312,6 +359,7 @@ public Select limit(@Nullable BindMarker bindMarker) {
relations,
groupByClauses,
orderings,
ann,
bindMarker,
perPartitionLimit,
allowsFiltering);
Expand All @@ -331,6 +379,7 @@ public Select perPartitionLimit(int perPartitionLimit) {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
allowsFiltering);
Expand All @@ -348,6 +397,7 @@ public Select perPartitionLimit(@Nullable BindMarker bindMarker) {
relations,
groupByClauses,
orderings,
ann,
limit,
bindMarker,
allowsFiltering);
Expand All @@ -365,6 +415,7 @@ public Select allowFiltering() {
relations,
groupByClauses,
orderings,
ann,
limit,
perPartitionLimit,
true);
Expand All @@ -391,15 +442,20 @@ public String asCql() {
CqlHelper.append(relations, builder, " WHERE ", " AND ", null);
CqlHelper.append(groupByClauses, builder, " GROUP BY ", ",", null);

boolean first = true;
for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
if (first) {
builder.append(" ORDER BY ");
first = false;
} else {
builder.append(",");
if (ann != null) {
builder.append(" ORDER BY ").append(this.ann.columnId.asCql(true)).append(" ANN OF ");
QueryBuilder.literal(ann.vector).appendTo(builder);
} else {
boolean first = true;
for (Map.Entry<CqlIdentifier, ClusteringOrder> entry : orderings.entrySet()) {
if (first) {
builder.append(" ORDER BY ");
first = false;
} else {
builder.append(",");
}
builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
}
builder.append(entry.getKey().asCql(true)).append(" ").append(entry.getValue().name());
}

if (limit != null) {
Expand Down Expand Up @@ -512,4 +568,14 @@ public boolean allowsFiltering() {
public String toString() {
return asCql();
}

public static class Ann {
private final CqlVector<? extends Number> vector;
private final CqlIdentifier columnId;

private Ann(CqlIdentifier columnId, CqlVector<? extends Number> vector) {
this.vector = vector;
this.columnId = columnId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.deleteFrom;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;

import com.datastax.oss.driver.api.core.data.CqlVector;
import org.junit.Test;

public class DeleteSelectorTest {
Expand All @@ -34,6 +35,16 @@ public void should_generate_column_deletion() {
.hasCql("DELETE v FROM ks.foo WHERE k=?");
}

@Test
public void should_generate_vector_deletion() {
assertThat(
deleteFrom("foo")
.column("v")
.whereColumn("k")
.isEqualTo(literal(CqlVector.newInstance(0.1, 0.2))))
.hasCql("DELETE v FROM foo WHERE k=[0.1, 0.2]");
}

@Test
public void should_generate_field_deletion() {
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static org.assertj.core.api.Assertions.catchThrowable;

import com.datastax.oss.driver.api.core.data.CqlVector;
import com.datastax.oss.driver.api.querybuilder.term.Term;
import com.datastax.oss.driver.internal.querybuilder.insert.DefaultInsert;
import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableMap;
Expand All @@ -41,6 +42,12 @@ public void should_generate_column_assignments() {
.hasCql("INSERT INTO foo (a,b) VALUES (?,?)");
}

@Test
public void should_generate_vector_literals() {
assertThat(insertInto("foo").value("a", literal(CqlVector.newInstance(0.1, 0.2, 0.3))))
.hasCql("INSERT INTO foo (a) VALUES ([0.1, 0.2, 0.3])");
}

@Test
public void should_keep_last_assignment_if_column_listed_twice() {
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,10 @@ public void should_generate_alter_table_with_no_compression() {
assertThat(alterTable("bar").withNoCompression())
.hasCql("ALTER TABLE bar WITH compression={'sstable_compression':''}");
}

@Test
public void should_generate_alter_table_with_vector() {
assertThat(alterTable("bar").alterColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("ALTER TABLE bar ALTER v TYPE vector<float, 3>");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,10 @@ public void should_generate_alter_table_with_rename_three_columns() {
assertThat(alterType("bar").renameField("x", "y").renameField("u", "v").renameField("b", "a"))
.hasCql("ALTER TYPE bar RENAME x TO y AND u TO v AND b TO a");
}

@Test
public void should_generate_alter_type_with_vector() {
assertThat(alterType("foo", "bar").alterField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("ALTER TYPE foo.bar ALTER vec TYPE vector<float, 3>");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,13 @@ public void should_generate_create_table_time_window_compaction() {
.hasCql(
"CREATE TABLE bar (k int PRIMARY KEY,v text) WITH compaction={'class':'TimeWindowCompactionStrategy','compaction_window_size':10,'compaction_window_unit':'DAYS','timestamp_resolution':'MICROSECONDS','unsafe_aggressive_sstable_expiration':false}");
}

@Test
public void should_generate_vector_column() {
assertThat(
createTable("foo")
.withPartitionKey("k", DataTypes.INT)
.withColumn("v", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("CREATE TABLE foo (k int PRIMARY KEY,v vector<float, 3>)");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,13 @@ public void should_create_type_with_collections() {
.withField("map", DataTypes.mapOf(DataTypes.INT, DataTypes.TEXT)))
.hasCql("CREATE TYPE ks1.type (map map<int, text>)");
}

@Test
public void should_create_type_with_vector() {
assertThat(
createType("ks1", "type")
.withField("c1", DataTypes.INT)
.withField("vec", DataTypes.vectorOf(DataTypes.FLOAT, 3)))
.hasCql("CREATE TYPE ks1.type (c1 int,vec vector<float, 3>)");
}
}
Loading