Skip to content

Commit

Permalink
adjust pass order and follow naming conventions
Browse files Browse the repository at this point in the history
  • Loading branch information
TR666 committed Jan 3, 2024
1 parent d0df7b5 commit 8d67748
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(qk_qkv_attention_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(vis_decoder_attention_xpu_fuse_pass inference DIR xpu DEPS
pass_library(decoder_attention_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/xpu/vis_decoder_attention_xpu_fuse_pass.h"
#include "paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.h"

#include "glog/logging.h"

Expand All @@ -27,9 +27,9 @@ namespace ir {

namespace patterns {

struct VisDecoderAttentionFusePattern : public PatternBase {
VisDecoderAttentionFusePattern(PDPattern* pattern,
const std::string& name_scope);
struct DecoderAttentionFusePattern : public PatternBase {
DecoderAttentionFusePattern(PDPattern* pattern,
const std::string& name_scope);

// declare operator node's name
PATTERN_DECL_NODE(reshape2_1);
Expand Down Expand Up @@ -63,7 +63,7 @@ struct VisDecoderAttentionFusePattern : public PatternBase {
PATTERN_DECL_NODE(output);
};

VisDecoderAttentionFusePattern::VisDecoderAttentionFusePattern(
DecoderAttentionFusePattern::DecoderAttentionFusePattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* input_q = pattern->NewNode(input_q_repr())
Expand Down Expand Up @@ -179,16 +179,16 @@ VisDecoderAttentionFusePattern::VisDecoderAttentionFusePattern(

} // namespace patterns

void VisDecoderAttentionXPUFusePass::ApplyVisDecoderAttentionXPUFuse(
void DecoderAttentionXPUFusePass::ApplyDecoderAttentionXPUFuse(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::VisDecoderAttentionFusePattern pattern(gpd.mutable_pattern(),
name_scope_);
patterns::DecoderAttentionFusePattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle VisDecoderAttentionXPUFusePass";
VLOG(4) << "handle DecoderAttentionXPUFusePass";

// declare operator node's name
GET_IR_NODE(reshape2_1);
Expand Down Expand Up @@ -292,22 +292,22 @@ void VisDecoderAttentionXPUFusePass::ApplyVisDecoderAttentionXPUFuse(
AddStatis(found_subgraph_count);
}

void VisDecoderAttentionXPUFusePass::ApplyImpl(ir::Graph* graph) const {
void DecoderAttentionXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);

ApplyVisDecoderAttentionXPUFuse(graph);
ApplyDecoderAttentionXPUFuse(graph);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(vis_decoder_attention_xpu_fuse_pass,
paddle::framework::ir::VisDecoderAttentionXPUFusePass);
REGISTER_PASS(decoder_attention_xpu_fuse_pass,
paddle::framework::ir::DecoderAttentionXPUFusePass);

REGISTER_PASS_CAPABILITY(vis_decoder_attention_xpu_fuse_pass)
REGISTER_PASS_CAPABILITY(decoder_attention_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"qkv_attention_xpu", 0));
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ Fused subgraph:
*/

class VisDecoderAttentionXPUFusePass : public FusePassBase {
class DecoderAttentionXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void ApplyVisDecoderAttentionXPUFuse(ir::Graph* graph) const;
void ApplyDecoderAttentionXPUFuse(ir::Graph* graph) const;

const std::string name_scope_{"vis_decoder_attention_xpu_fuse_pass"};
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,12 +538,12 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"qk_qkv_attention_xpu_fuse_pass",
"vis_decoder_attention_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_adaptive_seqlen_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_cachekv_layout_trans_pass",
"fused_multi_transformer_int8_cachekv_layout_trans_pass",
"decoder_attention_xpu_fuse_pass",
"one_beam_size_fuse_pass",
"fold_interp_outsize_fuse_pass",
"fold_two_squeeze2_fuse_pass",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestVisDecoderAttentionXPUFusePass(PassAutoScanTest):
class TestDecoderAttentionXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["qkv_attention_xpu"], (1e-1, 1e-1)
Expand Down Expand Up @@ -164,7 +164,7 @@ def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["vis_decoder_attention_xpu_fuse_pass"],
passes=["decoder_attention_xpu_fuse_pass"],
)


Expand Down

0 comments on commit 8d67748

Please sign in to comment.