Cond wrapper #24

Closed
albertz opened this issue Aug 9, 2021 · 6 comments · Fixed by #102
@albertz
Member

albertz commented Aug 9, 2021

Similar to #23, this proposes an API using with Cond(...) as cond_obj:, which corresponds to if ...:, followed by with cond_obj.false_branch():, which corresponds to the else: branch.

Example:

x = ... # whatever
cond = ...  # scalar tensor, True or False
with Cond(cond) as cond_obj:
  y = mod_true_case(x)
with cond_obj.false_branch():
  y = mod_false_case(x)

But this is not so clear yet.

@albertz
Member Author

albertz commented Aug 10, 2021

x = ... # whatever
cond = ...  # scalar tensor, True or False
with Cond(cond) as cond_obj:
  y = mod_true_case(x)
with cond_obj.false_branch():
  y = mod_false_case(x)

The example cannot work like this because it requires that we know about the Python local variable name, which we do not.

We need to specifically connect both the true-case value and the false-case value such that RETURNN can make it the same layer name.

Maybe just disallow accessing y directly outside the scope, so that it cannot be used like this by accident.

And then require something like:

x = ... # whatever
cond = ...  # scalar tensor, True or False
with Cond(cond) as cond_obj:
  y_true = mod_true_case(x)
with cond_obj.false_branch():
  y_false = mod_false_case(x)
  y = cond_obj.output(y_true, y_false)

So all outputs have to be passed through Cond.output explicitly, which combines the true-branch value with the false-branch value.

@JackTemaki
Contributor

The way it is written here seems more complicated than necessary to me, because in the end you could reduce it to just:

x = ... # whatever
cond = ...  # scalar tensor, True or False
y_true = mod_true_case(x)
y_false = mod_false_case(x)
y = Cond(cond, y_true, y_false)

This looks somewhat simpler but has the same expressiveness. As long as there is no syntax that directly resembles if and else, I think it is easier to stick to something like this.

@albertz
Member Author

albertz commented Oct 1, 2021

No, it cannot work like that. The branches have to be inside some special scope, because you don't want to execute both mod_true_case and mod_false_case, as would be the case in your example (also if this were equivalent PyTorch code).
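A plain-Python analogy of this point, as a sketch: when both branch values are passed as arguments, both function calls run before the selecting function is even entered. The function cond_select and the calls list are hypothetical and only serve the illustration.

```python
calls = []  # records which branch functions actually ran


def mod_true_case(x):
  calls.append("true")
  return x + 1


def mod_false_case(x):
  calls.append("false")
  return x - 1


def cond_select(cond, y_true, y_false):
  """Hypothetical function-style Cond: it only picks between finished values."""
  return y_true if cond else y_false


x = 10
cond = True

# Function-style: both arguments are evaluated before cond_select runs.
y = cond_select(cond, mod_true_case(x), mod_false_case(x))
calls_eager = list(calls)

calls.clear()

# Scoped if/else: only the selected branch runs.
if cond:
  y_scoped = mod_true_case(x)
else:
  y_scoped = mod_false_case(x)
calls_scoped = list(calls)
```

Both variants produce the same value, but only the scoped one avoids running the unused branch.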

@albertz
Member Author

albertz commented Jan 5, 2022

Another API suggestion (draft):

with Cond(cond) as cond_obj:
  y = mod_true_case(x)
  cond_obj.else()
  y = mod_false_case(x)

In both the original suggestion and this draft, we have the problem that the first y would be marked as unused, and also it is not really defined that this is the output.

So, maybe:

with Cond(cond) as cond_obj:
  y = mod_true_case(x)
  cond_obj.else(y)
  y = mod_false_case(x)
  y = cond_obj.end(y)

The arguments to else and end would define the outputs of the true and false branch, respectively, and the return value of end would be the layer ref(s) that can be used outside.

The same written shorter:

with Cond(cond) as cond_obj:
  cond_obj.else(mod_true_case(x))
  y = cond_obj.end(mod_false_case(x))

Or maybe this is cleaner?

with Cond(cond) as cond_obj:
  cond_obj.true(mod_true_case(x))
  cond_obj.false(mod_false_case(x))
  y = cond_obj.result()

Maybe not, because this does not indicate that true has the side effect of changing the scope to the false branch, while else above does indicate this.

@albertz
Member Author

albertz commented Jan 5, 2022

Should this Cond generalize to the case that cond is not a scalar? In that case, it would wrap tf.where. Or it could split up the tensors to have separate code paths anyway.
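The two options for a non-scalar cond can be sketched in pure Python on lists. This assumes element-wise selection in the style of tf.where for the first option. The helper names select_elementwise and split_and_merge and the toy branch functions are hypothetical.

```python
def mod_true_case(v):
  return v * 2


def mod_false_case(v):
  return -v


def select_elementwise(cond, xs):
  """Compute both branches for every element, then select per element."""
  ys_true = [mod_true_case(v) for v in xs]
  ys_false = [mod_false_case(v) for v in xs]
  return [t if c else f for c, t, f in zip(cond, ys_true, ys_false)]


def split_and_merge(cond, xs):
  """Run each branch only on the elements routed to it, then merge in order."""
  out = [None] * len(xs)
  for i, (c, v) in enumerate(zip(cond, xs)):
    out[i] = mod_true_case(v) if c else mod_false_case(v)
  return out


cond = [True, False, True]
xs = [1, 2, 3]
ys_select = select_elementwise(cond, xs)
ys_split = split_and_merge(cond, xs)
```

Both give the same result. The difference is that the first computes both branches on all elements, while the second keeps separate code paths and computes each branch only where it is needed.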

@albertz
Member Author

albertz commented Feb 4, 2022

Note, else is a reserved keyword, so the method must be called else_.

Another variant, somewhat more similar to the last example but reducing the overhead slightly:

with Cond(cond) as cond_obj:
  cond_obj.true = mod_true_case(x)
  cond_obj.false = mod_false_case(x)
  y = cond_obj.result

Re: the side effects of the true = ... and false = ... assignments, I think when this is properly documented, this might not be much of an issue. I think no matter what API we end up with here, anything is fine, as you will easily get used to it. But it should be readable (the meaning of the code should be clear, even when not remembering the exact API) and straightforward to use (straightforward = no mental overhead needed, when you know the API).
