Skip to content

Commit

Permalink
feat(firestore): add support for VectorValue (#16476)
Browse files Browse the repository at this point in the history
* feat(firestore): add support for VectorValue

* android

* fix iOS

* android

* fix format

* add to example app

* android

* JS

* with correct JS version

* format

* more tests

* fix tests

* change to test
  • Loading branch information
Lyokone authored Dec 18, 2024
1 parent 71e1f21 commit cc23f17
Show file tree
Hide file tree
Showing 16 changed files with 291 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.firebase.firestore.Query;
import com.google.firebase.firestore.QuerySnapshot;
import com.google.firebase.firestore.SnapshotMetadata;
import com.google.firebase.firestore.VectorValue;
import io.flutter.plugin.common.StandardMessageCodec;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -55,6 +56,7 @@ class FlutterFirebaseFirestoreMessageCodec extends StandardMessageCodec {
private static final byte DATA_TYPE_FIRESTORE_INSTANCE = (byte) 196;
private static final byte DATA_TYPE_FIRESTORE_QUERY = (byte) 197;
private static final byte DATA_TYPE_FIRESTORE_SETTINGS = (byte) 198;
private static final byte DATA_TYPE_VECTOR_VALUE = (byte) 199;

@Override
protected void writeValue(ByteArrayOutputStream stream, Object value) {
Expand All @@ -70,6 +72,9 @@ protected void writeValue(ByteArrayOutputStream stream, Object value) {
writeAlignment(stream, 8);
writeDouble(stream, ((GeoPoint) value).getLatitude());
writeDouble(stream, ((GeoPoint) value).getLongitude());
} else if (value instanceof VectorValue) {
stream.write(DATA_TYPE_VECTOR_VALUE);
writeValue(stream, ((VectorValue) value).toArray());
} else if (value instanceof DocumentReference) {
stream.write(DATA_TYPE_DOCUMENT_REFERENCE);
FirebaseFirestore firestore = ((DocumentReference) value).getFirestore();
Expand Down Expand Up @@ -238,6 +243,13 @@ protected Object readValueOfType(byte type, ByteBuffer buffer) {
case DATA_TYPE_GEO_POINT:
readAlignment(buffer, 8);
return new GeoPoint(buffer.getDouble(), buffer.getDouble());
case DATA_TYPE_VECTOR_VALUE:
final ArrayList<Double> arrayList = (ArrayList<Double>) readValue(buffer);
double[] doubleArray = new double[arrayList.size()];
for (int i = 0; i < arrayList.size(); i++) {
doubleArray[i] = Objects.requireNonNull(arrayList.get(i), "Null value at index " + i);
}
return FieldValue.vector(doubleArray);
case DATA_TYPE_DOCUMENT_REFERENCE:
FirebaseFirestore firestore = (FirebaseFirestore) readValue(buffer);
final String path = (String) readValue(buffer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ void runDocumentReferenceTests() {
'null': null,
'timestamp': Timestamp.now(),
'geopoint': const GeoPoint(1, 2),
'vectorValue': const VectorValue([1, 2, 3]),
'reference': firestore.doc('foo/bar'),
'nan': double.nan,
'infinity': double.infinity,
Expand Down Expand Up @@ -444,6 +445,11 @@ void runDocumentReferenceTests() {
expect(data['geopoint'], isA<GeoPoint>());
expect((data['geopoint'] as GeoPoint).latitude, equals(1));
expect((data['geopoint'] as GeoPoint).longitude, equals(2));
expect(data['vectorValue'], isA<VectorValue>());
expect(
(data['vectorValue'] as VectorValue).toArray(),
equals([1, 2, 3]),
);
expect(data['reference'], isA<DocumentReference>());
expect((data['reference'] as DocumentReference).id, equals('bar'));
expect(data['nan'].isNaN, equals(true));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ import 'settings_e2e.dart';
import 'snapshot_metadata_e2e.dart';
import 'timestamp_e2e.dart';
import 'transaction_e2e.dart';
import 'write_batch_e2e.dart';
import 'vector_value_e2e.dart';
import 'web_snapshot_listeners.dart';
import 'write_batch_e2e.dart';

bool kUseFirestoreEmulator = true;

Expand Down Expand Up @@ -52,6 +53,7 @@ void main() {
runDocumentReferenceTests();
runFieldValueTests();
runGeoPointTests();
runVectorValueTests();
runQueryTests();
runSnapshotMetadataTests();
runTimestampTests();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright 2020, the Chromium project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.

import 'package:cloud_firestore/cloud_firestore.dart';
import 'package:flutter_test/flutter_test.dart';

void runVectorValueTests() {
group('$VectorValue', () {
late FirebaseFirestore firestore;

setUpAll(() async {
firestore = FirebaseFirestore.instance;
});

Future<DocumentReference<Map<String, dynamic>>> initializeTest(
String path,
) async {
String prefixedPath = 'flutter-tests/$path';
await firestore.doc(prefixedPath).delete();
return firestore.doc(prefixedPath);
}

test('sets a $VectorValue & returns one', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value');

await doc.set({
'foo': const VectorValue([10.0, -10.0]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([10.0, -10.0]));
});

test('updates a $VectorValue & returns', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-update');

await doc.set({
'foo': const VectorValue([10.0, -10.0]),
});

await doc.update({
'foo': const VectorValue([-10.0, 10.0]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([-10.0, 10.0]));
});

test('handles empty vector', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-empty');

try {
await doc.set({
'foo': const VectorValue([]),
});
fail('Should have thrown an exception');
} catch (e) {
expect(e, isA<FirebaseException>());
expect(
(e as FirebaseException).code.contains('invalid-argument'),
isTrue,
);
}
});

test('handles single dimension vector', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-single');

await doc.set({
'foo': const VectorValue([42.0]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([42.0]));
});

test('handles maximum dimensions vector', () async {
List<double> maxDimensions = List.filled(2048, 1);
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-max-dimensions');

await doc.set({
'foo': VectorValue(maxDimensions),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals(maxDimensions));
});

test('handles maximum dimensions + 1 vector', () async {
List<double> maxPlusOneDimensions = List.filled(2049, 1);
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-max-plus-one');

try {
await doc.set({
'foo': VectorValue(maxPlusOneDimensions),
});

fail('Should have thrown an exception');
} catch (e) {
expect(e, isA<FirebaseException>());
expect(
(e as FirebaseException).code.contains('invalid-argument'),
isTrue,
);
}
});

test('handles very large values in vector', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-large-values');

await doc.set({
'foo': const VectorValue([1e10, -1e10]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([1e10, -1e10]));
});

test('handles floats in vector', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-floats');

await doc.set({
'foo': const VectorValue([3.14, 2.718]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([3.14, 2.718]));
});

test('handles negative values in vector', () async {
DocumentReference<Map<String, dynamic>> doc =
await initializeTest('vector-value-negative');

await doc.set({
'foo': const VectorValue([-42.0, -100.0]),
});

DocumentSnapshot<Map<String, dynamic>> snapshot = await doc.get();

VectorValue vectorValue = snapshot.data()!['foo'];
expect(vectorValue, isA<VectorValue>());
expect(vectorValue.toArray(), equals([-42.0, -100.0]));
});
});
}
13 changes: 13 additions & 0 deletions packages/cloud_firestore/cloud_firestore/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ class _FilmListState extends State<FilmList> {
list.map((e) => e.taskState),
);
return;
case 'vectorValue':
const vectorValue = VectorValue([1.0, 2.0, 3.0]);
final vectorValueDoc = await FirebaseFirestore.instance
.collection('firestore-example-app')
.add({'vectorValue': vectorValue});

final snapshot = await vectorValueDoc.get();
print(snapshot.data());
return;
default:
return;
}
Expand All @@ -250,6 +259,10 @@ class _FilmListState extends State<FilmList> {
value: 'load_bundle',
child: Text('Load bundle'),
),
const PopupMenuItem(
value: 'vectorValue',
child: Text('Test Vector Value'),
),
];
},
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ - (id)readValueOfType:(UInt8)type {
[self readBytes:&longitude length:8];
return [[FIRGeoPoint alloc] initWithLatitude:latitude longitude:longitude];
}
case FirestoreDataTypeVectorValue: {
return [[FIRVectorValue alloc] initWithArray:[self readValue]];
}
case FirestoreDataTypeDocumentReference: {
FIRFirestore *firestore = [self readValue];
NSString *documentPath = [self readValue];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ - (void)writeValue:(id)value {
[self writeAlignment:8];
[self writeBytes:(UInt8 *)&latitude length:8];
[self writeBytes:(UInt8 *)&longitude length:8];
} else if ([value isKindOfClass:[FIRVectorValue class]]) {
FIRVectorValue *vector = value;
[self writeByte:FirestoreDataTypeVectorValue];
[self writeValue:vector.array];
} else if ([value isKindOfClass:[FIRDocumentReference class]]) {
FIRDocumentReference *document = value;
NSString *documentPath = [document path];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ typedef NS_ENUM(UInt8, FirestoreDataType) {
FirestoreDataTypeFirestoreInstance = 196,
FirestoreDataTypeFirestoreQuery = 197,
FirestoreDataTypeFirestoreSettings = 198,
FirestoreDataTypeVectorValue = 199,
};

@interface FLTFirebaseFirestoreReaderWriter : FlutterStandardReaderWriter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export 'package:cloud_firestore_platform_interface/cloud_firestore_platform_inte
FieldPath,
Blob,
GeoPoint,
VectorValue,
Timestamp,
Source,
GetOptions,
Expand Down Expand Up @@ -57,11 +58,11 @@ part 'src/filters.dart';
part 'src/firestore.dart';
part 'src/load_bundle_task.dart';
part 'src/load_bundle_task_snapshot.dart';
part 'src/persistent_cache_index_manager.dart';
part 'src/query.dart';
part 'src/query_document_snapshot.dart';
part 'src/query_snapshot.dart';
part 'src/snapshot_metadata.dart';
part 'src/transaction.dart';
part 'src/utils/codec_utility.dart';
part 'src/write_batch.dart';
part 'src/persistent_cache_index_manager.dart';
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ export 'src/platform_interface/platform_interface_firestore.dart';
export 'src/platform_interface/platform_interface_index_definitions.dart';
export 'src/platform_interface/platform_interface_load_bundle_task.dart';
export 'src/platform_interface/platform_interface_load_bundle_task_snapshot.dart';
export 'src/platform_interface/platform_interface_persistent_cache_index_manager.dart';
export 'src/platform_interface/platform_interface_query.dart';
export 'src/platform_interface/platform_interface_query_snapshot.dart';
export 'src/platform_interface/platform_interface_transaction.dart';
export 'src/platform_interface/platform_interface_write_batch.dart';
export 'src/platform_interface/platform_interface_persistent_cache_index_manager.dart';
export 'src/platform_interface/utils/load_bundle_task_state.dart';
export 'src/set_options.dart';
export 'src/settings.dart';
export 'src/snapshot_metadata.dart';
export 'src/timestamp.dart';
export 'src/vector_value.dart';

/// Helper method exposed to determine whether a given [collectionPath] points to
/// a valid Firestore collection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

// TODO(Lyokone): remove once we bump Flutter SDK min version to 3.3
// ignore: unnecessary_import
import 'dart:typed_data';
import 'dart:core';

import 'package:cloud_firestore_platform_interface/cloud_firestore_platform_interface.dart';
import 'package:cloud_firestore_platform_interface/src/method_channel/method_channel_field_value.dart';
Expand Down Expand Up @@ -43,6 +43,7 @@ class FirestoreMessageCodec extends StandardMessageCodec {
static const int _kFirestoreInstance = 196;
static const int _kFirestoreQuery = 197;
static const int _kFirestoreSettings = 198;
static const int _kVectorValue = 199;

static const Map<FieldValueType, int> _kFieldValueCodes =
<FieldValueType, int>{
Expand Down Expand Up @@ -124,6 +125,9 @@ class FirestoreMessageCodec extends StandardMessageCodec {
buffer.putUint8(_kInfinity);
} else if (value == double.negativeInfinity) {
buffer.putUint8(_kNegativeInfinity);
} else if (value is VectorValue) {
buffer.putUint8(_kVectorValue);
writeValue(buffer, value.toArray());
} else {
super.writeValue(buffer, value);
}
Expand All @@ -148,6 +152,10 @@ class FirestoreMessageCodec extends StandardMessageCodec {
FirebaseFirestorePlatform.instanceFor(
app: app, databaseId: databaseId);
return firestore.doc(path);
case _kVectorValue:
final List<Object?> vector = (readValue(buffer)!) as List<Object?>;
final List<double> doubles = vector.map((e) => e! as double).toList();
return VectorValue(doubles);
case _kBlob:
final int length = readSize(buffer);
final List<int> bytes = buffer.getUint8List(length);
Expand Down
Loading

0 comments on commit cc23f17

Please sign in to comment.