Skip to content

Commit

Permalink
feat(firestore): SUM and AVG aggregations (googleapis#8293)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhshkh authored Oct 18, 2023
1 parent 5823db5 commit 011f9ff
Show file tree
Hide file tree
Showing 3 changed files with 530 additions and 21 deletions.
218 changes: 217 additions & 1 deletion firestore/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
apiv1 "cloud.google.com/go/firestore/apiv1/admin"
"cloud.google.com/go/firestore/apiv1/admin/adminpb"
firestorev1 "cloud.google.com/go/firestore/apiv1/firestorepb"
pb "cloud.google.com/go/firestore/apiv1/firestorepb"
"cloud.google.com/go/internal/pretty"
"cloud.google.com/go/internal/testutil"
"cloud.google.com/go/internal/uid"
Expand All @@ -43,6 +44,7 @@ import (
"google.golang.org/genproto/googleapis/type/latlng"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -150,7 +152,11 @@ func initIntegrationTest() {
// desc 'The query requires multiple indexes'.
func createIndexes(ctx context.Context, dbPath string) {

indexFields := [][]string{{"updatedAt", "weight", "height"}, {"weight", "height"}}
indexFields := [][]string{
{"updatedAt", "weight", "height"},
{"weight", "height"},
{"width", "depth"},
{"width", "model"}}
indexNames = make([]string, len(indexFields))
indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, iColl.ID)

Expand Down Expand Up @@ -2355,6 +2361,216 @@ func TestIntegration_BulkWriter(t *testing.T) {
})
}

func TestIntegration_AggregationQueries(t *testing.T) {
ctx := context.Background()
coll := integrationColl(t)
client := integrationClient(t)
h := testHelper{t}
docs := []map[string]interface{}{
{"width": 1.5, "depth": 99, "model": "A"},
{"width": 2.6, "depth": 98, "model": "A"},
{"width": 3.7, "depth": 97, "model": "B"},
{"width": 4.8, "depth": 96, "model": "B"},
{"width": 5.9, "depth": 95, "model": "C"},
{"width": 6.0, "depth": 94, "model": "B"},
{"width": 7.1, "depth": 93, "model": "C"},
{"width": 8.2, "depth": 93, "model": "A"},
}
for _, doc := range docs {
newDoc := coll.NewDoc()
h.mustCreate(newDoc, doc)
}

query := coll.Where("width", ">=", 1)

limitQuery := coll.Where("width", ">=", 1).Limit(4)
limitToLastQuery := coll.Where("width", ">=", 2.6).OrderBy("width", Asc).LimitToLast(4)

startAtQuery := coll.Where("width", ">=", 2.6).OrderBy("width", Asc).StartAt(3.7)
startAfterQuery := coll.Where("width", ">=", 2.6).OrderBy("width", Asc).StartAfter(3.7)

endAtQuery := coll.Where("width", ">=", 2.6).OrderBy("width", Asc).EndAt(7.1)
endBeforeQuery := coll.Where("width", ">=", 2.6).OrderBy("width", Asc).EndBefore(7.1)

emptyResultsQuery := coll.Where("width", "<", 1)
emptyResultsQueryPtr := &emptyResultsQuery

testcases := []struct {
desc string
aggregationQuery *AggregationQuery
wantErr bool
runInTransaction bool
result AggregationResult
}{
{
desc: "Multiple aggregations",
aggregationQuery: query.NewAggregationQuery().WithCount("count1").WithAvg("width", "width_avg1").WithAvg("depth", "depth_avg1").WithSum("width", "width_sum1").WithSum("depth", "depth_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(8)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(39.8)}},
"depth_sum1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(765)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(4.975)}},
"depth_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(95.625)}},
},
},
{
desc: "Aggregations in transaction",
aggregationQuery: query.NewAggregationQuery().WithCount("count1").WithAvg("width", "width_avg1").WithAvg("depth", "depth_avg1").WithSum("width", "width_sum1").WithSum("depth", "depth_sum1"),
wantErr: false,
runInTransaction: true,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(8)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(39.8)}},
"depth_sum1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(765)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(4.975)}},
"depth_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(95.625)}},
},
},
{
desc: "WithSum aggregation without alias",
aggregationQuery: query.NewAggregationQuery().WithSum("width", ""),
wantErr: false,
result: map[string]interface{}{
"field_1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(39.8)}},
},
},
{
desc: "WithSumPath aggregation without alias",
aggregationQuery: query.NewAggregationQuery().WithSumPath([]string{"width"}, ""),
wantErr: false,
result: map[string]interface{}{
"field_1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(39.8)}},
},
},
{
desc: "WithAvg aggregation without alias",
aggregationQuery: query.NewAggregationQuery().WithAvg("width", ""),
wantErr: false,
result: map[string]interface{}{
"field_1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(4.975)}},
},
},
{
desc: "WithAvgPath aggregation without alias",
aggregationQuery: query.NewAggregationQuery().WithAvgPath([]string{"width"}, ""),
wantErr: false,
result: map[string]interface{}{
"field_1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(4.975)}},
},
},
{
desc: "Aggregations with limit",
aggregationQuery: (&limitQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(4)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(12.6)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(3.15)}},
},
},
{
desc: "Aggregations with StartAt",
aggregationQuery: (&startAtQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(6)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(35.7)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(5.95)}},
},
},
{
desc: "Aggregations with StartAfter",
aggregationQuery: (&startAfterQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(5)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(32)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(6.4)}},
},
},
{
desc: "Aggregations with EndAt",
aggregationQuery: (&endAtQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(6)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(30.1)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(5.016666666666667)}},
},
},
{
desc: "Aggregations with EndBefore",
aggregationQuery: (&endBeforeQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(5)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(23)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(4.6)}},
},
},
{
desc: "Aggregations with LimitToLast",
aggregationQuery: (&limitToLastQuery).NewAggregationQuery().WithCount("count1").WithAvgPath([]string{"width"}, "width_avg1").WithSumPath([]string{"width"}, "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(4)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(27.2)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: float64(6.8)}},
},
},
{
desc: "Aggregations on empty results",
aggregationQuery: emptyResultsQueryPtr.NewAggregationQuery().WithCount("count1").WithAvg("width", "width_avg1").WithSum("width", "width_sum1"),
wantErr: false,
result: map[string]interface{}{
"count1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(0)}},
"width_sum1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(0)}},
"width_avg1": &pb.Value{ValueType: &pb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}},
},
},
{
desc: "Aggregation on non-numeric field",
aggregationQuery: query.NewAggregationQuery().WithAvg("model", "model_avg1").WithSum("model", "model_sum1"),
wantErr: false,
result: map[string]interface{}{
"model_sum1": &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(0)}},
"model_avg1": &pb.Value{ValueType: &pb.Value_NullValue{NullValue: structpb.NullValue_NULL_VALUE}},
},
},
{
desc: "Aggregation on non existent key",
aggregationQuery: query.NewAggregationQuery().WithAvg("randKey", "key_avg1").WithSum("randKey", "key_sum1"),
wantErr: true,
},
}

for _, tc := range testcases {
var aggResult AggregationResult
var err error
if tc.runInTransaction {
client.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
aggResult, err = tc.aggregationQuery.Transaction(tx).Get(ctx)
return err
})
} else {
aggResult, err = tc.aggregationQuery.Get(ctx)
}
if err != nil && !tc.wantErr {
t.Errorf("%s: got: %v, want: nil", tc.desc, err)
continue
}
if err == nil && tc.wantErr {
t.Errorf("%s: got: %v, wanted error", tc.desc, err)
continue
}
if !reflect.DeepEqual(aggResult, tc.result) {
t.Errorf("%s: got: %v, want: %v", tc.desc, aggResult, tc.result)
continue
}
}
}

func TestIntegration_CountAggregationQuery(t *testing.T) {
str := uid.NewSpace("firestore-count", &uid.Options{})
datum := str.New()
Expand Down
Loading

0 comments on commit 011f9ff

Please sign in to comment.