Skip to content

Commit

Permalink
Fix of bugs in nnvm branch (apache#169)
Browse files Browse the repository at this point in the history
* Fix build error in travis

Another string conversion fix

* Fixed JSON and added testsets

* Fixed errors in julia 0.4
  • Loading branch information
Arkoniak authored and pluskid committed Jan 13, 2017
1 parent 2393d4a commit 179daa5
Show file tree
Hide file tree
Showing 17 changed files with 160 additions and 72 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Compat 0.9.1
Formatting
BinDeps
JSON
BaseTestNext
1 change: 1 addition & 0 deletions deps/build.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Compat

################################################################################
# First try to detect and load existing libmxnet
################################################################################
Expand Down
8 changes: 4 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra

train_execs = Array(Executor, num_dev)
for i = 1:num_dev
data_shapes = Dict([k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_data(data)])
label_shapes = Dict([k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_label(data)])
data_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_data(data)))
label_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_label(data)))
train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=grad_req, data_shapes..., label_shapes...)
dbg_str = mx.debug_str(train_execs[i])
info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i]))
Expand Down Expand Up @@ -574,8 +574,8 @@ end
function save_checkpoint(sym :: SymbolicNode, arg_params :: Dict{Base.Symbol, NDArray},
aux_params :: Dict{Base.Symbol, NDArray}, prefix :: AbstractString, epoch :: Int)
save("$prefix-symbol.json", sym)
save_dict = merge(Dict([Symbol("arg:$k") => v for (k,v) in arg_params]),
Dict([Symbol("aux:$k") => v for (k,v) in aux_params]))
save_dict = merge(Dict{Base.Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)),
Dict{Base.Symbol, NDArray}(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params)))
save_filename = format("{1}-{2:04d}.params", prefix, epoch)
save(save_filename, save_dict)
info("Saved checkpoint to '$save_filename'")
Expand Down
22 changes: 15 additions & 7 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -931,11 +931,14 @@ end
function _julia_to_mx_param(val :: Any)
string(val)
end
function _julia_to_mx_param(val :: Float16)
string(val)
function _julia_to_mx_param(val :: Float64)
@sprintf("%.16e", val)
end
function _julia_to_mx_param(val :: Real)
@sprintf("%e", val)
function _julia_to_mx_param(val :: Float32)
@sprintf("%.8e", val)
end
function _julia_to_mx_param(val :: Float16)
@sprintf("%.4e", val)
end

# Import corresponding math functions from base so the automatically defined libmxnet
Expand Down Expand Up @@ -986,6 +989,9 @@ function _get_ndarray_function_def(name :: String)
end

args = collect(args) # tuple to list
if length(args) == 0
args = MX_handle[]
end

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
Expand All @@ -1000,9 +1006,11 @@ function _get_ndarray_function_def(name :: String)
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
end

