Commit

fix OOM issue in accelerate-training-of-large-embedding-tables-by-LazyAdam nb unit test (#816)

* fix OOM issue in unit test

* fix test_soft_embedding
rnyak authored Oct 20, 2022
1 parent 969c9a0 commit f11d77a
Showing 2 changed files with 10 additions and 2 deletions.
@@ -64,16 +64,23 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"2022-09-28 21:36:00.826297: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n",
+"/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (3.0.4) doesn't match a supported version!\n",
+" warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n",
+"2022-10-20 19:37:18.414599: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX\n",
 "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
-"2022-09-28 21:36:03.012876: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 16255 MB memory: -> device: 0, name: Tesla V100-SXM2-32GB-LS, pci bus id: 0000:86:00.0, compute capability: 7.0\n"
+"2022-10-20 19:37:21.731176: I tensorflow/core/common_runtime/gpu/gpu_process_state.cc:222] Using CUDA malloc Async allocator for GPU: 0\n",
+"2022-10-20 19:37:21.731361: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1532] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8080 MB memory: -> device: 0, name: Tesla V100-SXM2-16GB-N, pci bus id: 0000:0a:00.0, compute capability: 7.0\n",
+"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+" from .autonotebook import tqdm as notebook_tqdm\n",
+"None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
 ]
 }
 ],
 "source": [
 "import os\n",
 "\n",
 "import tensorflow as tf\n",
+"os.environ[\"TF_GPU_ALLOCATOR\"]=\"cuda_malloc_async\"\n",
 "\n",
 "import merlin.models.tf as mm\n",
 "from merlin.datasets.synthetic import generate_data\n",
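The notebook change above sets the `TF_GPU_ALLOCATOR` environment variable, and the refreshed log line "Using CUDA malloc Async allocator for GPU: 0" suggests the setting was picked up. A minimal standalone sketch of the same idea follows. It assumes the variable is set before TensorFlow creates its GPU device; `gpu_allocator_setting` is a hypothetical helper added only for illustration and is not part of the commit.

```python
import os

# Assumption: this assignment runs before TensorFlow initializes the GPU,
# so the allocator choice is visible when the device is created.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"


def gpu_allocator_setting() -> str:
    """Hypothetical helper: report which allocator has been requested."""
    return os.environ.get("TF_GPU_ALLOCATOR", "<default>")


# TensorFlow would be imported and used after this point, for example:
# import tensorflow as tf
print(gpu_allocator_setting())
```

Setting the variable in the notebook's first cell keeps the fix local to that notebook rather than changing the test environment for every other test.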
1 change: 1 addition & 0 deletions tests/unit/torch/features/test_embedding.py
@@ -161,6 +161,7 @@ def test_soft_embedding_invalid_embeddings_dim():


 def test_soft_embedding():
+    torch.manual_seed(0)
     embeddings_dim = 16
     num_embeddings = 64
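The commit message does not say why the seed was added. One plausible reading is that the test depended on randomly initialized weights and was therefore flaky. The sketch below illustrates, under that assumption, how `torch.manual_seed` makes the initialization repeatable. It uses a plain `torch.nn.Embedding` as a stand-in for the library's soft embedding layer, and `make_embedding_table` is a hypothetical helper.

```python
import torch


def make_embedding_table(seed: int) -> torch.Tensor:
    """Hypothetical helper: build a freshly initialized embedding table."""
    # Reset the global RNG so the random initialization below is repeatable.
    torch.manual_seed(seed)
    # Sizes mirror the test: num_embeddings = 64, embeddings_dim = 16.
    table = torch.nn.Embedding(num_embeddings=64, embedding_dim=16)
    return table.weight.detach().clone()


first = make_embedding_table(0)
second = make_embedding_table(0)

# With the same seed, both tables hold identical weights.
print(torch.equal(first, second))
```

Seeding at the top of the test, as the diff does, pins every later random draw in that test to the same sequence on each run.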
