Skip to content

Commit

Permalink
【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) (#495)
Browse files Browse the repository at this point in the history
* add is_inference

* update

* update

* update

* add geometric_

* add cauchy_

* add random_

* update

* update

* add chi2

* add Constraint

* add Gamma

* update

* add Poisson LKJCholesky

* update

* add StudentT PositiveDefiniteTransform

* update

* update

* add remote

* add remote

* add remote

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* add DistributedOptimizer

* update

* update

* update

* update

* upadte api_matcher

* update
  • Loading branch information
decade-afk authored Oct 22, 2024
1 parent bf9fe55 commit 616434d
Show file tree
Hide file tree
Showing 16 changed files with 1,252 additions and 10 deletions.
149 changes: 145 additions & 4 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,22 @@
"memory_format"
]
},
"torch.Tensor.cauchy_": {},
"torch.Tensor.cauchy_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.cauchy_",
"min_input_args": 0,
"args_list": [
"median",
"sigma",
"*",
"generator"
],
"kwargs_change": {
"median": "loc",
"sigma":"scale",
"generator":""
}
},
"torch.Tensor.cdouble": {
"Matcher": "TensorCdoubleMatcher",
"paddle_api": "paddle.Tensor.astype",
Expand Down Expand Up @@ -1765,7 +1780,20 @@
"other": "y"
}
},
"torch.Tensor.geometric_": {},
"torch.Tensor.geometric_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.geometric_",
"min_input_args": 1,
"args_list": [
"p",
"*",
"generator"
],
"kwargs_change": {
"p": "probs",
"generator":""
}
},
"torch.Tensor.geqrf": {},
"torch.Tensor.ger": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -2142,7 +2170,17 @@
"paddle_api": "paddle.Tensor.is_floating_point",
"min_input_args": 0
},
"torch.Tensor.is_inference": {},
"torch.Tensor.is_inference": {
"Matcher": "Is_InferenceMatcher",
"min_input_args": 0
},
"torch.is_inference": {
"Matcher": "Is_InferenceMatcher",
"min_input_args": 1,
"args_list":[
"input"
]
},
"torch.Tensor.is_pinned": {
"Matcher": "Is_PinnedMatcher",
"min_input_args": 0
Expand Down Expand Up @@ -3220,7 +3258,22 @@
"Matcher": "UnchangeMatcher",
"min_input_args": 0
},
"torch.Tensor.random_": {},
"torch.Tensor.random_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.uniform_",
"min_input_args": 0,
"args_list": [
"from",
"to",
"*",
"generator"
],
"kwargs_change": {
"from": "min",
"to": "max",
"generator": ""
}
},
"torch.Tensor.ravel": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.flatten",
Expand Down Expand Up @@ -6392,6 +6445,18 @@
],
"min_input_args": 1
},
"torch.distributed.rpc.remote":{
"Matcher": "RpcRemoteMatcher",
"paddle_api": "paddle.distributed.rpc.rpc_async",
"min_input_args": 2,
"args_list": [
"to",
"func",
"args",
"kwargs",
"timeout"
]
},
"torch.distributed.rpc.shutdown": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.rpc.shutdown",
Expand Down Expand Up @@ -6475,6 +6540,52 @@
},
"min_input_args": 2
},
"torch.distributions.lkj_cholesky.LKJCholesky":{
"Matcher": "LKJCholeskyMatcher",
"paddle_api": "paddle.distribution.LKJCholesky",
"min_input_args": 1,
"args_list": [
"dim",
"concentration",
"validate_args"
]
},
"torch.distributions.studentT.StudentT":{
"Matcher": "StudentTMatcher",
"paddle_api": "paddle.distribution.StudentT",
"min_input_args": 1,
"args_list": [
"df",
"loc",
"scale",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.transforms.PositiveDefiniteTransform":{
"Matcher": "TransformsPositiveDefiniteTransformMatcher",
"min_input_args": 0,
"args_list": [
"cache_size"
],
"kwargs_change": {
"cache_size": ""
}
},
"torch.distributions.poisson.Poisson":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Poisson",
"min_input_args": 1,
"args_list": [
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.Bernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Bernoulli",
Expand Down Expand Up @@ -6525,6 +6636,18 @@
"total_count": "1"
}
},
"torch.distributions.chi2.Chi2":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Chi2",
"min_input_args": 1,
"args_list": [
"df",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.Categorical": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Categorical",
Expand Down Expand Up @@ -6567,6 +6690,24 @@
"cache_size": ""
}
},
"torch.distributions.constraints.Constraint" : {
"Matcher": "DistributionsConstrainMatcher",
"paddle_api": "paddle.distribution.constraint.Constraint",
"abstract": true
},
"torch.distributions.gamma.Gamma":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Gamma",
"min_input_args": 2,
"args_list": [
"concentration",
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.ContinuousBernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ContinuousBernoulli",
Expand Down
150 changes: 150 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,127 @@ def get_paddle_nodes(self, args, kwargs):
for i in range(1, len(new_args)):
code = "{}({}, {})".format(self.get_paddle_api(), code, new_args[i])
return ast.parse(code).body


class StudentTMatcher(BaseMatcher):
    """Convert ``torch.distributions.studentT.StudentT`` calls to a paddle_aux helper.

    The helper wraps ``paddle.distribution.StudentT`` and reshapes samples back
    to ``df.shape`` (paddle's ``sample()`` appears to add an extra dim for
    scalar parameters — NOTE(review): confirm against paddle's StudentT).
    """

    def generate_aux_code(self):
        # Auxiliary class written once into the shared paddle_aux module.
        API_TEMPLATE = textwrap.dedent(
            """
            import paddle
            class StudentT_Aux_Class:
                def __init__(self, df, loc, scale):
                    self.df = paddle.to_tensor(df)
                    self.loc = paddle.to_tensor(loc)
                    self.scale = paddle.to_tensor(scale)
                    self.sT = paddle.distribution.StudentT(self.df, self.loc, self.scale)
                def sample(self):
                    return paddle.reshape(self.sT.sample(), self.df.shape)
            """
        )
        return API_TEMPLATE

    def generate_code(self, kwargs):
        self.write_aux_code()
        # paddle.distribution.StudentT has no validate_args parameter; drop it.
        if "validate_args" in kwargs:
            del kwargs["validate_args"]
        # Fill torch's documented defaults when the caller omits them:
        # torch.distributions.StudentT(df, loc=0.0, scale=1.0).
        # Bug fix: the previous default of 0.1 silently shifted the distribution.
        if "loc" not in kwargs:
            kwargs["loc"] = 0.0
        if "scale" not in kwargs:
            kwargs["scale"] = 1.0
        kwargs = self.kwargs_to_str(kwargs)
        API_TEMPLATE = textwrap.dedent(
            """
            paddle_aux.StudentT_Aux_Class({})
            """
        )
        code = API_TEMPLATE.format(kwargs)
        return code


class TransformsPositiveDefiniteTransformMatcher(BaseMatcher):
    """Replace ``torch.distributions.transforms.PositiveDefiniteTransform``
    with a paddle_aux shim exposing ``__call__`` and ``inv``.

    The ``cache_size`` argument is dropped via kwargs_change in the mapping;
    the replacement constructor takes no arguments.
    """

    def generate_aux_code(self):
        # NOTE(review): `x.T` in paddle reverses *all* axes, so the helper
        # matches torch only for 2-D inputs — confirm batched use is out of scope.
        return textwrap.dedent(
            """
            import paddle
            class TransformsPositiveDefiniteTransform:
                def __call__(self, x):
                    x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed()
                    return x @ x.T
                def inv(self, y):
                    y = paddle.linalg.cholesky(y)
                    return y.tril(-1) + y.diagonal(axis1=-2, axis2=-1).log().diag_embed()
            """
        )

    def generate_code(self, kwargs):
        self.write_aux_code()
        replacement = textwrap.dedent(
            """
            paddle_aux.TransformsPositiveDefiniteTransform()
            """
        )
        return replacement


class LKJCholeskyMatcher(BaseMatcher):
    """Map ``torch.distributions.lkj_cholesky.LKJCholesky`` onto a paddle_aux
    wrapper around ``paddle.distribution.LKJCholesky``.

    The wrapper unsqueezes a leading axis on ``sample()`` so the sample shape
    lines up with torch's output.
    """

    def generate_aux_code(self):
        return textwrap.dedent(
            """
            import paddle
            class LKJCholesky_Aux_Class:
                def __init__(self, dim, concentration, sample_method='onion'):
                    self.lkj = paddle.distribution.LKJCholesky(dim, concentration, sample_method)
                def sample(self):
                    return paddle.unsqueeze(self.lkj.sample(), axis=0)
            """
        )

    def generate_code(self, kwargs):
        self.write_aux_code()
        # paddle's LKJCholesky has no validate_args parameter; discard it.
        kwargs.pop("validate_args", None)
        return textwrap.dedent(
            """
            paddle_aux.LKJCholesky_Aux_Class({})
            """
        ).format(self.kwargs_to_str(kwargs))



class Is_InferenceMatcher(BaseMatcher):
    """Translate ``torch.Tensor.is_inference`` / ``torch.is_inference``.

    NOTE(review): paddle has no inference-mode flag; this maps to the tensor's
    ``stop_gradient`` attribute, whose semantics differ from torch inference
    tensors — confirm the approximation is acceptable for converted code.
    """

    def generate_code(self, kwargs):
        # Function form supplies "input"; method form falls back to the
        # tensor expression the call was made on.
        tensor_expr = kwargs.get("input", self.paddleClass)
        return f"{tensor_expr}.stop_gradient"


class DistributionsConstrainMatcher(BaseMatcher):
    """Replace ``torch.distributions.constraints.Constraint`` with a
    paddle_aux shim exposing the ``check`` method expected by torch callers.
    """

    def generate_aux_code(self):
        return textwrap.dedent(
            """
            import paddle
            class DistributionsConstrain:
                def check(self, value):
                    return paddle.distribution.constraint.Constraint()(value)
            """
        )

    def generate_code(self, kwargs):
        self.write_aux_code()
        return textwrap.dedent(
            """
            paddle_aux.DistributionsConstrain()
            """
        )


class IInfoMatcher(BaseMatcher):
Expand Down Expand Up @@ -5145,6 +5266,35 @@ def generate_code(self, kwargs):
return code


class RpcRemoteMatcher(BaseMatcher):
    """Map ``torch.distributed.rpc.remote`` onto ``paddle.distributed.rpc.rpc_async``.

    torch's ``remote()`` returns an RRef consumed via ``.to_here()``, while
    paddle's ``rpc_async`` returns a future consumed via ``.wait()`` — the
    auxiliary ``rpc_remote`` class adapts the former interface to the latter.
    """

    def generate_aux_code(self):
        return textwrap.dedent(
            """
            class rpc_remote:
                def __init__(self, remote_obj):
                    self.remote = remote_obj
                def to_here(self):
                    return self.remote.wait()
            """
        )

    def generate_code(self, kwargs):
        self.write_aux_code()
        # paddle.distributed.rpc.rpc_async names the callable "fn", not "func".
        kwargs["fn"] = kwargs.pop("func")
        call_args = self.kwargs_to_str(kwargs)
        return textwrap.dedent(
            """
            paddle_aux.rpc_remote(paddle.distributed.rpc.rpc_async({}))
            """
        ).format(call_args)


class GetNumThreadsMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
Expand Down
Loading

0 comments on commit 616434d

Please sign in to comment.