output_handles = [Base.cconvert(MX_handle, x) for x in output_vars]
if length(output_handles) > 0
output_handles_pp = [Base.cconvert(Ptr{MX_handle}, output_handles)]
if length(output_vars) > 0
output_handles = map((x) -> Base.cconvert(MX_handle, x), output_vars)
# XXX: Julia 0.4 has bug: [Array{MX_handle}] == Array{MX_handle}
output_handles_pp = Array{Array{MX_handle}}(1)
output_handles_pp[1] = Base.cconvert(Ptr{MX_handle}, output_handles)
else
output_handles_pp = [Base.convert(Ptr{MX_handle}, 0)]
end
Expand Down
5 changes: 4 additions & 1 deletion src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ macro _list_symbol_info(self, func_name)
ref_sz = Ref{MX_uint}(0)
ref_names = Ref{char_pp}(0)
@mxcall($func_name, (MX_handle, Ref{MX_uint}, Ref{char_pp}),
$self, ref_sz, ref_names)
$(esc(self)), ref_sz, ref_names)
narg = ref_sz[]
names = unsafe_wrap(Array, ref_names[], narg)
names = [Symbol(unsafe_wrap(String, x)) for x in names]
Expand Down Expand Up @@ -493,6 +493,9 @@ end
function /(self :: SymbolicNode, arg :: Real)
./(self, arg)
end
function /(arg :: Real, self :: SymbolicNode)
_RDivScalar(self, scalar=arg)
end
function ./(arg :: Real, self :: SymbolicNode)
_RDivScalar(self, scalar=arg)
end
Expand Down
4 changes: 2 additions & 2 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function get_mnist_ubyte()
:train_label => "train-labels-idx1-ubyte",
:test_data => "t10k-images-idx3-ubyte",
:test_label => "t10k-labels-idx1-ubyte")
filenames = Dict([k => joinpath(mnist_dir, v) for (k,v) in filenames])
filenames = Dict(map((x) -> x[1] => joinpath(mnist_dir, x[2]), filenames))
if !all(isfile, values(filenames))
cd(mnist_dir) do
mnist_dir = download("http://data.dmlc.ml/mxnet/data/mnist.zip", "mnist.zip")
Expand All @@ -38,7 +38,7 @@ function get_cifar10()
cifar10_dir = joinpath(data_dir, "cifar10")
mkpath(cifar10_dir)
filenames = Dict(:train => "cifar/train.rec", :test => "cifar/test.rec")
filenames = Dict([k => joinpath(cifar10_dir, v) for (k,v) in filenames])
filenames = Dict(map((x) -> x[1] => joinpath(cifar10_dir, x[2]), filenames))
if !all(isfile, values(filenames))
cd(cifar10_dir) do
run(`http://data.dmlc.ml/mxnet/data/cifar10.zip`)
Expand Down
19 changes: 14 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

# run test in the whole directory, latest modified files
# are run first, this makes waiting time shorter when writing
Expand All @@ -12,9 +17,13 @@ function test_dir(dir)
end

include(joinpath(dirname(@__FILE__), "common.jl"))
test_dir(joinpath(dirname(@__FILE__), "unittest"))
@testset "MXNet Test" begin
test_dir(joinpath(dirname(@__FILE__), "unittest"))

# run the basic MNIST mlp example
if haskey(ENV, "CONTINUOUS_INTEGRATION")
include(joinpath(Pkg.dir("MXNet"), "examples", "mnist", "mlp-test.jl"))
# run the basic MNIST mlp example
if haskey(ENV, "CONTINUOUS_INTEGRATION")
@testset "MNIST Test" begin
include(joinpath(Pkg.dir("MXNet"), "examples", "mnist", "mlp-test.jl"))
end
end
end
1 change: 1 addition & 0 deletions test/travis/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ if [ ${TRAVIS_OS_NAME} == "linux" ]; then
mkdir shadow_bin
ln -s `which gcc-4.8` shadow_bin/gcc
ln -s `which g++-4.8` shadow_bin/g++

export PATH=$PWD/shadow_bin:$PATH
fi
11 changes: 9 additions & 2 deletions test/unittest/bind.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestBind
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

using ..Main: rand_dims, reldiff

Expand Down Expand Up @@ -70,7 +75,9 @@ end
################################################################################
# Run tests
################################################################################
test_arithmetic()
@testset "Bind Test" begin
test_arithmetic()
end

end

15 changes: 11 additions & 4 deletions test/unittest/io.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestIO
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

using ..Main: rand_dims, reldiff

Expand Down Expand Up @@ -117,8 +122,10 @@ function test_arrays_shuffle()
@test reldiff(data_got, data[:,Int[label_got...]]) < 1e-6
end

test_arrays_shuffle()
test_arrays()
test_mnist()
@testset "IO Test" begin
test_arrays_shuffle()
test_arrays()
test_mnist()
end

end
15 changes: 11 additions & 4 deletions test/unittest/kvstore.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestKVStore
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

using ..Main: rand_dims

Expand Down Expand Up @@ -62,8 +67,10 @@ function test_aggregator()
end
end

test_kv_basic()
test_single_kv_pair()
test_aggregator()
@testset "KVStore Test" begin
test_kv_basic()
test_single_kv_pair()
test_aggregator()
end

end
13 changes: 10 additions & 3 deletions test/unittest/name.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestNameManager
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

