
Parallel / distributed training #1148

Closed
wants to merge 1 commit into from

Conversation

shelhamer
Member

Parallel and distributed training of Caffe models by streaming gradients among solvers. Parallelizes training without redefining models. This is the integration PR for @cypof's original contribution in #1140.

TODO

  • switch all threads to boost::thread with thread local storage
  • fix distributed hyper-parameter schedule
  • validate parallelism on multi-GPU Caffenet training
  • lint

Follow-up

  • distributed synchronous SGD for comparison with asynchronous / HogWild approach
  • synchronization by device peer-to-peer communication instead of host reduce

Please review and collaborate here. The original details of #1140 by @cypof are:


A set of classes to synchronize SGD between multiple solvers, based on the Hogwild paper and on our work at Flickr extending the approach to GPUs and distributed configurations by streaming gradients between solvers.

Features

  • Models can be trained in parallel without modification. Caffe’s training code is also mostly untouched.
  • Modular design. The code is broken down into simple components that each synchronize one segment, e.g. CPU/GPU or CPU/LAN. They can be combined to create an architecture, either in-process or between processes by memory-mapping the weights to /dev/shm (see the sketch after this list).
  • Works on commodity hardware, apparently even on 1G Ethernet, at least for MNIST. Synchronization and SGD run asynchronously to keep both compute and networking resources fully utilized. Bandwidth and latency across machines are optimized using raw sockets and user-space networking.
  • No additional memory used on the GPUs.
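
A minimal sketch of the between-process case mentioned above: two processes can share a raw weight buffer by mapping the same file under /dev/shm. The path, size handling, and function name are illustrative assumptions, not the components used by this PR.

// Sketch only: share a float weight buffer between processes via /dev/shm.
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cstddef>
#include <cstdio>

float* map_shared_weights(const char* name, size_t count) {
  int fd = shm_open(name, O_CREAT | O_RDWR, 0600);   // e.g. "/caffe_weights"
  if (fd < 0) { perror("shm_open"); return nullptr; }
  if (ftruncate(fd, count * sizeof(float)) != 0) {   // size the shared segment
    perror("ftruncate"); close(fd); return nullptr;
  }
  void* p = mmap(nullptr, count * sizeof(float), PROT_READ | PROT_WRITE,
                 MAP_SHARED, fd, 0);                 // both processes see updates
  close(fd);
  return p == MAP_FAILED ? nullptr : static_cast<float*>(p);
}

Both the solver process and the synchronization process would map the same name and operate on the same underlying weights.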

Limitations

  • Only supports data-parallelism. Limited forms of model-parallelism should be possible using the same components but no work has been done.
  • Training is less stable than on a single GPU. In particular, disabling momentum at least at the beginning of training seems to help.
  • No deployment / monitoring tools. We are looking at integrating with IPython.parallel.

Tests

Early results on MNIST seem to show linear scaling. We tested on up to 6 machines with 4 solvers each for CPU, and 2 machines with 2 GPUs each. GPUs do not perform well on this small network but still seem to scale linearly.

[graph: MNIST scaling results]

In the weeks to come we plan to start testing on larger networks and clusters. Currently our GPU machines are connected through 1G Ethernet; please contact us if you are interested in helping benchmark on better hardware.

Architecture

We made the Caffe singleton thread-local, to allow multiple solvers to run in parallel on their own thread. Synchronization works by sharing the weight buffers between solvers in the same address space, and by asynchronously measuring and exchanging gradients between address spaces.
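
As a rough illustration of the thread-local singleton idea (using C++11 thread_local here rather than the boost::thread-specific storage mentioned in the TODO list), a simplified stand-in might look like this; the class and members are assumptions, not Caffe's actual singleton:

// Sketch: one context per thread, so each solver thread keeps its own device
// id, handles, and RNG state without locking. Simplified stand-in only.
class ThreadContext {
 public:
  static ThreadContext& Get() {
    static thread_local ThreadContext instance;  // constructed once per thread
    return instance;
  }
  void set_device(int device) { device_ = device; }
  int device() const { return device_; }

 private:
  ThreadContext() : device_(0) {}
  int device_;
};

Each solver thread would call ThreadContext::Get().set_device(i) before training, and every later Get() on that thread sees the same per-thread state.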

Bugs / Todos

  • If the bandwidth between the GPU and host is set too high, machines seem to hang.
  • Incorrect hyper-params schedule in the distributed case. The total count of iterations needs to be tracked, maybe through the monitoring tool.
  • Thread-local Caffe singletons are not destroyed, we need to design a proper shutdown strategy.

@bug-fixed
Contributor

Hi, thank you very much for this great PR!
I have forked the dev branch from @cypof, and compiled with the default settings.
Then I launched gpus.bin with ImageNet; the following is some output. It works, but processing time seems longer and it needs more memory: gpus.bin uses about 40GB of memory and hogwild.bin uses about 57GB.
The hardware is two K20m GPUs with ECC on.
The software is CentOS 6.5 and CUDA 6.0 with driver 331.62.
Because this is a cluster system, I don't have root privileges, so raw.bin couldn't run.

[screenshot: gpus.bin training log]

@cypof
Member

cypof commented Sep 24, 2014

The high memory use is probably due to lmdb. Each solver creates a data_layer, which maps the data file and reads from a different location, so a lot of memory shows up as used. It's just a cache for the data file, so it should not be a problem; the OS can discard it if needed.

For info, we should modify data_layer to map the file only once per process, just to avoid running out of virtual memory. I once tried to start dozens of solvers with an 8TB map size and ran into the x86-64 48-bit address limit!

For ImageNet we also don't have very convincing results here; increasing the GPU bandwidth helps but crashes our machines. I'm not sure why yet, we are still well below the PCI limit.

Also, I have a UDP-based synchronization prototype that's a bit slower than raw sockets but does not require a sudo install step. I could look at it again if there is demand for it.

@bug-fixed
Contributor

@cypof, thank you very much for the comments and for sharing your valuable experience!
In my test of gpus.bin, it seems to crash (exit unexpectedly) when evaluating the network after the first 1000 iterations, and the log doesn't show any error message, so I'm not sure where the problem is.
If I understand correctly, data parallelism requires much more concurrent disk access. In our cluster the data is stored on a single node mounted as an NFS share and many users access it at the same time, so I don't know whether that single storage node could bear the heavy load; too many data accesses may affect the performance of the whole cluster.
If I understand correctly, HDFS could manage the indexes of all batches of data, and the server would control which part of the data runs on which client, so HDFS might integrate with the data parallelism in raw.bin. This is just a rough thought and may not be reasonable.
In addition, the cluster serves many users and, as an ordinary user, I don't have root privileges, so please share your UDP version of raw.bin if convenient. Running the program as root might modify global settings of the whole cluster and have unexpected impacts on other users.

@bhack
Contributor

bhack commented Sep 24, 2014

@cypof Have you ever evaluated ZeroMQ's multi-transport support, or does it create too much overhead?

@bhack
Contributor

bhack commented Sep 24, 2014

@zszhong Data parallelism is also an interesting topic. Spark uses Berkeley's Resilient Distributed Datasets (RDDs) for distributed SGD updates.

@bug-fixed
Contributor

@bhack, thank you for the reply!
I'm not familiar with distributed file systems; I just thought the integration could work.
On http://www.image-net.org there is another big data set, fall 2011, which contains about 22,000 classes and totals about 750GB, so with a data set of this size, distributed training with data parallelism may be necessary. The server knows where the batched data is and takes responsibility for updating the parameter database; if I understand correctly, this is similar to a game server: many clients send the results computed on their own nodes to the server, and the server updates its state and sends new data back to the clients. This is just a rough thought, and there are probably better solutions.

@shelhamer
Member Author

Let's first focus on finishing the parallel implementation as begun here, fix crashes, and validate the approach for ImageNet training. All possible file systems and communication protocols can't be covered here and future extension is always an option.

The most effective help is to address the TODO list and help check ILSVRC distributed training. I will give multi-GPU training a try and report back.

@shelhamer shelhamer added this to the 1.0 milestone Sep 24, 2014
@cypof
Member

cypof commented Sep 24, 2014

In a distributed setting it's preferable to store the data files locally, perhaps sharded if they don't fit on each machine. Synchronization is likely to use all available bandwidth, or at least it should. There are probably a lot of optimizations we could do on top of it, but as an API, user-space networking is supposed to be as fast as it can be, similar to writing kernel code.

@shelhamer
Member Author

I've drafted a list of TODOs for merge and soon-to-come follow ups that include @cypof's initial TODOs -- see the PR description.

@kloudkl
Contributor

kloudkl commented Sep 25, 2014

For such a highly complex system mixing distributed systems, single-node parallelism, and machine learning algorithms, simply reading the code can hardly convey the whole big picture. From the existing feedback, there are still a lot of aspects to be researched. It would be good to have a detailed design document hosted on the wiki so that everyone can easily understand and review the internals piece by piece and different teams can attack the problems collaboratively.

@thatguymike
Contributor

Let's work on this in pieces. Let's get the GPU stuff cleaned up and stable. Is there a simple reproduction for the instability? It sounds like just cranking up the CHUNK size in parallel.hpp causes things to break. Is MNIST still the right place to start?

@cypof
Member

cypof commented Sep 25, 2014

Yes just increasing the bandwidth using CHUNK should be enough. MNIST seems stable but ImageNet makes the box unreachable after an hour of running gpus.bin on two GPUs.

@thatguymike
Contributor

Yes, I can't reproduce with MNIST, rebuilding ImageNet now. Looking quickly at MNIST, there might be a memory leak, as I watched resident memory cruise up in funny pulses during use, but it could also just be lmdb doing its thing. I will also note there is a segfault when gpu.bin exits that traces back to the shutdown of cuBLAS; it looks like we have freed a buffer it is also trying to free.

Larger question. Looking at GPUSync, it looks like you are trying to run a separate thread hitting the GPU pulling data as fast as you can. It seems that should be interlocked with the solver state on the same GPU, but perhaps I'm missing something in what you are trying to do. What is preventing you from issuing copies back to back without anything changing on the GPU? Seems you wouldn't want to bother with a copy unless the solver has progressed.

In theory, if you are racing on the sync, as you get to larger transfers you are just going to beat up the PCIe bus and interrupt bus traffic, which could easily cause instability. Smaller chunks would give more scheduling slots for other work to slide in. One "simple" check you can do would be to sleep the thread for a bit at the bottom of the loop (say 100ms) to see if things stabilize. One more nitpick: in that infinite loop in GPUSync::run, should we check for a clean-up and exit criterion, since in theory that current code is not performance critical (famous last words)?
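
A hedged sketch of those two suggestions, an exit flag plus a short sleep at the bottom of the loop; the class name, transfer_chunk(), and the 100 ms figure are illustrative, not the PR's actual GPUSync code:

#include <atomic>
#include <chrono>
#include <thread>

// Sketch: a sync loop that can be shut down cleanly and that yields the bus
// between passes instead of issuing copies back to back.
class GpuSyncSketch {
 public:
  void run() {
    while (!stop_.load(std::memory_order_relaxed)) {
      transfer_chunk();  // hypothetical: copy the next weight/gradient chunk
      // Give other GPU and PCIe traffic a scheduling slot before the next pass.
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
  }
  void stop() { stop_.store(true); }

 private:
  void transfer_chunk() { /* ... */ }
  std::atomic<bool> stop_{false};
};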

Also, it seems unclear if we always need to transfer the maximum chunk size. Seems it should only need to be the active parameters, but again, perhaps I'm misunderstanding where you are trying to go.

@cypof
Member

cypof commented Sep 25, 2014

OK, I hope you will reproduce on ImageNet. I agree that the separate thread might cause other tasks to get delayed, but I'm not sure why that should be too much of a problem, and why the whole machine seems to hang.

I think asynchrony is the right approach because it separates the training and sync code, and because the GPU processes batches faster than we can get the gradients across the bus. Several batches will go through while the sync thread does one pass over the buffer. If the two activities are not ridiculously unbalanced, SGD should still converge, and it gives us a lot of freedom to do more sophisticated processing on the sync thread in the long run. We could skip sparse gradients, encode on fewer bits, etc.

I think for now one way to simplify things and monitor what is going on on the bus would be to replace the data pre-fetch thread by a queue on the GPU, e.g. 100 long. The sync thread can continuously fill this queue with new batches, and only sync the gradients if the queue is already full. That would also get us better utilization of the GPUs in general by buffering against latency spikes, particularly if data is served from a network. I can't start it right now, but if you agree I could put that in the todo.
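
A small sketch of that policy under stated assumptions: a bounded batch queue that the prefetch/sync thread keeps full, with gradient sync only happening when no batch needs to be loaded. Batch, load_next_batch(), sync_gradients(), and the capacity are placeholders, not the PR's code.

#include <condition_variable>
#include <deque>
#include <mutex>

struct Batch {};  // placeholder for a prefetched, GPU-resident batch

class PrefetchQueue {
 public:
  explicit PrefetchQueue(size_t capacity) : capacity_(capacity) {}

  // Called in a loop by the prefetch/sync thread.
  void step() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (queue_.size() < capacity_) {
      queue_.push_back(load_next_batch());  // keep the GPU fed first
      not_empty_.notify_one();
    } else {
      lock.unlock();
      sync_gradients();                     // queue full: spend bus time on sync
    }
  }

  // Called by the solver thread for each iteration.
  Batch pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    not_empty_.wait(lock, [&] { return !queue_.empty(); });
    Batch b = queue_.front();
    queue_.pop_front();
    return b;
  }

 private:
  Batch load_next_batch() { return Batch{}; }  // placeholder loader
  void sync_gradients() { /* ... */ }          // placeholder gradient exchange
  size_t capacity_;
  std::deque<Batch> queue_;
  std::mutex mutex_;
  std::condition_variable not_empty_;
};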

@thatguymike
Contributor

Here is what I'm getting at. With MNIST and the default setup, I get this behavior, where "step" is printed in the flow after each call inside Solve to ForwardBackward and "pull" in the GPUSync loop:

I0925 16:58:41.672260 6082 solver.cpp:195] Step
I0925 16:58:41.672756 6091 parallel.cpp:195] pull
I0925 16:58:41.673331 6088 solver.cpp:195] Step
I0925 16:58:41.673641 6103 parallel.cpp:195] pull
I0925 16:58:41.689064 6091 parallel.cpp:195] pull
I0925 16:58:41.689996 6103 parallel.cpp:195] pull
I0925 16:58:41.695577 6091 parallel.cpp:195] pull
I0925 16:58:41.696666 6103 parallel.cpp:195] pull
I0925 16:58:41.701489 6091 parallel.cpp:195] pull
I0925 16:58:41.702142 6103 parallel.cpp:195] pull
I0925 16:58:41.711835 6082 solver.cpp:195] Step
I0925 16:58:41.712268 6088 solver.cpp:195] Step
I0925 16:58:41.714303 6091 parallel.cpp:195] pull
I0925 16:58:41.714488 6103 parallel.cpp:195] pull
I0925 16:58:41.718863 6091 parallel.cpp:195] pull
I0925 16:58:41.719328 6103 parallel.cpp:195] pull
I0925 16:58:41.724057 6091 parallel.cpp:195] pull
I0925 16:58:41.725388 6103 parallel.cpp:195] pull
I0925 16:58:41.731945 6091 parallel.cpp:195] pull
I0925 16:58:41.732102 6103 parallel.cpp:195] pull
I0925 16:58:41.740545 6082 solver.cpp:195] Step
I0925 16:58:41.741618 6088 solver.cpp:195] Step
I0925 16:58:41.743661 6091 parallel.cpp:195] pull
I0925 16:58:41.744518 6103 parallel.cpp:195] pull

The data pull is outpacing the stepping. I understand your desire to be asynchronous, but I'm seeing data getting pulled faster than we process the batch, and GPU call tracing confirms this. From what you write above, I would assume >= 1 step for each pull is the desired behavior.

@cypof
Member

cypof commented Sep 26, 2014

Ah, I see, but do you log the pull for each chunk or at the end of the whole buffer? Each pull should be for a different location over the weights. For ImageNet, at the maximum rate I could get the thing to go, I can still only synchronize the whole weight buffer once every couple of seconds.

@bhack
Contributor

bhack commented Sep 26, 2014

@cypof Perhaps some of Boost's lock-free data structures could be explored, since Caffe already depends on some Boost modules.

@thatguymike
Contributor

Ah, I see your point. You do multiple transfers to get all the data, so the logging should actually be inside the if check that resets the chunk offset pointer.

@thatguymike
Contributor

I still see multiple pulls often for each step. My gut says an interlock is still needed there. It also looks like you are sending everything on the default CUDA stream, so everything will end up synchronizing access to the GPU, i.e. all submissions to the device will go in order. (ImageNet still building, likely complete in the morning.)

@thatguymike
Contributor

And of course the database build died overnight with some issue in the EXIF data of one of the images. Super lovely. Patching the convert_imageset.cpp file to hack around the issue.

@cypof
Member

cypof commented Sep 26, 2014

Ha! That makes me wonder how many transfer errors we would get on weeks-long runs over networks with such large bandwidths. Maybe we should add CRCs, at least in RawSync. Thanks for your feedback; I still hope we can get away with no interlock, more below. I am looking at CUDA channels, it looks like that should help. You said LMDB creates thread-local caches; that definitely might be unhappy with pre-fetch threads that, afaik, get created and destroyed on each batch.

I started an architecture document, but I don't have much yet. Basically the same principle is used for all synchronization segments. The model weights are replicated on all nodes. For each segment, one of the ends is designated master. The master sends a set of weights to the slave. The slave compares them to a copy of its own weights; the copy has not changed since the last message from the master. The difference between this copy and the slave's weights is the gradient. The slave updates its weights using the master position plus the gradient, and sends the gradient to the master. On the GPU the copy is kept on the host to save memory, which was a bad idea; I need to look at moving it to the GPU to optimize bandwidth instead.
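
Read literally, the per-chunk exchange described above amounts to something like the following sketch; the array names are illustrative, and the real code works lock-free, chunk by chunk, while SGD keeps updating the live weights concurrently:

// Sketch of one slave-side pass over a chunk of n weights.
// master: weights just received from the master
// copy:   the slave's saved snapshot from the previous exchange
// w:      the slave's live weights
// grad:   the gradient sent back to the master
void exchange_chunk(const float* master, float* copy, float* w,
                    float* grad, int n) {
  for (int i = 0; i < n; ++i) {
    grad[i] = w[i] - copy[i];    // progress made since the last exchange
    w[i] = master[i] + grad[i];  // adopt the master position plus own gradient
    copy[i] = w[i];              // snapshot for the next exchange
  }
}
// The master then adds grad to its own weights, so both ends accumulate each
// other's progress while staying within a bounded distance.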

This mechanism allows slaves to apply other solvers' gradients and return their own in one pass over the memory. It also does not override progress on either side: in most cases each side adds the other's gradient while keeping slaves within a bounded distance of the master. Memory is scanned in order on all nodes, which helps locality. The amount of work lost due to races during additions on each node seems negligible. On the host, for a single-socket CPU, I have seen it be zero for hours-long runs with all cores running. I'm not sure how; maybe the cores always manage to own the cache line during addition for large weight arrays.

In terms of locking, there is no way for each node to know if the other side has made progress on a chunk since the last exchange. I assume it's extremely likely that at least one end will have made progress, particularly for hubs like host memory that receive updates from several other nodes and GPUs. So the best strategy is probably to let go of all synchronization and just max out all the links. As an engineering problem, there has to be a way to keep things stable once we have fixed the memory pressure and queuing problems.

SGD runs on weights that can change at any time. The hardware has to avoid torn reads on float-length words, which seems to be the case for GPUs as well. I have played with ways to synchronize things since last year, and it just lowers compute utilization. SGD doesn't seem to be affected by partial writes, but it depends heavily on the solvers' weights being as close as possible to each other. Any latency kills performance, and locking requires either stopping SGD while buffers are copied, or having separate buffers for transfer, which means SGD is running on stale weights and is likely to produce wasted work.

@kloudkl
Contributor

kloudkl commented Sep 30, 2014

I don't know whether this is good news or bad news for Caffe and the larger deep learning community. The startup Skymind has open-sourced its Deeplearning4j project to run deep learning on the omnipotent big data framework YARN, which is also known as Hadoop 2.x or modern Hadoop.
https://github.com/agibsonccc/java-deeplearning
http://www.oscon.com/oscon2014/public/schedule/detail/33709
http://www.wired.com/2014/06/skymind-deep-learning/

@kloudkl
Contributor

kloudkl commented Sep 30, 2014

Two more players:
DeepDist: Lightning-Fast Deep Learning on Spark Via parallel stochastic gradient updates
H2O: Distributed Deep Learning with H2O in the Cloud
Even the old dog Mahout is going to reinvent itself to stay relevant in the post-MapReduce big data era represented by the in-memory distributed computation framework Spark.

@kloudkl
Contributor

kloudkl commented Sep 30, 2014

Open source distributed machine learning frameworks with deep neural network support:
Petuum: Q. Ho, J. Cipar, H. Cui, J.-K. Kim, S. Lee, P. B. Gibbons, G. Gibson, G. R. Ganger and E. P. Xing, More Effective Distributed ML via a Stale Synchronous Parallel Parameter Server. Neural Information Processing Systems, 2013 (NIPS 2013).
Metronome: Suite of parallel iterative algorithms built on top of Iterative Reduce
DimmWitted(to be open sourced): Ce Zhang, Christopher Ré. DimmWitted: A Study of Main-Memory Statistical Analytics. VLDB, 2014.

@kloudkl
Contributor

kloudkl commented Sep 30, 2014

Among all of the alternatives, deepdist is the simplest and most intuitive implementation of Downpour SGD[1].

[1] J Dean, GS Corrado, R Monga, K Chen, M Devin, QV Le, MZ Mao, M’A Ranzato, A Senior, P Tucker, K Yang, and AY Ng. Large Scale Distributed Deep Networks. NIPS 2012: Neural Information Processing Systems, Lake Tahoe, Nevada, 2012.

@thatguymike
Contributor

I was able to run 2 GPUs okay, but I did notice a little swapping after a few hours. With 4 GPUs things start swapping hard after ~10 minutes and effectively swap-lock the machine; 8 GPUs go south quite quickly. You can see this happen by polling vmstat at a given interval.

The current code initializes an mdb_env per thread, and the resident size grows quickly. I tried to quickly hack around that by instantiating the database once, but then I get errors on the first transactions about invalid rlocks. Still working through that. I haven't been able to find examples of the right way to handle threaded readers. Does anyone have ideas here, or is anyone more familiar with lmdb/leveldb and threading? This looks like the primary issue, as the footprint appears near linear in the number of GPUs. However, there could also be a slow memory leak hiding somewhere, as the memory utilization does slowly creep up.

Secondly, we should be preallocating pinned memory and doing async transfers. This grabs memory up front instead of dynamically pinning, which can be quite expensive under the heavy VM load that the database interaction currently causes. We need to check other data transfers through the stack for the same issue.

@thatguymike
Contributor

Okay, changing CaffeMallocHost to call cudaMallocHost seems to have massively improved stability. It appears that without this the VM struggles to move things around to dynamically pin memory. I understand the concern expressed in syncedmem.hpp; we can work around it by attempting to catch the error and defaulting to regular malloc. The problem with assuming that unpinned memory is always safe is exactly the issue being hit here: if something else in the system is really putting the VM under pressure, finding memory to pin can be difficult. That also aligns with things seeming worse as the chunk size is increased, since it becomes even more difficult to find memory to pin and can cause the VM to move things around.

I'll continue my 4 GPU run, but I'm long past where things went bad before now. I'm still worried we are leaking memory somewhere.

@thatguymike
Contributor

I can confirm that, up to 8 GPUs, changing the allocation in syncedmem.hpp to use cudaMallocHost and the free to cudaFreeHost seems to fix the stability issues with multiple GPUs, even with much larger chunk sizes than the default. Waiting for clearance to submit a patch, but it's a simple change someone else can make while I get through the paperwork...

The main issue is figuring out a clean way to only allocate pinned memory if GPUs are actually being used. Falling through when no GPUs are in the system is straightforward: check the return code and switch to regular malloc/free. However, if you are only using CPUs and GPUs are in the system, you don't want to be pinning memory. I don't see a great way around that without passing through a mode flag or some other mechanism to track CPU vs GPU use.
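
A minimal sketch of that fallback, assuming a hypothetical allocator wrapper and a caller-supplied flag for GPU use; the real change would live in syncedmem.hpp / CaffeMallocHost, and the names here are not Caffe's:

#include <cstdlib>
#include <cuda_runtime.h>

// Sketch: pin host memory only when a GPU will be used, and fall back to
// plain malloc if pinning fails (or in CPU-only mode).
inline void* host_alloc(size_t size, bool use_gpu, bool* pinned) {
  if (use_gpu) {
    void* ptr = nullptr;
    if (cudaMallocHost(&ptr, size) == cudaSuccess) {
      *pinned = true;   // pinned: cheap, truly async transfers later
      return ptr;
    }
  }
  *pinned = false;      // CPU-only mode, or pinning failed under VM pressure
  return std::malloc(size);
}

inline void host_free(void* ptr, bool pinned) {
  if (pinned) cudaFreeHost(ptr);
  else std::free(ptr);
}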

@cypof
Member

cypof commented Dec 24, 2014

New version, finally! I had to rebranch from dev. Some of the todos and @futurely's suggestions are fixed. There is still a lot to do; in particular, momentum is still not working well and ImageNet is unstable for 8+ GPUs. But performance is much better, and there are two big new components that are pretty promising.

The main one is an RDMA-based networking layer that allows direct communication between GPU memory and the network adapter, over either InfiniBand or RoCE. It supports multicast, which turned out to help a lot: GPUs are faster at receiving than sending, so multicast allows a small output stream from each GPU to expand into a lot of bandwidth on the receiving ends.

The other big piece is a modification to data_layer that allows multiple sources, with one data-loading thread per source, one data-transform thread per GPU, and queues that prefetch asynchronously on a separate CUDA stream. We have not benchmarked it much yet, but it seems to help a lot. In particular, it opens only one database environment per file and prevents threads from being continuously created and destroyed.
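
A structural sketch of that pipeline under stated assumptions (it is not the PR's data_layer): one long-lived loading thread per source feeds a shared queue of raw batches, and one long-lived transform thread per GPU drains it into that GPU's prefetch queue, so no thread is created or destroyed per batch. Batch, load_source(), transform_for_gpu(), and the queue type are placeholders.

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>

struct Batch { int source = 0; int gpu = -1; };  // placeholder payload

template <typename T>
class BlockingQueue {                            // minimal thread-safe queue
 public:
  void push(T v) {
    { std::lock_guard<std::mutex> l(m_); q_.push_back(std::move(v)); }
    cv_.notify_one();
  }
  T pop() {
    std::unique_lock<std::mutex> l(m_);
    cv_.wait(l, [&] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop_front();
    return v;
  }
 private:
  std::deque<T> q_;
  std::mutex m_;
  std::condition_variable cv_;
};

Batch load_source(int s) { return Batch{s, -1}; }                 // placeholder read
Batch transform_for_gpu(Batch b, int g) { b.gpu = g; return b; }  // placeholder transform

void start_pipeline(int num_sources, int num_gpus,
                    BlockingQueue<Batch>& raw,
                    std::vector<BlockingQueue<Batch>>& per_gpu) {
  std::vector<std::thread> threads;
  for (int s = 0; s < num_sources; ++s)   // each source opened exactly once
    threads.emplace_back([&raw, s] { for (;;) raw.push(load_source(s)); });
  for (int g = 0; g < num_gpus; ++g)      // one transformer per GPU
    threads.emplace_back([&raw, &per_gpu, g] {
      for (;;) per_gpu[g].push(transform_for_gpu(raw.pop(), g));
    });
  for (auto& t : threads) t.detach();     // threads live for the whole run
}

Solver threads would then pop ready batches from their GPU's queue, which also buffers against latency spikes in loading or transfer.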

Latest graphs:

ImageNet

[graph: ImageNet results]

CIFAR-10

[graph: CIFAR-10 results]

@bhack
Contributor

bhack commented Jan 17, 2015

@cypof and others: has anybody already checked the Parameter Server Caffe effort at CMU?

@bhack
Contributor

bhack commented Jan 17, 2015

See also the last doc link of this ticket

@immars

immars commented Mar 14, 2015

compilation failure?
nvcc version 6.5, V6.5.12, g++ 4.8.2

src/caffe/parallel.cu(14): error: name followed by "::" must be a class or namespace name
          detected during instantiation of "void caffe::sync_master_kernel(Dtype *, Dtype **, size_t *, int, int, const cudaStream_t &, size_t) [with Dtype=float]"

@immars

immars commented Mar 24, 2015

another implementation:

https://github.com/sailorsb/caffe-parallel

data parallelism over MPI

@raingo

raingo commented Apr 23, 2015

@cypof How was this graph generated? Is it on training data or testing data?

[graph image]

@cypof
Member

cypof commented Apr 23, 2015

It was on testing data, the validation set.

@wangsheng1001

Hi,

I got a segmentation fault when running the CIFAR example on two nodes.
I do not run as root, but with 'setcap cap_net_raw+ep'.
Details are as follows:

Environment:
Ubuntu 13.04 kernel 3.8.0

Command on two machines:
build/examples/parallel/raw.bin examples/parallel/cifar_solver.prototxt 1 0025908a56d8:0025908a4dfc

Error:
E0609 21:21:57.113490 15274 raw.cpp:99] # of gpus 0# of cpus 1
*** Aborted at 1433856117 (unix time) try "date -d @1433856117" if you are using GNU date ***
PC: @ 0x46b802 caffe::RawSync<>::run()
*** SIGSEGV (@0x5d1aadf80) received by PID 15274 (TID 0x7d60af7fe700) from PID 18446744072932220800; stack trace: ***
@ 0x7f60bede3bd0 (unknown)
@ 0x46b802 caffe::RawSync<>::run()
@ 0x7f60c06ffaf2 thread_proxy
@ 0x7f60beddbf8e start_thread
@ 0x7f60beb05a0d (unknown)
Segmentation fault (core dumped)

More Details from gdb:
Program received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7efff0c28700 (LWP 15611)]
0x000000000046b802 in caffe::RawSync::run() ()
(gdb) bt
#0 0x000000000046b802 in caffe::RawSync::run() ()
#1 0x00007ffff5999af2 in thread_proxy () from /users/dinhtta/local/lib/libboost_thread.so.1.56.0
#2 0x00007ffff4075f8e in start_thread (arg=0x7efff0c28700) at pthread_create.c:311
#3 0x00007ffff3d9fa0d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:113

@shelhamer shelhamer removed the focus label Jun 30, 2015
@shelhamer shelhamer removed this from the Future milestone Jun 30, 2015
@shelhamer
Member Author

Instead of attacking all axes of parallelism at once, see the multi-GPU data parallelism of #2114 for a start. Closing this since sync SGD with data parallelism covers a lot of cases. However, it clearly does not cover all of them, so this PR can serve as the historical record for the good pieces we might pick up later.

@shelhamer shelhamer closed this Jun 30, 2015
@bhack
Contributor

bhack commented Sep 5, 2015

To add something post-modern to the historical record, I reference https://github.com/Yangqing/caffe2/issues/11
