Add interfaces and documentation for abstract indexing operations #93

Merged 6 commits on Mar 1, 2023

@eb8680 eb8680 commented Feb 16, 2023

Addresses #80, #12.

First in a series of PRs extracted from the refactoring in #92.

This PR adds a new data structure IndexSet and some new operations scatter, gather, and indices_of for abstracting away many of the tensor-specific implementation details of MultiWorldCounterfactual, both internally and in model code. These operations are documented but not implemented in this PR, except for IndexSet. Mostly-complete implementations can be found in the draft PR #92 and will be added in the next PR in this series.

Tested:

  • Unit tests for IndexSet that check some algebraic properties of their broadcasting operation join
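
To make the tested properties concrete, here is a minimal sketch of what a dictionary-of-sets `IndexSet` and its `join` could look like. This is a hypothetical stand-in written for illustration, not the implementation in this PR; the class body and the `join` signature are assumptions.

```python
from typing import Dict, Iterable, Set


class IndexSet(Dict[str, Set[int]]):
    """Hypothetical stand-in: maps index names to sets of integer indices."""

    def __init__(self, **mapping: Iterable[int]):
        # Copy every value into a fresh set so inputs are never aliased.
        super().__init__({name: set(values) for name, values in mapping.items()})


def join(*indexsets: IndexSet) -> IndexSet:
    """Union of keys, and union of the value sets at shared keys."""
    names = set().union(*(s.keys() for s in indexsets))
    return IndexSet(
        **{name: set().union(*(s.get(name, set()) for s in indexsets)) for name in names}
    )


a = IndexSet(x={0, 1})
b = IndexSet(x={1, 2}, y={3})
c = IndexSet(y={4})

assert join(a, b) == join(b, a)                    # commutative
assert join(a, join(b, c)) == join(join(a, b), c)  # associative
assert join(a, a) == a                             # idempotent
assert join(a, IndexSet()) == a                    # empty IndexSet is the identity
```

Under these assumptions `join` behaves like the join of a semilattice, which is the kind of algebraic property the unit tests are described as checking.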

@eb8680 added the enhancement, status:awaiting review, and refactor labels Feb 16, 2023

@cscherrer cscherrer left a comment

Hi @eb8680, some things aren't clear to me. I've added some comments about these, but I don't want to hold up progress. So hopefully this is helpful, and when we hit diminishing returns on any discussion I'll leave it to you to review. Most importantly, there are tests that seem to do a good job capturing most of the intended semantics.

for which a value is defined::

>>> IndexSet(x={0, 1}, y={2, 3})
{"x": {0, 1}, "y": {2, 3}}

The printed representation could be a little confusing, because it looks like calling IndexSet returns a Dict. Would it work to print this as IndexSet({"x": {0, 1}, "y": {2, 3}})?

Ideally, if you copy a return value and paste it again at the REPL, you'll get the same value. [Is there a name for this property? I've never heard one]
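
A minimal sketch of a `__repr__` with this round-trip property is below. It assumes a keyword-argument constructor like the one in the quoted doctest and index names that are valid Python identifiers; the class is a hypothetical stand-in, not the code under review.

```python
class IndexSet(dict):
    """Hypothetical stand-in: maps index names to sets of integer indices."""

    def __init__(self, **mapping):
        super().__init__({name: set(values) for name, values in mapping.items()})

    def __repr__(self):
        def fmt(values):
            # "{}" would evaluate to an empty dict, so spell empty sets as "set()".
            if not values:
                return "set()"
            return "{" + ", ".join(str(i) for i in sorted(values)) + "}"

        # Sort the names so the output is deterministic.
        args = ", ".join(f"{name}={fmt(self[name])}" for name in sorted(self))
        return f"{type(self).__name__}({args})"


s = IndexSet(x={0, 1}, y={2, 3})
print(repr(s))
assert eval(repr(s)) == s  # pasting the printed form back rebuilds an equal value
```

Printing the constructor call in keyword form, rather than wrapping a dict literal, keeps the output pasteable without requiring the constructor to accept a positional dict.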

Comment on lines +135 to +136
Compute the union of multiple :class:`IndexSet` s
as the union of their keys and of value sets at shared keys.

Based on this description, would it be better to call this union? Or does it extend a Dict method? If it's the latter, does
join(IndexSet(d1:Dict), IndexSet(d2:Dict)) == IndexSet(join(d1, d2))?
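
For reference, the operation described in the docstring differs from Python's built-in dict merging, which overwrites the value at a shared key instead of taking a union. The sketch below contrasts the two on plain dicts of sets; the `join` helper here is a hypothetical illustration, not the function added in this PR.

```python
def join(*mappings):
    """Key-wise union: union of keys, and union of value sets at shared keys."""
    out = {}
    for mapping in mappings:
        for name, values in mapping.items():
            # setdefault creates a fresh set, so the inputs are never mutated.
            out.setdefault(name, set()).update(values)
    return out


d1 = {"x": {0, 1}}
d2 = {"x": {1, 2}, "y": {3}}

merged = {**d1, **d2}  # built-in merge: the right-hand value wins at "x"
joined = join(d1, d2)  # key-wise union: the value sets at "x" are combined

assert merged == {"x": {1, 2}, "y": {3}}
assert joined == {"x": {0, 1, 2}, "y": {3}}
```

With a helper like this defined on plain dicts, the equation in the question holds by construction whenever `IndexSet` is a thin wrapper around the underlying dict.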

... assert indices_of(X) == IndexSet()
... assert indices_of(T) == IndexSet(T={0, 1})
... assert indices_of(Y) == IndexSet(T={0, 1})


Assuming I'm understanding this right (wasn't obvious at first), maybe we could add something like

Suggested change
After the call to :func:`pyro.sample`, we'd have ``indices_of(T) == IndexSet()``. But following this, a call to :func:`intervene` in the `MultiWorldCounterfactual` context extends ``T``...

Or maybe in the comments, like

        ...     T = pyro.sample("T", get_T_dist(X))  # indices_of(T) == IndexSet()
        ...     T = intervene(T, t, name="T")        # After world-splitting, indices_of(T) == IndexSet(T={0, 1})

like ``torch.sparse`` or relational databases.

However, this is beyond the scope of this library as it currently exists.
Instead, :func:`gather` currently binds free variables in ``indexset``

Why indexset and not IndexSet?

>>> indices_of(Y) == IndexSet(T={0, 1})
True
>>> Y0 = gather(Y, IndexSet(T={0}))
>>> indices_of(Y0) == IndexSet() != IndexSet(T={0})

If IndexSet(T={0}) is different from IndexSet(), it's not clear to me why X = pyro.sample("X", get_X_dist()) would lead to indices_of(X) == IndexSet().

@eb8680 added the status:awaiting response label and removed the status:awaiting review label Feb 24, 2023
@eb8680 added the status:awaiting review label and removed the status:awaiting response label Mar 1, 2023
@eb8680 merged commit 7175eb4 into master Mar 1, 2023
@eb8680 deleted the eb-indexset-primitives branch August 7, 2023 15:36