-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add interfaces and documentation for abstract indexing operations #93
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @eb8680 , some things aren't clear to me. I've added some comments about these, but I don't want to hold up progress. So hopefully this is helpful, and when we hit diminishing returns on any discussion I'll leave it to you to review. Most importantly there are tests that seem to do a good job capturing most of the intended semantics.
for which a value is defined:: | ||
|
||
>>> IndexSet(x={0, 1}, y={2, 3}}) | ||
{"x": {0, 1}, "y": {2, 3}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The printed representatoin could be a little confusing, because it looks like calling IndexSet
returns a Dict
. Would it work to print this as IndexSet({"x": {0, 1}, "y": {2, 3}})
?
Ideally, if you copy a return value and paste it again at the REPL, you'll get the same value. [Is there a name for this property? I've never heard one]
Compute the union of multiple :class:`IndexSet` s | ||
as the union of their keys and of value sets at shared keys. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on this description, would it be better to call this union
? Or does it extend a Dict
method? If it's the latter, does
join(IndexSet(d1:Dict), IndexSet(d2:Dict)) == IndexSet(join(d1, d2))
?
... assert indices_of(X) == IndexSet() | ||
... assert indices_of(T) == IndexSet(T={0, 1}) | ||
... assert indices_of(Y) == IndexSet(T={0, 1}) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming I'm understanding this right (wasn't obvious at first), maybe we could add something like
After the call to :func:`pyro.sample`, we'd have ``indices_of(T) == IndexSet()``. But following this, a call to :func:`intervene` in the `MultiWorldCounterfactual` context extends ``T``... | |
Or maybe in the comments, like
... T = pyro.sample("T", get_T_dist(X)) # `indices_of(T) == IndexSet()
... T = intervene(T, t, name="T") # After world-splitting, indices_of(T) == IndexSet({T={0, 1})
causal_pyro/primitives.py
Outdated
like ``torch.sparse`` or relational databases. | ||
|
||
However, this is beyond the scope of this library as it currently exists. | ||
Instead, :func:`gather` currently binds free variables in ``indexset`` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why indexset
and not IndexSet
?
causal_pyro/primitives.py
Outdated
>>> indices_of(Y) == IndexSet(T={0, 1}) | ||
True | ||
>>> Y0 = gather(Y, IndexSet(T={0})) | ||
>>> indices_of(Y0) == IndexSet() != IndexSet(T={0}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If IndexSet(T={0})
is different from IndexSet()
, it's not clear to me why X = pyro.sample("X", get_X_dist())
would lead to IndexSet(X) == IndexSet()
Addresses #80, #12.
First in a series of PRs extracted from the refactoring in #92.
This PR adds a new data structure
IndexSet
and some new operationsscatter
,gather
, andindices_of
for abstracting away many of the tensor-specific implementation details ofMultiWorldCounterfactual
, both internally and in model code. These operations are documented but not implemented in this PR, except forIndexSet
. Mostly-complete implementations can be found in the draft PR #92 and will be added in the next PR in this series.Tested:
IndexSet
that check some algebraic properties of their broadcasting operationjoin