Skip to content

Commit

Permalink
Merge pull request apache#229 from dmlc/vc/bilinear_initializer
Browse files Browse the repository at this point in the history
fixes bilinear initializer following approach in apache#34
  • Loading branch information
vchuravy authored Apr 13, 2017
2 parents 7a6120a + 8947ead commit 44a6c36
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/initializer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,28 @@ function _init_loc_bias(self :: AbstractInitializer, name :: Base.Symbol, array
end

function _init_bilinear(self :: AbstractInitializer, name :: Base.Symbol, array :: NDArray)
# ported from python version:
#weight = np.zeros(np.prod(arr.shape), dtype='float32')
#shape = arr.shape
#f = np.ceil(shape[3] / 2.)
#c = (2 * f - 1 - f % 2) / (2. * f)
#for i in range(np.prod(shape)):
# x = i % shape[3]
# y = (i / shape[3]) % shape[2]
# weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
#arr[:] = weight.reshape(shape)

weight=zeros(array)

h,w,channels,n=size(array)
f = ceil(w / 2.)
c = (2 * f - 1 - f % 2) / (2. * f)

for i=1:length(weight)
x = i % w
y = (i / w) % h
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
@assert ndims(array) == 4

W, H, C, N = size(array) # Inverse of NCHW layout
filter = Base.zeros(eltype(array), W, H)

@assert H == W

f = ceil(Int, W / 2) # factor
c = (2 * f - 1 - f % 2) / (2 * f) # center
for x in 0:(W-1)
for y in 0:(H-1)
filter[x+1, y+1] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
end
end

@nd_as_jl rw=array begin
for i in 1:N
for j in 1:C
array[:,:, j, i] = filter
end
end
end
array[:,:,:,:]=weight
end

function _init_bias(self :: AbstractInitializer, name :: Base.Symbol, array :: NDArray)
Expand Down
18 changes: 18 additions & 0 deletions test/unittest/initializer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testset "Initializers" begin
@testset "Bilinear initializer" begin
# Setup a filter with scale = 2
expectedFilter = Float32[
0.0625 0.1875 0.1875 0.0625;
0.1875 0.5625 0.5625 0.1875;
0.1875 0.5625 0.5625 0.1875;
0.0625 0.1875 0.1875 0.0625]
filter = mx.zeros(Float32, 4, 4, 1, 4)
mx.init(mx.XavierInitializer(), :upsampling0_weight, filter)

mx.@nd_as_jl ro=filter begin
for s in 1:size(filter, 4)
@test all(filter[:, :, 1, s] .== expectedFilter)
end
end
end
end

0 comments on commit 44a6c36

Please sign in to comment.