Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test_query_generator for query generation test on each snowflake plan node #2407

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:

if isinstance(logical_plan, Selectable):
# Selectable doesn't have children. It already has the expr_to_alias dict.
self.alias_maps_to_use = logical_plan.expr_to_alias
self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
else:
use_maps = {}
# get counts of expr_to_alias keys
Expand Down
15 changes: 10 additions & 5 deletions tests/integ/compiler/test_query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@
]


def reset_node(node: LogicalPlan) -> None:
def reset_node(node: LogicalPlan, query_generator: QueryGenerator) -> None:
def reset_selectable(selectable_node: Selectable) -> None:
# reset the analyzer to use the current query generator instance to
# ensure the new query generator is used during the resolve process
selectable_node.analyzer = query_generator
if not isinstance(selectable_node, (SelectSnowflakePlan, SelectSQL)):
selectable_node._snowflake_plan = None
if isinstance(selectable_node, (SelectStatement, SetStatement)):
selectable_node._sql_query = None
selectable_node._projection_in_str = None
if isinstance(selectable_node, SelectStatement):
selectable_node.expr_to_alias = selectable_node.from_.expr_to_alias

if isinstance(node, SnowflakePlan):
# do not reset leaf snowflake plan
Expand Down Expand Up @@ -100,9 +105,9 @@ def check_generated_plan_queries(

nodes = nodes[::-1] # reverse the list
for node in nodes:
reset_node(node)
reset_node(node, query_generator)
if isinstance(node, SnowflakePlan):
re_resolve_and_compare_plan_queries(plan, query_generator)
re_resolve_and_compare_plan_queries(node, query_generator)


def verify_multiple_create_queries(
Expand Down Expand Up @@ -372,8 +377,8 @@ def test_multiple_plan_query_generation(session):
)
snowflake_plan = session._analyzer.resolve(create_table_logic_plan)
query_generator = create_query_generator(snowflake_plan)
reset_node(snowflake_plan)
reset_node(df_res._plan)
reset_node(snowflake_plan, query_generator)
reset_node(df_res._plan, query_generator)
logical_plans = [snowflake_plan.source_plan, df_res._plan.source_plan]
with SqlCounter(query_count=0, describe_count=0):
generated_queries = query_generator.generate_queries(logical_plans)
Expand Down
Loading