diff --git a/.gitignore b/.gitignore
index 00368ede67d3d..7e21ba0b750df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,6 @@
*.DS_Store
build/
+*.user
+
+.vscode
+.idea
\ No newline at end of file
diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake
index 617bd7ea7162b..529b4b9d15d09 100644
--- a/cmake/cblas.cmake
+++ b/cmake/cblas.cmake
@@ -65,12 +65,14 @@ set(OPENBLAS_ROOT $ENV{OPENBLAS_ROOT} CACHE PATH "Folder contains Openblas")
set(OPENBLAS_INCLUDE_SEARCH_PATHS
${OPENBLAS_ROOT}/include
/usr/include
- /usr/include/openblas)
+ /usr/include/openblas
+ /usr/local/opt/openblas/include)
set(OPENBLAS_LIB_SEARCH_PATHS
${OPENBLAS_ROOT}/lib
/usr/lib
/usr/lib/blas/openblas
- /usr/lib/openblas)
+ /usr/lib/openblas
+ /usr/local/opt/openblas/lib)
find_path(OPENBLAS_INC_DIR NAMES cblas.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake
index e2ff923a22923..e5b59be19369d 100644
--- a/cmake/cudnn.cmake
+++ b/cmake/cudnn.cmake
@@ -15,7 +15,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib
/usr/lib)
-find_library(CUDNN_LIBRARY NAMES libcudnn.so # libcudnn_static.a
+find_library(CUDNN_LIBRARY NAMES libcudnn.so libcudnn.dylib # libcudnn_static.a
PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to cuDNN library.")
diff --git a/cmake/util.cmake b/cmake/util.cmake
index 5b56304656e38..d776c3ae49952 100644
--- a/cmake/util.cmake
+++ b/cmake/util.cmake
@@ -1,16 +1,55 @@
# Some common routine for paddle compile.
-
# target_circle_link_libraries
# Link libraries to target which has circle dependencies.
#
# First Argument: target name want to be linked with libraries
# Rest Arguments: libraries which link together.
function(target_circle_link_libraries TARGET_NAME)
- target_link_libraries(${TARGET_NAME}
- -Wl,--start-group
- ${ARGN}
- -Wl,--end-group)
+ if(APPLE)
+ set(LIBS)
+ set(inArchive OFF)
+ set(libsInArgn)
+
+ foreach(arg ${ARGN})
+ if(${arg} STREQUAL "ARCHIVE_START")
+ set(inArchive ON)
+ elseif(${arg} STREQUAL "ARCHIVE_END")
+ set(inArchive OFF)
+ else()
+ if(inArchive)
+ list(APPEND LIBS "-Wl,-force_load")
+ endif()
+ list(APPEND LIBS ${arg})
+ list(APPEND libsInArgn ${arg})
+ endif()
+ endforeach()
+ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
+ list(APPEND LIBS "-undefined dynamic_lookup")
+ endif()
+ list(REVERSE libsInArgn)
+ target_link_libraries(${TARGET_NAME}
+ ${LIBS}
+ ${libsInArgn})
+
+ else() # LINUX
+ set(LIBS)
+
+ foreach(arg ${ARGN})
+ if(${arg} STREQUAL "ARCHIVE_START")
+ list(APPEND LIBS "-Wl,--whole-archive")
+ elseif(${arg} STREQUAL "ARCHIVE_END")
+ list(APPEND LIBS "-Wl,--no-whole-archive")
+ else()
+ list(APPEND LIBS ${arg})
+ endif()
+ endforeach()
+
+ target_link_libraries(${TARGET_NAME}
+ "-Wl,--start-group"
+ ${LIBS}
+ "-Wl,--end-group")
+ endif()
endfunction()
# compile_cu_as_cpp
@@ -41,20 +80,20 @@ function(link_paddle_exe TARGET_NAME)
if(PADDLE_WITH_INTERNAL)
set(INTERAL_LIBS paddle_internal_gserver paddle_internal_parameter)
target_circle_link_libraries(${TARGET_NAME}
- -Wl,--whole-archive
+ ARCHIVE_START
paddle_internal_gserver
paddle_internal_owlqn
- -Wl,--no-whole-archive
+ ARCHIVE_END
paddle_internal_parameter)
else()
set(INTERAL_LIBS "")
endif()
target_circle_link_libraries(${TARGET_NAME}
- -Wl,--whole-archive
+ ARCHIVE_START
paddle_gserver
${METRIC_LIBS}
- -Wl,--no-whole-archive
+ ARCHIVE_END
paddle_pserver
paddle_trainer_lib
paddle_network
diff --git a/doc/build/build_from_source.md b/doc/build/build_from_source.md
index a191d31318aa6..a6090d6819162 100644
--- a/doc/build/build_from_source.md
+++ b/doc/build/build_from_source.md
@@ -1,141 +1,306 @@
-Build and Install
+Installing from Sources
=================
-## Requirement
+* [1. Download and Setup](#download)
+* [2. Requirements](#requirements)
+* [3. Build on Ubuntu](#ubuntu)
+* [4. Build on Mac OS X](#mac)
-### Dependents
+## Download and Setup
+You can download PaddlePaddle from the [github source](https://github.com/baidu/Paddle).
-- **CMake**: required for 2.8+ version
-- **g++**: a recent c++ compiler supporting c++11, >= 4.6, < 5
-- **BLAS library**: such as openBLAS, MKL, ATLAS
-- **protobuf**: required for 2.4+ version, 3.x is not supported
-- **python**: currently only 2.7 version is supported
+```bash
+git clone https://github.com/baidu/Paddle paddle
+```
+
+## Requirements
+
+To compile the source code, your computer must be equipped with GCC >= 4.6 or a Clang compiler.
+### Dependencies
+
+- **CMake**: version >= 2.8
+- **BLAS**: MKL, OpenBlas or ATLAS
+- **protobuf**: version >= 2.4, **Note: 3.x is not supported**
+- **python**: only python 2.7 is supported currently
+
+### Options
+
+PaddlePaddle supports several build options. To enable them, you first need to install the related libraries.
+
+ Optional | Description
+ ------------ | :-----------
+ **WITH_GPU** | Compile with GPU mode.
+ **WITH_DOUBLE** | Compile with double precision floating-point, default: single precision. |
+ **WITH_GLOG** | Compile with glog. If not found, default: an internal log implementation.
+ **WITH_GFLAGS** | Compile with gflags. If not found, default: an internal flag implementation.
+ **WITH_TESTING** | Compile with gtest for PaddlePaddle's unit testing.
+ **WITH_DOC** | Compile to generate PaddlePaddle's docs, default: disabled (OFF).
+ **WITH_SWIG_PY** | Compile with python predict API, default: disabled (OFF).
+ **WITH_STYLE_CHECK**| Compile with code style check, default: enabled (ON).
+
-### Optional
+**Note:**
+ - The GPU version works best with Cuda Toolkit 7.5 and cuDNN v5.
+ - Other versions like Cuda Toolkit 6.5, 7.0, 8.0 and cuDNN v2, v3, v4 are also supported.
+ - **To utilize cuDNN v5, Cuda Toolkit 7.5 is prerequisite and vice versa.**
-PaddlePaddle also support some build options, you have to install related libraries.
+As a simple example, consider the following:
-- **WITH_GPU**: Compile with gpu mode
- - The GPU version works best with Cuda Toolkit 7.5 and cuDNN v5
- - Other versions Cuda Toolkit 6.5, 7.0 and cuDNN v2, v3, v4 are also supported
- - Note: to utilize cuDNN v5, Cuda Toolkit 7.5 is prerequisite and vice versa
-- **WITH_DOUBLE**: Compile with double precision, otherwise use single precision
-- **WITH_GLOG**: Compile with glog, otherwise use a log implement internally
-- **WITH_GFLAGS**: Compile with gflags, otherwise use a flag implement internally
-- **WITH_TESTING**: Compile with gtest and run unittest for PaddlePaddle
-- **WITH_DOC**: Compile with documentation
-- **WITH_SWIG_PY**: Compile with python predict api
-- **WITH_STYLE_CHECK**: Style check for source code
+1. **Python Dependencies(optional)**
+
+ To compile PaddlePaddle with python predict API, make sure swig installed and set `-DWITH_SWIG_PY=ON` as follows:
+
+ ```bash
+ # install swig on ubuntu
+ sudo apt-get install swig
+ # install swig on Mac OS X
+ brew install swig
+
+ # active swig in cmake
+ cmake .. -DWITH_SWIG_PY=ON
+ ```
+
+2. **Doc Dependencies(optional)**
+
+ To generate PaddlePaddle's documentation, install dependencies and set `-DWITH_DOC=ON` as follows:
+ ```bash
+ pip install 'sphinx>=1.4.0'
+ pip install sphinx_rtd_theme breathe recommonmark
-## Building on Ubuntu14.04
+ # install doxygen on Ubuntu
+ sudo apt-get install doxygen
+ # install doxygen on Mac OS X
+ brew install doxygen
+
+ # active docs in cmake
+   cmake .. -DWITH_DOC=ON
+ ```
+
+## Build on Ubuntu 14.04
### Install Dependencies
- **CPU Dependencies**
-```bash
-# necessary
-sudo apt-get update
-sudo apt-get install -y g++ make cmake build-essential libatlas-base-dev python python-pip libpython-dev m4 libprotobuf-dev protobuf-compiler python-protobuf python-numpy git
-# optional
-sudo apt-get install libgoogle-glog-dev
-sudo apt-get install libgflags-dev
-sudo apt-get install libgtest-dev
-sudo pip install wheel
-pushd /usr/src/gtest
-cmake .
-make
-sudo cp *.a /usr/lib
-popd
-```
-
+ ```bash
+ # necessary
+ sudo apt-get update
+ sudo apt-get install -y g++ make cmake build-essential libatlas-base-dev python python-pip libpython-dev m4 libprotobuf-dev protobuf-compiler python-protobuf python-numpy git
+ # optional
+ sudo apt-get install libgoogle-glog-dev
+ sudo apt-get install libgflags-dev
+ sudo apt-get install libgtest-dev
+ sudo pip install wheel
+ pushd /usr/src/gtest
+ cmake .
+ make
+ sudo cp *.a /usr/lib
+ popd
+ ```
-- **GPU Dependencies(optional)**
+- **GPU Dependencies (optional)**
-If you need to build GPU version, the first thing you need is a machine that has GPU and CUDA installed.
-And you also need to install cuDNN.
+ To build GPU version, you will need the following installed:
-You can download CUDA toolkit and cuDNN from nvidia website:
-
-```bash
-https://developer.nvidia.com/cuda-downloads
-https://developer.nvidia.com/cudnn
-```
-You can copy cuDNN files into the CUDA toolkit directory, such as:
+ 1. a CUDA-capable GPU
+ 2. A supported version of Linux with a gcc compiler and toolchain
+ 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
+  4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
+
+ The CUDA development environment relies on tight integration with the host development environment,
+ including the host compiler and C runtime libraries, and is therefore only supported on
+ distribution versions that have been qualified for this CUDA Toolkit release.
+
+ After downloading cuDNN library, issue the following commands:
+
+ ```bash
+ sudo tar -xzf cudnn-7.5-linux-x64-v5.1.tgz -C /usr/local
+ sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib64/libcudnn*
+ ```
+ Then you need to set LD\_LIBRARY\_PATH, CUDA\_HOME and PATH environment variables in ~/.bashrc.
+
+ ```bash
+ export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
+ export CUDA_HOME=/usr/local/cuda
+ export PATH=/usr/local/cuda/bin:$PATH
+ ```
+
+### Build and Install
+
+As usual, the best option is to create a build folder under the Paddle project directory.
```bash
-sudo tar -xzf cudnn-7.5-linux-x64-v5.1.tgz -C /usr/local
-sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib64/libcudnn*
+mkdir build && cd build
+cmake ..
```
-Then you need to set LD\_LIBRARY\_PATH, CUDA\_HOME and PATH environment variables in ~/.bashrc.
+
+CMake first checks PaddlePaddle's dependencies in the system default path. After installing some optional
+libraries, corresponding build option will be set automatically (for instance, glog, gtest and gflags).
+If still not found, you can manually set it based on CMake error information from your screen.
+
+As a simple example, consider the following:
+
+- **Only CPU**
+
+ ```bash
+ cmake .. -DWITH_GPU=OFF -DWITH_DOC=OFF
+ ```
+- **GPU**
+
+ ```bash
+ cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF
+ ```
+
+- **GPU with doc and swig**
+
+ ```bash
+ cmake .. -DWITH_GPU=ON -DWITH_DOC=ON -DWITH_SWIG_PY=ON
+ ```
+
+Finally, you can download source code and build:
```bash
-export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-export CUDA_HOME=/usr/local/cuda
-export PATH=/usr/local/cuda/bin:$PATH
+# you can add build option here, such as:
+cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX=
+# please use sudo make install, if you want
+# to install PaddlePaddle into the system
+make -j `nproc` && make install
+# set PaddlePaddle installation path in ~/.bashrc
+export PATH=/bin:$PATH
```
-- **Python Dependencies(optional)**
-If you want to compile PaddlePaddle with python predict api, you need to add -DWITH_SWIG_PY=ON in cmake command and install these first:
+**Note:**
+
+If you set `WITH_SWIG_PY=ON`, related python dependencies also need to be installed.
+Otherwise, PaddlePaddle will automatically install python dependencies
+the first time a user runs paddle commands, such as `paddle version` or `paddle train`.
+It may require sudo privileges:
```bash
-sudo apt-get install swig
+# you can run
+sudo pip install /opt/paddle/share/wheels/*.whl
+# or just run
+sudo paddle version
```
-- **Doc Dependencies(optional)**
+## Build on Mac OS X
-If you want to compile PaddlePaddle with doc, you need to add -DWITH_DOC=ON in cmake command and install these first:
+### Prerequisites
+This guide is based on Mac OS X 10.11 (El Capitan). Note that if you are running an up to date version of OS X,
+you will already have Python 2.7.10 and Numpy 1.8 installed.
+
+The best option is to use the package manager homebrew to handle installations and upgrades for you.
+To install [homebrew](http://brew.sh/), first open a terminal window (you can find Terminal in the Utilities folder in Applications), and issue the command:
```bash
-pip install 'sphinx>=1.4.0'
-pip install sphinx_rtd_theme breathe recommonmark
-sudo apt-get install doxygen
+# install brew
+/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
+# install pip
+easy_install pip
```
-### Build and Install
+### Install Dependencies
-CMake will find dependent libraries in system default paths first. After installing some optional libraries, corresponding build option will automatically be on(such as glog, gtest and gflags). And if libraries are not found, you have to set following variables manually in cmake command(CUDNN_ROOT, ATLAS_ROOT, MKL_ROOT, OPENBLAS_ROOT).
+- **CPU Dependencies**
-Here are some examples of cmake command with different options:
+ ```bash
+ # Install fundamental dependents
+ brew install glog gflags cmake protobuf openblas
+
+ # Install google test on Mac OS X
+ # Download gtest 1.7.0
+ wget https://github.com/google/googletest/archive/release-1.7.0.tar.gz
+ tar -xvf googletest-release-1.7.0.tar.gz && cd googletest-release-1.7.0
+ # Build gtest
+  mkdir build && cd build && cmake ..
+ make
+ # Install gtest library
+ sudo cp -r ../include/gtest /usr/local/include/
+ sudo cp lib*.a /usr/local/lib
+ ```
-**only cpu**
+- **GPU Dependencies(optional)**
-```bash
-cmake -DWITH_GPU=OFF -DWITH_DOC=OFF
-```
+ To build GPU version, you will need the following installed:
+
+  1. a CUDA-capable GPU
+  2. Mac OS X 10.11 or later
+  3. the Clang compiler and toolchain installed using Xcode
+  4. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
+  5. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
+
+ The CUDA development environment relies on tight integration with the host development environment,
+ including the host compiler and C runtime libraries, and is therefore only supported on
+ distribution versions that have been qualified for this CUDA Toolkit release.
+
+ 1. After downloading cuDNN library, issue the following commands:
+
+ ```bash
+ sudo tar -xzf cudnn-7.5-osx-x64-v5.0-ga.tgz -C /usr/local
+ sudo chmod a+r /usr/local/cuda/include/cudnn.h /usr/local/cuda/lib64/libcudnn*
+ ```
+ 2. Then you need to set DYLD\_LIBRARY\_PATH, CUDA\_HOME and PATH environment variables in ~/.bashrc.
-**gpu**
+ ```bash
+ export DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH
+ export PATH=/usr/local/cuda/bin:$PATH
+ ```
+
+### Build and Install
+
+As usual, the best option is to create a build folder under the Paddle project directory.
```bash
-cmake -DWITH_GPU=ON -DWITH_DOC=OFF
+mkdir build && cd build
+cmake ..
```
-**gpu with doc and swig**
+CMake first checks PaddlePaddle's dependencies in the system default path. After installing some optional
+libraries, corresponding build option will be set automatically (for instance, glog, gtest and gflags).
+If still not found, you can manually set it based on CMake error information from your screen.
-```bash
-cmake -DWITH_GPU=ON -DWITH_DOC=ON -DWITH_SWIG_PY=ON
-```
+As a simple example, consider the following:
-Finally, you can download source code and build:
+- **Only CPU**
+
+ ```bash
+ cmake .. -DWITH_GPU=OFF -DWITH_DOC=OFF
+ ```
+- **GPU**
+
+ ```bash
+ cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF
+ ```
+
+- **GPU with doc and swig**
+
+ ```bash
+ cmake .. -DWITH_GPU=ON -DWITH_DOC=ON -DWITH_SWIG_PY=ON
+ ```
+
+Finally, you can build PaddlePaddle:
```bash
-git clone https://github.com/baidu/Paddle paddle
-cd paddle
-mkdir build
-cd build
# you can add build option here, such as:
-cmake -DWITH_GPU=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX= ..
-# please use sudo make install, if you want
-# to install PaddlePaddle into the system
+cmake .. -DWITH_GPU=ON -DWITH_DOC=OFF -DCMAKE_INSTALL_PREFIX=
+# please use sudo make install, if you want to install PaddlePaddle into the system
make -j `nproc` && make install
-# PaddlePaddle installation path
-export PATH=/bin:$PATH
+# set PaddlePaddle installation path in ~/.bashrc
+export PATH=/bin:$PATH
```
-**Note**
-And if you set WITH_SWIG_PY=ON, you have to install related python predict api at the same time:
+
+**Note:**
+
+If you set `WITH_SWIG_PY=ON`, related python dependencies also need to be installed.
+Otherwise, PaddlePaddle will automatically install python dependencies
+the first time a user runs paddle commands, such as `paddle version` or `paddle train`.
+It may require sudo privileges:
```bash
-pip install /opt/paddle/share/wheels/*.whl
-```
+# you can run
+sudo pip install /opt/paddle/share/wheels/*.whl
+# or just run
+sudo paddle version
+```
\ No newline at end of file
diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index 79487c4cf4d41..b3140617af188 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -20,6 +20,7 @@ limitations under the License. */
#include
#include
#include "paddle/utils/GlobalConstants.h"
+#include "paddle/utils/TypeDefs.h"
/// Import PaddlePaddle's enumeration into global namespace.
using namespace paddle::enumeration_wrapper; // NOLINT
diff --git a/paddle/api/Util.cpp b/paddle/api/Util.cpp
index 4e655c324a1ed..8a6741078f2f1 100644
--- a/paddle/api/Util.cpp
+++ b/paddle/api/Util.cpp
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/utils/Util.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Flags.h"
+#include "paddle/utils/Excepts.h"
#include "paddle/parameter/Parameter.h"
#include
diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py
index 21b4ca1dd6171..bc1afc5898e82 100644
--- a/paddle/api/paddle_ld_flags.py
+++ b/paddle/api/paddle_ld_flags.py
@@ -15,6 +15,19 @@
try:
from paddle_api_config import *
import os.path
+ import platform
+
+ system = platform.system().lower()
+ is_osx = (system == 'darwin')
+ is_win = (system == 'windows')
+ is_lin = (system == 'linux')
+
+ if is_lin:
+ whole_start = "-Wl,--whole-archive"
+ whole_end = "-Wl,--no-whole-archive"
+ elif is_osx:
+ whole_start = ""
+ whole_end = ""
LIB_DIRS = ["math", 'utils', 'parameter', "gserver", "api", "cuda", "pserver", "trainer"]
PARENT_LIB_DIRS = ['proto']
@@ -56,9 +69,9 @@ def parent_dir_str(self):
def libs_str(self):
libs = [
- "-Wl,--whole-archive",
+ whole_start,
"-lpaddle_gserver",
- "-Wl,--no-whole-archive",
+ whole_end,
"-lpaddle_pserver",
"-lpaddle_trainer_lib",
"-lpaddle_network",
diff --git a/paddle/cuda/include/hl_device_functions.cuh b/paddle/cuda/include/hl_device_functions.cuh
index 27e3f450c5c1c..88d950d6c1713 100755
--- a/paddle/cuda/include/hl_device_functions.cuh
+++ b/paddle/cuda/include/hl_device_functions.cuh
@@ -16,28 +16,37 @@ limitations under the License. */
#ifndef HL_DEVICE_FUNCTIONS_CUH_
#define HL_DEVICE_FUNCTIONS_CUH_
-namespace hppl {
-
-static __inline__ __device__ double atomicAdd(double* address, double val) {
- // NOLINTNEXTLINE
- unsigned long long int* address_as_ull = (unsigned long long int*)address;
- unsigned long long int old = *address_as_ull, assumed; // NOLINT
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull,
- assumed,
- __double_as_longlong(val +
- __longlong_as_double(assumed)));
- } while (assumed != old);
-
- return __longlong_as_double(old);
-}
+namespace paddle {
+
+template
+inline __device__ T paddleAtomicAdd(T* address, T val);
-} // namespace hppl
+template <>
+inline __device__ float paddleAtomicAdd(float* address, float val) {
+ return atomicAdd(address, val);
+}
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
-using hppl::atomicAdd;
+template <>
+inline __device__ double paddleAtomicAdd(double* address, double val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+ return atomicAdd(address, val);
+#else
+ // NOLINTNEXTLINE
+ unsigned long long int* address_as_ull = (unsigned long long int*)address;
+ unsigned long long int old = *address_as_ull, assumed; // NOLINT
+
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull,
+ assumed,
+ __double_as_longlong(val +
+ __longlong_as_double(assumed)));
+ } while (assumed != old);
+
+ return __longlong_as_double(old);
#endif
+}
+} // namespace paddle
+
#endif /* HL_DEVICE_FUNCTIONS_CUH_ */
diff --git a/paddle/cuda/include/hl_gpu_lstm.cuh b/paddle/cuda/include/hl_gpu_lstm.cuh
index 2ca33f2b13a1f..07806e11c18a2 100644
--- a/paddle/cuda/include/hl_gpu_lstm.cuh
+++ b/paddle/cuda/include/hl_gpu_lstm.cuh
@@ -192,10 +192,10 @@ __global__ void KeLstmBackward(Op op,
if (isBatch) {
if (value.prevStateValue) {
- if (grad.checkIgGrad) atomicAdd(grad.checkIgGrad+frameIdx, rCheckIGrad);
- if (grad.checkFgGrad) atomicAdd(grad.checkFgGrad+frameIdx, rCheckFGrad);
+ if (grad.checkIgGrad) paddle::paddleAtomicAdd(grad.checkIgGrad+frameIdx, rCheckIGrad);
+ if (grad.checkFgGrad) paddle::paddleAtomicAdd(grad.checkFgGrad+frameIdx, rCheckFGrad);
}
- if (grad.checkOgGrad) atomicAdd(grad.checkOgGrad+frameIdx, rCheckOGrad);
+ if (grad.checkOgGrad) paddle::paddleAtomicAdd(grad.checkOgGrad+frameIdx, rCheckOGrad);
} else {
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
diff --git a/paddle/cuda/include/hl_matrix_type.cuh b/paddle/cuda/include/hl_matrix_type.cuh
index 85b60cc313fa7..6917f36290141 100644
--- a/paddle/cuda/include/hl_matrix_type.cuh
+++ b/paddle/cuda/include/hl_matrix_type.cuh
@@ -27,6 +27,8 @@ typedef float4 vecType;
typedef double2 vecType;
#endif
#else
+#include
+#include
#include
#ifndef HPPL_TYPE_DOUBLE
typedef __m128 vecType;
diff --git a/paddle/cuda/include/hl_sse_matrix_kernel.cuh b/paddle/cuda/include/hl_sse_matrix_kernel.cuh
index d774150c21e61..c90d49e4adeb5 100644
--- a/paddle/cuda/include/hl_sse_matrix_kernel.cuh
+++ b/paddle/cuda/include/hl_sse_matrix_kernel.cuh
@@ -25,6 +25,9 @@ limitations under the License. */
#define VECTOR_LEN 4
#define VECTOR_SET _mm_set_ps1
#else
+#if defined(__APPLE__) || defined(__OSX__)
+#define _mm_set_pd1 _mm_set1_pd
+#endif
/* number of double in vector */
#define VECTOR_LEN 2
#define VECTOR_SET _mm_set_pd1
diff --git a/paddle/cuda/src/hl_cuda_device.cc b/paddle/cuda/src/hl_cuda_device.cc
index f07538d6ba713..acd8e2fe6afb4 100644
--- a/paddle/cuda/src/hl_cuda_device.cc
+++ b/paddle/cuda/src/hl_cuda_device.cc
@@ -209,7 +209,18 @@ __thread cudaStream_t default_stream = 0;
__thread bool g_sync_flag = true;
bool hl_start_flag = false;
-#define gettid() syscall(SYS_gettid)
+inline pid_t gettid() {
+#if defined(__APPLE__) || defined(__OSX__)
+ pid_t tid = syscall(SYS_thread_selfid);
+#else
+ #ifndef __NR_gettid
+ #define __NR_gettid 224
+ #endif
+ pid_t tid = syscall(__NR_gettid);
+#endif
+ CHECK_NE(tid, -1);
+ return tid;
+}
void hl_init(int device) {
CHECK(hl_start_flag)
diff --git a/paddle/cuda/src/hl_cuda_lstm.cu b/paddle/cuda/src/hl_cuda_lstm.cu
index 64699c9f6d450..cf009620bf69d 100644
--- a/paddle/cuda/src/hl_cuda_lstm.cu
+++ b/paddle/cuda/src/hl_cuda_lstm.cu
@@ -564,11 +564,11 @@ __global__ void KeLstmBackward(real *gateValue,
/* TODO: Temporary save & merger in another kernel */
if (frameIdy == 1) {
- if (checkIgGrad) atomicAdd(checkIgGrad+frameIdx, rCheckGrad);
+ if (checkIgGrad) paddle::paddleAtomicAdd(checkIgGrad+frameIdx, rCheckGrad);
} else if (frameIdy == 2) {
- if (checkFgGrad) atomicAdd(checkFgGrad+frameIdx, rCheckGrad);
+ if (checkFgGrad) paddle::paddleAtomicAdd(checkFgGrad+frameIdx, rCheckGrad);
} else if (frameIdy == 3) {
- if (checkOgGrad) atomicAdd(checkOgGrad+frameIdx, rCheckGrad);
+ if (checkOgGrad) paddle::paddleAtomicAdd(checkOgGrad+frameIdx, rCheckGrad);
}
}
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index ecc44944e4fa1..38e4f16217c2a 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -623,7 +623,7 @@ __global__ void KeCosSimDerivative(real* grad,
prevGradY[index] +=
scale * grad[ty] * prevOutX[index] * reciprocal;
} else {
- atomicAdd(prevGradY + index,
+ paddle::paddleAtomicAdd(prevGradY + index,
scale * grad[ty] * prevOutX[index] * reciprocal);
}
}
@@ -640,7 +640,7 @@ __global__ void KeCosSimDerivative(real* grad,
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY);
} else {
- atomicAdd(prevGradY + index, output[ty] * grad[ty] *
+ paddle::paddleAtomicAdd(prevGradY + index, output[ty] * grad[ty] *
(prevOutX[index] * reciprocalXY -
prevOutY[index] * reciprocalSquareSumY));
}
diff --git a/paddle/cuda/src/hl_cuda_sequence.cu b/paddle/cuda/src/hl_cuda_sequence.cu
index f88a2682fd060..e028880156e5b 100644
--- a/paddle/cuda/src/hl_cuda_sequence.cu
+++ b/paddle/cuda/src/hl_cuda_sequence.cu
@@ -362,7 +362,7 @@ __global__ void KeMatrixAddRows(real* output,
if (AddRow == 0) {
outputData[i] += tableData[i];
} else {
- atomicAdd(&tableData[i], outputData[i]);
+ paddle::paddleAtomicAdd(&tableData[i], outputData[i]);
}
}
}
diff --git a/paddle/cuda/src/hl_cuda_sparse.cuh b/paddle/cuda/src/hl_cuda_sparse.cuh
index becb6c66492c1..db5c9ce979885 100644
--- a/paddle/cuda/src/hl_cuda_sparse.cuh
+++ b/paddle/cuda/src/hl_cuda_sparse.cuh
@@ -280,7 +280,7 @@ __global__ void KeSMatrixCscMulDense(real *C_d,
if (index_n_t < dimN) {
real tmp;
tmp = alpha*a_r*b_r[n];
- atomicAdd(C_d_r, tmp);
+ paddle::paddleAtomicAdd(C_d_r, tmp);
C_d_r += CU_CSC_MUL_DENSE_THREAD_X;
index_n_t += CU_CSC_MUL_DENSE_THREAD_X;
}
@@ -328,7 +328,7 @@ __global__ void KeSMatrixCscMulDense(real *C_d,
if (index_n_t < dimN) {
real tmp;
tmp = alpha*a_r*b_r[n];
- atomicAdd(C_d_r, tmp);
+ paddle::paddleAtomicAdd(C_d_r, tmp);
C_d_r += CU_CSC_MUL_DENSE_THREAD_X;
index_n_t += CU_CSC_MUL_DENSE_THREAD_X;
}
@@ -629,7 +629,7 @@ __global__ void KeSMatrixDenseMulCsr(real *C_d,
for (int n=0; n < CU_DM_CSR_N; n++) {
if (index_m_t++ < dimM) {
tmp = alpha * b_r * a_r[n];
- atomicAdd(C_d_r, tmp);
+ paddle::paddleAtomicAdd(C_d_r, tmp);
C_d_r += dimN;
}
}
@@ -660,7 +660,7 @@ __global__ void KeSMatrixDenseMulCsr(real *C_d,
for (int n=0; n < CU_DM_CSR_N; n++) {
if (index_m_t++ < dimM) {
tmp = alpha * b_r * a_r[n];
- atomicAdd(C_d_r, tmp);
+ paddle::paddleAtomicAdd(C_d_r, tmp);
C_d_r += dimN;
}
}
@@ -912,7 +912,7 @@ __global__ void KeSMatrixCsrColumnSum(real* a_val, real* csr_val,
for (int idx = gid; idx < dimNNZ; idx += gridDim.x * blockDim.x) {
int colIdx = csr_col[idx];
real val = csr_val[idx];
- atomicAdd(a_val + colIdx, val);
+ paddle::paddleAtomicAdd(a_val + colIdx, val);
}
}
diff --git a/paddle/cuda/src/hl_dso_loader.cc b/paddle/cuda/src/hl_dso_loader.cc
index 3558b163b5ae0..eee9984e07326 100644
--- a/paddle/cuda/src/hl_dso_loader.cc
+++ b/paddle/cuda/src/hl_dso_loader.cc
@@ -69,23 +69,40 @@ static inline void GetDsoHandleWithSearchPath(
CHECK(nullptr != *dso_handle)
<< "For Gpu version of PaddlePaddle, it couldn't find CUDA library: "
- << dlPath.c_str() << " Please make sure you already specify its path."
- << "Note: for training data on Cpu using Gpu version of PaddlePaddle,"
- << "you must specify libcudart.so via LD_LIBRARY_PATH.";
+ << dlPath.c_str() << ". Please make sure you already specify its path. "
+ << "Note: for training data on Cpu using Gpu version of PaddlePaddle, "
+ << "you must specify libcudart via export LD_LIBRARY_PATH for Linux or "
+ << "export DYLD_LIBRARY_PATH for MAC OS.";
}
void GetCublasDsoHandle(void** dso_handle) {
+#if defined(__APPLE__) || defined(__OSX__)
+ GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcublas.dylib", dso_handle);
+#else
GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcublas.so", dso_handle);
+#endif
}
void GetCudnnDsoHandle(void** dso_handle) {
+#if defined(__APPLE__) || defined(__OSX__)
+ GetDsoHandleWithSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle);
+#else
GetDsoHandleWithSearchPath(FLAGS_cudnn_dir, "libcudnn.so", dso_handle);
+#endif
}
void GetCudartDsoHandle(void** dso_handle) {
+#if defined(__APPLE__) || defined(__OSX__)
+ GetDsoHandleWithSearchPath("", "libcudart.dylib", dso_handle);
+#else
GetDsoHandleWithSearchPath("", "libcudart.so", dso_handle);
+#endif
}
void GetCurandDsoHandle(void** dso_handle) {
+#if defined(__APPLE__) || defined(__OSX__)
+ GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcurand.dylib", dso_handle);
+#else
GetDsoHandleWithSearchPath(FLAGS_cuda_dir, "libcurand.so", dso_handle);
+#endif
}
diff --git a/paddle/cuda/src/hl_table_apply.cu b/paddle/cuda/src/hl_table_apply.cu
index 05335c5f835fc..52ee4610edf67 100644
--- a/paddle/cuda/src/hl_table_apply.cu
+++ b/paddle/cuda/src/hl_table_apply.cu
@@ -35,7 +35,7 @@ __global__ void KeMatrixAddRows(real* output, int ldo,
real *tab = table + tableId * ldt;
for (int i = idx; i < dim; i += blockDimX) {
if (AddRow) {
- atomicAdd(&tab[i], out[i]);
+ paddle::paddleAtomicAdd(&tab[i], out[i]);
} else {
out[i] += tab[i];
}
diff --git a/paddle/gserver/dataproviders/DataProviderGroup.h b/paddle/gserver/dataproviders/DataProviderGroup.h
index decbde6c91758..0689f90f3e7dd 100644
--- a/paddle/gserver/dataproviders/DataProviderGroup.h
+++ b/paddle/gserver/dataproviders/DataProviderGroup.h
@@ -65,7 +65,8 @@ void DataProviderGroup::reset() {
provider_ = nullptr;
// shuffle file list
- std::random_shuffle(fileList_.begin(), fileList_.end());
+ std::shuffle(fileList_.begin(), fileList_.end(),
+ ThreadLocalRandomEngine::get());
startLoader();
DataProvider::reset();
diff --git a/paddle/gserver/dataproviders/ProtoDataProvider.cpp b/paddle/gserver/dataproviders/ProtoDataProvider.cpp
index b0c14c85b2d81..344644755f240 100644
--- a/paddle/gserver/dataproviders/ProtoDataProvider.cpp
+++ b/paddle/gserver/dataproviders/ProtoDataProvider.cpp
@@ -374,7 +374,8 @@ void ProtoDataProvider::reset() {
}
void ProtoDataProvider::shuffle() {
- std::random_shuffle(shuffledSequenceIds_.begin(), shuffledSequenceIds_.end());
+ std::shuffle(shuffledSequenceIds_.begin(), shuffledSequenceIds_.end(),
+ ThreadLocalRandomEngine::get());
}
/*
diff --git a/paddle/gserver/dataproviders/PyDataProvider.cpp b/paddle/gserver/dataproviders/PyDataProvider.cpp
index aeefd16063df8..1332c0ab635b6 100644
--- a/paddle/gserver/dataproviders/PyDataProvider.cpp
+++ b/paddle/gserver/dataproviders/PyDataProvider.cpp
@@ -17,6 +17,8 @@ limitations under the License. */
#include "paddle/utils/PythonUtil.h"
#include
#include "paddle/utils/Util.h"
+#include "paddle/utils/Excepts.h"
+
namespace paddle {
@@ -44,7 +46,6 @@ PyDataProvider::PyDataProvider(const DataConfig& config, bool useGpu,
}
void PyDataProvider::loadData(const std::vector& fileList) {
- int feFlag = fegetexcept();
VLOG(1) << "module:" << pyModuleName_ << " class:" << pyClassName_;
classInstance_ =
createPythonClass(pyModuleName_, pyClassName_, fileList, pyUserArgs_);
@@ -55,7 +56,7 @@ void PyDataProvider::loadData(const std::vector& fileList) {
std::string headerInfo =
std::string(PyString_AsString(obj.get()), PyString_Size(obj.get()));
parseHeaderData(headerInfo);
- feenableexcept(feFlag);
+ feenableexcept(FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW);
}
void PyDataProvider::parseHeaderData(const std::string& headerData) {
diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp
index eb1522a178d48..3127b4dd9a2fd 100644
--- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp
+++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp
@@ -385,17 +385,17 @@ void NeuralNetwork::setOutputGrad(const std::vector& args) {
}
}
-extern NeuralNetwork* newCustomNeuralNetwork(
- const std::string& name, NeuralNetwork* network) __attribute__((weak));
+extern NeuralNetwork* newCustomNerualNetwork(
+ const std::string& name, NeuralNetwork* network) __attribute__((weak));
NeuralNetwork* NeuralNetwork::newNeuralNetwork(
const std::string& name,
NeuralNetwork* rootNetwork) {
- if (newCustomNeuralNetwork) {
- return newCustomNeuralNetwork(name, rootNetwork);
- } else {
- return new NeuralNetwork(name, rootNetwork);
- }
+ if (newCustomNerualNetwork) {
+ return newCustomNerualNetwork(name, rootNetwork);
+ } else {
+ return new NeuralNetwork(name, rootNetwork);
+ }
}
} // namespace paddle
diff --git a/paddle/gserver/tests/concat_table_a.conf b/paddle/gserver/tests/concat_table_a.conf
index 2e3c518883e20..a8ff70f883318 100644
--- a/paddle/gserver/tests/concat_table_a.conf
+++ b/paddle/gserver/tests/concat_table_a.conf
@@ -16,9 +16,9 @@
from paddle.trainer_config_helpers import *
-settings(batch_size=1000)
+settings(batch_size=300)
-data = data_layer(name ="input", size=100000)
+data = data_layer(name ="input", size=10000)
# emb1 is equal to emb2, note that bias_attr=false
# and act=LinearActivation() in default.
diff --git a/paddle/gserver/tests/concat_table_b.conf b/paddle/gserver/tests/concat_table_b.conf
index 6da24a5fbc55c..95d7c10f7b0cd 100644
--- a/paddle/gserver/tests/concat_table_b.conf
+++ b/paddle/gserver/tests/concat_table_b.conf
@@ -16,9 +16,9 @@
from paddle.trainer_config_helpers import *
-settings(batch_size=1000)
+settings(batch_size=300)
-data = data_layer(name ="input", size=100000)
+data = data_layer(name ="input", size=10000)
proj1 = table_projection(input=data, size=128)
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 5c80eb546cfaf..3150c31e4900c 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -50,7 +50,7 @@ TEST(Operator, dot_mul) {
TEST(Projection, context) {
for (auto contextStart : {-5, -3, -1, 0, 3}) {
for (auto contextLength : {1, 2, 5, 7}) {
- for (auto batchSize : {1, 2, 5, 20, 100}) {
+ for (auto batchSize : {1, 2, 5, 20, 50}) {
for (auto trainablePadding : {false, true}) {
LOG(INFO) << " contextStart=" << contextStart
<< " contextLength=" << contextLength
diff --git a/paddle/gserver/tests/test_PyDataProvider2.cpp b/paddle/gserver/tests/test_PyDataProvider2.cpp
index c5fe31b29187f..e75e53ab7f431 100644
--- a/paddle/gserver/tests/test_PyDataProvider2.cpp
+++ b/paddle/gserver/tests/test_PyDataProvider2.cpp
@@ -321,7 +321,7 @@ TEST(PyDataProvider2, input_order) {
if (!realBatchSize) {
break;
}
- ASSERT_EQ(batch.getStreams().size(), 2);
+ ASSERT_EQ(batch.getStreams().size(), (size_t)2);
for (size_t i = 0; i < realBatchSize; ++i) {
ASSERT_EQ(batch.getStream(0).ids->getData()[i], 0);
ASSERT_EQ(batch.getStream(1).ids->getData()[i], 1);
diff --git a/paddle/math/Allocator.h b/paddle/math/Allocator.h
index 36166236e9eff..f7aa60380f23e 100644
--- a/paddle/math/Allocator.h
+++ b/paddle/math/Allocator.h
@@ -16,7 +16,7 @@ limitations under the License. */
#pragma once
#include
-#include <malloc.h>
+#include <stdlib.h>
#include "hl_gpu.h"
#include "paddle/utils/Logging.h"
@@ -48,9 +48,10 @@ class CpuAllocator : public Allocator {
* @return Pointer to the allocated memory
*/
virtual void* alloc(size_t size) {
- void* ptr = memalign(32ul, size);
- CHECK(ptr) << "Fail to allocate CPU memory: size=" << size;
- return ptr;
+ void* ptr;
+ CHECK_EQ(posix_memalign(&ptr, 32ul, size), 0);
+ CHECK(ptr) << "Fail to allocate CPU memory: size=" << size;
+ return ptr;
}
/**
diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h
index fe486c741d6f5..43075977dc9ce 100644
--- a/paddle/math/MathFunctions.h
+++ b/paddle/math/MathFunctions.h
@@ -23,6 +23,8 @@ extern "C" {
}
#endif
+#include <cmath>
+
namespace paddle {
template
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 1b7f9ac5dac16..e351bede724ac 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -2514,7 +2514,8 @@ void SharedCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB,
for (int k = 0; k < blockNum_; ++k) {
blockSeq.push_back(k);
}
- std::random_shuffle(blockSeq.begin(), blockSeq.end());
+ std::shuffle(blockSeq.begin(), blockSeq.end(),
+ ThreadLocalRandomEngine::get());
}
std::vector& localBufRows = *localBufRows_;
int* cols = a->getCols();
diff --git a/paddle/math/PoolAllocator.h b/paddle/math/PoolAllocator.h
index 22af0eb893753..aca8ffb0ab42e 100644
--- a/paddle/math/PoolAllocator.h
+++ b/paddle/math/PoolAllocator.h
@@ -19,6 +19,7 @@ limitations under the License. */
#include
#include
#include
+#include