Skip to content

Commit

Permalink
Handle batchable lists of length 0 (#1600)
Browse files Browse the repository at this point in the history
Signed-off-by: eduardo apolinario <eapolinario@users.noreply.github.com>
Co-authored-by: eduardo apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
eapolinario and eapolinario committed May 16, 2023
1 parent 2d6aaa5 commit 5d4da60
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 6 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,14 +987,17 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batchSize = len(python_val) # default batch size
batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated:
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batchSize = annotation.val
batch_size = annotation.val
break
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore
if batch_size > 0:
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore
else:
lit_list = []
else:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,8 @@ def test_is_batchable():
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]),
# Case 4: Empty list
([[], typing.List[FlytePickle], []]),
],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
Expand Down

0 comments on commit 5d4da60

Please sign in to comment.