Enhance LDA tutorial #2790

Merged: 3 commits into pyro-ppl:dev on Mar 29, 2021

Conversation

@fehiepsi (Member) commented on Mar 29, 2021

This PR fixes some issues in the LDA tutorial and adds some enhancements to improve the results.

Fixes

  • In cell 2, change

pyro.sample('counts', dist.Multinomial(len(counts), theta), obs=counts)

to

total_count = int(counts.sum())
pyro.sample('counts', dist.Multinomial(total_count, theta), obs=counts)

because len(counts) is the vocabulary size, whereas the Multinomial total count should be the number of words in the document, i.e. counts.sum().
  • Above cell 9, change log-normal to logistic-normal, i.e. change

To address this issue, they used an encoder network that approximates the Dirichlet prior p(θ|α) with a log-normal distribution.

to

To address this issue, they used an encoder network that approximates the Dirichlet prior p(θ|α) with a logistic-normal distribution (more precisely, a softmax-normal distribution).

because a log-normal distribution cannot be used to approximate a Dirichlet; the reference also discusses the logistic-normal, not the log-normal (see the small demo after this list).

  • Add bias=False to Decoder.beta to match the discussion: w_n | β, θ ∼ Categorical(σ(βθ)). Otherwise, we should change the discussion text to w_n | β, θ ∼ Categorical(σ(βθ + bias)). bias=False matches the behavior of the original implementation of ProdLDA.
  • Fix the total_count argument and remove to_event(1) in

pyro.sample('obs', dist.Multinomial(docs.shape[1], count_param).to_event(1), obs=docs)

Using to_event(1) here gives us a wrong model (Multinomial already treats its last dimension as the event dimension). Empirically, the tutorial currently reports epoch_loss=1.12e+07, while after the fix epoch_loss=3.72e+05. A combined sketch of the corrected sample statement is shown after this list.

  • Make it clear that the output of Encoder is logtheta, rather than theta.
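
As a small aside on the logistic-normal (softmax-normal) wording above: applying softmax to a Gaussian draw produces a point on the probability simplex, which is what makes it a sensible stand-in for a Dirichlet sample, while exponentiating (log-normal) does not. A two-line illustration (variable names are just for this demo):

import torch
import torch.nn.functional as F

z = torch.randn(4)            # Gaussian draw
theta = F.softmax(z, dim=0)   # theta >= 0 and theta.sum() == 1, like a Dirichlet sample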
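
Putting the bias=False, total_count, and to_event(1) fixes together, the corrected model step could look roughly like the sketch below. Names, shapes, and exact arguments here are illustrative; the notebook in this PR is the authoritative version.

import pyro
import pyro.distributions as dist
import torch.nn.functional as F

def model(docs, decoder, num_topics):
    pyro.module("decoder", decoder)
    with pyro.plate("documents", docs.shape[0]):
        # logistic-normal (softmax-normal) approximation of the Dirichlet prior
        logtheta = pyro.sample(
            "logtheta",
            dist.Normal(docs.new_zeros((docs.shape[0], num_topics)),
                        docs.new_ones((docs.shape[0], num_topics))).to_event(1))
        theta = F.softmax(logtheta, -1)
        # decoder applies softmax(beta(theta)) with beta = nn.Linear(..., bias=False)
        count_param = decoder(theta)
        # total_count must cover the longest document; Multinomial already treats
        # its last dimension as the event dimension, so no extra .to_event(1)
        total_count = int(docs.sum(-1).max())
        pyro.sample("obs", dist.Multinomial(total_count, count_param), obs=docs)

Note that a single integer total_count has to stand in for every document here, which is what the discussion about inhomogeneous total_count further down refers to.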

Enhancements

  • Use affine=False in BatchNorm1d: I had no luck with affine=True; inference seems to overfit with those extra affine parameters, and the resulting topics do not make much sense.
  • Use stop_words='english' in cell 7, which seems to help: the number of unique words is reduced from 12999 to 12722 and words such as his/he/was are removed, which is a nice preprocessing improvement IMO. A brief sketch of both enhancements follows this list.
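
For concreteness, here is a sketch of both enhancements (the corpus and exact keyword arguments in the notebook may differ; documents below is a toy stand-in for the 20 newsgroups data):

from sklearn.feature_extraction.text import CountVectorizer
import torch.nn as nn

documents = ["he said his dog was barking", "the topic model learns topics"]

# stop_words='english' drops uninformative tokens such as his/he/was
vectorizer = CountVectorizer(stop_words='english')
counts = vectorizer.fit_transform(documents)

# BatchNorm1d without learnable affine parameters, e.g. over the vocabulary
# dimension in the decoder, so there are no extra scale/shift parameters to overfit
bn = nn.BatchNorm1d(counts.shape[1], affine=False)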

Result

According to the notebook, the word-cloud topics are more coherent than the current ones. IMO, the result is pretty good now. :) cc @ucals

@fritzo (Member) left a comment


LGTM, thanks for improving this!

Fix total_count argument and remove to_event(1) at pyro.sample('obs', ...)

It's disconcerting that validation missed this. Do you think there's any way we could have strengthened model validation so as to catch this otherwise silent error?

@fritzo merged commit d8fdfb0 into pyro-ppl:dev on Mar 29, 2021

@fritzo (Member) commented on Mar 29, 2021

@fehiepsi does this run fine in the latest Pyro release? If so I'll push it and update the website.

@fehiepsi (Member, Author) commented

strengthened model validation so as to catch this otherwise silent error?

We can remove the current validation relaxation once Multinomial supports inhomogeneous total_count. Then this kind of issue would be detected whenever users do not provide a correct total_count.

does this run fine in the latest Pyro release?

Yes, it gave the same result.
