Skip to content

Commit

Permalink
Merge pull request apache#35 from vchuravy/vc/symbol_attr
Browse files Browse the repository at this point in the history
basic interface for setting and getting attributes
  • Loading branch information
pluskid committed Nov 23, 2015
2 parents 2c72446 + 46257d2 commit 05bcddf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
29 changes: 29 additions & 0 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,35 @@ function get_internals(self :: SymbolicNode)
return SymbolicNode(MX_SymbolHandle(ref_hdr[]))
end

#=doc
.. function:: get_attr(self :: SymbolicNode, key :: Symbol)
Get attribute attached to this :class:`SymbolicNode` belonging to key.
=#
function get_attr(self :: SymbolicNode, key :: Symbol)
key_s = bytestring(string(key))
ref_out = Ref{Cstring}()
ref_success = Ref{Cint}(-1)
@mxcall(:MXSymbolGetAttr, (MX_handle, Cstring, Ref{Cstring}, Ref{Cint}), self, key_s, ref_out, ref_success)
if ref_success[] == 1
return bytestring(ref_out[])
else
throw(KeyError(key))
end
end

#=doc
.. function:: set_attr(self:: SymbolicNode, key :: Symbol, value :: AbstractString)
Set the attribute key to value for this :class:`SymbolicNode`.
=#
function set_attr(self :: SymbolicNode, key :: Symbol, value :: AbstractString)
key_s = bytestring(string(key))
value_s = bytestring(value)

@mxcall(:MXSymbolSetAttr, (MX_handle, Cstring, Cstring), self, key_s, value_s)
end

#=doc
.. function:: Variable(name :: Union{Base.Symbol, AbstractString})
Expand Down
9 changes: 9 additions & 0 deletions test/unittest/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ function test_saveload()
rm(fname)
end

function test_attrs()
info("SymbolicNode::Attributes")

data = mx.Variable(:data)

mx.set_attr(data, :test, "1.0")
@test mx.get_attr(data, :test) == "1.0"
end

################################################################################
# Run tests
Expand All @@ -91,5 +99,6 @@ test_compose()
test_infer_shape()
test_infer_shape_error()
test_saveload()
test_attrs()

end

0 comments on commit 05bcddf

Please sign in to comment.