-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Lda training visualization in visdom #1399
Changes from 2 commits
bb65439
9d2e78d
33818ec
281222c
c507bbb
6f75ccc
d9db4e2
cd5f822
f4728e0
40cf092
d4f69f5
fde7d4d
3f18076
546908e
651a61a
13dfddc
1376d90
44c8e58
92949a3
5b22e4d
c369fc5
a32960d
48526d9
adf2a60
a272090
d3389bb
96949f7
7d0f0ec
dcc64a1
47434f9
30c9b64
e55af47
df5e01f
b334c50
c54e6bf
5f3d902
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,19 +24,22 @@ class Metric(object): | |
def __init__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No need to define an empty `__init__`. |
||
pass | ||
|
||
def get_value(self, **parameters): | ||
def set_parameters(self, **parameters): | ||
""" | ||
Set the parameters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I'll replace with |
||
""" | ||
for parameter, value in parameters.items(): | ||
setattr(self, parameter, value) | ||
|
||
def get_value(self): | ||
pass | ||
|
||
|
||
class CoherenceMetric(Metric): | ||
""" | ||
Metric class for coherence evaluation | ||
""" | ||
def __init__(self, corpus=None, texts=None, dictionary=None, coherence=None, window_size=None, topn=None, logger="shell", viz_env=None, title=None): | ||
def __init__(self, corpus=None, texts=None, dictionary=None, coherence=None, window_size=None, topn=10, logger=None, viz_env=None, title=None): | ||
""" | ||
Args: | ||
corpus : Gensim document corpus. | ||
|
@@ -98,7 +101,7 @@ def get_value(self, **kwargs): | |
# only one of the model or topic would be defined | ||
self.model = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why should you do this assignment? (only in current Callback) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As both |
||
self.topics = None | ||
super(CoherenceMetric, self).get_value(**kwargs) | ||
super(CoherenceMetric, self).set_parameters(**kwargs) | ||
cm = gensim.models.CoherenceModel(self.model, self.topics, self.texts, self.corpus, self.dictionary, self.window_size, self.coherence, self.topn) | ||
return cm.get_coherence() | ||
|
||
|
@@ -107,7 +110,7 @@ class PerplexityMetric(Metric): | |
""" | ||
Metric class for perplexity evaluation | ||
""" | ||
def __init__(self, corpus=None, logger="shell", viz_env=None, title=None): | ||
def __init__(self, corpus=None, logger=None, viz_env=None, title=None): | ||
""" | ||
Args: | ||
corpus : Gensim document corpus | ||
|
@@ -127,7 +130,7 @@ def get_value(self, **kwargs): | |
Args: | ||
model : Trained topic model | ||
""" | ||
super(PerplexityMetric, self).get_value(**kwargs) | ||
super(PerplexityMetric, self).set_parameters(**kwargs) | ||
corpus_words = sum(cnt for document in self.corpus for _, cnt in document) | ||
perwordbound = self.model.bound(self.corpus) / corpus_words | ||
return np.exp2(-perwordbound) | ||
|
@@ -137,7 +140,7 @@ class DiffMetric(Metric): | |
""" | ||
Metric class for topic difference evaluation | ||
""" | ||
def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, normed=True, logger="shell", viz_env=None, title=None): | ||
def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, normed=True, logger=None, viz_env=None, title=None): | ||
""" | ||
Args: | ||
distance : measure used to calculate difference between any topic pair. Available values: | ||
|
@@ -167,7 +170,7 @@ def get_value(self, **kwargs): | |
model : Trained topic model | ||
other_model : second topic model instance to calculate the difference from | ||
""" | ||
super(DiffMetric, self).get_value(**kwargs) | ||
super(DiffMetric, self).set_parameters(**kwargs) | ||
diff_matrix, _ = self.model.diff(self.other_model, self.distance, self.num_words, self.n_ann_terms, self.normed) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now you can use new version for diff (with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
return np.diagonal(diff_matrix) | ||
|
||
|
@@ -176,7 +179,7 @@ class ConvergenceMetric(Metric): | |
""" | ||
Metric class for convergence evaluation | ||
""" | ||
def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, normed=True, logger="shell", viz_env=None, title=None): | ||
def __init__(self, distance="jaccard", num_words=100, n_ann_terms=10, normed=True, logger=None, viz_env=None, title=None): | ||
""" | ||
Args: | ||
distance : measure used to calculate difference between any topic pair. Available values: | ||
|
@@ -206,7 +209,7 @@ def get_value(self, **kwargs): | |
model : Trained topic model | ||
other_model : second topic model instance to calculate the difference from | ||
""" | ||
super(ConvergenceMetric, self).get_value(**kwargs) | ||
super(ConvergenceMetric, self).set_parameters(**kwargs) | ||
diff_matrix, _ = self.model.diff(self.other_model, self.distance, self.num_words, self.n_ann_terms, self.normed) | ||
return np.sum(np.diagonal(diff_matrix)) | ||
|
||
|
@@ -257,10 +260,16 @@ def on_epoch_end(self, epoch, topics=None): | |
epoch : current epoch no. | ||
topics : topic distribution from current epoch (required for coherence of unsupported topic models) | ||
""" | ||
# stores current epoch's metric values | ||
current_metrics = {} | ||
|
||
# plot all metrics in current epoch | ||
for i, metric in enumerate(self.metrics): | ||
value = metric.get_value(topics=topics, model=self.model, other_model=self.previous) | ||
metric_label = type(metric).__name__[:-6] | ||
metric_label = type(metric).__name__ | ||
|
||
current_metrics[metric_label] = value | ||
|
||
# check for any metric which need model state from previous epoch | ||
if isinstance(metric, (DiffMetric, ConvergenceMetric)): | ||
self.previous = copy.deepcopy(self.model) | ||
|
@@ -269,24 +278,27 @@ def on_epoch_end(self, epoch, topics=None): | |
if epoch == 0: | ||
if value.ndim > 0: | ||
diff_mat = np.array([value]) | ||
viz_metric = self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=metric_label, title=metric.title)) | ||
viz_metric = self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=metric_label[:-6], title=metric.title)) | ||
# store current epoch's diff diagonal | ||
self.diff_mat.put(diff_mat) | ||
# saving initial plot window | ||
self.windows.append(copy.deepcopy(viz_metric)) | ||
else: | ||
viz_metric = self.viz.line(Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=metric_label, title=metric.title)) | ||
viz_metric = self.viz.line(Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, opts=dict(xlabel='Epochs', ylabel=metric_label[:-6], title=metric.title)) | ||
# saving initial plot window | ||
self.windows.append(copy.deepcopy(viz_metric)) | ||
else: | ||
if value.ndim > 0: | ||
# concatenate with previous epoch's diff diagonals | ||
diff_mat = np.concatenate((self.diff_mat.get(), np.array([value]))) | ||
self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, win=self.windows[i], opts=dict(xlabel='Epochs', ylabel=metric_label, title=metric.title)) | ||
self.viz.heatmap(X=diff_mat.T, env=metric.viz_env, win=self.windows[i], opts=dict(xlabel='Epochs', ylabel=metric_label[:-6], title=metric.title)) | ||
self.diff_mat.put(diff_mat) | ||
else: | ||
self.viz.updateTrace(Y=np.array([value]), X=np.array([epoch]), env=metric.viz_env, win=self.windows[i]) | ||
|
||
if metric.logger == "shell": | ||
statement = "".join(("Epoch ", str(epoch), ": ", metric_label, " estimate: ", str(value))) | ||
statement = "".join(("Epoch ", str(epoch), ": ", metric_label[:-6], " estimate: ", str(value))) | ||
self.log_type.info(statement) | ||
|
||
return current_metrics | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -631,8 +631,13 @@ def rho(): | |
return pow(offset + pass_ + (self.num_updates / chunksize), -decay) | ||
|
||
if self.callbacks: | ||
# pass the list of input callbacks to Callback class | ||
callback = Callback(self.callbacks) | ||
callback.set_model(self) | ||
# initialize metrics dict to store metric values after every epoch | ||
self.metrics = {} | ||
for metric in self.callbacks: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dict comprehension more readable? Also, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated to use |
||
self.metrics[type(metric).__name__] = [] | ||
|
||
for pass_ in xrange(passes): | ||
if self.dispatcher: | ||
|
@@ -686,8 +691,11 @@ def rho(): | |
if reallen != lencorpus: | ||
raise RuntimeError("input corpus size changed during training (don't use generators as input)") | ||
|
||
# append current epoch's metric values | ||
if self.callbacks: | ||
callback.on_epoch_end(pass_) | ||
current_metrics = callback.on_epoch_end(pass_) | ||
for metric, value in current_metrics.items(): | ||
self.metrics[metric].append(value) | ||
|
||
if dirty: | ||
# finish any remaining updates | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`os.path.join` is more standard.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated