Issue Description
My use case is a kind of EmbeddingLayer where the dictionary size is small (<1M entries): I'm implementing a HashGridEncoding layer for training NeRFs.
Nd4j.scatterUpdate is slow, in fact slower than a naive CPU implementation.
Its run time seems to be roughly constant with respect to the dictionary size and linear in the number of updates.
This is a problem for an EmbeddingLayer with a small dictionary and a large batch size.
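For context, the two operations an embedding (or hash-grid) layer relies on can be sketched on plain Java arrays. This is a hypothetical illustration, not ND4J code: EmbeddingOps, gatherRows and scatterAddRows are made-up names, and float[][] stands in for an INDArray table. The forward pass gathers rows (the job Nd4j.pullRows does); the backward pass adds each gradient row back into the row it came from, which is what Nd4j.scatterUpdate with UpdateOp.ADD is used for here. Repeated indexes must accumulate rather than overwrite.

```java
import java.util.Arrays;

// Hypothetical helper class; float[][] stands in for an INDArray embedding table.
final class EmbeddingOps {
    private EmbeddingOps() {}

    // Forward pass (gather): out[i] is a copy of table[indexes[i]], the job Nd4j.pullRows does.
    static float[][] gatherRows(float[][] table, int[] indexes) {
        float[][] out = new float[indexes.length][];
        for (int i = 0; i < indexes.length; i++) {
            out[i] = table[indexes[i]].clone();
        }
        return out;
    }

    // Backward pass (scatter-add): table[indexes[i]] += grads[i]; repeated indexes accumulate.
    static void scatterAddRows(float[][] table, int[] indexes, float[][] grads) {
        for (int i = 0; i < indexes.length; i++) {
            float[] row = table[indexes[i]];
            for (int j = 0; j < row.length; j++) {
                row[j] += grads[i][j];
            }
        }
    }

    public static void main(String[] args) {
        float[][] table = {{0f, 0f}, {1f, 2f}, {3f, 4f}};
        int[] batch = {1, 1, 2};
        float[][] rows = gatherRows(table, batch);
        System.out.println(Arrays.deepToString(rows));   // prints [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]]
        float[][] grads = {{1f, 1f}, {1f, 1f}, {5f, 5f}};
        scatterAddRows(table, batch, grads);             // row 1 receives two updates
        System.out.println(Arrays.deepToString(table));  // prints [[0.0, 0.0], [3.0, 4.0], [8.0, 9.0]]
    }
}
```

The gather only reads shared rows, while the scatter-add writes to them, which is where duplicate indexes start to matter.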
Expected behavior
Nd4j.scatterUpdate is fast, with minimal overhead.
I would expect it to be on par with Nd4j.pullRows, which is a lot faster (although it is also a bottleneck in the forward pass).
Encountered behavior
Nd4j.scatterUpdate runs in almost constant time with respect to the size of the array being updated, but it has a large overhead.
For comparison, here is the naive CPU workaround I benchmarked it against:
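```java
// CPU workaround; dictionarySize is the static field of the test class below
public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates) {
    // Dense host-side accumulator with the same shape as weights
    float[][] tempWeights = new float[dictionarySize[0]][dictionarySize[1]];
    // One host array per embedding column, each holding indexes.length values
    float[][] tempUpdates = new float[dictionarySize[1]][];
    // Column-major copy so that each column is contiguous
    updates = updates.dup('f');
    for (int j = 0; j < dictionarySize[1]; j++) {
        // toFloatVector() copies only this column, not the whole backing buffer
        tempUpdates[j] = updates.getColumn(j).toFloatVector();
    }
    // Accumulate on the host; repeated indexes add up
    for (int i = 0; i < indexes.length; i++) {
        for (int j = 0; j < dictionarySize[1]; j++) {
            tempWeights[indexes[i]][j] += tempUpdates[j][i];
        }
    }
    // Copy the accumulated deltas back and add them to weights in place
    weights.addi(Nd4j.create(tempWeights));
}
```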
Up to around 1M "weights", Nd4j.scatterUpdate (doScatterUpdate in the test class below) is slower than this workaround by a factor of up to 50x.
However, the CPU workaround doesn't scale beyond that size.
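One way a CPU fallback could scale better with the number of updates is to split the updates across threads, give each chunk a private copy of the (small) table, and sum the copies at the end. The sketch below is a hypothetical illustration on plain row-major float arrays, not ND4J code; PrivatizedScatterAdd and its parameters are made-up names. It assumes the dictionary is small, since memory grows as chunks × table size, and it does nothing about the host/device copies the workaround above also pays for. Because floats are summed in a different order than a sequential loop, results can differ in the last bits.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical sketch: parallel scatter-add where every chunk of updates has a private accumulator.
final class PrivatizedScatterAdd {
    private PrivatizedScatterAdd() {}

    // weights: row-major table with `cols` columns; updates: indexes.length x cols, row-major.
    // chunks must be >= 1; memory use is chunks * weights.length floats.
    static void scatterAdd(float[] weights, int cols, int[] indexes, float[] updates, int chunks) {
        int n = indexes.length;
        List<float[]> partials = IntStream.range(0, chunks).parallel()
                .mapToObj(c -> {
                    // Private copy: no write conflicts, even if every update targets the same row.
                    float[] partial = new float[weights.length];
                    int from = (int) ((long) n * c / chunks);
                    int to = (int) ((long) n * (c + 1) / chunks);
                    for (int i = from; i < to; i++) {
                        int dst = indexes[i] * cols;
                        int src = i * cols;
                        for (int j = 0; j < cols; j++) {
                            partial[dst + j] += updates[src + j];
                        }
                    }
                    return partial;
                })
                .collect(Collectors.toList());
        // Sequentially fold the per-chunk sums into the real table.
        for (float[] partial : partials) {
            for (int k = 0; k < weights.length; k++) {
                weights[k] += partial[k];
            }
        }
    }
}
```

Unlike a per-row split, this spreads the work evenly even when every index is identical, as in the test class.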
Nd4j.scatterUpdate runs in linear time with respect to the number of updates, but even for a small number of updates it takes an abnormal amount of time, so perhaps some optimization is possible.
Since Nd4j.pullRows does essentially the same job in reverse (a gather rather than a scatter), it might be worth comparing the two implementations: pullRows is much faster.
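A plausible reason a scatter-add is harder to make fast than a gather like pullRows is that several updates can target the same row (in the test class, all of them do), so a parallel implementation needs atomics or some other way to resolve write conflicts, whereas a gather writes each output row exactly once. One common way around this is to group the updates by destination row first, for example with a counting sort, which is cheap when the dictionary is small, and then let one worker reduce each row. The sketch below shows that idea on plain row-major Java arrays; GroupedScatterAdd is a hypothetical name and this is not a description of ND4J's implementation.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical sketch: conflict-free scatter-add that first groups updates by destination row.
final class GroupedScatterAdd {
    private GroupedScatterAdd() {}

    // weights: numRows x cols, row-major; updates: indexes.length x cols, row-major.
    static void scatterAdd(float[] weights, int numRows, int cols, int[] indexes, float[] updates) {
        // 1. Counting sort: start[r + 1] counts the updates aimed at row r,
        //    and the prefix sum turns the counts into bucket boundaries.
        int[] start = new int[numRows + 1];
        for (int idx : indexes) {
            start[idx + 1]++;
        }
        for (int r = 0; r < numRows; r++) {
            start[r + 1] += start[r];
        }
        // 2. order lists update positions grouped by row, keeping their original order.
        int[] order = new int[indexes.length];
        int[] next = Arrays.copyOf(start, numRows);
        for (int i = 0; i < indexes.length; i++) {
            order[next[indexes[i]]++] = i;
        }
        // 3. One task per row: no two tasks write the same row, so no atomics are needed,
        //    and each row is summed in a fixed order, so the result is deterministic.
        IntStream.range(0, numRows).parallel().forEach(r -> {
            int dst = r * cols;
            for (int k = start[r]; k < start[r + 1]; k++) {
                int src = order[k] * cols;
                for (int j = 0; j < cols; j++) {
                    weights[dst + j] += updates[src + j];
                }
            }
        });
    }
}
```

The parallelism here is only across distinct rows, so a batch where every index is the same gets no speed-up from step 3, although the result stays deterministic.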
Version Information
OS: Windows
GPU: Nvidia RTX 3060 Laptop GPU (driver 31.0.15.3667)
cuDNN is also installed, in case it is used at all.
Test class
```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.factory.Nd4j;

public class ScatterUpdateTest {
    // Embedding table shape: [dictionary rows, embedding columns]
    static int[] dictionarySize = {10000, 2};
    // Number of row updates per call (the batch size)
    static int indexesSize = 1000000;
    // All updates hit row 1: 6 + 4 * 1,000,000 = 4,000,006 (exact in float32)
    static final float EXPECTED = 6f + 4f * indexesSize;

    public static void main(String[] args) {
        // Every index is 1, so all updates accumulate into the same row
        INDArray indices = Nd4j.create(indexesSize).assign(1).castTo(DataType.INT32);
        int[] indexes = indices.data().asInt();

        // Correctness check for both implementations
        INDArray weights = Nd4j.create(dictionarySize).assign(6);
        INDArray updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdate(weights, indices, updates);
        check(weights, "scatterUpdates Nd4j");
        weights = Nd4j.create(dictionarySize).assign(6);
        updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdateCPUWorkaround(weights, indexes, updates);
        check(weights, "scatterUpdates CPU");

        // Warmup
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }
        Nd4j.getExecutioner().commit();

        long start = System.nanoTime();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);
        }
        // CUDA ops may run asynchronously: wait for them before stopping the clock
        Nd4j.getExecutioner().commit();
        System.out.println("scatterUpdates Nd4j : " + (System.nanoTime() - start) / 1_000_000 + " ms");

        start = System.nanoTime();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }
        Nd4j.getExecutioner().commit();
        System.out.println("scatterUpdates CPU : " + (System.nanoTime() - start) / 1_000_000 + " ms");
    }

    // Java assertions are disabled by default, so check the result explicitly
    static void check(INDArray weights, String name) {
        float actual = weights.getFloat(1, 0);
        if (actual != EXPECTED) {
            throw new IllegalStateException(name + ": expected " + EXPECTED + " but got " + actual);
        }
    }

    // weights[indices[i], :] += updates[i, :] for every i
    public static void doScatterUpdate(INDArray weights, INDArray indices, INDArray updates) {
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weights, indices, updates, 1);
    }

    public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates) {
        // Dense host-side accumulator with the same shape as weights
        float[][] tempWeights = new float[dictionarySize[0]][dictionarySize[1]];
        // One host array per embedding column, each holding indexes.length values
        float[][] tempUpdates = new float[dictionarySize[1]][];
        // Column-major copy so that each column is contiguous
        updates = updates.dup('f');
        for (int j = 0; j < dictionarySize[1]; j++) {
            tempUpdates[j] = updates.getColumn(j).toFloatVector();
        }
        // Accumulate on the host; repeated indexes add up
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < dictionarySize[1]; j++) {
                tempWeights[indexes[i]][j] += tempUpdates[j][i];
            }
        }
        weights.addi(Nd4j.create(tempWeights));
    }
}
```
Contributing
I have little experience with C++, so letting me contribute to that part is a risk. That said, if you tell me there's no demand for this and no time to improve it, I will have a look.