Skip to content

Commit

Permalink
fix UT cov
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws committed Nov 22, 2024
1 parent 2a65ec8 commit 2a3e2cf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ private static Tuple<Map<String, Float>, Map<String, Float>> pruneByAlphaMass(
* @param pruneRatio The ratio or threshold for prune
* @param sparseVector The input sparse vector as a map of string keys to float values
* @return A tuple containing two maps: the first with high-scoring elements,
* the second with low-scoring elements (or null if requiresPrunedEntries is false)
* the second with low-scoring elements
*/
public static Tuple<Map<String, Float>, Map<String, Float>> splitSparseVector(
PruneType pruneType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,33 @@ public void testPruneByAlphaMass() {
assertEquals(1.0f, tupleResult.v2().get("d"), 0.001);
}

public void testNonePrune() {
Map<String, Float> input = new HashMap<>();
input.put("a", 5.0f);
input.put("b", 3.0f);
input.put("c", 4.0f);
input.put("d", 1.0f);

// Test prune
Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.NONE, 2, input);

assertEquals(4, result.size());
assertEquals(5.0f, result.get("a"), 0.001);
assertEquals(3.0f, result.get("b"), 0.001);
assertEquals(4.0f, result.get("c"), 0.001);
assertEquals(1.0f, result.get("d"), 0.001);

// Test split
Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.NONE, 2, input);

assertEquals(4, tupleResult.v1().size());
assertEquals(0, tupleResult.v2().size());
assertEquals(5.0f, tupleResult.v1().get("a"), 0.001);
assertEquals(3.0f, tupleResult.v1().get("b"), 0.001);
assertEquals(4.0f, tupleResult.v1().get("c"), 0.001);
assertEquals(1.0f, tupleResult.v1().get("d"), 0.001);
}

public void testEmptyInput() {
Map<String, Float> input = new HashMap<>();

Expand Down Expand Up @@ -154,13 +181,13 @@ public void testInvalidPruneType() {
// Test prune
IllegalArgumentException exception1 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.splitSparseVector(null, 2, input)
() -> PruneUtils.pruneSparseVector(null, 2, input)
);
assertEquals(exception1.getMessage(), "Prune type and prune ratio must be provided");

IllegalArgumentException exception2 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.splitSparseVector(null, 2, input)
() -> PruneUtils.pruneSparseVector(null, 2, input)
);
assertEquals(exception2.getMessage(), "Prune type and prune ratio must be provided");

Expand All @@ -181,7 +208,7 @@ public void testInvalidPruneType() {
public void testNullSparseVector() {
IllegalArgumentException exception1 = assertThrows(
IllegalArgumentException.class,
() -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null)
() -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null)
);
assertEquals(exception1.getMessage(), "Sparse vector must be provided");

Expand Down

0 comments on commit 2a3e2cf

Please sign in to comment.