diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index 62bf24d18b2b..d7584d73baa1 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -122,6 +122,35 @@ function get_internals(self :: SymbolicNode) return SymbolicNode(MX_SymbolHandle(ref_hdr[])) end +#=doc +.. function:: get_attr(self :: SymbolicNode, key :: Symbol) + + Get attribute attached to this :class:`SymbolicNode` belonging to key. +=# +function get_attr(self :: SymbolicNode, key :: Symbol) + key_s = bytestring(string(key)) + ref_out = Ref{Cstring}() + ref_success = Ref{Cint}(-1) + @mxcall(:MXSymbolGetAttr, (MX_handle, Cstring, Ref{Cstring}, Ref{Cint}), self, key_s, ref_out, ref_success) + if ref_success[] == 1 + return bytestring(ref_out[]) + else + throw(KeyError(key)) + end +end + +#=doc +.. function:: set_attr(self:: SymbolicNode, key :: Symbol, value :: AbstractString) + + Set the attribute key to value for this :class:`SymbolicNode`. +=# +function set_attr(self :: SymbolicNode, key :: Symbol, value :: AbstractString) + key_s = bytestring(string(key)) + value_s = bytestring(value) + + @mxcall(:MXSymbolSetAttr, (MX_handle, Cstring, Cstring), self, key_s, value_s) +end + #=doc .. function:: Variable(name :: Union{Base.Symbol, AbstractString}) diff --git a/test/unittest/symbolic-node.jl b/test/unittest/symbolic-node.jl index 33948adfcd40..9dabcf281c89 100644 --- a/test/unittest/symbolic-node.jl +++ b/test/unittest/symbolic-node.jl @@ -81,6 +81,14 @@ function test_saveload() rm(fname) end +function test_attrs() + info("SymbolicNode::Attributes") + + data = mx.Variable(:data) + + mx.set_attr(data, :test, "1.0") + @test mx.get_attr(data, :test) == "1.0" +end ################################################################################ # Run tests @@ -91,5 +99,6 @@ test_compose() test_infer_shape() test_infer_shape_error() test_saveload() +test_attrs() end