Skip to content

Commit

Permalink
WIP: adding new models in this file
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis-Mozart committed Mar 21, 2024
1 parent c65abc2 commit de48fca
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
from typing import Tuple, List
from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL,\
ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult
ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult, FMult, PolyMult
from .models.pykeen_models import PykeenKGE
from .models.transformers import BytE
import time
Expand Down Expand Up @@ -421,6 +421,12 @@ def intialize_model(args: dict,verbose=0) -> Tuple[object, str]:
elif model_name == 'DeCaL':
model =DeCaL(args=args)
form_of_labelling = 'EntityPrediction'
elif model_name == 'FMult':
model =FMult(args=args)
form_of_labelling = 'EntityPrediction'
elif model_name == 'PolyMult':
model =PolyMult(args=args)
form_of_labelling = 'EntityPrediction'
else:
raise ValueError(f"--model_name: {model_name} is not found.")
return model, form_of_labelling
Expand Down

0 comments on commit de48fca

Please sign in to comment.