pnnx convert onnx sdap reduce min/max/mean/sum/prod (#5579)
* pnnx convert onnx sdap

* test reduce
nihui authored Jul 12, 2024
1 parent 3752d71 commit 1c40615
Showing 11 changed files with 617 additions and 42 deletions.
91 changes: 91 additions & 0 deletions tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
@@ -80,4 +80,95 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)

static bool NearlyEqual(float a, float b, float epsilon)
{
if (a == b)
return true;

float diff = (float)fabs(a - b);
if (diff <= epsilon)
return true;

// relative error
return diff < epsilon * std::max(fabs(a), fabs(b));
}

class F_scaled_dot_product_attention_onnx : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
12 11
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
Transpose op_0 1 1 key kt perm=(0,1,3,2)
prim::Constant op_1 0 1 scale value=%sqrt_scale
aten::mul op_2 2 1 query scale q
prim::Constant op_3 0 1 scale2 value=%sqrt_scale
aten::mul op_4 2 1 kt scale2 k
MatMul op_5 2 1 q k qk
Softmax op_6 1 1 qk 4 axis=-1
MatMul op_7 2 1 4 value out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.scaled_dot_product_attention";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dropout_p"] = 0.f;
op->params["is_causal"] = false;

const float sqrt_scale = captured_params.at("sqrt_scale").f;
const float scale = sqrt_scale * sqrt_scale;

op->params["scale"] = scale;

if (!op->inputs[0]->shape.empty())
{
const int embed_dim = op->inputs[0]->shape[op->inputs[0]->shape.size() - 1];
if (NearlyEqual(scale, 1.f / sqrt(embed_dim), 0.001))
{
// drop scale=None for compatibility with old torch
op->params.erase("scale");
}
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx, 10)

class F_scaled_dot_product_attention_onnx_1 : public F_scaled_dot_product_attention_onnx
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
14 13
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
Transpose op_0 1 1 key kt perm=(0,1,3,2)
prim::Constant op_1 0 1 scale value=%sqrt_scale
aten::mul op_2 2 1 query scale q
prim::Constant op_3 0 1 scale2 value=%sqrt_scale
aten::mul op_4 2 1 kt scale2 k
MatMul op_5 2 1 q k qk
aten::add op_6 2 1 qk attn_mask qkm
Softmax op_7 1 1 qkm 4 axis=-1
MatMul op_8 2 1 4 value out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_onnx_1, 10)

} // namespace pnnx
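
For context, the pattern above corresponds to a common ONNX export of F.scaled_dot_product_attention: query and the transposed key are each multiplied by sqrt(scale), so the pass recovers scale = sqrt_scale * sqrt_scale and drops the parameter when it matches the default 1 / sqrt(embed_dim). A minimal sketch that should produce such a graph (the export call, shapes and opset are assumptions, not part of this commit):

# Hypothetical repro of the decomposed SDPA graph that this pass rewrites;
# opset, shapes and file name are assumptions, not taken from the commit.
import torch
import torch.nn.functional as F

class M(torch.nn.Module):
    def forward(self, q, k, v):
        # default scale is 1 / sqrt(embed_dim) = 1 / sqrt(64) = 0.125
        return F.scaled_dot_product_attention(q, k, v)

q = torch.rand(1, 8, 128, 64)
k = torch.rand(1, 8, 128, 64)
v = torch.rand(1, 8, 128, 64)

# Exporting typically yields Mul(query, s), Mul(key^T, s), MatMul, Softmax, MatMul
# with s = sqrt(0.125) ~= 0.3536, i.e. the %sqrt_scale captured by the pass above,
# so the recovered scale is s * s = 0.125 and the explicit parameter can be dropped.
torch.onnx.export(M(), (q, k, v), "sdpa.onnx", opset_version=14)
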
80 changes: 65 additions & 15 deletions tools/pnnx/src/pass_level2/torch_max.cpp
@@ -83,30 +83,80 @@ pnnx.Output output 1 0 out
        if (captured_params.find("op_0.axes") != captured_params.end())
        {
            op->params["dim"] = captured_params.at("op_0.axes");
-       }
-       else
-       {
-           // reduce all
-           const int input_rank = (int)op->inputs[0]->shape.size();
-           std::vector<int> dim(input_rank);
-           for (int i = 0; i < input_rank; i++)
+
+           if (captured_params.find("op_0.keepdims") != captured_params.end())
            {
-               dim[i] = i;
+               op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
            }
+           else
+           {
+               op->params["keepdim"] = true;
+           }
-           op->params["dim"] = dim;
        }
-
-       if (captured_params.find("op_0.keepdims") != captured_params.end())
-       {
-           op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
-       }
        else
        {
-           op->params["keepdim"] = true;
+           // reduce all
        }
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx, 20)

class torch_max_onnx_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
ReduceMax op_0 1 1 input out %*=%*
ArgMax op_1 1 1 input indices %*=%*
pnnx.Output output 2 0 out indices
)PNNXIR";
}

const char* type_str() const
{
return "torch.max";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.axes") == captured_params.end())
return false;

if (captured_params.find("op_0.keepdims") == captured_params.end())
return false;

if (captured_params.find("op_1.axis") == captured_params.end())
return false;

if (captured_params.find("op_1.keepdims") == captured_params.end())
return false;

if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1)
return false;

if (captured_params.at("op_1.axis").type != 2)
return false;

if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i)
return false;

if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i)
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("op_0.axes").ai[0];
op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx_1, 19)

} // namespace pnnx
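
For context, ONNX has no single operator that returns both values and indices, so torch.max(x, dim) is commonly exported as a ReduceMax plus an ArgMax over the same axis with matching keepdims; the match() above checks exactly that agreement before folding the pair back into one torch.max. The torch_min_onnx_1 pass below mirrors this with ReduceMin/ArgMin. A hedged sketch of a model that would produce this pattern (export details and file name are assumptions):

# Hypothetical example of a graph matched by torch_max_onnx_1; shapes and the
# output file name are assumptions.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # returns (values, indices); commonly exported as ReduceMax + ArgMax
        return torch.max(x, dim=1, keepdim=True)

x = torch.rand(1, 3, 16)
torch.onnx.export(M(), (x,), "max.onnx")
# pnnx can then fold the ReduceMax/ArgMax pair back into a single
# torch.max(input, dim=1, keepdim=True) call.
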
80 changes: 65 additions & 15 deletions tools/pnnx/src/pass_level2/torch_min.cpp
@@ -83,30 +83,80 @@ pnnx.Output output 1 0 out
        if (captured_params.find("op_0.axes") != captured_params.end())
        {
            op->params["dim"] = captured_params.at("op_0.axes");
-       }
-       else
-       {
-           // reduce all
-           const int input_rank = (int)op->inputs[0]->shape.size();
-           std::vector<int> dim(input_rank);
-           for (int i = 0; i < input_rank; i++)
+
+           if (captured_params.find("op_0.keepdims") != captured_params.end())
            {
-               dim[i] = i;
+               op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
            }
+           else
+           {
+               op->params["keepdim"] = true;
+           }
-           op->params["dim"] = dim;
        }
-
-       if (captured_params.find("op_0.keepdims") != captured_params.end())
-       {
-           op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
-       }
        else
        {
-           op->params["keepdim"] = true;
+           // reduce all
        }
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx, 20)

class torch_min_onnx_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
ReduceMin op_0 1 1 input out %*=%*
ArgMin op_1 1 1 input indices %*=%*
pnnx.Output output 2 0 out indices
)PNNXIR";
}

const char* type_str() const
{
return "torch.min";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.axes") == captured_params.end())
return false;

if (captured_params.find("op_0.keepdims") == captured_params.end())
return false;

if (captured_params.find("op_1.axis") == captured_params.end())
return false;

if (captured_params.find("op_1.keepdims") == captured_params.end())
return false;

if (captured_params.at("op_0.axes").type != 5 || captured_params.at("op_0.axes").ai.size() != 1)
return false;

if (captured_params.at("op_1.axis").type != 2)
return false;

if (captured_params.at("op_0.axes").ai[0] != captured_params.at("op_1.axis").i)
return false;

if (captured_params.at("op_0.keepdims").i != captured_params.at("op_1.keepdims").i)
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
op->params["dim"] = captured_params.at("op_0.axes").ai[0];
op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 19)

} // namespace pnnx
32 changes: 21 additions & 11 deletions tools/pnnx/src/pass_level2/torch_prod.cpp
@@ -58,24 +58,34 @@ pnnx.Output output 1 0 out
return "torch.prod";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.find("op_0.axes") == captured_params.end())
return false;

if (captured_params.at("op_0.axes").type != 2 && captured_params.at("op_0.axes").type != 5)
return false;

if (captured_params.at("op_0.axes").type == 5 && captured_params.at("op_0.axes").ai.size() > 1)
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
-       if (captured_params.find("op_0.axes") != captured_params.end())
+       int dim;
+       if (captured_params.at("op_0.axes").type == 2)
        {
-           op->params["dim"] = captured_params.at("op_0.axes");
+           dim = captured_params.at("op_0.axes").i;
        }
-       else
+       else // if (captured_params.at("op_0.axes").type == 5)
        {
-           // reduce all
-           const int input_rank = (int)op->inputs[0]->shape.size();
-           std::vector<int> dim(input_rank);
-           for (int i = 0; i < input_rank; i++)
-           {
-               dim[i] = i;
-           }
-           op->params["dim"] = dim;
+           dim = captured_params.at("op_0.axes").ai[0];
        }

+       op->params["dim"] = dim;
+
if (captured_params.find("op_0.keepdims") != captured_params.end())
{
op->params["keepdim"] = captured_params.at("op_0.keepdims").i ? true : false;
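
For context, the revised torch.prod pass only accepts single-axis reductions: match() requires axes to be either a plain int or a one-element int array, and write() normalizes both encodings into one dim. A hedged sketch of an export that produces such a graph (exporter behaviour and file name are assumptions):

# Hypothetical example: torch.prod over one dim typically exports to an ONNX
# ReduceProd whose axes may surface as a plain int or as a one-element int list.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.prod(x, dim=2, keepdim=False)

x = torch.rand(1, 5, 9, 11)
torch.onnx.export(M(), (x,), "prod.onnx")
# A ReduceProd with no axes, or with more than one axis, is rejected by the
# match() above and is not rewritten by this pattern.
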
8 changes: 7 additions & 1 deletion tools/pnnx/tests/onnx/CMakeLists.txt
@@ -36,7 +36,7 @@ pnnx_onnx_add_test(F_pad)
pnnx_onnx_add_test(F_prelu)
pnnx_onnx_add_test(F_relu)
pnnx_onnx_add_test(F_relu6)
-# pnnx_onnx_add_test(F_scaled_dot_product_attention)
+pnnx_onnx_add_test(F_scaled_dot_product_attention)
pnnx_onnx_add_test(F_sigmoid)
pnnx_onnx_add_test(F_softmax)
pnnx_onnx_add_test(F_upsample_bilinear)
@@ -103,3 +103,9 @@ pnnx_onnx_add_test(shufflenet_v2_x1_0)
pnnx_onnx_add_test(squeezenet1_1)
pnnx_onnx_add_test(swin_t)
pnnx_onnx_add_test(vit_b_32)

pnnx_onnx_add_test(torch_max)
pnnx_onnx_add_test(torch_mean)
pnnx_onnx_add_test(torch_min)
pnnx_onnx_add_test(torch_prod)
pnnx_onnx_add_test(torch_sum)
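
The new test targets presumably correspond to Python scripts under tools/pnnx/tests/onnx that are not shown in this diff. A hypothetical outline of such a test, with file names, shapes and the comparison flow assumed rather than taken from the commit:

# Hypothetical outline of an ONNX reduce test (e.g. test_torch_max.py); the real
# scripts may differ.
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        out0 = torch.max(x)                              # reduce over all dims
        out1, idx1 = torch.max(x, dim=1, keepdim=False)  # ReduceMax + ArgMax pair
        return out0, out1, idx1

def test():
    net = Model().eval()
    x = torch.rand(1, 3, 16)
    a = net(x)

    torch.onnx.export(net, (x,), "test_torch_max.onnx")
    # The CMake target would then run the pnnx converter on the exported file and
    # compare the outputs of the regenerated model against a.
    return True

if __name__ == "__main__":
    import sys
    sys.exit(0 if test() else 1)
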