function test_default()
info("NameManager::default")
Expand All @@ -25,7 +30,9 @@ function test_prefix()
@test get!(prefix_manager, "", name) == Symbol("$prefix$(name)0")
end

test_default()
test_prefix()
@testset "Name Test" begin
test_default()
test_prefix()
end

end
53 changes: 30 additions & 23 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestNDArray
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

using ..Main: rand_dims, reldiff

Expand Down Expand Up @@ -121,8 +126,8 @@ function test_plus()
a6 = copy(t6, mx.cpu())
scalar_small = Float16(1e-5)
scalar_large = Float16(1e4)
@test reldiff(t6 + scalar_small, copy(a6 .+ scalar_small)) < 1e-2
@test reldiff(t6 + scalar_large, copy(a6 .+ scalar_large)) < 1e-2
@test reldiff(t6 + scalar_small, copy(a6 .+ scalar_small)) < 1e-1
@test reldiff(t6 + scalar_large, copy(a6 .+ scalar_large)) < 1e-1
end

function test_minus()
Expand Down Expand Up @@ -172,8 +177,8 @@ function test_minus()
a6 = copy(t6, mx.cpu())
scalar_small = Float16(1e-5)
scalar_large = Float16(1e4)
@test reldiff(t6 - scalar_small, copy(a6 .- scalar_small)) < 1e-2
@test reldiff(t6 - scalar_large, copy(a6 .- scalar_large)) < 1e-2
@test reldiff(t6 - scalar_small, copy(a6 .- scalar_small)) < 1e-1
@test reldiff(t6 - scalar_large, copy(a6 .- scalar_large)) < 1e-1
end

function test_mul()
Expand Down Expand Up @@ -213,7 +218,7 @@ function test_mul()

t6, a6 = rand_tensors(Float16, dims)
scalar_small = Float16(1e-5)
@test reldiff(t6 * scalar_small, copy(a6 .* scalar_small)) < 1e-2
@test reldiff(t6 * scalar_small, copy(a6 .* scalar_small)) < 1e-1
end

function test_div()
Expand Down Expand Up @@ -254,7 +259,7 @@ function test_div()

t6, a6 = rand_tensors(Float16, dims)
scalar_large = 1e4
@test reldiff(t6 / scalar_large, copy(a6 ./ scalar_large)) < 1e-2
@test reldiff(t6 / scalar_large, copy(a6 ./ scalar_large)) < 1e-1
end

function test_gd()
Expand Down Expand Up @@ -300,7 +305,7 @@ function test_saveload()

# save and load dictionary of ndarrays
names = [Symbol("array$i") for i = 1:n_arrays]
dict = Dict([n => v for (n,v) in zip(names, nd_arrays)])
dict = Dict([(n, v) for (n,v) in zip(names, nd_arrays)])
mx.save(fname, dict)
data = mx.load(fname, mx.NDArray)
@test isa(data, Dict{Symbol, mx.NDArray})
Expand Down Expand Up @@ -397,20 +402,22 @@ end
################################################################################
# Run tests
################################################################################
test_assign()
test_copy()
test_slice()
test_plus()
test_minus()
test_mul()
test_div()
test_gd()
test_saveload()
test_clip()
test_sqrt()
test_eltype()
test_nd_as_jl()
test_dot()
test_kwargs()
@testset "NDArray Test" begin
test_assign()
test_copy()
test_slice()
test_plus()
test_minus()
test_mul()
test_div()
test_gd()
test_saveload()
test_clip()
test_sqrt()
test_eltype()
test_nd_as_jl()
test_dot()
test_kwargs()
end

end
12 changes: 10 additions & 2 deletions test/unittest/operator.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module TestOperator
using MXNet
using Base.Test
if VERSION v"0.5.0-dev+7720"
using Base.Test
else
using BaseTestNext
const Test = BaseTestNext
end

using ..Main: rand_dims, reldiff

Expand Down Expand Up @@ -31,6 +36,9 @@ end
################################################################################
# Run tests
################################################################################
test_scalar_op()

@testset "Operator Test" begin
test_scalar_op()
end

end
Loading

0 comments on commit 179daa5

Please sign in to comment.