From 7695845e48cc46acf55dfb0fb9eb5449917f4766 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 14 Mar 2024 16:11:27 +0100 Subject: [PATCH 1/8] Working version of continual learning --- README.md | 7 ++--- dicee/config.py | 2 ++ dicee/evaluator.py | 2 +- dicee/executer.py | 31 +++++++++++----------- dicee/models/base_model.py | 2 ++ dicee/scripts/run.py | 48 ++++++++++++++++++++++------------- dicee/trainer/dice_trainer.py | 2 +- tests/test_regression_cl.py | 2 +- 8 files changed, 57 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 8a1bd828..8fbe9fe2 100644 --- a/README.md +++ b/README.md @@ -95,15 +95,16 @@ A KGE model can also be trained from the command line ```bash dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" ``` -dicee automaticaly detects available GPUs and trains a model with distributed data parallels technique. Under the hood, dicee uses lighning as a default trainer. +dicee automatically detects available GPUs and trains a model with distributed data parallels technique. ```bash # Train a model by only using the GPU-0 CUDA_VISIBLE_DEVICES=0 dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" # Train a model by only using GPU-1 CUDA_VISIBLE_DEVICES=1 dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" -NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1 python dicee/scripts/run.py --trainer PL --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" +# Train a model by using all available GPUs +dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" ``` -Under the hood, dicee executes run.py script and uses lighning as a default trainer +Under the hood, dicee executes the run.py script and uses [lightning](https://lightning.ai/) as a default trainer. 
```bash # Two equivalent executions # (1) diff --git a/dicee/config.py b/dicee/config.py index 5ba66e14..3d7921cc 100644 --- a/dicee/config.py +++ b/dicee/config.py @@ -133,6 +133,8 @@ def __init__(self, **kwargs): self.block_size: int = None "block size of LLM" + self.continual_learning=None + "Path of a pretrained model size of LLM" def __iter__(self): # Iterate diff --git a/dicee/evaluator.py b/dicee/evaluator.py index fdee4f7b..ab46393a 100644 --- a/dicee/evaluator.py +++ b/dicee/evaluator.py @@ -456,7 +456,7 @@ def dummy_eval(self, trained_model, form_of_labelling: str): valid_set=valid_set, test_set=test_set, trained_model=trained_model) - elif self.args.scoring_technique in ['KvsAll', 'KvsSample', '1vsAll', 'PvsAll', 'CCvsAll']: + elif self.args.scoring_technique in ["AllvsAll",'KvsAll', 'KvsSample', '1vsAll']: self.eval_with_vs_all(train_set=train_set, valid_set=valid_set, test_set=test_set, diff --git a/dicee/executer.py b/dicee/executer.py index 4399af63..7be7fbef 100644 --- a/dicee/executer.py +++ b/dicee/executer.py @@ -234,31 +234,32 @@ class ContinuousExecute(Execute): (1) Loading & Preprocessing & Serializing input data. (2) Training & Validation & Testing (3) Storing all necessary info + + During the continual learning we can only modify *** num_epochs *** parameter. + Trained model stored in the same folder as the seed model for the training. + Trained model is noted with the current time. """ def __init__(self, args): - assert os.path.exists(args.path_experiment_folder) - assert os.path.isfile(args.path_experiment_folder + '/configuration.json') - # (1) Load Previous input configuration - previous_args = load_json(args.path_experiment_folder + '/configuration.json') - dargs = vars(args) - del args - for k in list(dargs.keys()): - if dargs[k] is None: - del dargs[k] - # (2) Update (1) with new input - previous_args.update(dargs) + # (1) Current input configuration. 
+ assert os.path.exists(args.continual_learning) + assert os.path.isfile(args.continual_learning + '/configuration.json') + # (2) Load previous input configuration. + previous_args = load_json(args.continual_learning + '/configuration.json') + args=vars(args) + # + previous_args["num_epochs"]=args["num_epochs"] + previous_args["continual_learning"]=args["continual_learning"] + print("Updated configuration:",previous_args) try: - report = load_json(dargs['path_experiment_folder'] + '/report.json') + report = load_json(args['continual_learning'] + '/report.json') previous_args['num_entities'] = report['num_entities'] previous_args['num_relations'] = report['num_relations'] except AssertionError: print("Couldn't find report.json.") previous_args = SimpleNamespace(**previous_args) - previous_args.full_storage_path = previous_args.path_experiment_folder print('ContinuousExecute starting...') print(previous_args) - # TODO: can we remove continuous_training from Execute ? super().__init__(previous_args, continuous_training=True) def continual_start(self) -> dict: @@ -279,7 +280,7 @@ def continual_start(self) -> dict: """ # (1) self.trainer = DICE_Trainer(args=self.args, is_continual_training=True, - storage_path=self.args.path_experiment_folder) + storage_path=self.args.continual_learning) # (2) self.trained_model, form_of_labelling = self.trainer.continual_start() diff --git a/dicee/models/base_model.py b/dicee/models/base_model.py index dbbdf9cf..c6a51890 100644 --- a/dicee/models/base_model.py +++ b/dicee/models/base_model.py @@ -431,6 +431,8 @@ class IdentityClass(torch.nn.Module): def __init__(self, args=None): super().__init__() self.args = args + def __call__(self, x): + return x @staticmethod def forward(x): diff --git a/dicee/scripts/run.py b/dicee/scripts/run.py index 7a4edcd8..ad085758 100755 --- a/dicee/scripts/run.py +++ b/dicee/scripts/run.py @@ -1,5 +1,5 @@ import json -from dicee.executer import Execute +from dicee.executer import Execute, 
ContinuousExecute import argparse def get_default_arguments(description=None): @@ -43,9 +43,9 @@ def get_default_arguments(description=None): parser.add_argument('--optim', type=str, default='Adam', help='An optimizer', choices=['Adam', 'AdamW', 'SGD',"NAdam", "Adagrad", "ASGD"]) - parser.add_argument('--embedding_dim', type=int, default=32, + parser.add_argument('--embedding_dim', type=int, default=256, help='Number of dimensions for an embedding vector. ') - parser.add_argument("--num_epochs", type=int, default=500, help='Number of epochs for training. ') + parser.add_argument("--num_epochs", type=int, default=100, help='Number of epochs for training. ') parser.add_argument('--batch_size', type=int, default=1024, help='Mini batch size. If None, automatic batch finder is applied') parser.add_argument("--lr", type=float, default=0.01) @@ -73,14 +73,6 @@ def get_default_arguments(description=None): parser.add_argument("--gradient_accumulation_steps", type=int, default=0, help="e.g. gradient_accumulation_steps=2 " "implies that gradients are accumulated at every second mini-batch") - parser.add_argument('--num_folds_for_cv', type=int, default=0, - help='Number of folds in k-fold cross validation.' - 'If >2 ,no evaluation scenario is applied implies no evaluation.') - parser.add_argument("--eval_model", type=str, default="train_val_test", - choices=["None", "train", "train_val", "train_val_test", "test"], - help='Evaluating link prediction performance on data splits. ') - parser.add_argument("--save_model_at_every_epoch", type=int, default=None, - help='At every X number of epochs model will be saved. If None, we save 4 times.') parser.add_argument("--label_smoothing_rate", type=float, default=0.0, help='None for not using it.') parser.add_argument("--kernel_size", type=int, default=3, help="Square kernel size for convolution based models.") @@ -90,19 +82,34 @@ def get_default_arguments(description=None): help='Number of cores to be used. 
0 implies using single CPU') parser.add_argument("--random_seed", type=int, default=1, help='Seed for all, see pl seed_everything().') + parser.add_argument('--p', type=int, default=0, + help='P for Clifford Algebra') + parser.add_argument('--q', type=int, default=1, + help='Q for Clifford Algebra') + parser.add_argument('--pykeen_model_kwargs', type=json.loads, default={}) + + # Evaluation Related + parser.add_argument('--num_folds_for_cv', type=int, default=0, + help='Number of folds in k-fold cross validation.' + 'If >2 ,no evaluation scenario is applied implies no evaluation.') + parser.add_argument("--eval_model", type=str, default="train_val_test", + choices=["None", "train", "train_val", "train_val_test", "test"], + help='Evaluating link prediction performance on data splits. ') + parser.add_argument("--save_model_at_every_epoch", type=int, default=None, + help='At every X number of epochs model will be saved. If None, we save 4 times.') + # Continual Learning + parser.add_argument("--continual_learning", type=str, default=None, + help="The path of a folder containing a pretrained model and configurations") + parser.add_argument("--sample_triples_ratio", type=float, default=None, help='Sample input data.') parser.add_argument("--read_only_few", type=int, default=None, help='READ only first N triples. 
If 0, read all.') parser.add_argument("--add_noise_rate", type=float, default=0.0, help='Add x % of noisy triples into training dataset.') - parser.add_argument('--p', type=int, default=0, - help='P for Clifford Algebra') - parser.add_argument('--q', type=int, default=1, - help='Q for Clifford Algebra') + # WIP + parser.add_argument('--r', type=int, default=0, help='R for Clifford Algebra') - parser.add_argument('--pykeen_model_kwargs', type=json.loads, default={}) - # WIP parser.add_argument('--block_size', type=int, default=8, help='Block size for BytE') parser.add_argument("--byte_pair_encoding", @@ -122,7 +129,12 @@ def get_default_arguments(description=None): return parser.parse_args(description) def main(): - Execute(get_default_arguments()).start() + + args = get_default_arguments() + if args.continual_learning: + ContinuousExecute(args).continual_start() + else: + Execute(get_default_arguments()).start() if __name__ == '__main__': main() diff --git a/dicee/trainer/dice_trainer.py b/dicee/trainer/dice_trainer.py index 34272f48..4ab3376d 100644 --- a/dicee/trainer/dice_trainer.py +++ b/dicee/trainer/dice_trainer.py @@ -165,7 +165,7 @@ def continual_start(self): self.trainer = self.initialize_trainer(callbacks=get_callbacks(self.args)) model, form_of_labelling = self.initialize_or_load_model() assert form_of_labelling in ['EntityPrediction', 'RelationPrediction', 'Pyke'] - assert self.args.scoring_technique in ['KvsSample', '1vsAll', 'KvsAll', 'NegSample'] + assert self.args.scoring_technique in ["AllvsAll",'KvsSample', '1vsAll', 'KvsAll', 'NegSample'] train_loader = self.initialize_dataloader( reload_dataset(path=self.storage_path, form_of_labelling=form_of_labelling, scoring_technique=self.args.scoring_technique, diff --git a/tests/test_regression_cl.py b/tests/test_regression_cl.py index 7b530a0b..3494db1c 100644 --- a/tests/test_regression_cl.py +++ b/tests/test_regression_cl.py @@ -25,7 +25,7 @@ def test_k_vs_all(self): args.init_param = 'xavier_normal' 
result = Execute(args).start() - args.path_experiment_folder = result['path_experiment_folder'] + args.continual_learning = result['path_experiment_folder'] cl_result = ContinuousExecute(args).continual_start() assert cl_result['Train']['H@10'] >= result['Train']['H@10'] From a06f9ae08ff0fa788deb48e5d3658d0be05d57ce Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Thu, 14 Mar 2024 16:56:37 +0100 Subject: [PATCH 2/8] Refactoring the process of downloading pretrained models --- README.md | 41 ++++++++++++++++------------------------- dicee/static_funcs.py | 17 +++++++++++++++-- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 8fbe9fe2..c3e3a59b 100644 --- a/README.md +++ b/README.md @@ -109,32 +109,12 @@ Under the hood, dicee executes the run.py script and uses [lightning](https://li # Two equivalent executions # (1) dicee --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" -# Evaluate Keci on Train set: Evaluate Keci on Train set -# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737} -# Evaluate Keci on Validation set: Evaluate Keci on Validation set -# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072362996241839} -# Evaluate Keci on Test set: Evaluate Keci on Test set -# {'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861} - # (2) CUDA_VISIBLE_DEVICES=0,1 python dicee/scripts/run.py --trainer PL --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" -# Evaluate Keci on Train set: Evaluate Keci on Train set -# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737} -# Evaluate Keci on Train set: Evaluate Keci on Train set -# Evaluate Keci on Validation set: Evaluate Keci on Validation set -# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072362996241839} -# Evaluate 
Keci on Test set: Evaluate Keci on Test set -# {'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861} ``` Similarly, models can be easily trained with torchrun ```bash torchrun --standalone --nnodes=1 --nproc_per_node=gpu dicee/scripts/run.py --trainer torchDDP --dataset_dir "KGs/UMLS" --model Keci --eval_model "train_val_test" -# Evaluate Keci on Train set: Evaluate Keci on Train set: Evaluate Keci on Train set -# {'H@1': 0.9518788343558282, 'H@3': 0.9988496932515337, 'H@10': 1.0, 'MRR': 0.9753123402351737} -# Evaluate Keci on Validation set: Evaluate Keci on Validation set -# {'H@1': 0.6932515337423313, 'H@3': 0.9041411042944786, 'H@10': 0.9754601226993865, 'MRR': 0.8072499937521418} -# Evaluate Keci on Test set: Evaluate Keci on Test set -{'H@1': 0.6951588502269289, 'H@3': 0.9039334341906202, 'H@10': 0.9750378214826021, 'MRR': 0.8064032293278861} ``` You can also train a model in multi-node multi-gpu setting. ```bash @@ -144,7 +124,7 @@ torchrun --nnodes 2 --nproc_per_node=gpu --node_rank 1 --rdzv_id 455 --rdzv_bac Train a KGE model by providing the path of a single file and store all parameters under newly created directory called `KeciFamilyRun`. ```bash -dicee --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib +dicee --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib --eval_model None ``` where the data is in the following form ```bash @@ -153,6 +133,11 @@ _:1 . . ``` +**Continual Training:** the training phase of a pretrained model can be resumed. 
+```bash +dicee --continual_learning KeciFamilyRun --path_single_kg "KGs/Family/family-benchmark_rich_background.owl" --model Keci --path_to_store_single_run KeciFamilyRun --backend rdflib --eval_model None +``` + **Apart from n-triples or standard link prediction dataset formats, we support ["owl", "nt", "turtle", "rdf/xml", "n3"]***. Moreover, a KGE model can be also trained by providing **an endpoint of a triple store**. ```bash @@ -286,16 +271,22 @@ pre_trained_kge.predict_topk(r=[".."],t=[".."],topk=10) ## Downloading Pretrained Models +We provide plenty of pretrained knowledge graph embedding models at [dice-research.org/projects/DiceEmbeddings/](https://files.dice-research.org/projects/DiceEmbeddings/).
To see a code snippet ```python from dicee import KGE -# (1) Load a pretrained ConEx on DBpedia -model = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll") +mure = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Pykeen_MuRE-dim128-epoch256-KvsAll") +quate = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Pykeen_QuatE-dim128-epoch256-KvsAll") +keci = KGE(url="https://files.dice-research.org/projects/DiceEmbeddings/YAGO3-10-Keci-dim128-epoch256-KvsAll") +quate.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3) +# [('Asia', 0.9894362688064575), ('Europe', 0.01575559377670288), ('Tadanari_Lee', 0.012544365599751472)] +keci.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3) +# [('Asia', 0.6522021293640137), ('Chinggis_Khaan_International_Airport', 0.36563414335250854), ('Democratic_Party_(Mongolia)', 0.19600993394851685)] +mure.predict_topk(h=["Mongolia"],r=["isLocatedIn"],topk=3) +# [('Asia', 0.9996906518936157), ('Ulan_Bator', 0.0009907372295856476), ('Philippines', 0.0003116439620498568)] ``` -- For more please look at [dice-research.org/projects/DiceEmbeddings/](https://files.dice-research.org/projects/DiceEmbeddings/) -
## How to Deploy diff --git a/dicee/static_funcs.py b/dicee/static_funcs.py index 6a876bb6..d6e3202a 100644 --- a/dicee/static_funcs.py +++ b/dicee/static_funcs.py @@ -624,7 +624,19 @@ def download_file(url, destination_folder="."): print(f"Failed to download: {url}") -def download_files_from_url(base_url, destination_folder="."): +def download_files_from_url(base_url:str, destination_folder=".")->None: + """ + + Parameters + ---------- + base_url: e.g. "https://files.dice-research.org/projects/DiceEmbeddings/KINSHIP-Keci-dim128-epoch256-KvsAll" + + destination_folder: e.g. "KINSHIP-Keci-dim128-epoch256-KvsAll" + + Returns + ------- + + """ # lazy import from bs4 import BeautifulSoup @@ -639,7 +651,8 @@ def download_files_from_url(base_url, destination_folder="."): hrefs = [i for i in hrefs if len(i) > 3 and "." in i] for file_url in hrefs: download_file(base_url + "/" + file_url, destination_folder) - + else: + print("ERROR:", response.status_code) def download_pretrained_model(url: str) -> str: assert url[-1] != "/" From cd8935bd83ff8130c7853f319f96101ce6c1d7ef Mon Sep 17 00:00:00 2001 From: Louis-Mozart Date: Tue, 26 Mar 2024 18:29:54 +0100 Subject: [PATCH 3/8] WIP: DualE with NegSampling implemented --- dicee/models/__init__.py | 1 + dicee/models/dualE.py | 365 +++++++++++++++++++++++++++++++++++++++ dicee/scripts/run.py | 2 +- dicee/static_funcs.py | 5 +- 4 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 dicee/models/dualE.py diff --git a/dicee/models/__init__.py b/dicee/models/__init__.py index 27d8d4d0..a8240a92 100644 --- a/dicee/models/__init__.py +++ b/dicee/models/__init__.py @@ -6,3 +6,4 @@ from .clifford import Keci, KeciBase, CMult, DeCaL # noqa from .pykeen_models import * # noqa from .function_space import * # noqa +from .dualE import DualE diff --git a/dicee/models/dualE.py b/dicee/models/dualE.py new file mode 100644 index 00000000..0d654b6c --- /dev/null +++ b/dicee/models/dualE.py @@ -0,0 +1,365 @@ +import torch +import 
torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.autograd import Variable +from .base_model import BaseKGE +import numpy as np +from numpy.random import RandomState + + + +# class OMult(BaseKGE): +# def __init__(self, args): +# super().__init__(args) +# self.name = 'OMult' + +# @staticmethod +# def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, +# emb_rel_e7): +# denominator = torch.sqrt( +# emb_rel_e0 ** 2 + emb_rel_e1 ** 2 + emb_rel_e2 ** 2 + emb_rel_e3 ** 2 + emb_rel_e4 ** 2 +# + emb_rel_e5 ** 2 + emb_rel_e6 ** 2 + emb_rel_e7 ** 2) +# y0 = emb_rel_e0 / denominator +# y1 = emb_rel_e1 / denominator +# y2 = emb_rel_e2 / denominator +# y3 = emb_rel_e3 / denominator +# y4 = emb_rel_e4 / denominator +# y5 = emb_rel_e5 / denominator +# y6 = emb_rel_e6 / denominator +# y7 = emb_rel_e7 / denominator +# return y0, y1, y2, y3, y4, y5, y6, y7 + +# def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor, tail_ent_emb: torch.FloatTensor): +# # (2) Split (1) into real and imaginary parts. 
+# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7 = torch.hsplit( +# head_ent_emb, 8) +# emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7 = torch.hsplit( +# rel_ent_emb, +# 8) +# if isinstance(self.normalize_relation_embeddings, IdentityClass): +# (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, +# emb_rel_e5, emb_rel_e6, emb_rel_e7) = self.octonion_normalizer(emb_rel_e0, +# emb_rel_e1, emb_rel_e2, emb_rel_e3, +# emb_rel_e4, emb_rel_e5, emb_rel_e6, +# emb_rel_e7) + +# emb_tail_e0, emb_tail_e1, emb_tail_e2, emb_tail_e3, emb_tail_e4, emb_tail_e5, emb_tail_e6, emb_tail_e7 = torch.hsplit( +# tail_ent_emb, 8) +# # (3) Octonion Multiplication +# e0, e1, e2, e3, e4, e5, e6, e7 = octonion_mul( +# O_1=( +# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7), +# O_2=(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7)) +# # (4) +# # (4.3) Inner product +# e0_score = (e0 * emb_tail_e0).sum(dim=1) +# e1_score = (e1 * emb_tail_e1).sum(dim=1) +# e2_score = (e2 * emb_tail_e2).sum(dim=1) +# e3_score = (e3 * emb_tail_e3).sum(dim=1) +# e4_score = (e4 * emb_tail_e4).sum(dim=1) +# e5_score = (e5 * emb_tail_e5).sum(dim=1) +# e6_score = (e6 * emb_tail_e6).sum(dim=1) +# e7_score = (e7 * emb_tail_e7).sum(dim=1) + +# return e0_score + e1_score + e2_score + e3_score + e4_score + e5_score + e6_score + e7_score + + + +class DualE(BaseKGE): + def __init__(self, args): + super().__init__(args) + self.name = 'DualE' + self.lmbda = 0.0 + + self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim) + self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim) + + # self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim) + # self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim) + + # 
self.emb_1 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_2 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_3 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_4 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_5 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_6 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_7 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.emb_8 = nn.Embedding(self.config.entTotal, self.config.hidden_size) + # self.rel_1 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_2 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_3 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_4 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_5 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_6 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_7 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_8 = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.rel_w = nn.Embedding(self.config.relTotal, self.config.hidden_size) + # self.criterion = nn.Softplus() + # self.fc = nn.Linear(100, 50, bias=False) + # self.ent_dropout = torch.nn.Dropout(self.config.ent_dropout) + # self.rel_dropout = torch.nn.Dropout(self.config.rel_dropout) + # self.bn = torch.nn.BatchNorm1d(self.config.hidden_size) + + # self.init_weights() + + def init_weights(self): + if True: + r, i, j, k,r_1,i_1,j_1,k_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size) + r, i, j, k,r_1,i_1,j_1,k_1 = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k),\ + torch.from_numpy(r_1), torch.from_numpy(i_1), torch.from_numpy(j_1), torch.from_numpy(k_1) + self.emb_1.weight.data = 
r.type_as(self.emb_1.weight.data) + self.emb_2.weight.data = i.type_as(self.emb_2.weight.data) + self.emb_3.weight.data = j.type_as(self.emb_3.weight.data) + self.emb_4.weight.data = k.type_as(self.emb_4.weight.data) + self.emb_5.weight.data = r_1.type_as(self.emb_5.weight.data) + self.emb_6.weight.data = i_1.type_as(self.emb_6.weight.data) + self.emb_7.weight.data = j_1.type_as(self.emb_7.weight.data) + self.emb_8.weight.data = k_1.type_as(self.emb_8.weight.data) + + s, x, y, z,s_1,x_1,y_1,z_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size) + s, x, y, z,s_1,x_1,y_1,z_1 = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z), \ + torch.from_numpy(s_1), torch.from_numpy(x_1), torch.from_numpy(y_1), torch.from_numpy(z_1) + self.rel_1.weight.data = s.type_as(self.rel_1.weight.data) + self.rel_2.weight.data = x.type_as(self.rel_2.weight.data) + self.rel_3.weight.data = y.type_as(self.rel_3.weight.data) + self.rel_4.weight.data = z.type_as(self.rel_4.weight.data) + self.rel_5.weight.data = s_1.type_as(self.rel_5.weight.data) + self.rel_6.weight.data = x_1.type_as(self.rel_6.weight.data) + self.rel_7.weight.data = y_1.type_as(self.rel_7.weight.data) + self.rel_8.weight.data = z_1.type_as(self.rel_8.weight.data) + nn.init.xavier_uniform_(self.rel_w.weight.data) + else: + nn.init.xavier_uniform_(self.emb_1.weight.data) + nn.init.xavier_uniform_(self.emb_2.weight.data) + nn.init.xavier_uniform_(self.emb_3.weight.data) + nn.init.xavier_uniform_(self.emb_4.weight.data) + nn.init.xavier_uniform_(self.emb_5.weight.data) + nn.init.xavier_uniform_(self.emb_6.weight.data) + nn.init.xavier_uniform_(self.emb_7.weight.data) + nn.init.xavier_uniform_(self.emb_8.weight.data) + nn.init.xavier_uniform_(self.rel_1.weight.data) + nn.init.xavier_uniform_(self.rel_2.weight.data) + nn.init.xavier_uniform_(self.rel_3.weight.data) + nn.init.xavier_uniform_(self.rel_4.weight.data) + nn.init.xavier_uniform_(self.rel_5.weight.data) + 
nn.init.xavier_uniform_(self.rel_6.weight.data) + nn.init.xavier_uniform_(self.rel_7.weight.data) + nn.init.xavier_uniform_(self.rel_8.weight.data) + + + + #Calculate the Dual Hamiltonian product + def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3): + + h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3 + h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3 + h_1=a_0*c_1+a_1*c_0+a_2*c_3-a_3*c_2 + h1_1=a_0*d_1+b_0*c_1+a_1*d_0+b_1*c_0+a_2*d_3+b_2*c_3-a_3*d_2-b_3*c_2 + h_2=a_0*c_2-a_1*c_3+a_2*c_0+a_3*c_1 + h1_2=a_0*d_2+b_0*c_2-a_1*d_3-b_1*c_3+a_2*d_0+b_2*c_0+a_3*d_1+b_3*c_1 + h_3=a_0*c_3+a_1*c_2-a_2*c_1+a_3*c_0 + h1_3=a_0*d_3+b_0*c_3+a_1*d_2+b_1*c_2-a_2*d_1-b_2*c_1+a_3*d_0+b_3*c_0 + + return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3) + + #Normalization of relationship embedding + def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8): + denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2 + denominator_1 = torch.sqrt(denominator_0) + #denominator_2 = torch.sqrt(r_5 ** 2 + r_6 ** 2 + r_7 ** 2 + r_8 ** 2) + deno_cross = r_5 * r_1 + r_6 * r_2 + r_7 * r_3 + r_8 * r_4 + + r_5 = r_5 - deno_cross / denominator_0 * r_1 + r_6 = r_6 - deno_cross / denominator_0 * r_2 + r_7 = r_7 - deno_cross / denominator_0 * r_3 + r_8 = r_8 - deno_cross / denominator_0 * r_4 + + r_1 = r_1 / denominator_1 + r_2 = r_2 / denominator_1 + r_3 = r_3 / denominator_1 + r_4 = r_4 / denominator_1 + #r_5 = r_5 / denominator_2 + #r_6 = r_6 / denominator_2 + #r_7 = r_7 / denominator_2 + #r_8 = r_8 / denominator_2 + return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 + + #Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity + def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ): + + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) + + o_1, o_2, o_3, 
o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) + + + score_r = (o_1 * e_1_t + o_2 * e_2_t + o_3 * e_3_t + o_4 * e_4_t + + o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t) + + return -torch.sum(score_r, -1) + + + + # def loss(self, score, regul, regul2): + # return ( + # torch.mean(self.criterion(score * self.batch_y)) + self.lmbda * regul + self.lmbda * regul2 + # ) + + def forward_triples(self, idx_triple): + + head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple) + + e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8) + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8) + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8) + + + + score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) + + regul = (torch.mean(torch.abs(e_1_h) ** 2) + + torch.mean(torch.abs(e_2_h) ** 2) + + torch.mean(torch.abs(e_3_h) ** 2) + + torch.mean(torch.abs(e_4_h) ** 2) + + torch.mean(torch.abs(e_5_h) ** 2) + + torch.mean(torch.abs(e_6_h) ** 2) + + torch.mean(torch.abs(e_7_h) ** 2) + + torch.mean(torch.abs(e_8_h) ** 2) + + torch.mean(torch.abs(e_1_t) ** 2) + + torch.mean(torch.abs(e_2_t) ** 2) + + torch.mean(torch.abs(e_3_t) ** 2) + + torch.mean(torch.abs(e_4_t) ** 2) + + torch.mean(torch.abs(e_5_t) ** 2) + + torch.mean(torch.abs(e_6_t) ** 2) + + torch.mean(torch.abs(e_7_t) ** 2) + + torch.mean(torch.abs(e_8_t) ** 2) + ) + regul2 = (torch.mean(torch.abs(r_1) ** 2) + + torch.mean(torch.abs(r_2) ** 2) + + torch.mean(torch.abs(r_3) ** 2) + + torch.mean(torch.abs(r_4) ** 2) + + torch.mean(torch.abs(r_5) ** 2) + + torch.mean(torch.abs(r_6) ** 2) + + torch.mean(torch.abs(r_7) ** 2) + + torch.mean(torch.abs(r_8) ** 2)) + + return score #self.loss(score, regul, regul2) + + def 
predict(self): + e_1_h = self.emb_1(self.batch_h) + e_2_h = self.emb_2(self.batch_h) + e_3_h = self.emb_3(self.batch_h) + e_4_h = self.emb_4(self.batch_h) + e_5_h = self.emb_5(self.batch_h) + e_6_h = self.emb_6(self.batch_h) + e_7_h = self.emb_7(self.batch_h) + e_8_h = self.emb_8(self.batch_h) + + e_1_t = self.emb_1(self.batch_t) + e_2_t = self.emb_2(self.batch_t) + e_3_t = self.emb_3(self.batch_t) + e_4_t = self.emb_4(self.batch_t) + e_5_t = self.emb_5(self.batch_t) + e_6_t = self.emb_6(self.batch_t) + e_7_t = self.emb_7(self.batch_t) + e_8_t = self.emb_8(self.batch_t) + + r_1 = self.rel_1(self.batch_r) + r_2 = self.rel_2(self.batch_r) + r_3 = self.rel_3(self.batch_r) + r_4 = self.rel_4(self.batch_r) + r_5 = self.rel_5(self.batch_r) + r_6 = self.rel_6(self.batch_r) + r_7 = self.rel_7(self.batch_r) + r_8 = self.rel_8(self.batch_r) + + score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) + return score.cpu().data.numpy() + + + + + def quaternion_init(self, in_features, out_features, criterion='he'): + + fan_in = in_features + fan_out = out_features + + if criterion == 'glorot': + s = 1. / np.sqrt(2 * (fan_in + fan_out)) + elif criterion == 'he': + s = 1. 
/ np.sqrt(2 * fan_in) + else: + raise ValueError('Invalid criterion: ', criterion) + rng = RandomState(2020) + + # Generating randoms and purely imaginary quaternions : + kernel_shape = (in_features, out_features) + + number_of_weights = np.prod(kernel_shape) + v_i = np.random.uniform(0.0, 1.0, number_of_weights) + v_j = np.random.uniform(0.0, 1.0, number_of_weights) + v_k = np.random.uniform(0.0, 1.0, number_of_weights) + + # Purely imaginary quaternions unitary + for i in range(0, number_of_weights): + norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001 + v_i[i] /= norm + v_j[i] /= norm + v_k[i] /= norm + v_i = v_i.reshape(kernel_shape) + v_j = v_j.reshape(kernel_shape) + v_k = v_k.reshape(kernel_shape) + + modulus = rng.uniform(low=-s, high=s, size=kernel_shape) + + + # Calculate the three parts about t + kernel_shape1 = (in_features, out_features) + number_of_weights1 = np.prod(kernel_shape1) + t_i = np.random.uniform(0.0, 1.0, number_of_weights1) + t_j = np.random.uniform(0.0, 1.0, number_of_weights1) + t_k = np.random.uniform(0.0, 1.0, number_of_weights1) + + # Purely imaginary quaternions unitary + for i in range(0, number_of_weights1): + norm1 = np.sqrt(t_i[i] ** 2 + t_j[i] ** 2 + t_k[i] ** 2) + 0.0001 + t_i[i] /= norm1 + t_j[i] /= norm1 + t_k[i] /= norm1 + t_i = t_i.reshape(kernel_shape1) + t_j = t_j.reshape(kernel_shape1) + t_k = t_k.reshape(kernel_shape1) + tmp_t = rng.uniform(low=-s, high=s, size=kernel_shape1) + + + phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) + phase1 = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape1) + + weight_r = modulus * np.cos(phase) + weight_i = modulus * v_i * np.sin(phase) + weight_j = modulus * v_j * np.sin(phase) + weight_k = modulus * v_k * np.sin(phase) + + wt_i = tmp_t * t_i * np.sin(phase1) + wt_j = tmp_t * t_j * np.sin(phase1) + wt_k = tmp_t * t_k * np.sin(phase1) + + i_0=weight_r + i_1=weight_i + i_2=weight_j + i_3=weight_k + i_4=(-wt_i*weight_i-wt_j*weight_j-wt_k*weight_k)/2 + 
i_5=(wt_i*weight_r+wt_j*weight_k-wt_k*weight_j)/2 + i_6=(-wt_i*weight_k+wt_j*weight_r+wt_k*weight_i)/2 + i_7=(wt_i*weight_j-wt_j*weight_i+wt_k*weight_r)/2 + + + return (i_0,i_1,i_2,i_3,i_4,i_5,i_6,i_7) diff --git a/dicee/scripts/run.py b/dicee/scripts/run.py index ad085758..3b95efaf 100755 --- a/dicee/scripts/run.py +++ b/dicee/scripts/run.py @@ -31,7 +31,7 @@ def get_default_arguments(description=None): parser.add_argument("--model", type=str, default="Keci", choices=["ComplEx", "Keci", "ConEx", "AConEx", "ConvQ", "AConvQ", "ConvO", "AConvO", "QMult", - "OMult", "Shallom", "DistMult", "TransE", "DeCaL", + "OMult", "Shallom", "DistMult", "TransE", "DualE", "BytE", "Pykeen_MuRE", "Pykeen_QuatE", "Pykeen_DistMult", "Pykeen_BoxE", "Pykeen_CP", "Pykeen_HolE", "Pykeen_ProjE", "Pykeen_RotatE", diff --git a/dicee/static_funcs.py b/dicee/static_funcs.py index d6e3202a..9ff1ad26 100644 --- a/dicee/static_funcs.py +++ b/dicee/static_funcs.py @@ -2,7 +2,7 @@ import torch import datetime from typing import Tuple, List -from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL,\ +from .models import CMult, Pyke, DistMult, KeciBase, Keci, TransE, DeCaL, DualE,\ ComplEx, AConEx, AConvO, AConvQ, ConvQ, ConvO, ConEx, QMult, OMult, Shallom, LFMult from .models.pykeen_models import PykeenKGE from .models.transformers import BytE @@ -421,6 +421,9 @@ def intialize_model(args: dict,verbose=0) -> Tuple[object, str]: elif model_name == 'DeCaL': model =DeCaL(args=args) form_of_labelling = 'EntityPrediction' + elif model_name == 'DualE': + model =DualE(args=args) + form_of_labelling = 'EntityPrediction' else: raise ValueError(f"--model_name: {model_name} is not found.") return model, form_of_labelling From 325063cf4ce727fb0fd7b83d9e5a4e7eb55a7955 Mon Sep 17 00:00:00 2001 From: Louis-Mozart Date: Wed, 27 Mar 2024 11:55:33 +0100 Subject: [PATCH 4/8] Work done: KvsAll implemented --- dicee/models/dualE.py | 313 ++++++------------------------------------ 1 file changed, 39 
insertions(+), 274 deletions(-) diff --git a/dicee/models/dualE.py b/dicee/models/dualE.py index 0d654b6c..f40a8a0e 100644 --- a/dicee/models/dualE.py +++ b/dicee/models/dualE.py @@ -1,69 +1,5 @@ import torch -import torch.autograd as autograd -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.autograd import Variable from .base_model import BaseKGE -import numpy as np -from numpy.random import RandomState - - - -# class OMult(BaseKGE): -# def __init__(self, args): -# super().__init__(args) -# self.name = 'OMult' - -# @staticmethod -# def octonion_normalizer(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, -# emb_rel_e7): -# denominator = torch.sqrt( -# emb_rel_e0 ** 2 + emb_rel_e1 ** 2 + emb_rel_e2 ** 2 + emb_rel_e3 ** 2 + emb_rel_e4 ** 2 -# + emb_rel_e5 ** 2 + emb_rel_e6 ** 2 + emb_rel_e7 ** 2) -# y0 = emb_rel_e0 / denominator -# y1 = emb_rel_e1 / denominator -# y2 = emb_rel_e2 / denominator -# y3 = emb_rel_e3 / denominator -# y4 = emb_rel_e4 / denominator -# y5 = emb_rel_e5 / denominator -# y6 = emb_rel_e6 / denominator -# y7 = emb_rel_e7 / denominator -# return y0, y1, y2, y3, y4, y5, y6, y7 - -# def score(self, head_ent_emb: torch.FloatTensor, rel_ent_emb: torch.FloatTensor, tail_ent_emb: torch.FloatTensor): -# # (2) Split (1) into real and imaginary parts. 
-# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7 = torch.hsplit( -# head_ent_emb, 8) -# emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7 = torch.hsplit( -# rel_ent_emb, -# 8) -# if isinstance(self.normalize_relation_embeddings, IdentityClass): -# (emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, -# emb_rel_e5, emb_rel_e6, emb_rel_e7) = self.octonion_normalizer(emb_rel_e0, -# emb_rel_e1, emb_rel_e2, emb_rel_e3, -# emb_rel_e4, emb_rel_e5, emb_rel_e6, -# emb_rel_e7) - -# emb_tail_e0, emb_tail_e1, emb_tail_e2, emb_tail_e3, emb_tail_e4, emb_tail_e5, emb_tail_e6, emb_tail_e7 = torch.hsplit( -# tail_ent_emb, 8) -# # (3) Octonion Multiplication -# e0, e1, e2, e3, e4, e5, e6, e7 = octonion_mul( -# O_1=( -# emb_head_e0, emb_head_e1, emb_head_e2, emb_head_e3, emb_head_e4, emb_head_e5, emb_head_e6, emb_head_e7), -# O_2=(emb_rel_e0, emb_rel_e1, emb_rel_e2, emb_rel_e3, emb_rel_e4, emb_rel_e5, emb_rel_e6, emb_rel_e7)) -# # (4) -# # (4.3) Inner product -# e0_score = (e0 * emb_tail_e0).sum(dim=1) -# e1_score = (e1 * emb_tail_e1).sum(dim=1) -# e2_score = (e2 * emb_tail_e2).sum(dim=1) -# e3_score = (e3 * emb_tail_e3).sum(dim=1) -# e4_score = (e4 * emb_tail_e4).sum(dim=1) -# e5_score = (e5 * emb_tail_e5).sum(dim=1) -# e6_score = (e6 * emb_tail_e6).sum(dim=1) -# e7_score = (e7 * emb_tail_e7).sum(dim=1) - -# return e0_score + e1_score + e2_score + e3_score + e4_score + e5_score + e6_score + e7_score @@ -71,83 +7,9 @@ class DualE(BaseKGE): def __init__(self, args): super().__init__(args) self.name = 'DualE' - self.lmbda = 0.0 - self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim) self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim) - - # self.entity_embeddings = torch.nn.Embedding(self.num_entities, self.embedding_dim) - # self.relation_embeddings = torch.nn.Embedding(self.num_relations, self.embedding_dim) - - # 
self.emb_1 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_2 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_3 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_4 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_5 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_6 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_7 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.emb_8 = nn.Embedding(self.config.entTotal, self.config.hidden_size) - # self.rel_1 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_2 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_3 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_4 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_5 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_6 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_7 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_8 = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.rel_w = nn.Embedding(self.config.relTotal, self.config.hidden_size) - # self.criterion = nn.Softplus() - # self.fc = nn.Linear(100, 50, bias=False) - # self.ent_dropout = torch.nn.Dropout(self.config.ent_dropout) - # self.rel_dropout = torch.nn.Dropout(self.config.rel_dropout) - # self.bn = torch.nn.BatchNorm1d(self.config.hidden_size) - - # self.init_weights() - - def init_weights(self): - if True: - r, i, j, k,r_1,i_1,j_1,k_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size) - r, i, j, k,r_1,i_1,j_1,k_1 = torch.from_numpy(r), torch.from_numpy(i), torch.from_numpy(j), torch.from_numpy(k),\ - torch.from_numpy(r_1), torch.from_numpy(i_1), torch.from_numpy(j_1), torch.from_numpy(k_1) - self.emb_1.weight.data = 
r.type_as(self.emb_1.weight.data) - self.emb_2.weight.data = i.type_as(self.emb_2.weight.data) - self.emb_3.weight.data = j.type_as(self.emb_3.weight.data) - self.emb_4.weight.data = k.type_as(self.emb_4.weight.data) - self.emb_5.weight.data = r_1.type_as(self.emb_5.weight.data) - self.emb_6.weight.data = i_1.type_as(self.emb_6.weight.data) - self.emb_7.weight.data = j_1.type_as(self.emb_7.weight.data) - self.emb_8.weight.data = k_1.type_as(self.emb_8.weight.data) - - s, x, y, z,s_1,x_1,y_1,z_1 = self.quaternion_init(self.config.entTotal, self.config.hidden_size) - s, x, y, z,s_1,x_1,y_1,z_1 = torch.from_numpy(s), torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(z), \ - torch.from_numpy(s_1), torch.from_numpy(x_1), torch.from_numpy(y_1), torch.from_numpy(z_1) - self.rel_1.weight.data = s.type_as(self.rel_1.weight.data) - self.rel_2.weight.data = x.type_as(self.rel_2.weight.data) - self.rel_3.weight.data = y.type_as(self.rel_3.weight.data) - self.rel_4.weight.data = z.type_as(self.rel_4.weight.data) - self.rel_5.weight.data = s_1.type_as(self.rel_5.weight.data) - self.rel_6.weight.data = x_1.type_as(self.rel_6.weight.data) - self.rel_7.weight.data = y_1.type_as(self.rel_7.weight.data) - self.rel_8.weight.data = z_1.type_as(self.rel_8.weight.data) - nn.init.xavier_uniform_(self.rel_w.weight.data) - else: - nn.init.xavier_uniform_(self.emb_1.weight.data) - nn.init.xavier_uniform_(self.emb_2.weight.data) - nn.init.xavier_uniform_(self.emb_3.weight.data) - nn.init.xavier_uniform_(self.emb_4.weight.data) - nn.init.xavier_uniform_(self.emb_5.weight.data) - nn.init.xavier_uniform_(self.emb_6.weight.data) - nn.init.xavier_uniform_(self.emb_7.weight.data) - nn.init.xavier_uniform_(self.emb_8.weight.data) - nn.init.xavier_uniform_(self.rel_1.weight.data) - nn.init.xavier_uniform_(self.rel_2.weight.data) - nn.init.xavier_uniform_(self.rel_3.weight.data) - nn.init.xavier_uniform_(self.rel_4.weight.data) - nn.init.xavier_uniform_(self.rel_5.weight.data) - 
nn.init.xavier_uniform_(self.rel_6.weight.data) - nn.init.xavier_uniform_(self.rel_7.weight.data) - nn.init.xavier_uniform_(self.rel_8.weight.data) - + self.num_ent = self.num_entities #Calculate the Dual Hamiltonian product @@ -201,165 +63,68 @@ def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + o_5 * e_5_t + o_6 * e_6_t + o_7 * e_7_t + o_8 * e_8_t) return -torch.sum(score_r, -1) - + def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ): + + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) + + o_1, o_2, o_3, o_4, o_5, o_6, o_7, o_8 = self._omult(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8) + - # def loss(self, score, regul, regul2): - # return ( - # torch.mean(self.criterion(score * self.batch_y)) + self.lmbda * regul + self.lmbda * regul2 - # ) + score_r = torch.mm(o_1, e_1_t) + torch.mm(o_2 ,e_2_t) + torch.mm(o_3, e_3_t) + torch.mm(o_4, e_4_t)\ + + torch.mm(o_5, e_5_t) + torch.mm(o_6, e_6_t) + torch.mm(o_7, e_7_t) +torch.mm( o_8 , e_8_t) + + return -score_r + def forward_triples(self, idx_triple): head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple) + e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8) + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(tail_ent_emb, 8) + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_emb, 8) - - score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) - regul = (torch.mean(torch.abs(e_1_h) ** 2) - + torch.mean(torch.abs(e_2_h) ** 2) - + torch.mean(torch.abs(e_3_h) ** 2) - + torch.mean(torch.abs(e_4_h) ** 2) - + torch.mean(torch.abs(e_5_h) ** 2) - + torch.mean(torch.abs(e_6_h) ** 
2) - + torch.mean(torch.abs(e_7_h) ** 2) - + torch.mean(torch.abs(e_8_h) ** 2) - + torch.mean(torch.abs(e_1_t) ** 2) - + torch.mean(torch.abs(e_2_t) ** 2) - + torch.mean(torch.abs(e_3_t) ** 2) - + torch.mean(torch.abs(e_4_t) ** 2) - + torch.mean(torch.abs(e_5_t) ** 2) - + torch.mean(torch.abs(e_6_t) ** 2) - + torch.mean(torch.abs(e_7_t) ** 2) - + torch.mean(torch.abs(e_8_t) ** 2) - ) - regul2 = (torch.mean(torch.abs(r_1) ** 2) - + torch.mean(torch.abs(r_2) ** 2) - + torch.mean(torch.abs(r_3) ** 2) - + torch.mean(torch.abs(r_4) ** 2) - + torch.mean(torch.abs(r_5) ** 2) - + torch.mean(torch.abs(r_6) ** 2) - + torch.mean(torch.abs(r_7) ** 2) - + torch.mean(torch.abs(r_8) ** 2)) - - return score #self.loss(score, regul, regul2) - - def predict(self): - e_1_h = self.emb_1(self.batch_h) - e_2_h = self.emb_2(self.batch_h) - e_3_h = self.emb_3(self.batch_h) - e_4_h = self.emb_4(self.batch_h) - e_5_h = self.emb_5(self.batch_h) - e_6_h = self.emb_6(self.batch_h) - e_7_h = self.emb_7(self.batch_h) - e_8_h = self.emb_8(self.batch_h) - - e_1_t = self.emb_1(self.batch_t) - e_2_t = self.emb_2(self.batch_t) - e_3_t = self.emb_3(self.batch_t) - e_4_t = self.emb_4(self.batch_t) - e_5_t = self.emb_5(self.batch_t) - e_6_t = self.emb_6(self.batch_t) - e_7_t = self.emb_7(self.batch_t) - e_8_t = self.emb_8(self.batch_t) - - r_1 = self.rel_1(self.batch_r) - r_2 = self.rel_2(self.batch_r) - r_3 = self.rel_3(self.batch_r) - r_4 = self.rel_4(self.batch_r) - r_5 = self.rel_5(self.batch_r) - r_6 = self.rel_6(self.batch_r) - r_7 = self.rel_7(self.batch_r) - r_8 = self.rel_8(self.batch_r) - - score = self._calc(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, - e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, - r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) - return score.cpu().data.numpy() - - - - - def quaternion_init(self, in_features, out_features, criterion='he'): - - fan_in = in_features - fan_out = out_features - - if criterion == 'glorot': - s = 1. 
/ np.sqrt(2 * (fan_in + fan_out)) - elif criterion == 'he': - s = 1. / np.sqrt(2 * fan_in) - else: - raise ValueError('Invalid criterion: ', criterion) - rng = RandomState(2020) - - # Generating randoms and purely imaginary quaternions : - kernel_shape = (in_features, out_features) - - number_of_weights = np.prod(kernel_shape) - v_i = np.random.uniform(0.0, 1.0, number_of_weights) - v_j = np.random.uniform(0.0, 1.0, number_of_weights) - v_k = np.random.uniform(0.0, 1.0, number_of_weights) + + return score + - # Purely imaginary quaternions unitary - for i in range(0, number_of_weights): - norm = np.sqrt(v_i[i] ** 2 + v_j[i] ** 2 + v_k[i] ** 2) + 0.0001 - v_i[i] /= norm - v_j[i] /= norm - v_k[i] /= norm - v_i = v_i.reshape(kernel_shape) - v_j = v_j.reshape(kernel_shape) - v_k = v_k.reshape(kernel_shape) - modulus = rng.uniform(low=-s, high=s, size=kernel_shape) + def forward_k_vs_all(self,x): + # (1) Retrieve embeddings & Apply Dropout & Normalization. + head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) + + e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h = torch.hsplit(head_ent_emb, 8) - # Calculate the three parts about t - kernel_shape1 = (in_features, out_features) - number_of_weights1 = np.prod(kernel_shape1) - t_i = np.random.uniform(0.0, 1.0, number_of_weights1) - t_j = np.random.uniform(0.0, 1.0, number_of_weights1) - t_k = np.random.uniform(0.0, 1.0, number_of_weights1) + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = torch.hsplit(rel_ent_emb, 8) - # Purely imaginary quaternions unitary - for i in range(0, number_of_weights1): - norm1 = np.sqrt(t_i[i] ** 2 + t_j[i] ** 2 + t_k[i] ** 2) + 0.0001 - t_i[i] /= norm1 - t_j[i] /= norm1 - t_k[i] /= norm1 - t_i = t_i.reshape(kernel_shape1) - t_j = t_j.reshape(kernel_shape1) - t_k = t_k.reshape(kernel_shape1) - tmp_t = rng.uniform(low=-s, high=s, size=kernel_shape1) + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = torch.hsplit(self.entity_embeddings.weight, 8) + e_1_t, e_2_t, e_3_t, 
e_4_t, e_5_t, e_6_t, e_7_t, e_8_t = self.T(e_1_t), self.T(e_2_t), self.T(e_3_t),\ + self.T(e_4_t), self.T(e_5_t), self.T(e_6_t), self.T(e_7_t), self.T(e_8_t) - phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) - phase1 = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape1) + score = self.kvsall_score(e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, + e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) - weight_r = modulus * np.cos(phase) - weight_i = modulus * v_i * np.sin(phase) - weight_j = modulus * v_j * np.sin(phase) - weight_k = modulus * v_k * np.sin(phase) - wt_i = tmp_t * t_i * np.sin(phase1) - wt_j = tmp_t * t_j * np.sin(phase1) - wt_k = tmp_t * t_k * np.sin(phase1) + return score + + def T(self, x): + + return x.transpose(1, 0) - i_0=weight_r - i_1=weight_i - i_2=weight_j - i_3=weight_k - i_4=(-wt_i*weight_i-wt_j*weight_j-wt_k*weight_k)/2 - i_5=(wt_i*weight_r+wt_j*weight_k-wt_k*weight_j)/2 - i_6=(-wt_i*weight_k+wt_j*weight_r+wt_k*weight_i)/2 - i_7=(wt_i*weight_j-wt_j*weight_i+wt_k*weight_r)/2 + - return (i_0,i_1,i_2,i_3,i_4,i_5,i_6,i_7) From e7d33c102b70ef2393023d98716ff5674c371e40 Mon Sep 17 00:00:00 2001 From: Louis-Mozart Date: Wed, 27 Mar 2024 12:09:10 +0100 Subject: [PATCH 5/8] Regression file created --- tests/test_regression_DualE.py | 36 ++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/test_regression_DualE.py diff --git a/tests/test_regression_DualE.py b/tests/test_regression_DualE.py new file mode 100644 index 00000000..c11d0ec5 --- /dev/null +++ b/tests/test_regression_DualE.py @@ -0,0 +1,36 @@ +from dicee.executer import Execute +import pytest +from dicee.config import Namespace + +class TestRegressionClifford: + @pytest.mark.filterwarnings('ignore::UserWarning') + def test_k_vs_all(self): + args = Namespace() + args.model = 'DualE' + args.scoring_technique = 'KvsAll' + args.optim = 'Adam' + args.dataset_dir = 'KGs/UMLS' + 
args.num_epochs = 32 + args.batch_size = 1024 + args.lr = 0.1 + args.embedding_dim = 32 + args.eval_model = 'train_val_test' + dualE_result = Execute(args).start() + + args = Namespace() + args.model = 'DeCaL' + args.scoring_technique = 'KvsAll' + args.optim = 'Adam' + args.p = 0 + args.q = 1 + args.r = 1 + args.dataset_dir = 'KGs/UMLS' + args.num_epochs = 32 + args.batch_size = 1024 + args.lr = 0.1 + args.embedding_dim = 32 + args.eval_model = 'train_val_test' + decal_result = Execute(args).start() + + assert decal_result["Train"]["MRR"] > dualE_result["Train"]["MRR"] + assert decal_result["Test"]["MRR"] > dualE_result["Test"]["MRR"] \ No newline at end of file From 39c1bb846bf63f5ee3bc3bc388819e74f6ad543c Mon Sep 17 00:00:00 2001 From: Louis-Mozart Date: Wed, 27 Mar 2024 15:19:12 +0100 Subject: [PATCH 6/8] Docstring added to DualE and DeCaL --- dicee/models/clifford.py | 88 ++++++++++++++++++++++++++------------- dicee/models/dualE.py | 89 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 141 insertions(+), 36 deletions(-) diff --git a/dicee/models/clifford.py b/dicee/models/clifford.py index 31b5f025..570090f7 100644 --- a/dicee/models/clifford.py +++ b/dicee/models/clifford.py @@ -764,7 +764,7 @@ def forward_triples(self, x: torch.Tensor) -> torch.FloatTensor: Parameter --------- - x: torch.LongTensor with (n,3) shape + x: torch.LongTensor with (n, ) shape Returns ------- @@ -844,9 +844,9 @@ def forward_triples(self, x: torch.Tensor) -> torch.FloatTensor: sigma_qr = 0 return h0r0t0 + score_p + score_q + score_r + sigma_pp + sigma_qq + sigma_rr + sigma_pq + sigma_qr + sigma_pr - def cl_pqr(self, a): + def cl_pqr(self, a:torch.tensor)->torch.tensor: - ''' Input: tensor(batch_size, emb_dim) ----> output: tensor with 1+p+q+r components with size (batch_size, emb_dim/(1+p+q+r)) each. + ''' Input: tensor(batch_size, emb_dim) ---> output: tensor with 1+p+q+r components with size (batch_size, emb_dim/(1+p+q+r)) each. 
1) takes a tensor of size (batch_size, emb_dim), split it into 1 + p + q +r components, hence 1+p+q+r must be a divisor of the emb_dim. @@ -861,17 +861,25 @@ def cl_pqr(self, a): def compute_sigmas_single(self, list_h_emb, list_r_emb, list_t_emb): '''here we compute all the sums with no others vectors interaction taken with the scalar product with t, that is, - 1) s0 = h_0r_0t_0 - 2) s1 = \sum_{i=1}^{p}h_ir_it_0 - 3) s2 = \sum_{j=p+1}^{p+q}h_jr_jt_0 - 4) s3 = \sum_{i=1}^{q}(h_0r_it_i + h_ir_0t_i) - 5) s4 = \sum_{i=p+1}^{p+q}(h_0r_it_i + h_ir_0t_i) - 5) s5 = \sum_{i=p+q+1}^{p+q+r}(h_0r_it_i + h_ir_0t_i) + + .. math:: + + s0 = h_0r_0t_0 + s1 = \sum_{i=1}^{p}h_ir_it_0 + s2 = \sum_{j=p+1}^{p+q}h_jr_jt_0 + s3 = \sum_{i=1}^{q}(h_0r_it_i + h_ir_0t_i) + s4 = \sum_{i=p+1}^{p+q}(h_0r_it_i + h_ir_0t_i) + s5 = \sum_{i=p+q+1}^{p+q+r}(h_0r_it_i + h_ir_0t_i) and return: - *) sigma_0t = \sigma_0 \cdot t_0 = s0 + s1 -s2 - *) s3, s4 and s5''' + .. math:: + + sigma_0t = \sigma_0 \cdot t_0 = s0 + s1 -s2 + s3, s4 and s5 + + + ''' p = self.p q = self.q @@ -906,15 +914,19 @@ def compute_sigmas_multivect(self, list_h_emb, list_r_emb): For same bases vectors interaction we have - 1) \sigma_pp = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(h_ir_{i'}-h_{i'}r_i) (models the interactions between e_i and e_i' for 1 <= i, i' <= p) - 2) \sigma_qq = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(h_jr_{j'}-h_{j'} (models the interactions between e_j and e_j' for p+1 <= j, j' <= p+q) - 3) \sigma_rr = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(h_kr_{k'}-h_{k'}r_k) (models the interactions between e_k and e_k' for p+q+1 <= k, k' <= p+q+r) - + .. 
math:: + + \sigma_pp = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(h_ir_{i'}-h_{i'}r_i) (models the interactions between e_i and e_i' for 1 <= i, i' <= p) + \sigma_qq = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(h_jr_{j'}-h_{j'} (models the interactions between e_j and e_j' for p+1 <= j, j' <= p+q) + \sigma_rr = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(h_kr_{k'}-h_{k'}r_k) (models the interactions between e_k and e_k' for p+q+1 <= k, k' <= p+q+r) + For different base vector interactions, we have - 4) \sigma_pq = \sum_{i=1}^{p}\sum_{j=p+1}^{p+q}(h_ir_j - h_jr_i) (interactionsn between e_i and e_j for 1<=i <=p and p+1<= j <= p+q) - 5) \sigma_pr = \sum_{i=1}^{p}\sum_{k=p+q+1}^{p+q+r}(h_ir_k - h_kr_i) (interactionsn between e_i and e_k for 1<=i <=p and p+q+1<= k <= p+q+r) - 6) \sigma_qr = \sum_{j=p+1}^{p+q}\sum_{j=p+q+1}^{p+q+r}(h_jr_k - h_kr_j) (interactionsn between e_j and e_k for p+1 <= j <=p+q and p+q+1<= j <= p+q+r) + .. math:: + + \sigma_pq = \sum_{i=1}^{p}\sum_{j=p+1}^{p+q}(h_ir_j - h_jr_i) (interactionsn between e_i and e_j for 1<=i <=p and p+1<= j <= p+q) + \sigma_pr = \sum_{i=1}^{p}\sum_{k=p+q+1}^{p+q+r}(h_ir_k - h_kr_i) (interactionsn between e_i and e_k for 1<=i <=p and p+q+1<= k <= p+q+r) + \sigma_qr = \sum_{j=p+1}^{p+q}\sum_{j=p+q+1}^{p+q+r}(h_jr_k - h_kr_j) (interactionsn between e_j and e_k for p+1 <= j <=p+q and p+q+1<= j <= p+q+r) ''' @@ -958,15 +970,15 @@ def forward_k_vs_all(self, x: torch.Tensor) -> torch.FloatTensor: """ Kvsall training - (1) Retrieve real-valued embedding vectors for heads and relations \mathbb{R}^d . - (2) Construct head entity and relation embeddings according to Cl_{p,q}(\mathbb{R}^d) . + (1) Retrieve real-valued embedding vectors for heads and relations + (2) Construct head entity and relation embeddings according to Cl_{p,q, r}(\mathbb{R}^d) . 
(3) Perform Cl multiplication (4) Inner product of (3) and all entity embeddings forward_k_vs_with_explicit and this funcitons are identical Parameter --------- - x: torch.LongTensor with (n,2) shape + x: torch.LongTensor with (n, ) shape Returns ------- torch.FloatTensor with (n, |E|) shape @@ -1097,9 +1109,12 @@ def construct_cl_multivector(self, x: torch.FloatTensor, re: int, p: int, q: int def compute_sigma_pp(self, hp, rp): """ - \sigma_{p,p}^* = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(x_iy_{i'}-x_{i'}y_i) + Compute + .. math:: + + \sigma_{p,p}^* = \sum_{i=1}^{p-1}\sum_{i'=i+1}^{p}(x_iy_{i'}-x_{i'}y_i) - sigma_{pp} captures the interactions between along p bases + \sigma_{pp} captures the interactions between along p bases For instance, let p e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops @@ -1125,7 +1140,12 @@ def compute_sigma_pp(self, hp, rp): def compute_sigma_qq(self, hq, rq): """ - Compute \sigma_{q,q}^* = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(x_jy_{j'}-x_{j'}y_j) Eq. 16 + Compute + + .. math:: + + \sigma_{q,q}^* = \sum_{j=p+1}^{p+q-1}\sum_{j'=j+1}^{p+q}(x_jy_{j'}-x_{j'}y_j) Eq. 16 + sigma_{q} captures the interactions between along q bases For instance, let q e_1, e_2, e_3, we compute interactions between e_1 e_2, e_1 e_3 , and e_2 e_3 This can be implemented with a nested two for loops @@ -1157,7 +1177,9 @@ def compute_sigma_qq(self, hq, rq): def compute_sigma_rr(self, hk, rk): """ - \sigma_{r,r}^* = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(x_ky_{k'}-x_{k'}y_k) + .. math:: + + \sigma_{r,r}^* = \sum_{k=p+q+1}^{p+q+r-1}\sum_{k'=k+1}^{p}(x_ky_{k'}-x_{k'}y_k) """ # Compute indexes for the upper triangle of p by p matrix @@ -1173,7 +1195,11 @@ def compute_sigma_rr(self, hk, rk): def compute_sigma_pq(self, *, hp, hq, rp, rq): """ - \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j + Compute + + .. 
math:: + + \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) @@ -1189,7 +1215,11 @@ def compute_sigma_pq(self, *, hp, hq, rp, rq): def compute_sigma_pr(self, *, hp, hk, rp, rk): """ - \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j + Compute + + .. math:: + + \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) @@ -1205,7 +1235,9 @@ def compute_sigma_pr(self, *, hp, hk, rp, rk): def compute_sigma_qr(self, *, hq, hk, rq, rk): """ - \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j + .. math:: + + \sum_{i=1}^{p} \sum_{j=p+1}^{p+q} (h_i r_j - h_j r_i) e_i e_j results = [] sigma_pq = torch.zeros(b, r, p, q) diff --git a/dicee/models/dualE.py b/dicee/models/dualE.py index f40a8a0e..d2688d41 100644 --- a/dicee/models/dualE.py +++ b/dicee/models/dualE.py @@ -2,8 +2,8 @@ from .base_model import BaseKGE - class DualE(BaseKGE): + """Dual Quaternion Knowledge Graph Embeddings (https://ojs.aaai.org/index.php/AAAI/article/download/16850/16657)""" def __init__(self, args): super().__init__(args) self.name = 'DualE' @@ -12,8 +12,9 @@ def __init__(self, args): self.num_ent = self.num_entities - #Calculate the Dual Hamiltonian product + def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0, d_1, d_2, d_3): + """Calculate the Dual Hamiltonian product""" h_0=a_0*c_0-a_1*c_1-a_2*c_2-a_3*c_3 h1_0=a_0*d_0+b_0*c_0-a_1*d_1-b_1*c_1-a_2*d_2-b_2*c_2-a_3*d_3-b_3*c_3 @@ -26,8 +27,35 @@ def _omult(self, a_0, a_1, a_2, a_3, b_0, b_1, b_2, b_3, c_0, c_1, c_2, c_3, d_0 return (h_0,h_1,h_2,h_3,h1_0,h1_1,h1_2,h1_3) - #Normalization of relationship embedding + def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8): + """Normalization of relationship embedding + + Inputs + -------- + Real and Imaginary parts of the Relation embeddings + + .. math:: + + W_r = (c,d) + c = (r_1, r_2, r_3, r_4) + d = (r_5, r_6, r_7, r_8) + + .. 
math:: + + \bar{d} = d - \frac{}{} c + c' = \frac{c}{\|c\|} = \frac{c_0 + c_1i + c_2j + c_3k}{c_0^2 + c_1^2 + c_2^2 + c_3^2} + + + Outputs + -------- + Normalized Real and Imaginary parts of the Relation embeddings + + .. math:: + + W_r' = (c', \bar{d}) + """ + denominator_0 = r_1 ** 2 + r_2 ** 2 + r_3 ** 2 + r_4 ** 2 denominator_1 = torch.sqrt(denominator_0) #denominator_2 = torch.sqrt(r_5 ** 2 + r_6 ** 2 + r_7 ** 2 + r_8 ** 2) @@ -48,10 +76,19 @@ def _onorm(self,r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8): #r_8 = r_8 / denominator_2 return r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 - #Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity + def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, - r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ): + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )->torch.tensor: + + """Calculate the inner product of the head entity and the relationship Hamiltonian product and the tail entity ref(Eq.8) + \phi(h,r,t) = + + + + + Inputs: + ---------- + (Tensors) Real and imaginary parts of the head, relation and tail embeddings + + Output: inner product of the head entity and the relationship Hamiltonian product and the tail entity""" r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) @@ -66,7 +103,17 @@ def _calc(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, def kvsall_score(self, e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, e_1_t, e_2_t, e_3_t, e_4_t, e_5_t, e_6_t, e_7_t, e_8_t, - r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ): + r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 )->torch.tensor: + """KvsAll scoring function + + Input + --------- + x: torch.LongTensor with (n, ) shape + + Output + ------- + torch.FloatTensor with (n) shape + """ r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 = self._onorm(r_1, r_2, r_3, r_4, r_5, r_6, r_7, r_8 ) @@ -80,7 +127,17 @@ def kvsall_score(self, 
e_1_h, e_2_h, e_3_h, e_4_h, e_5_h, e_6_h, e_7_h, e_8_h, return -score_r - def forward_triples(self, idx_triple): + def forward_triples(self, idx_triple:torch.tensor)-> torch.tensor: + """Negative Sampling forward pass: + + Input + --------- + x: torch.LongTensor with (n, ) shape + + Output + ------- + torch.FloatTensor with (n) shape + """ head_ent_emb, rel_emb, tail_ent_emb = self.get_triple_representation(idx_triple) @@ -102,6 +159,18 @@ def forward_triples(self, idx_triple): def forward_k_vs_all(self,x): + """KvsAll forward pass + + Input + --------- + x: torch.LongTensor with (n, ) shape + + Output + ------- + torch.FloatTensor with (n) shape + + """ + # (1) Retrieve embeddings & Apply Dropout & Normalization. head_ent_emb, rel_ent_emb = self.get_head_relation_representation(x) @@ -121,7 +190,11 @@ def forward_k_vs_all(self,x): return score - def T(self, x): + def T(self, x:torch.tensor)->torch.tensor: + """ Transpose function + + Input: Tensor with shape (nxm) + Output: Tensor with shape (mxn)""" return x.transpose(1, 0) From ef23bd9d4e5d713d628d00adb3cd22970f88dabf Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Wed, 19 Jun 2024 14:08:01 +0200 Subject: [PATCH 7/8] example removed --- examples/Train_and_Eval_KGE.ipynb | 1408 ----------------------------- 1 file changed, 1408 deletions(-) delete mode 100644 examples/Train_and_Eval_KGE.ipynb diff --git a/examples/Train_and_Eval_KGE.ipynb b/examples/Train_and_Eval_KGE.ipynb deleted file mode 100644 index d3f1c894..00000000 --- a/examples/Train_and_Eval_KGE.ipynb +++ /dev/null @@ -1,1408 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "f096e66a", - "metadata": {}, - "source": [ - "# Train and Evaluate a KGE model" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "913b10b7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: dicee in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages 
(0.0.2)\n", - "Requirement already satisfied: matplotlib>=3.6.2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (3.7.1)\n", - "Requirement already satisfied: modin[ray]>=0.16.2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (0.19.0)\n", - "Requirement already satisfied: scikit-learn>=1.1.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (1.2.2)\n", - "Requirement already satisfied: gradio>=3.0.17 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (3.23.0)\n", - "Requirement already satisfied: torch>=1.13.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (2.0.0)\n", - "Requirement already satisfied: pandas>=1.5.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (1.5.3)\n", - "Requirement already satisfied: pytorch-lightning==1.6.4 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (1.6.4)\n", - "Requirement already satisfied: pyarrow>=8.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (11.0.0)\n", - "Requirement already satisfied: polars>=0.15.13 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (0.16.14)\n", - "Requirement already satisfied: pytest>=6.2.5 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from dicee) (7.2.2)\n", - "Requirement already satisfied: tqdm>=4.57.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (4.65.0)\n", - "Requirement already satisfied: typing-extensions>=4.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (4.5.0)\n", - "Requirement already satisfied: packaging>=17.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (23.0)\n", - "Requirement already satisfied: numpy>=1.17.2 in 
/home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (1.24.2)\n", - "Requirement already satisfied: protobuf<=3.20.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (3.19.6)\n", - "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (2023.3.0)\n", - "Requirement already satisfied: torchmetrics>=0.4.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (0.11.4)\n", - "Requirement already satisfied: tensorboard>=2.2.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (2.12.0)\n", - "Requirement already satisfied: pyDeprecate>=0.3.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (0.3.2)\n", - "Requirement already satisfied: PyYAML>=5.4 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytorch-lightning==1.6.4->dicee) (6.0)\n", - "Requirement already satisfied: python-multipart in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.0.6)\n", - "Requirement already satisfied: uvicorn in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.21.1)\n", - "Requirement already satisfied: huggingface-hub>=0.13.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.13.3)\n", - "Requirement already satisfied: semantic-version in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (2.10.0)\n", - "Requirement already satisfied: pillow in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (9.4.0)\n", - "Requirement already satisfied: fastapi in 
/home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.95.0)\n", - "Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (2.2.0)\n", - "Requirement already satisfied: jinja2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (3.1.2)\n", - "Requirement already satisfied: markupsafe in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (2.1.2)\n", - "Requirement already satisfied: httpx in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.23.3)\n", - "Requirement already satisfied: requests in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (2.28.2)\n", - "Requirement already satisfied: mdit-py-plugins<=0.3.3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.3.3)\n", - "Requirement already satisfied: pydantic in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (1.10.7)\n", - "Requirement already satisfied: orjson in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (3.8.8)\n", - "Requirement already satisfied: aiofiles in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (23.1.0)\n", - "Requirement already satisfied: websockets>=10.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (10.4)\n", - "Requirement already satisfied: pydub in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.25.1)\n", - "Requirement already satisfied: ffmpy in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (0.3.0)\n", - "Requirement already satisfied: altair>=4.2.0 in 
/home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (4.2.2)\n", - "Requirement already satisfied: aiohttp in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gradio>=3.0.17->dicee) (3.8.4)\n", - "Requirement already satisfied: cycler>=0.10 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (0.11.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (1.4.4)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (4.39.2)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (1.0.7)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (3.0.9)\n", - "Requirement already satisfied: python-dateutil>=2.7 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from matplotlib>=3.6.2->dicee) (2.8.2)\n", - "Requirement already satisfied: psutil in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from modin[ray]>=0.16.2->dicee) (5.9.4)\n", - "Requirement already satisfied: ray[default]>=1.13.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from modin[ray]>=0.16.2->dicee) (2.3.1)\n", - "Requirement already satisfied: pytz>=2020.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pandas>=1.5.1->dicee) (2022.7.1)\n", - "Requirement already satisfied: tomli>=1.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytest>=6.2.5->dicee) (2.0.1)\n", - "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytest>=6.2.5->dicee) (1.1.1)\n", - 
"Requirement already satisfied: pluggy<2.0,>=0.12 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytest>=6.2.5->dicee) (1.0.0)\n", - "Requirement already satisfied: attrs>=19.2.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytest>=6.2.5->dicee) (22.2.0)\n", - "Requirement already satisfied: iniconfig in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pytest>=6.2.5->dicee) (2.0.0)\n", - "Requirement already satisfied: scipy>=1.3.2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from scikit-learn>=1.1.1->dicee) (1.10.1)\n", - "Requirement already satisfied: joblib>=1.1.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from scikit-learn>=1.1.1->dicee) (1.2.0)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from scikit-learn>=1.1.1->dicee) (3.1.0)\n", - "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.10.3.66)\n", - "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (2.14.3)\n", - "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (10.2.10.91)\n", - "Requirement already satisfied: triton==2.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (2.0.0)\n", - "Requirement already satisfied: networkx in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (3.0)\n", - "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.7.4.91)\n", - 
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.7.99)\n", - "Requirement already satisfied: filelock in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (3.10.2)\n", - "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (10.9.0.58)\n", - "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.4.0.1)\n", - "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (8.5.0.96)\n", - "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.7.91)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.7.99)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (11.7.101)\n", - "Requirement already satisfied: sympy in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from torch>=1.13.0->dicee) (1.11.1)\n", - "Requirement already satisfied: setuptools in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->dicee) (65.6.3)\n", - "Requirement already satisfied: wheel in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13.0->dicee) (0.38.4)\n", - "Requirement already satisfied: lit in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from triton==2.0.0->torch>=1.13.0->dicee) 
(16.0.0)\n", - "Requirement already satisfied: cmake in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from triton==2.0.0->torch>=1.13.0->dicee) (3.26.0)\n", - "Requirement already satisfied: toolz in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from altair>=4.2.0->gradio>=3.0.17->dicee) (0.12.0)\n", - "Requirement already satisfied: jsonschema>=3.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from altair>=4.2.0->gradio>=3.0.17->dicee) (4.17.3)\n", - "Requirement already satisfied: entrypoints in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from altair>=4.2.0->gradio>=3.0.17->dicee) (0.4)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (1.3.1)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (1.3.3)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (6.0.4)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (1.8.2)\n", - "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (3.1.0)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from aiohttp->gradio>=3.0.17->dicee) (4.0.2)\n", - "Requirement already satisfied: mdurl~=0.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from markdown-it-py[linkify]>=2.0.0->gradio>=3.0.17->dicee) (0.1.2)\n", - "Requirement already satisfied: linkify-it-py<3,>=1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from 
markdown-it-py[linkify]>=2.0.0->gradio>=3.0.17->dicee) (2.0.0)\n", - "Requirement already satisfied: six>=1.5 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib>=3.6.2->dicee) (1.16.0)\n", - "Requirement already satisfied: virtualenv>=20.0.24 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (20.21.0)\n", - "Requirement already satisfied: msgpack<2.0.0,>=1.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (1.0.5)\n", - "Requirement already satisfied: grpcio>=1.42.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (1.51.3)\n", - "Requirement already satisfied: click>=7.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (8.1.3)\n", - "Requirement already satisfied: colorful in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.5.5)\n", - "Requirement already satisfied: smart-open in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (6.3.0)\n", - "Requirement already satisfied: py-spy>=0.2.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.3.14)\n", - "Requirement already satisfied: gpustat>=1.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (1.0.0)\n", - "Requirement already satisfied: opencensus in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.11.2)\n", - "Requirement already satisfied: aiohttp-cors in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from 
ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.7.0)\n", - "Requirement already satisfied: prometheus-client>=0.7.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.16.0)\n", - "Requirement already satisfied: absl-py>=0.4 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (1.4.0)\n", - "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (0.4.6)\n", - "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (0.7.0)\n", - "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (1.8.1)\n", - "Requirement already satisfied: google-auth<3,>=1.6.3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (2.16.2)\n", - "Requirement already satisfied: werkzeug>=1.0.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (2.2.3)\n", - "Requirement already satisfied: markdown>=2.6.8 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (3.4.2)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: idna<4,>=2.5 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from requests->gradio>=3.0.17->dicee) (3.4)\n", - "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from requests->gradio>=3.0.17->dicee) 
(1.26.15)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from requests->gradio>=3.0.17->dicee) (2022.12.7)\n", - "Requirement already satisfied: starlette<0.27.0,>=0.26.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from fastapi->gradio>=3.0.17->dicee) (0.26.1)\n", - "Requirement already satisfied: sniffio in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from httpx->gradio>=3.0.17->dicee) (1.3.0)\n", - "Requirement already satisfied: rfc3986[idna2008]<2,>=1.3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from httpx->gradio>=3.0.17->dicee) (1.5.0)\n", - "Requirement already satisfied: httpcore<0.17.0,>=0.15.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from httpx->gradio>=3.0.17->dicee) (0.16.3)\n", - "Requirement already satisfied: mpmath>=0.19 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from sympy->torch>=1.13.0->dicee) (1.3.0)\n", - "Requirement already satisfied: h11>=0.8 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from uvicorn->gradio>=3.0.17->dicee) (0.14.0)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (4.9)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (0.2.8)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (5.3.0)\n", - "Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from 
google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (1.3.1)\n", - "Requirement already satisfied: blessed>=1.17.1 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gpustat>=1.0.0->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (1.20.0)\n", - "Requirement already satisfied: nvidia-ml-py<=11.495.46,>=11.450.129 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from gpustat>=1.0.0->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (11.495.46)\n", - "Requirement already satisfied: anyio<5.0,>=3.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from httpcore<0.17.0,>=0.15.0->httpx->gradio>=3.0.17->dicee) (3.6.2)\n", - "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from jsonschema>=3.0->altair>=4.2.0->gradio>=3.0.17->dicee) (0.19.3)\n", - "Requirement already satisfied: uc-micro-py in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio>=3.0.17->dicee) (1.0.1)\n", - "Requirement already satisfied: platformdirs<4,>=2.4 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from virtualenv>=20.0.24->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (3.2.0)\n", - "Requirement already satisfied: distlib<1,>=0.3.6 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from virtualenv>=20.0.24->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.3.6)\n", - "Requirement already satisfied: google-api-core<3.0.0,>=1.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from opencensus->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (2.11.0)\n", - "Requirement already satisfied: opencensus-context>=0.1.3 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from opencensus->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.1.3)\n", - "Requirement already satisfied: wcwidth>=0.1.4 
in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from blessed>=1.17.1->gpustat>=1.0.0->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (0.2.6)\n", - "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default]>=1.13.0->modin[ray]>=0.16.2->dicee) (1.59.0)\n", - "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (0.4.8)\n", - "Requirement already satisfied: oauthlib>=3.0.0 in /home/demir/anaconda3/envs/dice/lib/python3.10/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning==1.6.4->dicee) (3.2.2)\n" - ] - } - ], - "source": [ - "# Install dicee \n", - "!pip install dicee" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cf3fe02b", - "metadata": {}, - "outputs": [], - "source": [ - "from dicee import KGE, Execute\n", - "from dicee.config import Args" - ] - }, - { - "cell_type": "markdown", - "id": "168ae7ac", - "metadata": {}, - "source": [ - "# How to Train" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "66a3f46f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# (1) Load Default Params\n", - "args=Args()\n", - "# (2) Select a Dataset and a KGE\n", - "args.path_dataset_folder=\"../KGs/UMLS\"\n", - "args.model=\"AConEx\"\n", - "args.embedding_dim=32\n", - "args.num_epochs=256\n", - "args.num_of_output_channels= 32\n", - "args.batch_size=1024\n", - "args.scoring_technique=\"KvsAll\"\n", - "args.eval_model=\"train_val_test\"\n", - "args" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": 
"c21f128d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "acquired_abnormality\tlocation_of\texperimental_model_of_disease\r\n", - "anatomical_abnormality\tmanifestation_of\tphysiologic_function\r\n", - "alga\tisa\tentity\r\n", - "mental_or_behavioral_dysfunction\taffects\texperimental_model_of_disease\r\n", - "health_care_activity\tassociated_with\tanatomical_abnormality\r\n", - "population_group\tinteracts_with\tage_group\r\n", - "clinical_attribute\tresult_of\tnatural_phenomenon_or_process\r\n", - "body_part_organ_or_organ_component\tlocation_of\tbiologic_function\r\n", - "biologically_active_substance\tcomplicates\tanatomical_abnormality\r\n", - "disease_or_syndrome\tresult_of\tacquired_abnormality\r\n" - ] - } - ], - "source": [ - "!head ../KGs/UMLS/train.txt" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "885bc8da", - "metadata": {}, - "outputs": [], - "source": [ - "executor=Execute(args)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "8f39f407", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Start time:2023-04-05 12:59:36.761322\n", - "*** Read or Load Knowledge Graph ***\n", - "*** Reading ../KGs/UMLS/test.txt with Pandas ***\n", - "Reading with pandas.read_csv with sep ** s+ ** ...\n", - "Took 0.0068 seconds | Current Memory Usage 511.45 in MB\n", - "*** Reading ../KGs/UMLS/train.txt with Pandas ***\n", - "Reading with pandas.read_csv with sep ** s+ ** ...\n", - "Took 0.0096 seconds | Current Memory Usage 511.99 in MB\n", - "*** Reading ../KGs/UMLS/valid.txt with Pandas ***\n", - "Reading with pandas.read_csv with sep ** s+ ** ...\n", - "Took 0.0025 seconds | Current Memory Usage 511.99 in MB\n", - "[3.1 / 14] Add reciprocal triples to train, validation, and test sets, e.g. 
KG:= {(s,p,o)} union {(o,p_inverse,s)}\n", - "Done !\n", - "\n", - "\n", - "Concatenating data to obtain index...\n", - "Done !\n", - "\n", - "Creating a mapping from entities to integer indexes...\n", - "Done !\n", - "\n", - "Done ! 0.012 seconds\n", - "\n", - "Done !\n", - "\n", - "Done !\n", - "\n", - "Took 0.0366 seconds | Current Memory Usage 514.96 in MB\n", - "Data Type conversion...\n", - "Submit er-vocab, re-vocab, and ee-vocab via ProcessPoolExecutor...\n", - "Preprocessing took: 0.104 seconds\n", - "\n", - "------------------- Description of Dataset ../KGs/UMLS -------------------\n", - "Number of entities:135\n", - "Number of relations:92\n", - "Number of triples on train set:10432\n", - "Number of triples on valid set:1304\n", - "Number of triples on test set:1322\n", - "Entity Index:0.00000 in GB\n", - "Relation Index:0.00000 in GB\n", - "Train set :0.00006 in GB\n", - "\n", - "# of CPUs:4 | # of GPUs:0 | # of CPUs for dataloader:0\n", - "------------------- Train -------------------\n", - "Initializing TorchTrainer CPU Trainer...\tTook 0.0051 seconds | Current Memory Usage 515.43 in MB\n", - "Initializing Model...\tTook 0.0077 seconds | Current Memory Usage 517.64 in MB\n", - "Initializing Dataset...\tTook 0.0226 seconds\n", - "Took 0.0241 seconds | Current Memory Usage 518.88 in MB\n", - "Took 0.0243 seconds | Current Memory Usage 518.88 in MB\n", - "Initializing Dataloader...\tTook 0.0007 seconds | Current Memory Usage 518.93 in MB\n", - "AConEx(\n", - " (loss): BCEWithLogitsLoss()\n", - " (normalize_head_entity_embeddings): IdentityClass()\n", - " (normalize_relation_embeddings): IdentityClass()\n", - " (normalize_tail_entity_embeddings): IdentityClass()\n", - " (hidden_normalizer): IdentityClass()\n", - " (input_dp_ent_real): Dropout(p=0.0, inplace=False)\n", - " (input_dp_rel_real): Dropout(p=0.0, inplace=False)\n", - " (hidden_dropout): Dropout(p=0.0, inplace=False)\n", - " (entity_embeddings): Embedding(135, 32)\n", - " (relation_embeddings): 
Embedding(92, 32)\n", - " (conv2d): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", - " (fc1): Linear(in_features=2048, out_features=64, bias=True)\n", - " (norm_fc1): IdentityClass()\n", - " (bn_conv2d): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (feature_map_dropout): Dropout2d(p=0.0, inplace=False)\n", - ")\n", - " | Name | Type | Params\n", - "------------------------------------------------------------------------\n", - "0 | loss | BCEWithLogitsLoss | 0 \n", - "1 | normalize_head_entity_embeddings | IdentityClass | 0 \n", - "2 | normalize_relation_embeddings | IdentityClass | 0 \n", - "3 | normalize_tail_entity_embeddings | IdentityClass | 0 \n", - "4 | hidden_normalizer | IdentityClass | 0 \n", - "5 | input_dp_ent_real | Dropout | 0 \n", - "6 | input_dp_rel_real | Dropout | 0 \n", - "7 | hidden_dropout | Dropout | 0 \n", - "8 | entity_embeddings | Embedding | 4.3 K \n", - "9 | relation_embeddings | Embedding | 2.9 K \n", - "10 | conv2d | Conv2d | 320 \n", - "11 | fc1 | Linear | 131 K \n", - "12 | norm_fc1 | IdentityClass | 0 \n", - "13 | bn_conv2d | BatchNorm2d | 64 \n", - "14 | feature_map_dropout | Dropout2d | 0 \n", - "------------------------------------------------------------------------\n", - "138 K Trainable params\n", - "0 Non-trainable params\n", - "138 K Total params\n", - "0.555 Total estimated model params size (MB)\n", - "Adam (\n", - "Parameter Group 0\n", - " amsgrad: False\n", - " betas: (0.9, 0.999)\n", - " capturable: False\n", - " differentiable: False\n", - " eps: 1e-08\n", - " foreach: None\n", - " fused: None\n", - " lr: 0.1\n", - " maximize: False\n", - " weight_decay: 0.0\n", - ")\n", - "\n", - "Training is starting 2023-04-05 12:59:36.907856...\n", - "NumOfDataPoints:1560 | NumOfEpochs:256 | LearningRate:0.1 | BatchSize:1024 | EpochBatchsize:2\n", - "Epoch:1 | Batch:1 | Loss:3.2829627990722656 |ForwardBackwardUpdate:0.11secs | Mem. 
Usage 570.18MB\n", - "Epoch:1 | Batch:2 | Loss:86.3902359009 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.03sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:1 | Loss:44.83659935 | Runtime:0.004 mins\n", - "Epoch:2 | Batch:1 | Loss:41.616783142089844 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.2MB\n", - "Epoch:2 | Batch:2 | Loss:19.4291305542 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 48.9 %\n", - "Epoch:2 | Loss:30.52295685 | Runtime:0.002 mins\n", - "Epoch:3 | Batch:1 | Loss:9.976593017578125 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.22MB\n", - "Epoch:3 | Batch:2 | Loss:4.1640186310 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 49.0 %\n", - "Epoch:3 | Loss:7.07030582 | Runtime:0.002 mins\n", - "Epoch:4 | Batch:1 | Loss:2.334935188293457 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:4 | Batch:2 | Loss:1.5733017921 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.67MB avail. 48.9 %\n", - "Epoch:4 | Loss:1.95411849 | Runtime:0.002 mins\n", - "Epoch:5 | Batch:1 | Loss:1.432958722114563 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:5 | Batch:2 | Loss:1.5192382336 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 48.9 %\n", - "Epoch:5 | Loss:1.47609848 | Runtime:0.002 mins\n", - "Epoch:6 | Batch:1 | Loss:1.595378041267395 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.14MB\n", - "Epoch:6 | Batch:2 | Loss:1.6085883379 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.73MB avail. 48.9 %\n", - "Epoch:6 | Loss:1.60198319 | Runtime:0.002 mins\n", - "Epoch:7 | Batch:1 | Loss:1.5414557456970215 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:7 | Batch:2 | Loss:1.4164723158 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.15sec | Mem. Usage 588.69MB avail. 
48.9 %\n", - "Epoch:7 | Loss:1.47896403 | Runtime:0.004 mins\n", - "Epoch:8 | Batch:1 | Loss:1.3490146398544312 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:8 | Batch:2 | Loss:1.1586432457 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.76MB avail. 48.9 %\n", - "Epoch:8 | Loss:1.25382894 | Runtime:0.002 mins\n", - "Epoch:9 | Batch:1 | Loss:1.0980411767959595 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.19MB\n", - "Epoch:9 | Batch:2 | Loss:0.9264459014 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.78MB avail. 49.0 %\n", - "Epoch:9 | Loss:1.01224354 | Runtime:0.002 mins\n", - "Epoch:10 | Batch:1 | Loss:0.8234874606132507 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:10 | Batch:2 | Loss:0.7607272863 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.79MB avail. 49.0 %\n", - "Epoch:10 | Loss:0.79210737 | Runtime:0.002 mins\n", - "Epoch:11 | Batch:1 | Loss:0.6415410041809082 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:11 | Batch:2 | Loss:0.5779718757 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.76MB avail. 48.9 %\n", - "Epoch:11 | Loss:0.60975644 | Runtime:0.002 mins\n", - "Epoch:12 | Batch:1 | Loss:0.531267523765564 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:12 | Batch:2 | Loss:0.4684059024 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.71MB avail. 48.9 %\n", - "Epoch:12 | Loss:0.49983671 | Runtime:0.002 mins\n", - "Epoch:13 | Batch:1 | Loss:0.4715423583984375 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:13 | Batch:2 | Loss:0.4312486649 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.71MB avail. 49.0 %\n", - "Epoch:13 | Loss:0.45139551 | Runtime:0.002 mins\n", - "Epoch:14 | Batch:1 | Loss:0.4405898451805115 |ForwardBackwardUpdate:0.05secs | Mem. 
Usage 584.28MB\n", - "Epoch:14 | Batch:2 | Loss:0.4347212315 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.55MB avail. 49.0 %\n", - "Epoch:14 | Loss:0.43765554 | Runtime:0.002 mins\n", - "Epoch:15 | Batch:1 | Loss:0.42732974886894226 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:15 | Batch:2 | Loss:0.4374894798 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.61MB avail. 48.9 %\n", - "Epoch:15 | Loss:0.43240961 | Runtime:0.002 mins\n", - "Epoch:16 | Batch:1 | Loss:0.43023064732551575 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:16 | Batch:2 | Loss:0.3968146443 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.79MB avail. 48.9 %\n", - "Epoch:16 | Loss:0.41352265 | Runtime:0.002 mins\n", - "Epoch:17 | Batch:1 | Loss:0.40646809339523315 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:17 | Batch:2 | Loss:0.3878369033 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.91MB avail. 48.9 %\n", - "Epoch:17 | Loss:0.39715250 | Runtime:0.002 mins\n", - "Epoch:18 | Batch:1 | Loss:0.3904493451118469 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:18 | Batch:2 | Loss:0.3501742482 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.8MB avail. 49.0 %\n", - "Epoch:18 | Loss:0.37031180 | Runtime:0.002 mins\n", - "Epoch:19 | Batch:1 | Loss:0.3455673158168793 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.09MB\n", - "Epoch:19 | Batch:2 | Loss:0.3615186214 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.68MB avail. 49.0 %\n", - "Epoch:19 | Loss:0.35354297 | Runtime:0.001 mins\n", - "Epoch:20 | Batch:1 | Loss:0.33274781703948975 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:20 | Batch:2 | Loss:0.3184422851 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 
48.9 %\n", - "Epoch:20 | Loss:0.32559505 | Runtime:0.002 mins\n", - "Epoch:21 | Batch:1 | Loss:0.3171499967575073 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:21 | Batch:2 | Loss:0.2909499109 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 48.9 %\n", - "Epoch:21 | Loss:0.30404995 | Runtime:0.001 mins\n", - "Epoch:22 | Batch:1 | Loss:0.2913304567337036 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n", - "Epoch:22 | Batch:2 | Loss:0.2941762507 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.77MB avail. 48.9 %\n", - "Epoch:22 | Loss:0.29275335 | Runtime:0.002 mins\n", - "Epoch:23 | Batch:1 | Loss:0.28453922271728516 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:23 | Batch:2 | Loss:0.2737905979 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.94MB avail. 48.9 %\n", - "Epoch:23 | Loss:0.27916491 | Runtime:0.002 mins\n", - "Epoch:24 | Batch:1 | Loss:0.27182915806770325 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:24 | Batch:2 | Loss:0.2711850703 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.82MB avail. 48.9 %\n", - "Epoch:24 | Loss:0.27150711 | Runtime:0.002 mins\n", - "Epoch:25 | Batch:1 | Loss:0.2654229700565338 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:25 | Batch:2 | Loss:0.2560573816 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.84MB avail. 49.0 %\n", - "Epoch:25 | Loss:0.26074018 | Runtime:0.002 mins\n", - "Epoch:26 | Batch:1 | Loss:0.253724604845047 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.14MB\n", - "Epoch:26 | Batch:2 | Loss:0.2502148449 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.72MB avail. 49.0 %\n", - "Epoch:26 | Loss:0.25196972 | Runtime:0.002 mins\n", - "Epoch:27 | Batch:1 | Loss:0.2477717101573944 |ForwardBackwardUpdate:0.05secs | Mem. 
Usage 584.16MB\n", - "Epoch:27 | Batch:2 | Loss:0.2315254509 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.0 %\n", - "Epoch:27 | Loss:0.23964858 | Runtime:0.002 mins\n", - "Epoch:28 | Batch:1 | Loss:0.23392672836780548 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.31MB\n", - "Epoch:28 | Batch:2 | Loss:0.2280673385 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.64MB avail. 48.9 %\n", - "Epoch:28 | Loss:0.23099703 | Runtime:0.002 mins\n", - "Epoch:29 | Batch:1 | Loss:0.2297716736793518 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.33MB\n", - "Epoch:29 | Batch:2 | Loss:0.2091106921 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.92MB avail. 48.9 %\n", - "Epoch:29 | Loss:0.21944118 | Runtime:0.002 mins\n", - "Epoch:30 | Batch:1 | Loss:0.21492809057235718 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.22MB\n", - "Epoch:30 | Batch:2 | Loss:0.2151773423 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.81MB avail. 49.1 %\n", - "Epoch:30 | Loss:0.21505272 | Runtime:0.002 mins\n", - "Epoch:31 | Batch:1 | Loss:0.20622862875461578 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.11MB\n", - "Epoch:31 | Batch:2 | Loss:0.2124694288 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.7MB avail. 49.1 %\n", - "Epoch:31 | Loss:0.20934903 | Runtime:0.002 mins\n", - "Epoch:32 | Batch:1 | Loss:0.20401526987552643 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.27MB\n", - "Epoch:32 | Batch:2 | Loss:0.2003787160 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.86MB avail. 49.0 %\n", - "Epoch:32 | Loss:0.20219699 | Runtime:0.002 mins\n", - "Epoch:33 | Batch:1 | Loss:0.1984919011592865 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:33 | Batch:2 | Loss:0.1959941983 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.95MB avail. 
49.0 %\n", - "Epoch:33 | Loss:0.19724305 | Runtime:0.001 mins\n", - "Epoch:34 | Batch:1 | Loss:0.193772092461586 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:34 | Batch:2 | Loss:0.1905900538 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.51MB avail. 49.0 %\n", - "Epoch:34 | Loss:0.19218107 | Runtime:0.002 mins\n", - "Epoch:35 | Batch:1 | Loss:0.1889636367559433 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:35 | Batch:2 | Loss:0.1862379909 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.89MB avail. 49.0 %\n", - "Epoch:35 | Loss:0.18760081 | Runtime:0.001 mins\n", - "Epoch:36 | Batch:1 | Loss:0.18612734973430634 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.18MB\n", - "Epoch:36 | Batch:2 | Loss:0.1789232641 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.75MB avail. 49.0 %\n", - "Epoch:36 | Loss:0.18252531 | Runtime:0.002 mins\n", - "Epoch:37 | Batch:1 | Loss:0.18710067868232727 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.19MB\n", - "Epoch:37 | Batch:2 | Loss:0.1661313921 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.77MB avail. 49.0 %\n", - "Epoch:37 | Loss:0.17661604 | Runtime:0.002 mins\n", - "Epoch:38 | Batch:1 | Loss:0.17419378459453583 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.2MB\n", - "Epoch:38 | Batch:2 | Loss:0.1803482324 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.8MB avail. 49.0 %\n", - "Epoch:38 | Loss:0.17727101 | Runtime:0.002 mins\n", - "Epoch:39 | Batch:1 | Loss:0.17577333748340607 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:39 | Batch:2 | Loss:0.1678776294 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.49MB avail. 49.0 %\n", - "Epoch:39 | Loss:0.17182548 | Runtime:0.002 mins\n", - "Epoch:40 | Batch:1 | Loss:0.17488594353199005 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.09MB\n", - "Epoch:40 | Batch:2 | Loss:0.1606011838 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.68MB avail. 49.0 %\n", - "Epoch:40 | Loss:0.16774356 | Runtime:0.002 mins\n", - "Epoch:41 | Batch:1 | Loss:0.16182918846607208 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.1MB\n", - "Epoch:41 | Batch:2 | Loss:0.1773647517 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:41 | Loss:0.16959697 | Runtime:0.002 mins\n", - "Epoch:42 | Batch:1 | Loss:0.16739903390407562 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.12MB\n", - "Epoch:42 | Batch:2 | Loss:0.1582861990 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.7MB avail. 49.0 %\n", - "Epoch:42 | Loss:0.16284262 | Runtime:0.001 mins\n", - "Epoch:43 | Batch:1 | Loss:0.16058635711669922 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:43 | Batch:2 | Loss:0.1637157947 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 49.0 %\n", - "Epoch:43 | Loss:0.16215108 | Runtime:0.001 mins\n", - "Epoch:44 | Batch:1 | Loss:0.15870559215545654 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n", - "Epoch:44 | Batch:2 | Loss:0.1603547484 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 49.0 %\n", - "Epoch:44 | Loss:0.15953017 | Runtime:0.001 mins\n", - "Epoch:45 | Batch:1 | Loss:0.15698355436325073 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:45 | Batch:2 | Loss:0.1567720622 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.76MB avail. 49.0 %\n", - "Epoch:45 | Loss:0.15687781 | Runtime:0.001 mins\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:46 | Batch:1 | Loss:0.1537003219127655 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:46 | Batch:2 | Loss:0.1565562934 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.76MB avail. 
49.0 %\n", - "Epoch:46 | Loss:0.15512831 | Runtime:0.001 mins\n", - "Epoch:47 | Batch:1 | Loss:0.15052971243858337 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:47 | Batch:2 | Loss:0.1561212689 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 49.0 %\n", - "Epoch:47 | Loss:0.15332549 | Runtime:0.001 mins\n", - "Epoch:48 | Batch:1 | Loss:0.1507793515920639 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:48 | Batch:2 | Loss:0.1497120857 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.49MB avail. 49.0 %\n", - "Epoch:48 | Loss:0.15024572 | Runtime:0.001 mins\n", - "Epoch:49 | Batch:1 | Loss:0.14650507271289825 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.08MB\n", - "Epoch:49 | Batch:2 | Loss:0.1519042104 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.67MB avail. 49.0 %\n", - "Epoch:49 | Loss:0.14920464 | Runtime:0.001 mins\n", - "Epoch:50 | Batch:1 | Loss:0.14762242138385773 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.1MB\n", - "Epoch:50 | Batch:2 | Loss:0.1444831192 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:50 | Loss:0.14605277 | Runtime:0.001 mins\n", - "Epoch:51 | Batch:1 | Loss:0.14663973450660706 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:51 | Batch:2 | Loss:0.1409807801 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.53MB avail. 49.0 %\n", - "Epoch:51 | Loss:0.14381026 | Runtime:0.001 mins\n", - "Epoch:52 | Batch:1 | Loss:0.14680664241313934 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:52 | Batch:2 | Loss:0.1354602575 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.71MB avail. 49.0 %\n", - "Epoch:52 | Loss:0.14113345 | Runtime:0.001 mins\n", - "Epoch:53 | Batch:1 | Loss:0.14129842817783356 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.29MB\n", - "Epoch:53 | Batch:2 | Loss:0.1409944296 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.6MB avail. 49.0 %\n", - "Epoch:53 | Loss:0.14114643 | Runtime:0.001 mins\n", - "Epoch:54 | Batch:1 | Loss:0.14185933768749237 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.3MB\n", - "Epoch:54 | Batch:2 | Loss:0.1350610256 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.62MB avail. 49.0 %\n", - "Epoch:54 | Loss:0.13846018 | Runtime:0.001 mins\n", - "Epoch:55 | Batch:1 | Loss:0.13801316916942596 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:55 | Batch:2 | Loss:0.1378786117 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 49.0 %\n", - "Epoch:55 | Loss:0.13794589 | Runtime:0.001 mins\n", - "Epoch:56 | Batch:1 | Loss:0.13736185431480408 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:56 | Batch:2 | Loss:0.1345966011 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 49.0 %\n", - "Epoch:56 | Loss:0.13597923 | Runtime:0.001 mins\n", - "Epoch:57 | Batch:1 | Loss:0.1351136863231659 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.35MB\n", - "Epoch:57 | Batch:2 | Loss:0.1345768124 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.11sec | Mem. Usage 588.67MB avail. 49.0 %\n", - "Epoch:57 | Loss:0.13484525 | Runtime:0.003 mins\n", - "Epoch:58 | Batch:1 | Loss:0.1314854919910431 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.09MB\n", - "Epoch:58 | Batch:2 | Loss:0.1372490078 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:58 | Loss:0.13436725 | Runtime:0.002 mins\n", - "Epoch:59 | Batch:1 | Loss:0.13140921294689178 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:59 | Batch:2 | Loss:0.1333954036 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.7MB avail. 
48.9 %\n", - "Epoch:59 | Loss:0.13240231 | Runtime:0.001 mins\n", - "Epoch:60 | Batch:1 | Loss:0.12838269770145416 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:60 | Batch:2 | Loss:0.1350383461 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.67MB avail. 49.0 %\n", - "Epoch:60 | Loss:0.13171052 | Runtime:0.001 mins\n", - "Epoch:61 | Batch:1 | Loss:0.12701313197612762 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.14MB\n", - "Epoch:61 | Batch:2 | Loss:0.1335827410 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 49.0 %\n", - "Epoch:61 | Loss:0.13029794 | Runtime:0.001 mins\n", - "Epoch:62 | Batch:1 | Loss:0.1289120316505432 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.3MB\n", - "Epoch:62 | Batch:2 | Loss:0.1260796785 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.62MB avail. 49.0 %\n", - "Epoch:62 | Loss:0.12749586 | Runtime:0.001 mins\n", - "Epoch:63 | Batch:1 | Loss:0.12818162143230438 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:63 | Batch:2 | Loss:0.1237298623 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.91MB avail. 49.0 %\n", - "Epoch:63 | Loss:0.12595574 | Runtime:0.001 mins\n", - "Epoch:64 | Batch:1 | Loss:0.12201056629419327 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:64 | Batch:2 | Loss:0.1318084747 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.8MB avail. 49.0 %\n", - "Epoch:64 | Loss:0.12690952 | Runtime:0.001 mins\n", - "Epoch:65 | Batch:1 | Loss:0.12312677502632141 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.1MB\n", - "Epoch:65 | Batch:2 | Loss:0.1258103400 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:65 | Loss:0.12446856 | Runtime:0.001 mins\n", - "Epoch:66 | Batch:1 | Loss:0.12187240272760391 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.11MB\n", - "Epoch:66 | Batch:2 | Loss:0.1244671270 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.0 %\n", - "Epoch:66 | Loss:0.12316976 | Runtime:0.002 mins\n", - "Epoch:67 | Batch:1 | Loss:0.12050668150186539 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:67 | Batch:2 | Loss:0.1237472296 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.87MB avail. 49.0 %\n", - "Epoch:67 | Loss:0.12212696 | Runtime:0.002 mins\n", - "Epoch:68 | Batch:1 | Loss:0.11972417682409286 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n", - "Epoch:68 | Batch:2 | Loss:0.1215148494 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 48.9 %\n", - "Epoch:68 | Loss:0.12061951 | Runtime:0.001 mins\n", - "Epoch:69 | Batch:1 | Loss:0.11952636390924454 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:69 | Batch:2 | Loss:0.1186402813 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 48.9 %\n", - "Epoch:69 | Loss:0.11908332 | Runtime:0.001 mins\n", - "Epoch:70 | Batch:1 | Loss:0.11607828736305237 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:70 | Batch:2 | Loss:0.1215606406 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 48.9 %\n", - "Epoch:70 | Loss:0.11881946 | Runtime:0.002 mins\n", - "Epoch:71 | Batch:1 | Loss:0.11815118044614792 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:71 | Batch:2 | Loss:0.1141292751 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 48.9 %\n", - "Epoch:71 | Loss:0.11614023 | Runtime:0.001 mins\n", - "Epoch:72 | Batch:1 | Loss:0.11739254742860794 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:72 | Batch:2 | Loss:0.1120735779 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 
48.9 %\n", - "Epoch:72 | Loss:0.11473306 | Runtime:0.002 mins\n", - "Epoch:73 | Batch:1 | Loss:0.11727050691843033 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:73 | Batch:2 | Loss:0.1089676395 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 48.9 %\n", - "Epoch:73 | Loss:0.11311907 | Runtime:0.001 mins\n", - "Epoch:74 | Batch:1 | Loss:0.11456330865621567 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:74 | Batch:2 | Loss:0.1108285636 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.57MB avail. 48.9 %\n", - "Epoch:74 | Loss:0.11269594 | Runtime:0.001 mins\n", - "Epoch:75 | Batch:1 | Loss:0.1095118597149849 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:75 | Batch:2 | Loss:0.1169053540 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.54MB avail. 48.9 %\n", - "Epoch:75 | Loss:0.11320861 | Runtime:0.001 mins\n", - "Epoch:76 | Batch:1 | Loss:0.11145947873592377 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:76 | Batch:2 | Loss:0.1097679362 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 48.9 %\n", - "Epoch:76 | Loss:0.11061371 | Runtime:0.001 mins\n", - "Epoch:77 | Batch:1 | Loss:0.1090729609131813 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:77 | Batch:2 | Loss:0.1108762100 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 49.0 %\n", - "Epoch:77 | Loss:0.10997459 | Runtime:0.001 mins\n", - "Epoch:78 | Batch:1 | Loss:0.10636425763368607 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:78 | Batch:2 | Loss:0.1126901656 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.43MB avail. 
48.9 %\n", - "Epoch:78 | Loss:0.10952721 | Runtime:0.002 mins\n", - "Epoch:79 | Batch:1 | Loss:0.10761662572622299 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.3MB\n", - "Epoch:79 | Batch:2 | Loss:0.1069041416 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.56MB avail. 48.9 %\n", - "Epoch:79 | Loss:0.10726038 | Runtime:0.002 mins\n", - "Epoch:80 | Batch:1 | Loss:0.10537602752447128 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:80 | Batch:2 | Loss:0.1077241153 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 48.9 %\n", - "Epoch:80 | Loss:0.10655007 | Runtime:0.001 mins\n", - "Epoch:81 | Batch:1 | Loss:0.10347626358270645 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:81 | Batch:2 | Loss:0.1079140678 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 48.9 %\n", - "Epoch:81 | Loss:0.10569517 | Runtime:0.001 mins\n", - "Epoch:82 | Batch:1 | Loss:0.10309697687625885 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.35MB\n", - "Epoch:82 | Batch:2 | Loss:0.1052022204 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.94MB avail. 48.9 %\n", - "Epoch:82 | Loss:0.10414960 | Runtime:0.001 mins\n", - "Epoch:83 | Batch:1 | Loss:0.10399553924798965 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:83 | Batch:2 | Loss:0.0999430865 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 49.0 %\n", - "Epoch:83 | Loss:0.10196931 | Runtime:0.002 mins\n", - "Epoch:84 | Batch:1 | Loss:0.10225624591112137 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:84 | Batch:2 | Loss:0.0997121334 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.52MB avail. 48.9 %\n", - "Epoch:84 | Loss:0.10098419 | Runtime:0.001 mins\n", - "Epoch:85 | Batch:1 | Loss:0.09726635366678238 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.12MB\n", - "Epoch:85 | Batch:2 | Loss:0.1059302688 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.9MB avail. 48.9 %\n", - "Epoch:85 | Loss:0.10159831 | Runtime:0.002 mins\n", - "Epoch:86 | Batch:1 | Loss:0.09784471243619919 |ForwardBackwardUpdate:0.08secs | Mem. Usage 584.18MB\n", - "Epoch:86 | Batch:2 | Loss:0.1014050916 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.77MB avail. 48.9 %\n", - "Epoch:86 | Loss:0.09962490 | Runtime:0.002 mins\n", - "Epoch:87 | Batch:1 | Loss:0.10068716108798981 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:87 | Batch:2 | Loss:0.0926587135 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.66MB avail. 48.8 %\n", - "Epoch:87 | Loss:0.09667294 | Runtime:0.001 mins\n", - "Epoch:88 | Batch:1 | Loss:0.09627404063940048 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.09MB\n", - "Epoch:88 | Batch:2 | Loss:0.0975930914 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.67MB avail. 48.8 %\n", - "Epoch:88 | Loss:0.09693357 | Runtime:0.002 mins\n", - "Epoch:89 | Batch:1 | Loss:0.09736421704292297 |ForwardBackwardUpdate:0.07secs | Mem. Usage 584.1MB\n", - "Epoch:89 | Batch:2 | Loss:0.0923846513 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.96MB avail. 48.9 %\n", - "Epoch:89 | Loss:0.09487443 | Runtime:0.002 mins\n", - "Epoch:90 | Batch:1 | Loss:0.09502170979976654 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.26MB\n", - "Epoch:90 | Batch:2 | Loss:0.0937189087 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.53MB avail. 49.0 %\n", - "Epoch:90 | Loss:0.09437031 | Runtime:0.002 mins\n", - "Epoch:91 | Batch:1 | Loss:0.09145331382751465 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.12MB\n", - "Epoch:91 | Batch:2 | Loss:0.0974505767 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.71MB avail. 
49.0 %\n", - "Epoch:91 | Loss:0.09445195 | Runtime:0.002 mins\n", - "Epoch:92 | Batch:1 | Loss:0.09294049441814423 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.14MB\n", - "Epoch:92 | Batch:2 | Loss:0.0920354575 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 49.0 %\n", - "Epoch:92 | Loss:0.09248798 | Runtime:0.002 mins\n", - "Epoch:93 | Batch:1 | Loss:0.09122174233198166 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.16MB\n", - "Epoch:93 | Batch:2 | Loss:0.0919243991 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.0 %\n", - "Epoch:93 | Loss:0.09157307 | Runtime:0.002 mins\n", - "Epoch:94 | Batch:1 | Loss:0.09038746356964111 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.3MB\n", - "Epoch:94 | Batch:2 | Loss:0.0905173942 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.57MB avail. 49.0 %\n", - "Epoch:94 | Loss:0.09045243 | Runtime:0.002 mins\n", - "Epoch:95 | Batch:1 | Loss:0.0902063325047493 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:95 | Batch:2 | Loss:0.0880731046 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 48.9 %\n", - "Epoch:95 | Loss:0.08913972 | Runtime:0.002 mins\n", - "Epoch:96 | Batch:1 | Loss:0.08929985761642456 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:96 | Batch:2 | Loss:0.0870256647 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.45MB avail. 48.9 %\n", - "Epoch:96 | Loss:0.08816276 | Runtime:0.001 mins\n", - "Epoch:97 | Batch:1 | Loss:0.08559670299291611 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:97 | Batch:2 | Loss:0.0915402770 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 49.0 %\n", - "Epoch:97 | Loss:0.08856849 | Runtime:0.001 mins\n", - "Epoch:98 | Batch:1 | Loss:0.08726375550031662 |ForwardBackwardUpdate:0.06secs | Mem. 
Usage 584.34MB\n", - "Epoch:98 | Batch:2 | Loss:0.0854786411 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 584.6MB avail. 48.9 %\n", - "Epoch:98 | Loss:0.08637120 | Runtime:0.002 mins\n", - "Epoch:99 | Batch:1 | Loss:0.08444251865148544 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:99 | Batch:2 | Loss:0.0887010917 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.47MB avail. 48.9 %\n", - "Epoch:99 | Loss:0.08657181 | Runtime:0.002 mins\n", - "Epoch:100 | Batch:1 | Loss:0.0843043103814125 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.33MB\n", - "Epoch:100 | Batch:2 | Loss:0.0862324163 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.92MB avail. 48.9 %\n", - "Epoch:100 | Loss:0.08526836 | Runtime:0.002 mins\n", - "Epoch:101 | Batch:1 | Loss:0.08232592791318893 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.22MB\n", - "Epoch:101 | Batch:2 | Loss:0.0876403302 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 48.9 %\n", - "Epoch:101 | Loss:0.08498313 | Runtime:0.002 mins\n", - "Epoch:102 | Batch:1 | Loss:0.08463094383478165 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:102 | Batch:2 | Loss:0.0809089318 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.38MB avail. 48.9 %\n", - "Epoch:102 | Loss:0.08276994 | Runtime:0.001 mins\n", - "Epoch:103 | Batch:1 | Loss:0.08199337124824524 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:103 | Batch:2 | Loss:0.0837271959 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.56MB avail. 48.9 %\n", - "Epoch:103 | Loss:0.08286028 | Runtime:0.001 mins\n", - "Epoch:104 | Batch:1 | Loss:0.08308608084917068 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:104 | Batch:2 | Loss:0.0798392817 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.55MB avail. 
48.9 %\n", - "Epoch:104 | Loss:0.08146268 | Runtime:0.001 mins\n", - "Epoch:105 | Batch:1 | Loss:0.08183862268924713 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.28MB\n", - "Epoch:105 | Batch:2 | Loss:0.0795020312 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.87MB avail. 48.9 %\n", - "Epoch:105 | Loss:0.08067033 | Runtime:0.001 mins\n", - "Epoch:106 | Batch:1 | Loss:0.07992857694625854 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:106 | Batch:2 | Loss:0.0815002844 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.44MB avail. 48.9 %\n", - "Epoch:106 | Loss:0.08071443 | Runtime:0.001 mins\n", - "Epoch:107 | Batch:1 | Loss:0.07833510637283325 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:107 | Batch:2 | Loss:0.0825888067 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.11sec | Mem. Usage 588.62MB avail. 48.9 %\n", - "Epoch:107 | Loss:0.08046196 | Runtime:0.003 mins\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:108 | Batch:1 | Loss:0.07847463339567184 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:108 | Batch:2 | Loss:0.0804487914 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.91MB avail. 48.9 %\n", - "Epoch:108 | Loss:0.07946171 | Runtime:0.001 mins\n", - "Epoch:109 | Batch:1 | Loss:0.07641984522342682 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:109 | Batch:2 | Loss:0.0830163658 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 48.9 %\n", - "Epoch:109 | Loss:0.07971811 | Runtime:0.001 mins\n", - "Epoch:110 | Batch:1 | Loss:0.07835913449525833 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:110 | Batch:2 | Loss:0.0770355538 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 
48.9 %\n", - "Epoch:110 | Loss:0.07769734 | Runtime:0.001 mins\n", - "Epoch:111 | Batch:1 | Loss:0.07820626348257065 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.12MB\n", - "Epoch:111 | Batch:2 | Loss:0.0756673440 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.39MB avail. 48.9 %\n", - "Epoch:111 | Loss:0.07693680 | Runtime:0.001 mins\n", - "Epoch:112 | Batch:1 | Loss:0.07462061196565628 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.25MB\n", - "Epoch:112 | Batch:2 | Loss:0.0805993751 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.56MB avail. 49.0 %\n", - "Epoch:112 | Loss:0.07760999 | Runtime:0.002 mins\n", - "Epoch:113 | Batch:1 | Loss:0.07553128898143768 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:113 | Batch:2 | Loss:0.0773180574 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.54MB avail. 49.0 %\n", - "Epoch:113 | Loss:0.07642467 | Runtime:0.001 mins\n", - "Epoch:114 | Batch:1 | Loss:0.07622995972633362 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:114 | Batch:2 | Loss:0.0744187385 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 49.0 %\n", - "Epoch:114 | Loss:0.07532435 | Runtime:0.001 mins\n", - "Epoch:115 | Batch:1 | Loss:0.07482840865850449 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n", - "Epoch:115 | Batch:2 | Loss:0.0755823702 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.61MB avail. 49.0 %\n", - "Epoch:115 | Loss:0.07520539 | Runtime:0.001 mins\n", - "Epoch:116 | Batch:1 | Loss:0.07335062325000763 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:116 | Batch:2 | Loss:0.0768343881 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 49.0 %\n", - "Epoch:116 | Loss:0.07509251 | Runtime:0.001 mins\n", - "Epoch:117 | Batch:1 | Loss:0.07311473786830902 |ForwardBackwardUpdate:0.07secs | Mem. 
Usage 584.33MB\n", - "Epoch:117 | Batch:2 | Loss:0.0757996216 |ForwardBackwardUpdate:0.05sec | BatchConst.:0.02sec | Mem. Usage 588.65MB avail. 49.0 %\n", - "Epoch:117 | Loss:0.07445718 | Runtime:0.003 mins\n", - "Epoch:118 | Batch:1 | Loss:0.07391613721847534 |ForwardBackwardUpdate:0.11secs | Mem. Usage 584.34MB\n", - "Epoch:118 | Batch:2 | Loss:0.0729035586 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.02sec | Mem. Usage 588.93MB avail. 48.1 %\n", - "Epoch:118 | Loss:0.07340985 | Runtime:0.005 mins\n", - "Epoch:119 | Batch:1 | Loss:0.07280915975570679 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.23MB\n", - "Epoch:119 | Batch:2 | Loss:0.0733642727 |ForwardBackwardUpdate:0.05sec | BatchConst.:0.02sec | Mem. Usage 588.82MB avail. 48.1 %\n", - "Epoch:119 | Loss:0.07308672 | Runtime:0.003 mins\n", - "Epoch:120 | Batch:1 | Loss:0.07251515984535217 |ForwardBackwardUpdate:0.09secs | Mem. Usage 584.12MB\n", - "Epoch:120 | Batch:2 | Loss:0.0729089081 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.71MB avail. 48.0 %\n", - "Epoch:120 | Loss:0.07271203 | Runtime:0.003 mins\n", - "Epoch:121 | Batch:1 | Loss:0.07060059159994125 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.27MB\n", - "Epoch:121 | Batch:2 | Loss:0.0753241330 |ForwardBackwardUpdate:0.05sec | BatchConst.:0.02sec | Mem. Usage 588.57MB avail. 47.9 %\n", - "Epoch:121 | Loss:0.07296236 | Runtime:0.003 mins\n", - "Epoch:122 | Batch:1 | Loss:0.07082579284906387 |ForwardBackwardUpdate:0.07secs | Mem. Usage 584.27MB\n", - "Epoch:122 | Batch:2 | Loss:0.0738190934 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.56MB avail. 47.8 %\n", - "Epoch:122 | Loss:0.07232244 | Runtime:0.002 mins\n", - "Epoch:123 | Batch:1 | Loss:0.07156410068273544 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.28MB\n", - "Epoch:123 | Batch:2 | Loss:0.0709424764 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.87MB avail. 
47.9 %\n", - "Epoch:123 | Loss:0.07125329 | Runtime:0.002 mins\n", - "Epoch:124 | Batch:1 | Loss:0.07108917087316513 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.17MB\n", - "Epoch:124 | Batch:2 | Loss:0.0703925341 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.76MB avail. 47.9 %\n", - "Epoch:124 | Loss:0.07074085 | Runtime:0.002 mins\n", - "Epoch:125 | Batch:1 | Loss:0.07053133845329285 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.33MB\n", - "Epoch:125 | Batch:2 | Loss:0.0702253878 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.92MB avail. 47.9 %\n", - "Epoch:125 | Loss:0.07037836 | Runtime:0.002 mins\n", - "Epoch:126 | Batch:1 | Loss:0.07014317065477371 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.21MB\n", - "Epoch:126 | Batch:2 | Loss:0.0692904592 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.02sec | Mem. Usage 588.78MB avail. 48.0 %\n", - "Epoch:126 | Loss:0.06971681 | Runtime:0.002 mins\n", - "Epoch:127 | Batch:1 | Loss:0.06878573447465897 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.2MB\n", - "Epoch:127 | Batch:2 | Loss:0.0711285844 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.8MB avail. 48.1 %\n", - "Epoch:127 | Loss:0.06995716 | Runtime:0.003 mins\n", - "Epoch:128 | Batch:1 | Loss:0.06744014471769333 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.35MB\n", - "Epoch:128 | Batch:2 | Loss:0.0721412972 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.01sec | Mem. Usage 588.94MB avail. 48.2 %\n", - "Epoch:128 | Loss:0.06979072 | Runtime:0.003 mins\n", - "Epoch:129 | Batch:1 | Loss:0.06752190738916397 |ForwardBackwardUpdate:0.10secs | Mem. Usage 584.24MB\n", - "Epoch:129 | Batch:2 | Loss:0.0714396387 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.02sec | Mem. Usage 588.83MB avail. 48.3 %\n", - "Epoch:129 | Loss:0.06948077 | Runtime:0.003 mins\n", - "Epoch:130 | Batch:1 | Loss:0.06743156909942627 |ForwardBackwardUpdate:0.09secs | Mem. 
Usage 584.26MB\n", - "Epoch:130 | Batch:2 | Loss:0.0700578094 |ForwardBackwardUpdate:0.07sec | BatchConst.:0.03sec | Mem. Usage 584.53MB avail. 48.4 %\n", - "Epoch:130 | Loss:0.06874469 | Runtime:0.004 mins\n", - "Epoch:131 | Batch:1 | Loss:0.06625290215015411 |ForwardBackwardUpdate:0.09secs | Mem. Usage 584.12MB\n", - "Epoch:131 | Batch:2 | Loss:0.0710717663 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.03sec | Mem. Usage 588.89MB avail. 48.7 %\n", - "Epoch:131 | Loss:0.06866233 | Runtime:0.003 mins\n", - "Epoch:132 | Batch:1 | Loss:0.06696782261133194 |ForwardBackwardUpdate:0.08secs | Mem. Usage 584.18MB\n", - "Epoch:132 | Batch:2 | Loss:0.0685427785 |ForwardBackwardUpdate:0.10sec | BatchConst.:0.02sec | Mem. Usage 588.78MB avail. 48.8 %\n", - "Epoch:132 | Loss:0.06775530 | Runtime:0.004 mins\n", - "Epoch:133 | Batch:1 | Loss:0.06756077706813812 |ForwardBackwardUpdate:0.16secs | Mem. Usage 584.34MB\n", - "Epoch:133 | Batch:2 | Loss:0.0663831010 |ForwardBackwardUpdate:0.10sec | BatchConst.:0.04sec | Mem. Usage 588.93MB avail. 49.4 %\n", - "Epoch:133 | Loss:0.06697194 | Runtime:0.007 mins\n", - "Epoch:134 | Batch:1 | Loss:0.06800059229135513 |ForwardBackwardUpdate:0.10secs | Mem. Usage 584.22MB\n", - "Epoch:134 | Batch:2 | Loss:0.0644917637 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.05sec | Mem. Usage 588.81MB avail. 49.5 %\n", - "Epoch:134 | Loss:0.06624618 | Runtime:0.004 mins\n", - "Epoch:135 | Batch:1 | Loss:0.06692203879356384 |ForwardBackwardUpdate:0.08secs | Mem. Usage 584.11MB\n", - "Epoch:135 | Batch:2 | Loss:0.0660087466 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 584.38MB avail. 49.5 %\n", - "Epoch:135 | Loss:0.06646539 | Runtime:0.003 mins\n", - "Epoch:136 | Batch:1 | Loss:0.06407713890075684 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:136 | Batch:2 | Loss:0.0699972361 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 
49.4 %\n", - "Epoch:136 | Loss:0.06703719 | Runtime:0.002 mins\n", - "Epoch:137 | Batch:1 | Loss:0.06589104235172272 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:137 | Batch:2 | Loss:0.0653665662 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.85MB avail. 49.3 %\n", - "Epoch:137 | Loss:0.06562880 | Runtime:0.002 mins\n", - "Epoch:138 | Batch:1 | Loss:0.06402962654829025 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n", - "Epoch:138 | Batch:2 | Loss:0.0676817894 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.42MB avail. 49.3 %\n", - "Epoch:138 | Loss:0.06585571 | Runtime:0.001 mins\n", - "Epoch:139 | Batch:1 | Loss:0.06388145685195923 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n", - "Epoch:139 | Batch:2 | Loss:0.0666979402 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.87MB avail. 49.3 %\n", - "Epoch:139 | Loss:0.06528970 | Runtime:0.002 mins\n", - "Epoch:140 | Batch:1 | Loss:0.06414478272199631 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:140 | Batch:2 | Loss:0.0653334856 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.77MB avail. 49.3 %\n", - "Epoch:140 | Loss:0.06473913 | Runtime:0.002 mins\n", - "Epoch:141 | Batch:1 | Loss:0.06362896412611008 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:141 | Batch:2 | Loss:0.0652326643 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.93MB avail. 49.4 %\n", - "Epoch:141 | Loss:0.06443081 | Runtime:0.002 mins\n", - "Epoch:142 | Batch:1 | Loss:0.06386518478393555 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:142 | Batch:2 | Loss:0.0640804842 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 
49.2 %\n", - "Epoch:142 | Loss:0.06397283 | Runtime:0.002 mins\n", - "Epoch:143 | Batch:1 | Loss:0.06251579523086548 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:143 | Batch:2 | Loss:0.0655950382 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.3 %\n", - "Epoch:143 | Loss:0.06405542 | Runtime:0.001 mins\n", - "Epoch:144 | Batch:1 | Loss:0.061766862869262695 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:144 | Batch:2 | Loss:0.0662709773 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.71MB avail. 49.3 %\n", - "Epoch:144 | Loss:0.06401892 | Runtime:0.001 mins\n", - "Epoch:145 | Batch:1 | Loss:0.06411607563495636 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.15MB\n", - "Epoch:145 | Batch:2 | Loss:0.0603970587 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.2 %\n", - "Epoch:145 | Loss:0.06225657 | Runtime:0.002 mins\n", - "Epoch:146 | Batch:1 | Loss:0.06331866979598999 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.12MB\n", - "Epoch:146 | Batch:2 | Loss:0.0613979809 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.2 %\n", - "Epoch:146 | Loss:0.06235833 | Runtime:0.002 mins\n", - "Epoch:147 | Batch:1 | Loss:0.06237439811229706 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.14MB\n", - "Epoch:147 | Batch:2 | Loss:0.0622274652 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 49.2 %\n", - "Epoch:147 | Loss:0.06230093 | Runtime:0.002 mins\n", - "Epoch:148 | Batch:1 | Loss:0.0628281980752945 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:148 | Batch:2 | Loss:0.0608917363 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.2 %\n", - "Epoch:148 | Loss:0.06185997 | Runtime:0.001 mins\n", - "Epoch:149 | Batch:1 | Loss:0.063510961830616 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.17MB\n", - "Epoch:149 | Batch:2 | Loss:0.0583437458 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.76MB avail. 49.2 %\n", - "Epoch:149 | Loss:0.06092735 | Runtime:0.001 mins\n", - "Epoch:150 | Batch:1 | Loss:0.06208859011530876 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:150 | Batch:2 | Loss:0.0606142469 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.64MB avail. 49.2 %\n", - "Epoch:150 | Loss:0.06135142 | Runtime:0.001 mins\n", - "Epoch:151 | Batch:1 | Loss:0.06205631420016289 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.35MB\n", - "Epoch:151 | Batch:2 | Loss:0.0592227727 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.94MB avail. 49.2 %\n", - "Epoch:151 | Loss:0.06063954 | Runtime:0.001 mins\n", - "Epoch:152 | Batch:1 | Loss:0.059554219245910645 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:152 | Batch:2 | Loss:0.0629249066 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 49.2 %\n", - "Epoch:152 | Loss:0.06123956 | Runtime:0.002 mins\n", - "Epoch:153 | Batch:1 | Loss:0.061766546219587326 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:153 | Batch:2 | Loss:0.0589257069 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.85MB avail. 49.2 %\n", - "Epoch:153 | Loss:0.06034613 | Runtime:0.002 mins\n", - "Epoch:154 | Batch:1 | Loss:0.059790972620248795 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.14MB\n", - "Epoch:154 | Batch:2 | Loss:0.0611557141 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 49.2 %\n", - "Epoch:154 | Loss:0.06047334 | Runtime:0.002 mins\n", - "Epoch:155 | Batch:1 | Loss:0.057743772864341736 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:155 | Batch:2 | Loss:0.0638928488 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 
49.1 %\n", - "Epoch:155 | Loss:0.06081831 | Runtime:0.002 mins\n", - "Epoch:156 | Batch:1 | Loss:0.05995730683207512 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:156 | Batch:2 | Loss:0.0592187978 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.91MB avail. 49.1 %\n", - "Epoch:156 | Loss:0.05958805 | Runtime:0.002 mins\n", - "Epoch:157 | Batch:1 | Loss:0.05748409032821655 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:157 | Batch:2 | Loss:0.0631464794 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.11sec | Mem. Usage 588.8MB avail. 49.1 %\n", - "Epoch:157 | Loss:0.06031528 | Runtime:0.003 mins\n", - "Epoch:158 | Batch:1 | Loss:0.057980976998806 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.1MB\n", - "Epoch:158 | Batch:2 | Loss:0.0607736185 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.1 %\n", - "Epoch:158 | Loss:0.05937730 | Runtime:0.002 mins\n", - "Epoch:159 | Batch:1 | Loss:0.05897434428334236 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:159 | Batch:2 | Loss:0.0580624305 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.68MB avail. 49.2 %\n", - "Epoch:159 | Loss:0.05851839 | Runtime:0.001 mins\n", - "Epoch:160 | Batch:1 | Loss:0.060359034687280655 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:160 | Batch:2 | Loss:0.0547461882 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 49.1 %\n", - "Epoch:160 | Loss:0.05755261 | Runtime:0.001 mins\n", - "Epoch:161 | Batch:1 | Loss:0.0575183667242527 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n", - "Epoch:161 | Batch:2 | Loss:0.0598444939 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 49.1 %\n", - "Epoch:161 | Loss:0.05868143 | Runtime:0.001 mins\n", - "Epoch:162 | Batch:1 | Loss:0.057776376605033875 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.31MB\n", - "Epoch:162 | Batch:2 | Loss:0.0582922511 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.9MB avail. 49.1 %\n", - "Epoch:162 | Loss:0.05803431 | Runtime:0.001 mins\n", - "Epoch:163 | Batch:1 | Loss:0.05864546447992325 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:163 | Batch:2 | Loss:0.0556474961 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 49.1 %\n", - "Epoch:163 | Loss:0.05714648 | Runtime:0.002 mins\n", - "Epoch:164 | Batch:1 | Loss:0.056853387504816055 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.09MB\n", - "Epoch:164 | Batch:2 | Loss:0.0585261136 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.95MB avail. 49.1 %\n", - "Epoch:164 | Loss:0.05768975 | Runtime:0.002 mins\n", - "Epoch:165 | Batch:1 | Loss:0.057521119713783264 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:165 | Batch:2 | Loss:0.0565902591 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.57MB avail. 49.1 %\n", - "Epoch:165 | Loss:0.05705569 | Runtime:0.002 mins\n", - "Epoch:166 | Batch:1 | Loss:0.05831480398774147 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:166 | Batch:2 | Loss:0.0536241494 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.58MB avail. 49.1 %\n", - "Epoch:166 | Loss:0.05596948 | Runtime:0.001 mins\n", - "Epoch:167 | Batch:1 | Loss:0.05593831092119217 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.28MB\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:167 | Batch:2 | Loss:0.0569442473 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.87MB avail. 49.0 %\n", - "Epoch:167 | Loss:0.05644128 | Runtime:0.002 mins\n", - "Epoch:168 | Batch:1 | Loss:0.05365685373544693 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:168 | Batch:2 | Loss:0.0610043257 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. 
Usage 588.76MB avail. 49.0 %\n", - "Epoch:168 | Loss:0.05733059 | Runtime:0.001 mins\n", - "Epoch:169 | Batch:1 | Loss:0.05445089936256409 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:169 | Batch:2 | Loss:0.0584338717 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 49.0 %\n", - "Epoch:169 | Loss:0.05644239 | Runtime:0.001 mins\n", - "Epoch:170 | Batch:1 | Loss:0.05451720207929611 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.35MB\n", - "Epoch:170 | Batch:2 | Loss:0.0571436696 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.94MB avail. 49.1 %\n", - "Epoch:170 | Loss:0.05583044 | Runtime:0.002 mins\n", - "Epoch:171 | Batch:1 | Loss:0.05559270828962326 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:171 | Batch:2 | Loss:0.0552281141 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 49.1 %\n", - "Epoch:171 | Loss:0.05541041 | Runtime:0.002 mins\n", - "Epoch:172 | Batch:1 | Loss:0.05507534742355347 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.13MB\n", - "Epoch:172 | Batch:2 | Loss:0.0551640950 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 49.1 %\n", - "Epoch:172 | Loss:0.05511972 | Runtime:0.002 mins\n", - "Epoch:173 | Batch:1 | Loss:0.056973110884428024 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.14MB\n", - "Epoch:173 | Batch:2 | Loss:0.0503338948 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.72MB avail. 49.1 %\n", - "Epoch:173 | Loss:0.05365350 | Runtime:0.001 mins\n", - "Epoch:174 | Batch:1 | Loss:0.055764634162187576 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:174 | Batch:2 | Loss:0.0527380779 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.1 %\n", - "Epoch:174 | Loss:0.05425136 | Runtime:0.001 mins\n", - "Epoch:175 | Batch:1 | Loss:0.052320536226034164 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.32MB\n", - "Epoch:175 | Batch:2 | Loss:0.0581855550 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.64MB avail. 49.1 %\n", - "Epoch:175 | Loss:0.05525305 | Runtime:0.001 mins\n", - "Epoch:176 | Batch:1 | Loss:0.053551871329545975 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:176 | Batch:2 | Loss:0.0552392192 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.93MB avail. 49.1 %\n", - "Epoch:176 | Loss:0.05439555 | Runtime:0.002 mins\n", - "Epoch:177 | Batch:1 | Loss:0.05468503385782242 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:177 | Batch:2 | Loss:0.0523067452 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.8MB avail. 49.1 %\n", - "Epoch:177 | Loss:0.05349589 | Runtime:0.002 mins\n", - "Epoch:178 | Batch:1 | Loss:0.05471048876643181 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.25MB\n", - "Epoch:178 | Batch:2 | Loss:0.0512269996 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 49.1 %\n", - "Epoch:178 | Loss:0.05296874 | Runtime:0.002 mins\n", - "Epoch:179 | Batch:1 | Loss:0.052981551736593246 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.26MB\n", - "Epoch:179 | Batch:2 | Loss:0.0539744496 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.58MB avail. 49.1 %\n", - "Epoch:179 | Loss:0.05347800 | Runtime:0.002 mins\n", - "Epoch:180 | Batch:1 | Loss:0.05245537683367729 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.28MB\n", - "Epoch:180 | Batch:2 | Loss:0.0543906577 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.58MB avail. 49.1 %\n", - "Epoch:180 | Loss:0.05342302 | Runtime:0.002 mins\n", - "Epoch:181 | Batch:1 | Loss:0.053506892174482346 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n", - "Epoch:181 | Batch:2 | Loss:0.0521351323 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.58MB avail. 
49.1 %\n", - "Epoch:181 | Loss:0.05282101 | Runtime:0.002 mins\n", - "Epoch:182 | Batch:1 | Loss:0.0522281788289547 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:182 | Batch:2 | Loss:0.0539352261 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 49.1 %\n", - "Epoch:182 | Loss:0.05308170 | Runtime:0.002 mins\n", - "Epoch:183 | Batch:1 | Loss:0.0522024892270565 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.33MB\n", - "Epoch:183 | Batch:2 | Loss:0.0535104014 |ForwardBackwardUpdate:0.05sec | BatchConst.:0.05sec | Mem. Usage 588.92MB avail. 49.4 %\n", - "Epoch:183 | Loss:0.05285645 | Runtime:0.003 mins\n", - "Epoch:184 | Batch:1 | Loss:0.04964430257678032 |ForwardBackwardUpdate:0.15secs | Mem. Usage 584.22MB\n", - "Epoch:184 | Batch:2 | Loss:0.0571962520 |ForwardBackwardUpdate:0.08sec | BatchConst.:0.03sec | Mem. Usage 588.81MB avail. 50.0 %\n", - "Epoch:184 | Loss:0.05342028 | Runtime:0.005 mins\n", - "Epoch:185 | Batch:1 | Loss:0.05134941264986992 |ForwardBackwardUpdate:0.11secs | Mem. Usage 584.11MB\n", - "Epoch:185 | Batch:2 | Loss:0.0538177714 |ForwardBackwardUpdate:0.06sec | BatchConst.:0.03sec | Mem. Usage 588.7MB avail. 50.2 %\n", - "Epoch:185 | Loss:0.05258359 | Runtime:0.004 mins\n", - "Epoch:186 | Batch:1 | Loss:0.05343470349907875 |ForwardBackwardUpdate:0.08secs | Mem. Usage 584.27MB\n", - "Epoch:186 | Batch:2 | Loss:0.0509754270 |ForwardBackwardUpdate:0.04sec | BatchConst.:0.02sec | Mem. Usage 588.86MB avail. 50.1 %\n", - "Epoch:186 | Loss:0.05220507 | Runtime:0.003 mins\n", - "Epoch:187 | Batch:1 | Loss:0.04987652972340584 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:187 | Batch:2 | Loss:0.0589211956 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 50.1 %\n", - "Epoch:187 | Loss:0.05439886 | Runtime:0.002 mins\n", - "Epoch:188 | Batch:1 | Loss:0.051099248230457306 |ForwardBackwardUpdate:0.05secs | Mem. 
Usage 584.22MB\n", - "Epoch:188 | Batch:2 | Loss:0.0532381460 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 50.0 %\n", - "Epoch:188 | Loss:0.05216870 | Runtime:0.002 mins\n", - "Epoch:189 | Batch:1 | Loss:0.05241641774773598 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.09MB\n", - "Epoch:189 | Batch:2 | Loss:0.0507234186 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.36MB avail. 50.0 %\n", - "Epoch:189 | Loss:0.05156992 | Runtime:0.002 mins\n", - "Epoch:190 | Batch:1 | Loss:0.05150429904460907 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:190 | Batch:2 | Loss:0.0556306951 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 50.0 %\n", - "Epoch:190 | Loss:0.05356750 | Runtime:0.001 mins\n", - "Epoch:191 | Batch:1 | Loss:0.052306585013866425 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:191 | Batch:2 | Loss:0.0517870076 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 50.0 %\n", - "Epoch:191 | Loss:0.05204680 | Runtime:0.002 mins\n", - "Epoch:192 | Batch:1 | Loss:0.05249800905585289 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:192 | Batch:2 | Loss:0.0488391109 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.85MB avail. 50.0 %\n", - "Epoch:192 | Loss:0.05066856 | Runtime:0.002 mins\n", - "Epoch:193 | Batch:1 | Loss:0.05310021713376045 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:193 | Batch:2 | Loss:0.0492961146 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 50.0 %\n", - "Epoch:193 | Loss:0.05119817 | Runtime:0.002 mins\n", - "Epoch:194 | Batch:1 | Loss:0.05340215191245079 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:194 | Batch:2 | Loss:0.0493066721 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 
50.0 %\n", - "Epoch:194 | Loss:0.05135441 | Runtime:0.002 mins\n", - "Epoch:195 | Batch:1 | Loss:0.05069021135568619 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:195 | Batch:2 | Loss:0.0544601269 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 50.0 %\n", - "Epoch:195 | Loss:0.05257517 | Runtime:0.002 mins\n", - "Epoch:196 | Batch:1 | Loss:0.05041426047682762 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:196 | Batch:2 | Loss:0.0517515913 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 50.0 %\n", - "Epoch:196 | Loss:0.05108293 | Runtime:0.001 mins\n", - "Epoch:197 | Batch:1 | Loss:0.05015302821993828 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:197 | Batch:2 | Loss:0.0535663255 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.7MB avail. 50.0 %\n", - "Epoch:197 | Loss:0.05185968 | Runtime:0.001 mins\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:198 | Batch:1 | Loss:0.05207153037190437 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.12MB\n", - "Epoch:198 | Batch:2 | Loss:0.0504996330 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.71MB avail. 50.0 %\n", - "Epoch:198 | Loss:0.05128558 | Runtime:0.002 mins\n", - "Epoch:199 | Batch:1 | Loss:0.048883289098739624 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.28MB\n", - "Epoch:199 | Batch:2 | Loss:0.0520132184 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.55MB avail. 49.9 %\n", - "Epoch:199 | Loss:0.05044825 | Runtime:0.002 mins\n", - "Epoch:200 | Batch:1 | Loss:0.047928303480148315 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.15MB\n", - "Epoch:200 | Batch:2 | Loss:0.0520504415 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.73MB avail. 
50.0 %\n", - "Epoch:200 | Loss:0.04998937 | Runtime:0.002 mins\n", - "Epoch:201 | Batch:1 | Loss:0.04889510199427605 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:201 | Batch:2 | Loss:0.0504072234 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.9MB avail. 49.9 %\n", - "Epoch:201 | Loss:0.04965116 | Runtime:0.002 mins\n", - "Epoch:202 | Batch:1 | Loss:0.04873442277312279 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:202 | Batch:2 | Loss:0.0497742556 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 49.9 %\n", - "Epoch:202 | Loss:0.04925434 | Runtime:0.002 mins\n", - "Epoch:203 | Batch:1 | Loss:0.04655801132321358 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:203 | Batch:2 | Loss:0.0538899340 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.63MB avail. 49.9 %\n", - "Epoch:203 | Loss:0.05022397 | Runtime:0.002 mins\n", - "Epoch:204 | Batch:1 | Loss:0.04666663333773613 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:204 | Batch:2 | Loss:0.0530395508 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.64MB avail. 49.9 %\n", - "Epoch:204 | Loss:0.04985309 | Runtime:0.002 mins\n", - "Epoch:205 | Batch:1 | Loss:0.04929682984948158 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:205 | Batch:2 | Loss:0.0518559217 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.64MB avail. 49.9 %\n", - "Epoch:205 | Loss:0.05057638 | Runtime:0.002 mins\n", - "Epoch:206 | Batch:1 | Loss:0.04564577713608742 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.09MB\n", - "Epoch:206 | Batch:2 | Loss:0.0574958846 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.95MB avail. 49.9 %\n", - "Epoch:206 | Loss:0.05157083 | Runtime:0.001 mins\n", - "Epoch:207 | Batch:1 | Loss:0.05065576732158661 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.25MB\n", - "Epoch:207 | Batch:2 | Loss:0.0503311418 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.11sec | Mem. Usage 588.83MB avail. 49.8 %\n", - "Epoch:207 | Loss:0.05049345 | Runtime:0.003 mins\n", - "Epoch:208 | Batch:1 | Loss:0.049541424959897995 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:208 | Batch:2 | Loss:0.0489338860 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.56MB avail. 49.8 %\n", - "Epoch:208 | Loss:0.04923766 | Runtime:0.002 mins\n", - "Epoch:209 | Batch:1 | Loss:0.051391731947660446 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.28MB\n", - "Epoch:209 | Batch:2 | Loss:0.0452043973 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.87MB avail. 49.8 %\n", - "Epoch:209 | Loss:0.04829806 | Runtime:0.002 mins\n", - "Epoch:210 | Batch:1 | Loss:0.04955744370818138 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.17MB\n", - "Epoch:210 | Batch:2 | Loss:0.0502790846 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.9 %\n", - "Epoch:210 | Loss:0.04991826 | Runtime:0.002 mins\n", - "Epoch:211 | Batch:1 | Loss:0.04714949056506157 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:211 | Batch:2 | Loss:0.0505932197 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.6MB avail. 49.8 %\n", - "Epoch:211 | Loss:0.04887136 | Runtime:0.001 mins\n", - "Epoch:212 | Batch:1 | Loss:0.047772783786058426 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.19MB\n", - "Epoch:212 | Batch:2 | Loss:0.0486590602 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 49.8 %\n", - "Epoch:212 | Loss:0.04821592 | Runtime:0.002 mins\n", - "Epoch:213 | Batch:1 | Loss:0.04933064803481102 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.35MB\n", - "Epoch:213 | Batch:2 | Loss:0.0476627871 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.65MB avail. 
49.8 %\n", - "Epoch:213 | Loss:0.04849672 | Runtime:0.002 mins\n", - "Epoch:214 | Batch:1 | Loss:0.04761657118797302 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:214 | Batch:2 | Loss:0.0470872372 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.94MB avail. 49.8 %\n", - "Epoch:214 | Loss:0.04735190 | Runtime:0.002 mins\n", - "Epoch:215 | Batch:1 | Loss:0.04624146968126297 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:215 | Batch:2 | Loss:0.0490209050 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 49.8 %\n", - "Epoch:215 | Loss:0.04763119 | Runtime:0.002 mins\n", - "Epoch:216 | Batch:1 | Loss:0.04717108979821205 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.25MB\n", - "Epoch:216 | Batch:2 | Loss:0.0475515090 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.84MB avail. 49.8 %\n", - "Epoch:216 | Loss:0.04736130 | Runtime:0.002 mins\n", - "Epoch:217 | Batch:1 | Loss:0.04746459797024727 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.14MB\n", - "Epoch:217 | Batch:2 | Loss:0.0451246984 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.73MB avail. 49.8 %\n", - "Epoch:217 | Loss:0.04629465 | Runtime:0.002 mins\n", - "Epoch:218 | Batch:1 | Loss:0.04821448400616646 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.3MB\n", - "Epoch:218 | Batch:2 | Loss:0.0457179174 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.57MB avail. 49.8 %\n", - "Epoch:218 | Loss:0.04696620 | Runtime:0.001 mins\n", - "Epoch:219 | Batch:1 | Loss:0.046421248465776443 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.16MB\n", - "Epoch:219 | Batch:2 | Loss:0.0460630395 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.75MB avail. 49.8 %\n", - "Epoch:219 | Loss:0.04624214 | Runtime:0.001 mins\n", - "Epoch:220 | Batch:1 | Loss:0.04734170809388161 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.18MB\n", - "Epoch:220 | Batch:2 | Loss:0.0443059169 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.75MB avail. 49.8 %\n", - "Epoch:220 | Loss:0.04582381 | Runtime:0.002 mins\n", - "Epoch:221 | Batch:1 | Loss:0.04771336913108826 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:221 | Batch:2 | Loss:0.0434141345 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.47MB avail. 49.9 %\n", - "Epoch:221 | Loss:0.04556375 | Runtime:0.001 mins\n", - "Epoch:222 | Batch:1 | Loss:0.04515066370368004 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:222 | Batch:2 | Loss:0.0474022143 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.62MB avail. 49.9 %\n", - "Epoch:222 | Loss:0.04627644 | Runtime:0.001 mins\n", - "Epoch:223 | Batch:1 | Loss:0.04366186261177063 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:223 | Batch:2 | Loss:0.0481328703 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.63MB avail. 49.8 %\n", - "Epoch:223 | Loss:0.04589737 | Runtime:0.001 mins\n", - "Epoch:224 | Batch:1 | Loss:0.047253672033548355 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.33MB\n", - "Epoch:224 | Batch:2 | Loss:0.0428189822 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.65MB avail. 49.8 %\n", - "Epoch:224 | Loss:0.04503633 | Runtime:0.002 mins\n", - "Epoch:225 | Batch:1 | Loss:0.04419654607772827 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.35MB\n", - "Epoch:225 | Batch:2 | Loss:0.0475541390 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.94MB avail. 49.9 %\n", - "Epoch:225 | Loss:0.04587534 | Runtime:0.002 mins\n", - "Epoch:226 | Batch:1 | Loss:0.04221401736140251 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:226 | Batch:2 | Loss:0.0507492051 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.74MB avail. 
49.9 %\n", - "Epoch:226 | Loss:0.04648161 | Runtime:0.002 mins\n", - "Epoch:227 | Batch:1 | Loss:0.04547853767871857 |ForwardBackwardUpdate:0.06secs | Mem. Usage 584.17MB\n", - "Epoch:227 | Batch:2 | Loss:0.0431850813 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.76MB avail. 49.9 %\n", - "Epoch:227 | Loss:0.04433181 | Runtime:0.002 mins\n", - "Epoch:228 | Batch:1 | Loss:0.044031064957380295 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.18MB\n", - "Epoch:228 | Batch:2 | Loss:0.0471545197 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 49.9 %\n", - "Epoch:228 | Loss:0.04559279 | Runtime:0.002 mins\n", - "Epoch:229 | Batch:1 | Loss:0.04485464096069336 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch:229 | Batch:2 | Loss:0.0465643480 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.63MB avail. 49.9 %\n", - "Epoch:229 | Loss:0.04570949 | Runtime:0.002 mins\n", - "Epoch:230 | Batch:1 | Loss:0.046474575996398926 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.11MB\n", - "Epoch:230 | Batch:2 | Loss:0.0401118584 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.02sec | Mem. Usage 588.7MB avail. 49.9 %\n", - "Epoch:230 | Loss:0.04329322 | Runtime:0.002 mins\n", - "Epoch:231 | Batch:1 | Loss:0.04379909485578537 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:231 | Batch:2 | Loss:0.0459378101 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.52MB avail. 49.9 %\n", - "Epoch:231 | Loss:0.04486845 | Runtime:0.002 mins\n", - "Epoch:232 | Batch:1 | Loss:0.046187710016965866 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.28MB\n", - "Epoch:232 | Batch:2 | Loss:0.0430983864 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.6MB avail. 
49.9 %\n", - "Epoch:232 | Loss:0.04464305 | Runtime:0.002 mins\n", - "Epoch:233 | Batch:1 | Loss:0.04350961744785309 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.3MB\n", - "Epoch:233 | Batch:2 | Loss:0.0449445955 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.62MB avail. 49.9 %\n", - "Epoch:233 | Loss:0.04422711 | Runtime:0.002 mins\n", - "Epoch:234 | Batch:1 | Loss:0.04559534415602684 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.32MB\n", - "Epoch:234 | Batch:2 | Loss:0.0420075022 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.91MB avail. 49.9 %\n", - "Epoch:234 | Loss:0.04380142 | Runtime:0.002 mins\n", - "Epoch:235 | Batch:1 | Loss:0.04320334643125534 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:235 | Batch:2 | Loss:0.0443654656 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 49.9 %\n", - "Epoch:235 | Loss:0.04378441 | Runtime:0.002 mins\n", - "Epoch:236 | Batch:1 | Loss:0.0458211749792099 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:236 | Batch:2 | Loss:0.0407364257 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.81MB avail. 49.9 %\n", - "Epoch:236 | Loss:0.04327880 | Runtime:0.002 mins\n", - "Epoch:237 | Batch:1 | Loss:0.04326337203383446 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.24MB\n", - "Epoch:237 | Batch:2 | Loss:0.0459301770 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 49.9 %\n", - "Epoch:237 | Loss:0.04459677 | Runtime:0.002 mins\n", - "Epoch:238 | Batch:1 | Loss:0.0429919995367527 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:238 | Batch:2 | Loss:0.0451381095 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.85MB avail. 49.9 %\n", - "Epoch:238 | Loss:0.04406505 | Runtime:0.002 mins\n", - "Epoch:239 | Batch:1 | Loss:0.04168468713760376 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.15MB\n", - "Epoch:239 | Batch:2 | Loss:0.0453135408 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.69MB avail. 49.9 %\n", - "Epoch:239 | Loss:0.04349911 | Runtime:0.002 mins\n", - "Epoch:240 | Batch:1 | Loss:0.043360088020563126 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.18MB\n", - "Epoch:240 | Batch:2 | Loss:0.0483630225 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.77MB avail. 49.9 %\n", - "Epoch:240 | Loss:0.04586156 | Runtime:0.002 mins\n", - "Epoch:241 | Batch:1 | Loss:0.042422980070114136 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.34MB\n", - "Epoch:241 | Batch:2 | Loss:0.0433721431 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.61MB avail. 49.9 %\n", - "Epoch:241 | Loss:0.04289756 | Runtime:0.002 mins\n", - "Epoch:242 | Batch:1 | Loss:0.04205933213233948 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.2MB\n", - "Epoch:242 | Batch:2 | Loss:0.0438773744 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.79MB avail. 49.9 %\n", - "Epoch:242 | Loss:0.04296835 | Runtime:0.001 mins\n", - "Epoch:243 | Batch:1 | Loss:0.04169277846813202 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:243 | Batch:2 | Loss:0.0438640155 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.49MB avail. 49.9 %\n", - "Epoch:243 | Loss:0.04277840 | Runtime:0.002 mins\n", - "Epoch:244 | Batch:1 | Loss:0.04168829694390297 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.35MB\n", - "Epoch:244 | Batch:2 | Loss:0.0432196856 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.35MB avail. 49.9 %\n", - "Epoch:244 | Loss:0.04245399 | Runtime:0.002 mins\n", - "Epoch:245 | Batch:1 | Loss:0.04263291507959366 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.22MB\n", - "Epoch:245 | Batch:2 | Loss:0.0411167927 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.78MB avail. 
49.9 %\n", - "Epoch:245 | Loss:0.04187485 | Runtime:0.002 mins\n", - "Epoch:246 | Batch:1 | Loss:0.041595473885536194 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:246 | Batch:2 | Loss:0.0424383804 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.02sec | Mem. Usage 588.82MB avail. 49.9 %\n", - "Epoch:246 | Loss:0.04201693 | Runtime:0.002 mins\n", - "Epoch:247 | Batch:1 | Loss:0.04269953444600105 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.12MB\n", - "Epoch:247 | Batch:2 | Loss:0.0402541570 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 584.39MB avail. 49.9 %\n", - "Epoch:247 | Loss:0.04147685 | Runtime:0.002 mins\n", - "Epoch:248 | Batch:1 | Loss:0.04210883751511574 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:248 | Batch:2 | Loss:0.0406109691 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.53MB avail. 49.9 %\n", - "Epoch:248 | Loss:0.04135990 | Runtime:0.002 mins\n", - "Epoch:249 | Batch:1 | Loss:0.04235502704977989 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.27MB\n", - "Epoch:249 | Batch:2 | Loss:0.0402841270 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.6MB avail. 49.9 %\n", - "Epoch:249 | Loss:0.04131958 | Runtime:0.002 mins\n", - "Epoch:250 | Batch:1 | Loss:0.04164082556962967 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.29MB\n", - "Epoch:250 | Batch:2 | Loss:0.0422295630 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.6MB avail. 49.9 %\n", - "Epoch:250 | Loss:0.04193519 | Runtime:0.001 mins\n", - "Epoch:251 | Batch:1 | Loss:0.040893737226724625 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.31MB\n", - "Epoch:251 | Batch:2 | Loss:0.0425735973 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.62MB avail. 49.9 %\n", - "Epoch:251 | Loss:0.04173367 | Runtime:0.002 mins\n", - "Epoch:252 | Batch:1 | Loss:0.04230279102921486 |ForwardBackwardUpdate:0.04secs | Mem. 
Usage 584.32MB\n", - "Epoch:252 | Batch:2 | Loss:0.0402002037 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.91MB avail. 49.9 %\n", - "Epoch:252 | Loss:0.04125150 | Runtime:0.002 mins\n", - "Epoch:253 | Batch:1 | Loss:0.04151597619056702 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.21MB\n", - "Epoch:253 | Batch:2 | Loss:0.0404539071 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.8MB avail. 49.9 %\n", - "Epoch:253 | Loss:0.04098494 | Runtime:0.002 mins\n", - "Epoch:254 | Batch:1 | Loss:0.04145575314760208 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.23MB\n", - "Epoch:254 | Batch:2 | Loss:0.0406865627 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.82MB avail. 49.9 %\n", - "Epoch:254 | Loss:0.04107116 | Runtime:0.002 mins\n", - "Epoch:255 | Batch:1 | Loss:0.04104768857359886 |ForwardBackwardUpdate:0.05secs | Mem. Usage 584.25MB\n", - "Epoch:255 | Batch:2 | Loss:0.0432411432 |ForwardBackwardUpdate:0.03sec | BatchConst.:0.01sec | Mem. Usage 588.83MB avail. 49.9 %\n", - "Epoch:255 | Loss:0.04214442 | Runtime:0.002 mins\n", - "Epoch:256 | Batch:1 | Loss:0.041763968765735626 |ForwardBackwardUpdate:0.04secs | Mem. Usage 584.26MB\n", - "Epoch:256 | Batch:2 | Loss:0.0412489139 |ForwardBackwardUpdate:0.02sec | BatchConst.:0.01sec | Mem. Usage 588.58MB avail. 50.0 %\n", - "Epoch:256 | Loss:0.04150644 | Runtime:0.002 mins\n", - "Done ! 
It took 28.224 seconds.\n", - "\n", - "*** Save Trained Model ***\n", - "Took 0.0017 seconds | Current Memory Usage 588.89 in MB\n", - "Total computation time: 28.333 seconds\n", - "Evaluate AConEx on Train set: Evaluate AConEx on Train set\n", - "{'H@1': 0.8515145705521472, 'H@3': 0.9526457055214724, 'H@10': 0.9818826687116564, 'MRR': 0.9049869811033961}\n", - "Evaluate AConEx on Validation set: Evaluate AConEx on Validation set\n", - "{'H@1': 0.6694785276073619, 'H@3': 0.8351226993865031, 'H@10': 0.9386503067484663, 'MRR': 0.7677471996019155}\n", - "Evaluate AConEx on Test set: Evaluate AConEx on Test set\n", - "{'H@1': 0.6717095310136157, 'H@3': 0.8562783661119516, 'H@10': 0.9402420574886535, 'MRR': 0.771550307601881}\n" - ] - } - ], - "source": [ - "report=executor.start()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "e561c776", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'num_train_triples': 10432,\n", - " 'num_entities': 135,\n", - " 'num_relations': 92,\n", - " 'EstimatedSizeMB': 0.13260650634765625,\n", - " 'NumParam': 138784,\n", - " 'path_experiment_folder': '/home/demir/Desktop/Softwares/dice-embeddings/examples/Experiments/2023-04-05 12:59:36.749715',\n", - " 'Runtime': 28.33254861831665,\n", - " 'Train': {'H@1': 0.8515145705521472,\n", - " 'H@3': 0.9526457055214724,\n", - " 'H@10': 0.9818826687116564,\n", - " 'MRR': 0.9049869811033961},\n", - " 'Val': {'H@1': 0.6694785276073619,\n", - " 'H@3': 0.8351226993865031,\n", - " 'H@10': 0.9386503067484663,\n", - " 'MRR': 0.7677471996019155},\n", - " 'Test': {'H@1': 0.6717095310136157,\n", - " 'H@3': 0.8562783661119516,\n", - " 'H@10': 0.9402420574886535,\n", - " 'MRR': 0.771550307601881}}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "report" - ] - }, - { - "cell_type": "markdown", - "id": "8048736e", - "metadata": {}, - "source": [ - "# How to Eval" - ] - }, - { - "cell_type": "code", - 
"execution_count": 8, - "id": "96c6dabd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading model model.pt... Done! It took 0.002\n", - "Loading entity and relation indexes... Done! It took 0.0002\n", - "Loading indexed training data...\n" - ] - } - ], - "source": [ - "pre_trained_model=KGE(path=report['path_experiment_folder'])" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d9e5cb4d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([0.9401, 0.9305, 0.8677, 0.8673, 0.8347, 0.6596, 0.4561, 0.3340, 0.1839,\n", - " 0.1810]),\n", - " ['family_group',\n", - " 'professional_or_occupational_group',\n", - " 'population_group',\n", - " 'group',\n", - " 'age_group',\n", - " 'patient_or_disabled_group',\n", - " 'carbohydrate_sequence',\n", - " 'geographic_area',\n", - " 'sign_or_symptom',\n", - " 'qualitative_concept'])" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pre_trained_model.topk(head_entity='behavior',relation='associated_with',k=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "331f6f89", - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of entities:135 \t Number of relations:92\n", - "Finding missing triples..\n", - "Number of found missing triples: 1\n" - ] - }, - { - "data": { - "text/plain": [ - "{('behavior', 'associated_with', 'age_group')}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pre_trained_model.find_missing_triples(confidence=0.8,\n", - " entities=['behavior'],\n", - " relations=['associated_with'],topk=10)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "a60dce9d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"'2023-04-05 12:59:36.749715'\r\n" - ] - } - ], - "source": [ - "# Data is already saved\n", - "!ls Experiments" - ] - }, - { - "cell_type": "markdown", - "id": "97fc46d4", - "metadata": {}, - "source": [ - "# How to deploy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ab54c45", - "metadata": {}, - "outputs": [], - "source": [ - "# Pretrained model can be used with anyone\n", - "pre_trained_model.deploy(share=True,top_k=10)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file From 3eebbac5ea69d0cf66e7746b8be291642db522d0 Mon Sep 17 00:00:00 2001 From: Caglar Demir Date: Wed, 26 Jun 2024 11:07:22 +0200 Subject: [PATCH 8/8] Numpy verson fixed --- README.md | 4 ++-- requirements.txt | 1 + setup.py | 6 ++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c3e3a59b..4fd3934c 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ Deploy a pre-trained embedding model without writing a single line of code. ### Installation from Source ``` bash git clone https://github.com/dice-group/dice-embeddings.git -conda create -n dice python=3.10.13 --no-default-packages && conda activate dice && cd dice-embeddings && +conda create -n dice python=3.10.13 --no-default-packages && conda activate dice pip3 install -e . 
``` or @@ -48,7 +48,7 @@ wget https://files.dice-research.org/datasets/dice-embeddings/KGs.zip --no-check ``` To test the Installation ```bash -python -m pytest -p no:warnings -x # Runs >114 tests leading to > 15 mins +python -m pytest -p no:warnings -x # Runs >119 tests leading to > 15 mins python -m pytest -p no:warnings --lf # run only the last failed test python -m pytest -p no:warnings --ff # to run the failures first and then the rest of the tests. ``` diff --git a/requirements.txt b/requirements.txt index c267a2aa..8c925a22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +numpy==1.26.4 torch>=2.2.0 lightning>=2.1.3 pandas>=2.1.0 diff --git a/setup.py b/setup.py index 92b864e6..bb0ec75b 100644 --- a/setup.py +++ b/setup.py @@ -8,6 +8,7 @@ "torch>=2.2.0", "lightning>=2.1.3", "pandas>=2.1.0", + "numpy==1.26.4", "polars>=0.16.14", "scikit-learn>=1.2.2", "pyarrow>=11.0.0", @@ -33,12 +34,13 @@ def deps_list(*pkgs): extras = dict() extras["min"] = deps_list( "pandas", - "polars", "rdflib", # Loading KG + "polars", "pyarrow", "rdflib", # Loading KG "torch", "lightning", # Training KGE "tiktoken", # used for BPE "psutil", # Memory tracking: maybe remove later ? "matplotlib", # Unclear why it is needed - "pykeen" # additional kge models + "pykeen", # additional kge models + "numpy" ) # TODO: Remove polars, rdflib, tiktoken, psutil, matplotlib from min