-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable newer Python versions and use of own test data #9
base: master
Are you sure you want to change the base?
Conversation
# Save parameters | ||
config = vars(args) | ||
with open(os.path.join(args.model_dir, "config.json"), "w") as outfile: | ||
json.dump(config, outfile) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this in order to improve model loading, compared to the long file name with all parameters in the file name.
if do_full_eval == False: | ||
# random sample each context points in NeighborGraph() | ||
self.sample_neg_pts(ng_list) | ||
|
||
center_pred_embed = self.predict(ng_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this file, I added a predict
function for each model, which only embeds the point itself and not the positive and negative points. The predict
function is then used in the forward
functions. This looks like a lot of changes but it's mainly encapsulation of the center_pred_embed part
@@ -240,6 +240,136 @@ def make_args_combine(args): | |||
return args_combine | |||
|
|||
|
|||
def make_enc_dec(args, pointset=None, feature_embedding= None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not new code; it's only an encapsulation of code that was previously part of the __init__
of the Trainer
class. In the new version, it is a separate function, and can be used in the test.py
script.
Hi, thanks for providing this nice codebase! I am using it for an own project, and I did some modifications to the code in order to run it with later Python versions, and in order to use my own data. Maybe you want to consider merging my new version into your repository :) Here's a summary of the changes
Set
toset
,cPickle
topickle
, andxrange
torange
to enable Python 3 (I'm using 3.9.5)Trainer
class and to themodel.py
file. Maybe I overlooked some other option to do this; please let me know if that changes are redundant.