This method doesn't work properly after migrating to TensorFlow 2 #3

Open
soran-ghaderi opened this issue May 11, 2022 · 0 comments
Labels
bug (Something isn't working), enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@soran-ghaderi
Member

This method depends on the `_calibrate` method working properly.

def _predict_proba(self, X):
    """Predicts probabilities using the Platt scaling model (after calibration).

    Model must be calibrated beforehand with the ``calibrate`` method.

    :param X: Numpy array of triples to be evaluated.
    :type X: ndarray, shape [n, 3]
    :return: Probability of each triple to be true according to the Platt scaling calibration.
    :rtype: ndarray, shape [n, 3]
    """
    if not self.is_calibrated:
        msg = "Model has not been calibrated. Please call `model.calibrate(...)` before predicting probabilities."
        logger.error(msg)
        raise RuntimeError(msg)

    # tf.reset_default_graph()
    self._load_model_from_trained_params()

    w = tf.Variable(self.calibration_parameters[0], dtype=tf.float32, trainable=False)
    b = tf.Variable(self.calibration_parameters[1], dtype=tf.float32, trainable=False)

    x_idx = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
    x_tf = tf.Variable(x_idx, dtype=tf.int32, trainable=False)

    e_s, e_p, e_o = self._lookup_embeddings(x_tf)
    scores = self._fn(e_s, e_p, e_o)
    logits = -(w * scores + b)
    probas = tf.sigmoid(logits)

    # with tf.Session(config=self.tf_config) as sess:
    #     sess.run(tf.global_variables_initializer())
    #     return sess.run(probas)
    return probas
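
For anyone checking a TensorFlow 2 port of this method, it may help to have a framework-free reference for the last two lines of the computation, `logits = -(w * scores + b)` followed by a sigmoid. The sketch below uses only the Python standard library. The helper names (`platt_probability`, `platt_probabilities`) and the example scores and parameters are hypothetical and not part of the repository. It assumes `_fn` yields one score per triple, in which case the result has one probability per triple.

```python
import math


def platt_probability(score, w, b):
    """Map one raw triple score to a probability as sigmoid(-(w * score + b))."""
    logit = -(w * score + b)
    # Numerically stable sigmoid: never exponentiate a large positive number.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def platt_probabilities(scores, w, b):
    """Apply the Platt scaling transform to a sequence of scores."""
    return [platt_probability(s, w, b) for s in scores]


# Hypothetical scores and calibration parameters, for illustration only.
example_scores = [-2.0, 0.0, 3.5]
example_w, example_b = -1.0, 0.0
example_probas = platt_probabilities(example_scores, example_w, example_b)
```

Feeding the same scores and calibration parameters to both this reference and the ported method, then comparing the values, is one way to separate a problem in the scaling step from a problem in `_calibrate` or in the embedding lookup.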

@soran-ghaderi added the bug, enhancement, and help wanted labels May 11, 2022