Skip to content

Commit

Permalink
pnnx convert onnx logsoftmax/logsigmoid/mish/selu/sigmoid/silu/softmin/softplus/softshrink/softsign/tanh/tanhshrink (#5581)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Jul 13, 2024
1 parent 1c40615 commit e7cae68
Show file tree
Hide file tree
Showing 34 changed files with 1,872 additions and 6 deletions.
73 changes: 73 additions & 0 deletions tools/pnnx/src/pass_level2/F_log_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,77 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax, 10)

// Rewrites a standalone onnx LogSoftmax node into F.log_softmax,
// carrying the onnx "axis" attribute over as the "dim" parameter.
class F_log_softmax_onnx : public GraphRewriterPass
{
public:
    // Pattern: a single LogSoftmax between graph input and output.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input_0 0 1 input
LogSoftmax op_0 1 1 input out axis=%dim
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.log_softmax";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax_onnx, 10)

// Rewrites the onnx Transpose(perm) -> LogSoftmax(axis) -> Transpose(perm)
// idiom (produced when exporting log_softmax over a non-trailing dimension)
// into a single F.log_softmax over the original dimension.
class F_log_softmax_onnx_1 : public GraphRewriterPass
{
public:
    // Pattern: both Transpose ops must capture the same %perm.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
Transpose op_0 1 1 input a perm=%perm
LogSoftmax op_1 1 1 a b axis=%axis
Transpose op_2 1 1 b out perm=%perm
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.log_softmax";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        const std::vector<int>& perm = captured_params.at("perm").ai;
        int axis = captured_params.at("axis").i;

        // onnx allows a negative axis, counted from the last dimension
        if (axis < 0)
            axis += (int)perm.size();

        // reject an out-of-range axis on both sides, so perm[axis] in
        // write() can never read out of bounds
        if (axis < 0 || axis >= (int)perm.size())
            return false;

        // the two Transpose ops share the same perm, so the second can only
        // undo the first when perm is its own inverse - i.e. exactly two
        // positions are swapped with each other
        int excount = 0;
        for (int i = 0; i < (int)perm.size(); i++)
        {
            if (perm[i] != i)
                excount++;
        }

        if (excount != 2)
            return false;

        return true;
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        const std::vector<int>& perm = captured_params.at("perm").ai;
        int axis = captured_params.at("axis").i;

        // normalize a negative axis the same way match() does
        if (axis < 0)
            axis += (int)perm.size();

        // map the softmax axis back through the transpose to recover the
        // dimension in the original tensor layout
        op->params["dim"] = perm[axis];
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_log_softmax_onnx_1, 9)

} // namespace pnnx
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level2/F_logsigmoid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,26 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_logsigmoid, 10)

// Rewrites the sigmoid -> log chain (log(sigmoid(x))) exported by onnx
// into a single F.logsigmoid.
class F_logsigmoid_onnx : public GraphRewriterPass
{
public:
    // Pattern: aten::log applied to the output of aten::sigmoid.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
aten::sigmoid op_0 1 1 input a
aten::log op_1 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.logsigmoid";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_logsigmoid_onnx, 9)

} // namespace pnnx
23 changes: 23 additions & 0 deletions tools/pnnx/src/pass_level2/F_mish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,27 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish_1, 9)

// Rewrites the decomposed onnx form of mish, x * tanh(softplus(x)),
// back into a single F.mish.
class F_mish_onnx : public GraphRewriterPass
{
public:
    // Pattern: Softplus -> tanh, then multiplied by the original input.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
Softplus op_0 1 1 input a
aten::tanh op_1 1 1 a b
aten::mul op_2 2 1 input b out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.mish";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_mish_onnx, 9)

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level2/F_selu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,25 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_selu, 10)

// Rewrites a bare onnx Selu node into F.selu (no parameters carried over;
// the onnx op here has no captured attributes).
class F_selu_onnx : public GraphRewriterPass
{
public:
    // Pattern: a single Selu between graph input and output.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
Selu op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.selu";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_selu_onnx, 10)

} // namespace pnnx
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level2/F_softmin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,26 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmin, 9)

// Rewrites the neg -> Softmax chain (softmax(-x)) exported by onnx
// into a single F.softmin, keeping the softmax axis as "dim".
class F_softmin_onnx : public GraphRewriterPass
{
public:
    // Pattern: Softmax applied to the negated input.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
aten::neg op_0 1 1 input 6
Softmax op_1 1 1 6 out axis=%dim
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.softmin";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmin_onnx, 9)

} // namespace pnnx
58 changes: 58 additions & 0 deletions tools/pnnx/src/pass_level2/F_softplus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,62 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus, 10)

// Rewrites a bare onnx Softplus node into F.softplus.
class F_softplus_onnx : public GraphRewriterPass
{
public:
    // Pattern: a single Softplus between graph input and output.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input_0 0 1 input
Softplus op_0 1 1 input out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.softplus";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
        // onnx Softplus has no attributes; fill in the
        // torch.nn.functional.softplus defaults (beta=1, threshold=20)
        op->params["beta"] = 1.f;
        op->params["threshold"] = 20.f;
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus_onnx, 10)

// Rewrites the scaled softplus decomposition softplus(x * beta) / beta
// (how a non-default beta is exported to onnx) into F.softplus with that
// beta restored.
class F_softplus_onnx_1 : public GraphRewriterPass
{
public:
    // Pattern: both constants must capture the same %beta value
    // (the multiplier before Softplus and the divisor after it).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 beta value=%beta
aten::mul op_1 2 1 input beta a
Softplus op_2 1 1 a b
prim::Constant op_3 0 1 beta2 value=%beta
aten::div op_4 2 1 b beta2 out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.softplus";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // restore the captured beta; threshold falls back to the
        // torch.nn.functional.softplus default of 20
        op->params["beta"] = captured_params.at("beta");
        op->params["threshold"] = 20.f;
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softplus_onnx_1, 9)

} // namespace pnnx
58 changes: 58 additions & 0 deletions tools/pnnx/src/pass_level2/F_softshrink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,62 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softshrink, 10)

// Approximate float comparison: exact match, then an absolute-epsilon
// check (for values near zero), then a relative-error check scaled by
// the larger magnitude.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float abs_diff = std::fabs(a - b);

    // absolute tolerance handles values close to zero
    if (abs_diff <= epsilon)
        return true;

    // otherwise compare relative to the larger magnitude
    const float larger = std::max(std::fabs(a), std::fabs(b));
    return abs_diff < epsilon * larger;
}

// Rewrites the decomposed onnx form of softshrink,
//   where(x > lambd, x - lambd, 0) + where(x < -lambd, x + lambd, 0),
// back into a single F.softshrink with the captured lambd.
class F_softshrink_onnx : public GraphRewriterPass
{
public:
    // Pattern: %lambd is the positive threshold (gt/sub/add constants),
    // %lambd2 is the comparison constant of the lt branch, expected to be
    // the negation of %lambd (verified in match() below).
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
15 14
pnnx.Input input 0 1 input
prim::Constant op_0 0 1 lambd value=%lambd
aten::gt op_1 2 1 input lambd 8
prim::Constant op_2 0 1 lambd2 value=%lambd
aten::sub op_3 2 1 input lambd2 9
prim::Constant op_4 0 1 zero value=0
aten::where op_5 3 1 8 9 zero a
prim::Constant op_6 0 1 mlambd value=%lambd2
aten::lt op_7 2 1 input mlambd 11
prim::Constant op_8 0 1 lambd3 value=%lambd
aten::add op_9 2 1 input lambd3 12
prim::Constant op_10 0 1 zero2 value=0
aten::where op_11 3 1 11 12 zero2 b
aten::add op_12 2 1 a b out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.softshrink";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        float lambd = captured_params.at("lambd").f;
        float lambd2 = captured_params.at("lambd2").f;
        // the lt-branch threshold must be (approximately) -lambd,
        // otherwise this is not a softshrink
        return NearlyEqual(lambd, -lambd2, 0.001);
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // keep the positive threshold as the lambd parameter
        op->params["lambd"] = captured_params.at("lambd");
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softshrink_onnx, 10)

} // namespace pnnx
24 changes: 24 additions & 0 deletions tools/pnnx/src/pass_level2/F_softsign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,28 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softsign, 10)

// Rewrites the decomposed onnx form of softsign, x / (1 + |x|),
// back into a single F.softsign.
class F_softsign_onnx : public GraphRewriterPass
{
public:
    // Pattern: abs -> add constant 1 -> divide the input by the result.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
6 5
pnnx.Input input 0 1 input
aten::abs op_0 1 1 input 6
prim::Constant op_1 0 1 8 value=1
aten::add op_2 2 1 6 8 9
aten::div op_3 2 1 input 9 out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.softsign";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softsign_onnx, 10)

} // namespace pnnx
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level2/F_tanhshrink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,26 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanhshrink, 9)

// Rewrites the decomposed onnx form of tanhshrink, x - tanh(x),
// back into a single F.tanhshrink.
class F_tanhshrink_onnx : public GraphRewriterPass
{
public:
    // Pattern: tanh of the input subtracted from the input itself.
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
aten::tanh op_0 1 1 input 7
aten::sub op_1 2 1 input 7 out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type.
    const char* type_str() const
    {
        return "F.tanhshrink";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_tanhshrink_onnx, 9)

} // namespace pnnx
Loading

0 comments on commit e7cae68

Please sign in to comment.