Skip to content

Commit

Permalink
phi autoparallel fallback (PaddlePaddle#59241)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuancoder authored and SecretXV committed Nov 28, 2023
1 parent b4591d5 commit 8f40a4a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
33 changes: 28 additions & 5 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@
// 4. Select Kernel{}
// 5. Reshard Input{}\n
// 6. PrepareData (DataTransform & Prepare Dense Input){}
// RecordOpInfoSupplement{}
// 7. Infer Local DenseTensor Meta{}
// 8. DenseTensor Kernel Call{}
// 7. RecordOpInfoSupplement{}
// 8. Infer Local DenseTensor Meta{}
// 9. DenseTensor Kernel Call{}
// 10. Fallback{}
}}\n
// 9. Set Output Dist Attr For Default Impl{}\n
// 10. Return
// 11. Set Output Dist Attr For Default Impl{}\n
// 12. Return
{}
}}
"""
Expand Down Expand Up @@ -1462,6 +1463,26 @@ def generate_kernel_call_code(self) -> str:
result += MULTI_SINGLE_SET_DIST_OUT_DIMS.format(i, i)
return result

def generate_fallback_code(self) -> str:
fallback_code = ""
fallback_code += """
if (kernel_result.has_fallback_cpu) {"""
for kernel_out in self.dense_output_args:
fallback_code += f"""
TransDataBackend({kernel_out}, kernel_backend, {kernel_out});"""

inplace_flag = False
if len(self.inplace_map) > 0:
inplace_flag = True

fallback_code += self.reset_view_after_fallback(
self.outputs['types'], ' ', inplace_flag
)

fallback_code += """
}"""
return fallback_code

def generate_output_dist_attr_setting(self) -> str:
set_out_dist_attr_code = ""
if self.generate_general_infer_spmd is True:
Expand Down Expand Up @@ -1547,6 +1568,7 @@ def generate_auto_paralel_branch(self) -> str:
)
infer_meta_code = self.generate_infer_meta_code()
kernel_call_code = self.generate_kernel_call_code()
fallback_code = self.generate_fallback_code()
output_dist_attr_setting = self.generate_output_dist_attr_setting()
return_code = self.generate_return_code()

Expand All @@ -1560,6 +1582,7 @@ def generate_auto_paralel_branch(self) -> str:
record_op_info_supplement_code,
infer_meta_code,
kernel_call_code,
fallback_code,
output_dist_attr_setting,
return_code,
)
Expand Down
14 changes: 8 additions & 6 deletions paddle/phi/api/yaml/generator/dist_bw_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
// 5. Select Kernel{}
// 6. Reshard Input{}\n
// 7. PrepareData (DataTransform & Prepare Dense Input){}
// RecordOpInfoSupplement
{}
// 8. Infer Local DenseTensor Meta{}
// 9. DenseTensor Kernel Call{}
// 8. RecordOpInfoSupplement{}
// 9. Infer Local DenseTensor Meta{}
// 10. DenseTensor Kernel Call{}
// 11. Fallback{}
}}
// 10. Reshard Kernel Output to API output{}\n
// 11. Return
// 12. Reshard Kernel Output to API output{}\n
// 13. Return
{}
}}
"""
Expand Down Expand Up @@ -310,6 +310,7 @@ def generate_auto_paralel_branch(self) -> str:
)
infer_meta_code = self.generate_infer_meta_code()
kernel_call_code = self.generate_kernel_call_code()
fallback_code = self.generate_fallback_code()
reshard_output_code = self.generate_reshard_output_code()
return_code = self.generate_return_code()

Expand All @@ -324,6 +325,7 @@ def generate_auto_paralel_branch(self) -> str:
record_op_info_supplement_code,
infer_meta_code,
kernel_call_code,
fallback_code,
reshard_output_code,
return_code,
)
Expand Down

0 comments on commit 8f40a4a

Please sign in to comment.