From 1a20e4ae9313805fd83eb904ff9ee1837484a2b5 Mon Sep 17 00:00:00 2001 From: galipremsagar Date: Tue, 26 Dec 2023 04:38:25 +0000 Subject: [PATCH] Add more validation to MultiIndex.to_frame --- python/cudf/cudf/core/multiindex.py | 4 ++++ python/cudf/cudf/tests/test_multiindex.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py index 5c2b4e6c7b0..a2cc5450ca4 100644 --- a/python/cudf/cudf/core/multiindex.py +++ b/python/cudf/cudf/core/multiindex.py @@ -1028,6 +1028,10 @@ def to_frame(self, index=True, name=no_default, allow_duplicates=False): for level, name in enumerate(self.names) ] else: + if not is_list_like(name): + raise TypeError( + "'name' must be a list / sequence of column names." + ) if len(name) != len(self.levels): raise ValueError( "'name' should have the same length as " diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index 5fdeacc346f..0cdc0e42cc1 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -1953,13 +1953,13 @@ def test_multiindex_to_frame_allow_duplicates( ): gidx = cudf.from_pandas(pidx) - if ( + if name is None or ( ( len(pidx.names) != len(set(pidx.names)) and not all(x is None for x in pidx.names) ) and not allow_duplicates - and (name is None or name is no_default) + and name is no_default ): assert_exceptions_equal( pidx.to_frame,