diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java
index ee5b1d67bc..fa9cd94c11 100644
--- a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java
+++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/compression/FeatureRowsBatch.java
@@ -19,6 +19,8 @@
 import static feast.proto.types.ValueProto.Value.ValCase.*;
 import static feast.storage.connectors.bigquery.common.TypeUtil.*;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.Timestamp;
 import feast.proto.types.FeatureRowProto;
 import feast.proto.types.FieldProto;
 import feast.proto.types.ValueProto;
@@ -40,6 +42,8 @@
  *
  * <p>getFeatureRows provides reverse transformation
  */
 public class FeatureRowsBatch implements Serializable {
+  public static final ImmutableList<String> SERVICE_FIELDS =
+      ImmutableList.of("eventTimestamp", "ingestionId");
   private final Schema schema;
   private String featureSetReference;
   private List<Object> values = new ArrayList<>();
@@ -118,6 +122,12 @@ private Schema inferCommonSchema(Iterable<FeatureRowProto.FeatureRow> featureRow
             featureSetReference = row.getFeatureSet();
           }
         }));
+
+    fieldsInOrder.add(
+        Schema.Field.of("eventTimestamp", Schema.FieldType.array(Schema.FieldType.INT64)));
+    fieldsInOrder.add(
+        Schema.Field.of("ingestionId", Schema.FieldType.array(Schema.FieldType.STRING)));
+
     Schema schema = Schema.builder().addFields(fieldsInOrder).build();
     schema.setUUID(UUID.randomUUID());
     return schema;
@@ -132,16 +142,33 @@ private void initValues() {
   }

   private void toColumnar(Iterable<FeatureRowProto.FeatureRow> featureRows) {
+    int timestampColumnIdx = schema.indexOf("eventTimestamp");
+    int ingestionIdColumnIdx = schema.indexOf("ingestionId");
+
     featureRows.forEach(
         row -> {
-          Map<String, ValueProto.Value> rowValues =
-              row.getFieldsList().stream()
-                  .collect(Collectors.toMap(FieldProto.Field::getName, FieldProto.Field::getValue));
+          Map<String, ValueProto.Value> rowValues;
+          try {
+            rowValues =
+                row.getFieldsList().stream()
+                    .collect(
+                        Collectors.toMap(FieldProto.Field::getName, FieldProto.Field::getValue));
+          } catch (IllegalStateException e) {
+            // row contains feature duplicates
+            // omitting for now
+            return;
+          }

-          IntStream.range(0, schema.getFieldCount())
+          schema
+              .getFieldNames()
               .forEach(
-                  idx -> {
-                    Schema.Field field = schema.getField(idx);
+                  fieldName -> {
+                    if (SERVICE_FIELDS.contains(fieldName)) {
+                      return;
+                    }
+                    Schema.Field field = schema.getField(fieldName);
+                    int idx = schema.indexOf(fieldName);
+
                     if (rowValues.containsKey(field.getName())) {
                       Object o = protoValueToObject(rowValues.get(field.getName()));
                       if (o != null) {
@@ -152,6 +179,10 @@ private void toColumnar(Iterable<FeatureRowProto.FeatureRow> featureRows) {

                     ((List) values.get(idx)).add(defaultValues.get(field.getName()));
                   });
+
+          // adding service fields
+          ((List) values.get(timestampColumnIdx)).add(row.getEventTimestamp().getSeconds());
+          ((List) values.get(ingestionIdColumnIdx)).add(row.getIngestionId());
         });
   }

@@ -177,27 +208,45 @@ public static FeatureRowsBatch fromRow(Row row) {
   }

   public Iterator<FeatureRowProto.FeatureRow> getFeatureRows() {
+    int timestampColumnIdx = schema.indexOf("eventTimestamp");
+    int ingestionIdColumnIdx = schema.indexOf("ingestionId");
+
     return IntStream.range(0, ((List) values.get(0)).size())
         .parallel()
         .mapToObj(
             rowIdx ->
                 FeatureRowProto.FeatureRow.newBuilder()
                     .setFeatureSet(getFeatureSetReference())
+                    .setEventTimestamp(
+                        Timestamp.newBuilder()
+                            .setSeconds(
+                                (long) (((List) values.get(timestampColumnIdx)).get(rowIdx)))
+                            .build())
+                    .setIngestionId(
+                        (String) (((List) values.get(ingestionIdColumnIdx)).get(rowIdx)))
                     .addAllFields(
-                        IntStream.range(0, schema.getFieldCount())
-                            .mapToObj(
-                                fieldIdx ->
-                                    FieldProto.Field.newBuilder()
-                                        .setName(schema.getField(fieldIdx).getName())
-                                        .setValue(
-                                            objectToProtoValue(
-                                                ((List) values.get(fieldIdx)).get(rowIdx),
-                                                schemaToProtoTypes.get(
-                                                    schema
-                                                        .getField(fieldIdx)
-                                                        .getType()
-                                                        .getCollectionElementType())))
-                                        .build())
+                        schema.getFieldNames().stream()
+                            .map(
+                                fieldName -> {
+                                  if (SERVICE_FIELDS.contains(fieldName)) {
+                                    return null;
+                                  }
+                                  int fieldIdx = schema.indexOf(fieldName);
+
+                                  return FieldProto.Field.newBuilder()
+                                      .setName(schema.getField(fieldIdx).getName())
+                                      .setValue(
+                                          objectToProtoValue(
+                                              ((List) values.get(fieldIdx)).get(rowIdx),
+                                              schemaToProtoTypes.get(
+                                                  schema
+                                                      .getField(fieldIdx)
+                                                      .getType()
+                                                      .getCollectionElementType())))
+                                      .build();
+                                })
+                            .filter(Objects::nonNull)
                             .collect(Collectors.toList()))
                     .build())
         .iterator();
diff --git a/storage/connectors/bigquery/src/test/java/feast/storage/connectors/bigquery/writer/BigQuerySinkTest.java b/storage/connectors/bigquery/src/test/java/feast/storage/connectors/bigquery/writer/BigQuerySinkTest.java
index 3f35c5e4ae..d58bebf65b 100644
--- a/storage/connectors/bigquery/src/test/java/feast/storage/connectors/bigquery/writer/BigQuerySinkTest.java
+++ b/storage/connectors/bigquery/src/test/java/feast/storage/connectors/bigquery/writer/BigQuerySinkTest.java
@@ -124,6 +124,11 @@ private FeatureRow generateRow(String featureSet) {
     FeatureRow.Builder row =
         FeatureRow.newBuilder()
             .setFeatureSet(featureSet)
+            .setEventTimestamp(
+                com.google.protobuf.Timestamp.newBuilder()
+                    .setSeconds(System.currentTimeMillis() / 1000)
+                    .build())
+            .setIngestionId("ingestion-id")
             .addFields(field("entity", rd.nextInt(), ValueProto.ValueType.Enum.INT64))
             .addFields(FieldProto.Field.newBuilder().setName("null_value").build());
@@ -499,6 +504,8 @@ private List<FeatureRow> dropNullFeature(List<FeatureRow> input) {
             r ->
                 FeatureRow.newBuilder()
                     .setFeatureSet(r.getFeatureSet())
+                    .setIngestionId(r.getIngestionId())
+                    .setEventTimestamp(r.getEventTimestamp())
                     .addAllFields(copyFieldsWithout(r, "null_value"))
                     .build())
         .collect(Collectors.toList());
@@ -520,6 +527,8 @@ public static List<FeatureRow> sortFeaturesByName(List<FeatureRow> rows) {

           return FeatureRow.newBuilder()
               .setFeatureSet(row.getFeatureSet())
+              .setEventTimestamp(row.getEventTimestamp())
+              .setIngestionId(row.getIngestionId())
               .addAllFields(fieldsList)
               .build();
         })
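
For context, a minimal round-trip sketch (not part of the diff) of the behavior these changes add: eventTimestamp and ingestionId are stored as extra "service" columns and restored when decompressing, so they survive the FeatureRow -> columnar -> FeatureRow conversion. The sketch assumes FeatureRowsBatch exposes a constructor taking Iterable<FeatureRow> (as the inferCommonSchema/toColumnar flow above suggests); the feature set reference, field name, and ingestion id are illustrative. Note that only Timestamp.getSeconds() is stored, so sub-second precision of eventTimestamp is dropped on the way back.

import com.google.common.collect.ImmutableList;
import com.google.protobuf.Timestamp;
import feast.proto.types.FeatureRowProto.FeatureRow;
import feast.proto.types.FieldProto;
import feast.proto.types.ValueProto;
import feast.storage.connectors.bigquery.compression.FeatureRowsBatch;

public class RoundTripSketch {
  public static void main(String[] args) {
    FeatureRow row =
        FeatureRow.newBuilder()
            .setFeatureSet("project/feature_set") // illustrative feature set reference
            .setEventTimestamp(Timestamp.newBuilder().setSeconds(1_600_000_000L).build())
            .setIngestionId("job-42") // illustrative ingestion id
            .addFields(
                FieldProto.Field.newBuilder()
                    .setName("entity_id")
                    .setValue(ValueProto.Value.newBuilder().setInt64Val(42).build())
                    .build())
            .build();

    // Columnar compression and reverse transformation (constructor assumed above).
    FeatureRowsBatch batch = new FeatureRowsBatch(ImmutableList.of(row));
    FeatureRow restored = batch.getFeatureRows().next();

    // The service columns added by this diff carry both values through.
    assert restored.getEventTimestamp().getSeconds() == row.getEventTimestamp().getSeconds();
    assert restored.getIngestionId().equals(row.getIngestionId());
  }
}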