Refactoring the process of downloading pretrained models

Demirrr committed Mar 14, 2024
1 parent 7695845 commit a06f9ae
Showing 2 changed files with 31 additions and 27 deletions.
41 changes: 16 additions & 25 deletions README.md
@@ -109,32 +109,12 @@ Under the hood, dicee executes the run.py script and uses [lightning](https://li
# Two equivalent executions
# (1)
dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test"
# Evaluate Keci on Train set
# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737}
# Evaluate Keci on Validation set
# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072362996241839}
# Evaluate Keci on Test set
# {'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861}

# (2)
CUDA_VISIBLE_DEVICES=0,1 python dicee/scripts/run.py --trainer PL --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test"
# Evaluate Keci on Train set
# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737}
# Evaluate Keci on Validation set
# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072362996241839}
# Evaluate Keci on Test set
# {'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861}
```
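The same run can also be started from Python. A minimal sketch, assuming the `Namespace` attributes of dicee's Python API mirror the CLI flags above (the attribute names are assumptions, not verified against this commit):
```python
# A minimal sketch, assuming the Python API mirrors the CLI flags shown above.
from dicee.executers import Execute
from dicee.config import Namespace

args = Namespace()
args.dataset_dir = "KGs/UMLS"       # assumed counterpart of --dataset_dir
args.model = "Keci"                 # assumed counterpart of --model
args.eval_model = "train_val_test"  # assumed counterpart of --eval_model
reports = Execute(args).start()     # trains, evaluates, and returns a report dict
```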
Similarly, models can easily be trained with `torchrun`:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=gpu dicee/scripts/run.py --trainer torchDDP --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test"
# Evaluate Keci on Train set
# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737}
# Evaluate Keci on Validation set
# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072499937521418}
# Evaluate Keci on Test set
# {'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861}
```
You can also train a model in a multi-node, multi-GPU setting.
```bash
@@ -144,7 +124,7 @@ torchrun --nnodes 2 --nproc_per_node=gpu --node_rank 1 --rdzv_id 455 --rdzv_bac
Train a KGE model by providing the path of a single file and store all parameters under a newly created directory
called `KeciFamilyRun`.
```bash
-dicee --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib
+dicee --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib --eval_model None
```
where the data is in the following form
```bash
@@ -153,6 +133,11 @@ _:1 <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/2002/07
<http://www.benchmark.org/family#hasChild> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/2002/07/owl#ObjectProperty> .
<http://www.benchmark.org/family#hasParent> <http://www.w3.org/1999/02/22-rdf-syntax-ns#type> <http://www.w3.org/2002/07/owl#ObjectProperty> .
```
**Continual Training:** The training of a pretrained model can be resumed.
```bash
dicee --continual_learning KeciFamilyRun --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib --eval_model None
```
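The stored run can then be reloaded for inference. A short sketch, assuming `KGE` accepts the directory created by `--path_to_store_single_run` via its `path` argument (as in the `pre_trained_kge` example further below):
```python
from dicee import KGE
# Reload the run stored under KeciFamilyRun; the path argument is assumed to
# accept the directory created via --path_to_store_single_run.
pre_trained_kge = KGE(path="KeciFamilyRun")
```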

**Apart from N-Triples and standard link prediction dataset formats, we support `["owl", "nt", "turtle", "rdf/xml", "n3"]`.**
Moreover, a KGE model can also be trained by providing **an endpoint of a triple store**.
```bash
@@ -286,16 +271,22 @@ pre_trained_kge.predict_topk(r=[".."],t=[".."],topk=10)

## Downloading Pretrained Models

We provide many pretrained knowledge graph embedding models at [dice-research.org/projects/DiceEmbeddings/](https://files.dice-research.org/projects/DiceEmbeddings/).
<details> <summary> Code snippet </summary>

```python
from dicee import KGE
# (1) Load a pretrained Keci model trained on KINSHIP
model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
# (2) Load pretrained MuRE, QuatE, and Keci models trained on YAGO3-10
mure = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Pykeen_MuRE-dim128-epoch256-KvsAll")
quate = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Pykeen_QuatE-dim128-epoch256-KvsAll")
keci = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Keci-dim128-epoch256-KvsAll")
quate.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3)
# [('Asia', 0.9894362688064575), ('Europe', 0.01575559377670288), ('Tadanari_Lee', 0.012544365599751472)]
keci.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3)
# [('Asia', 0.6522021293640137), ('Chinggis_Khaan_International_Airport', 0.36563414335250854), ('Democratic_Party_(Mongolia)', 0.19600993394851685)]
mure.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3)
# [('Asia', 0.9996906518936157), ('Ulan_Bator', 0.0009907372295856476), ('Philippines', 0.0003116439620498568)]
```

- For more pretrained models, please have a look at [dice-research.org/projects/DiceEmbeddings/](https://files.dice-research.org/projects/DiceEmbeddings/)

</details>
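`KGE(url=...)` presumably relies on the download helpers refactored in this commit (see `dicee/static_funcs.py` below). A hedged sketch of calling the helper directly; that it returns the local directory path is an assumption based on its `-> str` annotation:
```python
from dicee.static_funcs import download_pretrained_model
# Download all files of a pretrained run into a local folder; the returned
# string is assumed to be the path of that folder (see the -> str annotation).
path = download_pretrained_model("https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll")
```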

## How to Deploy
17 changes: 15 additions & 2 deletions dicee/static_funcs.py
@@ -624,7 +624,19 @@ def download_file(url, destination_folder="."):
print(f"Failed to download: {url}")


-def download_files_from_url(base_url, destination_folder="."):
+def download_files_from_url(base_url: str, destination_folder=".") -> None:
"""
Parameters
----------
base_url: e.g. "https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll"
destination_folder: e.g. "KINSHIP-Keci-dim128-epoch256-KvsAll"
Returns
-------
"""
# lazy import
from bs4 import BeautifulSoup
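# Deferring this import avoids loading bs4 at module import time; it is only
# needed here to parse the HTML file listing.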

@@ -639,7 +651,8 @@ def download_files_from_url(base_url, destination_folder="."):
hrefs = [i for i in hrefs if len(i) > 3 and "." in i]
for file_url in hrefs:
download_file(base_url + "/" + file_url, destination_folder)

else:
print("ERROR:", response.status_code)

def download_pretrained_model(url: str) -> str:
assert url[-1] != "/"
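# A trailing "/" is rejected, presumably so the last URL segment can serve as the local folder name.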
