From 6c6c013a1a5c6c38e00f7bd4489ff1973ce53330 Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Tue, 7 Jul 2015 21:58:18 -0400 Subject: [PATCH 01/14] minor change --- Makefile | 9 +++++---- src/dag_engine/threaded_engine.cc | 13 ++++++------ test/test_threaded_engine.cc | 33 +++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index e037062bd9fa..155ef6baf7c5 100644 --- a/Makefile +++ b/Makefile @@ -46,9 +46,9 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif -BIN = test/api_registry_test +BIN = test/test_threaded_engine #test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o threaded_engine.o +OBJCXX11 = engine.o narray.o api_registry.o CUOBJ = narray_op_gpu.o operator_gpu.o SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -62,8 +62,8 @@ $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) storage.o: src/storage/storage.cc -engine.o: src/dag_engine/simple_engine.cc -threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h +#engine.o: src/dag_engine/simple_engine.cc +engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h @@ -77,6 +77,7 @@ api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) test/api_registry_test: test/api_registry_test.cc api/libmxnet.a +test/test_threaded_engine: test/test_threaded_engine.cc api/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc index 143b5e72f413..c85ff8e8dc24 100644 --- a/src/dag_engine/threaded_engine.cc +++ b/src/dag_engine/threaded_engine.cc @@ -67,17 +67,18 @@ class ThreadedEngine : public DAGEngine { exec_fun(ctx); on_complete(); }, exec_ctx, use_vars, mutate_vars); } - void PushDelete(Op delete_fun, Variable var) override { + void PushDelete(Op delete_fun, Context exec_ctx, Variable var) override { // TODO this->Push([delete_fun, var] (RunContext ctx) { delete_fun(ctx); - delete static_cast(var); - }, Context()/* TODO exec_ctx is missing?*/, {}, {var}); + delete static_cast(var); // TODO use variable pool instead + }, exec_ctx, {}, {var}); } Variable NewVar() override { // in practice return a ptr to a cell // that have the info about the variable // use ptr directly instead of ID because this avoids an indirect mapping + // TODO use variable pool instead VarDescr* vd = new VarDescr; vd->lock = SPINLOCK_INITIALIZER; vd->rw = 0; @@ -119,7 +120,6 @@ class ThreadedEngine : public DAGEngine { ++vard->rw; } if (vard->rw == 0) { - // if the next one is a delete // pop the next write vard->waitings.pop(); vard->rw = -1; @@ -153,14 +153,15 @@ class ThreadedEngine : public DAGEngine { } void OnDepsResolved(OpDescr* opd) { static default_random_engine generator; - static uniform_int_distribution distribution(0, numthreads_); + static uniform_int_distribution distribution(0, numthreads_ - 1); int thrid = distribution(generator); + //LOG(INFO) << "schedule operator " << opd << " to thread #" << thrid; worker_queues_[thrid]->Push(opd); } void WorkerRoutine(int thrid) { OpDescr* opd = nullptr; while(! worker_queues_[thrid]->Pop(opd)) { - LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; + //LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); opd = nullptr; } diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc index 40dea029cf6e..e1dead4af1a5 100644 --- a/test/test_threaded_engine.cc +++ b/test/test_threaded_engine.cc @@ -1,9 +1,42 @@ +#include +#include +#include + #include using namespace std; using namespace mxnet; +void Foo(RunContext rctx, int i) { + cout << "say: " << i << endl; +} + int main() { DAGEngine* engine = DAGEngine::Get(); + Context exec_ctx; + + // Test #1 + cout << "============= Test #1 ==============" << endl; + vector vars; + for(int i = 0; i < 10; ++i) { + vars.push_back(engine->NewVar()); + } + for(int i = 0; i < 10; ++i) { + engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, + exec_ctx, vars, {}); + } + + usleep(1000000); + + // Test #2 + cout << "============= Test #2 ==============" << endl; + for(int i = 0; i < 10; ++i) { + engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, + exec_ctx, {}, vars); + } + + usleep(1000000); + + // Test #3 return 0; } From 9fbaa6bfc633f9148b786f36ed06b41bf8fcb31e Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 7 Jul 2015 21:16:58 -0600 Subject: [PATCH 02/14] fix doc --- doc/Doxyfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/Doxyfile b/doc/Doxyfile index b3d9d7fdbb81..f1f8f62bf4c0 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -773,7 +773,7 @@ INPUT_ENCODING = UTF-8 # *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, *.vhdl, *.ucf, # *.qsf, *.as and *.js. -FILE_PATTERNS = *.cc *.h +FILE_PATTERNS = *.h # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. From 7edd02bb6c13c7aad505bc8404a0b3af2a6b8f82 Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Wed, 8 Jul 2015 00:44:55 -0400 Subject: [PATCH 03/14] fix style --- src/common/concurrent_blocking_queue.h | 54 +++++++++++++++++---- src/common/spin_lock.h | 29 ++++++++--- src/dag_engine/threaded_engine.cc | 66 +++++++++++++------------- test/test_threaded_engine.cc | 11 +++-- 4 files changed, 109 insertions(+), 51 deletions(-) diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h index aab39895b119..14bab00d8280 100644 --- a/src/common/concurrent_blocking_queue.h +++ b/src/common/concurrent_blocking_queue.h @@ -1,4 +1,11 @@ -#pragma once +/*! + * Copyright (c) 2015 by Contributors + * \file concurrent_blocking_queue.h + * \brief A simple lock-based consumer-producer queue. + */ +#ifndef MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ +#define MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ + #include #include #include @@ -6,11 +13,21 @@ #include #include +namespace common { + +/*! + * \brief A simple lock-based consumer-producer queue. + */ template class ConcurrentBlockingQueue { - const static int BUSY_LOOP = 1000; + static const int kBusyLoop = 1000; + public: ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) { } + /*! + * \brief Push object into the queue. Notify anyone who is waiting. + * \param e the object + */ void Push(const T& e) { std::lock_guard lock(mutex_); has_elmt_ = true; @@ -19,15 +36,22 @@ template class ConcurrentBlockingQueue { cv_.notify_all(); } } - bool Pop(T& rv) { - for (int i = 0; i < BUSY_LOOP; i++) { + /*! + * \brief Pop object out of the queue. If the queue is empty, the caller thread will sleep until + * (1) Producer pushed some product into the queue and the caller thread wins it. + * (2) A kill signal is passed to the queue. + * \param rv the pointer point to the return object + * \return whether an object is returned + */ + bool Pop(T* rv) { + for (int i = 0; i < kBusyLoop; i++) { if (has_elmt_) { std::lock_guard lock(mutex_); if (!has_elmt_) { assert(queue_.empty()); continue; } - rv = queue_.front(); + *rv = queue_.front(); queue_.pop_front(); if (queue_.empty()) has_elmt_ = false; @@ -40,28 +64,38 @@ template class ConcurrentBlockingQueue { cv_.wait(lock); } if (!exit_now_) { - rv = queue_.front(); + *rv = queue_.front(); queue_.pop_front(); if (queue_.empty()) has_elmt_ = false; return false; } else { - return true; + return true; } } } + /*! + * \brief pop all objects in the queue. + * \return a list containing all objects in the queue. + */ std::list PopAll() { std::lock_guard lock(mutex_); std::list rv; rv.swap(queue_); return rv; } - // Call `SignalForKill` before destruction + /*! + * \brief tell the queue to release all waiting consumers + */ void SignalForKill() { std::unique_lock lock(mutex_); exit_now_ = true; cv_.notify_all(); } + /*! + * \brief return the current queue size + * \return queue size + */ size_t QueueSize() { std::unique_lock lock(mutex_); return queue_.size(); @@ -77,3 +111,7 @@ template class ConcurrentBlockingQueue { ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete; ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete; }; + +} // namespace common + +#endif // MXNET_COMMON_CONCURRENT_BLOCKING_QUEUE_H_ diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h index 5a0cc3f786e6..60850f171ecf 100644 --- a/src/common/spin_lock.h +++ b/src/common/spin_lock.h @@ -1,17 +1,18 @@ -#ifndef _SPINLOCK_XCHG_H -#define _SPINLOCK_XCHG_H - -/* Spin lock using xchg. +/* Copyright (c) 2015 by Contributors + * Spin lock using xchg. * Copied from http://locklessinc.com/articles/locks/ */ +#ifndef MXNET_COMMON_SPIN_LOCK_H_ +#define MXNET_COMMON_SPIN_LOCK_H_ + /* Compile read-write barrier */ #define barrier() asm volatile("": : :"memory") /* Pause instruction to prevent excess processor bus usage */ #define cpu_relax() asm volatile("pause\n": : :"memory") -static inline unsigned short xchg_8(void *ptr, unsigned char x) { +static inline unsigned short xchg_8(void *ptr, unsigned char x) { // NOLINT(*) __asm__ __volatile__("xchgb %0,%1" :"=r" (x) :"m" (*(volatile unsigned char *)ptr), "0" (x) @@ -23,8 +24,15 @@ static inline unsigned short xchg_8(void *ptr, unsigned char x) { #define BUSY 1 typedef unsigned char spinlock; +/*! + * \brief use this value to initialize lock object + */ #define SPINLOCK_INITIALIZER 0 +/*! + * \brief lock + * \param lock the pointer to lock object + */ static inline void spin_lock(spinlock *lock) { while (1) { if (!xchg_8(lock, BUSY)) return; @@ -33,13 +41,22 @@ static inline void spin_lock(spinlock *lock) { } } +/*! + * \brief unlock + * \param lock the pointer to lock object + */ static inline void spin_unlock(spinlock *lock) { barrier(); *lock = 0; } +/*! + * \brief try lock + * \param lock the pointer to lock object + * \return whether the lock is grabbed or not + */ static inline int spin_trylock(spinlock *lock) { return xchg_8(lock, BUSY); } -#endif /* _SPINLOCK_XCHG_H */ +#endif // MXNET_COMMON_SPIN_LOCK_H_ diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc index c85ff8e8dc24..e5b44d5d1db2 100644 --- a/src/dag_engine/threaded_engine.cc +++ b/src/dag_engine/threaded_engine.cc @@ -1,3 +1,4 @@ +// Copyright (c) 2015 by Contributors #include #include #include @@ -6,8 +7,8 @@ #include #include -#include -#include +#include "dmlc/logging.h" +#include "mxnet/dag_engine.h" #include "../common/spin_lock.h" #include "../common/concurrent_blocking_queue.h" @@ -19,14 +20,14 @@ namespace mxnet { class ThreadedEngine : public DAGEngine { public: - ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { - for(int i = 0; i < numthreads; ++i) { + explicit ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { + for (int i = 0; i < numthreads; ++i) { worker_queues_.push_back(new ConcurrentBlockingQueue()); workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i); } } ~ThreadedEngine() { - for(int i = 0; i < numthreads_; ++i) { + for (int i = 0; i < numthreads_; ++i) { worker_queues_[i]->SignalForKill(); delete worker_queues_[i]; workers_[i].join(); @@ -36,10 +37,10 @@ class ThreadedEngine : public DAGEngine { Context exec_ctx, const vector &use_vars, const vector &mutate_vars) override { - shared_ptr opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, - [this] (OpDescr* o) { this->OnDepsResolved(o); } ); - for( Variable v : use_vars ) { // read - VarDescr* vard = static_cast(v); // safe to cast here + shared_ptr opd(new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, + [this] (OpDescr* o) { this->OnDepsResolved(o); }); + for ( Variable v : use_vars ) { // read + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); if (vard->rw < 0) { vard->waitings.push(make_pair(opd, DepType::kRead)); @@ -48,8 +49,8 @@ class ThreadedEngine : public DAGEngine { } spin_unlock(&vard->lock); } - for( Variable v : mutate_vars ) { // write - VarDescr* vard = static_cast(v); // safe to cast here + for ( Variable v : mutate_vars ) { // write + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); if (vard->rw != 0) { vard->waitings.push(make_pair(opd, DepType::kWrite)); @@ -68,28 +69,28 @@ class ThreadedEngine : public DAGEngine { }, exec_ctx, use_vars, mutate_vars); } void PushDelete(Op delete_fun, Context exec_ctx, Variable var) override { - // TODO this->Push([delete_fun, var] (RunContext ctx) { delete_fun(ctx); - delete static_cast(var); // TODO use variable pool instead + delete static_cast(var); // TODO(minjie): use variable pool instead }, exec_ctx, {}, {var}); } Variable NewVar() override { // in practice return a ptr to a cell // that have the info about the variable // use ptr directly instead of ID because this avoids an indirect mapping - // TODO use variable pool instead + // TODO(minjie): use variable pool instead VarDescr* vd = new VarDescr; vd->lock = SPINLOCK_INITIALIZER; vd->rw = 0; return vd; } void WaitForVar(Variable var) override { - // TODO + // TODO(minjie): tbd } void WaitForAll() override { - // TODO + // TODO(minjie): tbd } + private: enum class DepType { kRead = 0, @@ -104,18 +105,18 @@ class ThreadedEngine : public DAGEngine { }; struct VarDescr { spinlock lock; - int rw; // a semaphore-like count - // if rw > 0, the variable has several readers and the number - // means how many operators are currently reading it; - // if rw < 0, the varaible has one writer (should be -1) + int rw; // a semaphore-like count + // if rw > 0, the variable has several readers and the number + // means how many operators are currently reading it; + // if rw < 0, the varaible has one writer (should be -1) queue, DepType>> waitings; }; void TriggerWaiting(VarDescr* vard) { // ATTENTION: this function should be called with vard->lock held. CHECK(vard->rw == 0) << "the variable should be free during triggering"; - if(!vard->waitings.empty()) { + if (!vard->waitings.empty()) { // pop all reads first - while(vard->waitings.front().second == DepType::kRead) { + while (vard->waitings.front().second == DepType::kRead) { vard->waitings.pop(); ++vard->rw; } @@ -128,44 +129,45 @@ class ThreadedEngine : public DAGEngine { } void OnOpFinished(OpDescr* opd) { CHECK(opd) << "completing a nullptr op!"; - for(Variable v : opd->read_vars) { - VarDescr* vard = static_cast(v); // safe to cast here + for (Variable v : opd->read_vars) { + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw; - if(--vard->rw == 0) { + if (--vard->rw == 0) { TriggerWaiting(vard); } spin_unlock(&vard->lock); } - for(Variable v : opd->write_vars) { - VarDescr* vard = static_cast(v); // safe to cast here + for (Variable v : opd->write_vars) { + VarDescr* vard = static_cast(v); // safe to cast here spin_lock(&vard->lock); CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw; vard->rw = 0; TriggerWaiting(vard); spin_unlock(&vard->lock); } - delete opd; // delete the operator + delete opd; // delete the operator } RunContext GetRunContext(const Context& ctx) { - // TODO + // TODO(minjie): get the correct runtime context return RunContext(); } void OnDepsResolved(OpDescr* opd) { static default_random_engine generator; static uniform_int_distribution distribution(0, numthreads_ - 1); int thrid = distribution(generator); - //LOG(INFO) << "schedule operator " << opd << " to thread #" << thrid; + // LOG(INFO) << "schedule operator " << opd << " to thread #" << thrid; worker_queues_[thrid]->Push(opd); } void WorkerRoutine(int thrid) { OpDescr* opd = nullptr; - while(! worker_queues_[thrid]->Pop(opd)) { - //LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; + while (!worker_queues_[thrid]->Pop(opd)) { + // LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); opd = nullptr; } } + private: const int numthreads_; vector*> worker_queues_; diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc index e9d2566afb15..fecd552d1b50 100644 --- a/test/test_threaded_engine.cc +++ b/test/test_threaded_engine.cc @@ -1,8 +1,9 @@ // Copyright (c) 2015 by Contributors +#include #include #include -#include -#include + +#include "mxnet/dag_engine.h" using namespace std; using namespace mxnet; @@ -18,10 +19,10 @@ int main() { // Test #1 cout << "============= Test #1 ==============" << endl; vector vars; - for(int i = 0; i < 10; ++i) { + for (int i = 0; i < 10; ++i) { vars.push_back(engine->NewVar()); } - for(int i = 0; i < 10; ++i) { + for (int i = 0; i < 10; ++i) { engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, exec_ctx, vars, {}); } @@ -30,7 +31,7 @@ int main() { // Test #2 cout << "============= Test #2 ==============" << endl; - for(int i = 0; i < 10; ++i) { + for (int i = 0; i < 10; ++i) { engine->Push([i] (RunContext rctx) { Foo(rctx, i); }, exec_ctx, {}, vars); } From 6da99c94de739beb5d30fdc4957b91342b47739c Mon Sep 17 00:00:00 2001 From: linmin Date: Thu, 9 Jul 2015 00:14:58 +0800 Subject: [PATCH 04/14] Interface of Symbol --- include/mxnet/symbol.h | 233 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 include/mxnet/symbol.h diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h new file mode 100644 index 000000000000..5669f6d0932e --- /dev/null +++ b/include/mxnet/symbol.h @@ -0,0 +1,233 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbol.h + * \brief symbol interface of mxnet + */ +#ifndef MXNET_SYMBOL_H_ +#define MXNET_SYMBOL_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./tensor_blob.h" + +using std::shared_ptr; +using std::vector; +using std::map; + +namespace mxnet { +/*! + * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol should + * support expressions and often passed by value. While AtomicSymbol have many subclasses, passing by + * value would result in object slicing. + * + * Symbol is always composite, the head Node is the output node of the symbol. + * A atomic symbol can be seen as a special case of the composite symbol with only the head node. + */ +class Symbol { + protected: + /*! + * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol + * with input symbols. + */ + class Node { + protected: + /*! wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! inputs to this node */ + std::vector > in_symbol_; + /*! the output shape of the wrapped symbol */ + std::vector out_shape_; + /*! + * \brief hide the constructor + */ + explicit Node(AtomicSymbol* sym) : sym_(sym) {} + + public: + /*! + * \brief wrap the atomic symbol with a new Node and return this Node as shared_ptr + * \param sym the atomic symbol to be wrapped + * \return the shared_ptr to the Node that wraps the sym + */ + static std::shared_ptr Wrap(AtomicSymbol* sym) { + return std::make_shared(sym); + } + /*! + * \brief destructor + */ + virtual ~Node() { delete sym_; } + /*! + * \brief getter for the output shape of the wrapped atomic symbol + * \return const reference to the internal out_shape_ + */ + inline const std::vector& OutShape() const { return out_shape_; } + /*! + * \brief set the in_symbol_ + * \param in_symbol the input symbol to set for this Node. + * \tparam V vector > or its lvalue/rvalue references + */ + template + inline void SetInSymbol(V in_symbol) { + in_symbol_ = std::forward(in_symbol); + } + /*! + * \brief getter for the in_symbol_ + * \return the input symbols for this Node + */ + inline const std::vector >& InSymbol() { return in_symbol_; } + /*! + * \brief getter for the symbol wrapped in this Node + * \return get the pointer to the atomic symbol wrappe in this Node. + */ + inline const AtomicSymbol* Sym() const { return sym_; } + }; + /*! \brief the head node of the Symbol, it could be shared in many graphs */ + std::shared_ptr head_; + + public: + /*! + * \brief bind to device and returns an NArrayOperator. + * \param ctx context of the operator + * \return returns the pointer to a created NArrayOperator. It is on the user to delete. + */ + virtual NArrayOperator* Bind(Context ctx) const; + /*! + * \brief elementwise add to current symbol + * \param src the data to add + * \return reference of self + */ + Symbol &operator += (const Symbol &src); + /*! + * \brief elementwise subtract from current symbol + * \param src the data to substract + * \return reference of self + */ + Symbol &operator -= (const Symbol &src); + /*! + * \brief elementwise multiplication to current symbol + * \param src the data to multiply + * \return reference of self + */ + Symbol &operator *= (const Symbol &src); + /*! + * \brief elementwise division from current symbol + * \param src the data to divide + * \return reference of self + */ + Symbol &operator /= (const Symbol &src); + /*! + * \brief copy the symbol + * \return a deep copy of the graph + */ + virtual Symbol Copy() const { + // use Node* to avoid copying shared_ptr + std::map > old_new; + std::vector stk; + stk.push_back(head_.get()); + // copy nodes + while (!stk.empty()) { + Node* top = stk.back(); + stk.pop_back(); + if (old_new.count(top) == 0) { + old_new[top] = Node::Wrap(top->Sym()->Copy()); + } + for (const std::shared_ptr& n : top->InSymbol()) { + if (old_new->count(n.get()) == 0) { + stk.push_back(n.get()); + } + } + } + // connect nodes + for (auto kv : old_new) { + std::vector > in_symbol; + for (const std::shared_ptr& n : kv.first->InSymbol()) { + in_symbol.push_back(old_new[n.get()]); + } + kv.first->SetInSymbol(std::move(in_symbol)); + } + Symbol s; + s.head_ = old_new[this->head_.get()]; + return s; + } + /*! + * \brief compose with arguments + * \param args positional arguments for the symbol + * \return a new Symbol which is the composition of current symbol with its arguments + */ + virtual Symbol operator() (const vector& args); + /*! + * \brief compose with named arguments + * \param kwargs keyword arguments for the symbol + * \return a new symbol which is the composition of current symbol with its arguments + */ + virtual Symbol operator() (const map& kwargs) { + Symbol s = this->Copy(); + } + /*! + * \brief get the index th element from the returned tuple. + */ + virtual Symbol& operator[] (int index) { + } +}; + +/*! + * \brief AtomicSymbol is the base class of all atomic symbols. + * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance + * of AtomicSymbol can be shared in the graphs of different Symbols + */ +class AtomicSymbol { + /*! Only accessible from its wrapper Symbol */ + protected: + /*! + * \brief Constructor with param as the argument. + * \param param name value pairs of the param, the constructor call SetParam to set each of them. + */ + explicit AtomicSymbol(const std::map ¶m) { + for (std::map::iterator it = param.begin(); it != param.end(); ++it) { + this->SetParam(it->first.c_str(), it->second.c_str()); + } + } + /*! \brief get the number of inputs for this symbol */ + virtual int InCount() const { return 1; } + /*! \brief get the number of outputs for this symbol */ + virtual int OutCount() const { return 1; } + /*! + * \brief set param for the symbol from string + * \param name parameter name + * \param val string for the configuration + */ + virtual void SetParam(const char *name, const char *val) {} + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + */ + virtual void InferShape(std::vector *in_shape, std::vector *out_shape) = 0; + /*! + * \brief Copy this AtomicSymbol and returns a shared_ptr to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. + * \return a const reference of the shared_ptr to the copied object. + * with return value optimization may be returning const reference is not necessary. + */ + virtual AtomicSymbol* Copy() const = 0; + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. + * Calling bind from the Symbol wrapper would generate a NArrayOperator. + */ + virtual Operator* Bind(Context ctx) const = 0; + friend class Symbol; +}; + +} // namespace mxnet +#endif // MXNET_SYMBOL_H_ From 5a757ffc032db088de2ee3f3d1e9362a48911463 Mon Sep 17 00:00:00 2001 From: linmin Date: Tue, 14 Jul 2015 01:00:03 +0800 Subject: [PATCH 05/14] separate atomic_symbol from symbol --- include/mxnet/atomic_symbol.h | 74 +++++++++++++ include/mxnet/symbol.h | 194 ++++++---------------------------- 2 files changed, 104 insertions(+), 164 deletions(-) create mode 100644 include/mxnet/atomic_symbol.h diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h new file mode 100644 index 000000000000..086bba9c6bae --- /dev/null +++ b/include/mxnet/atomic_symbol.h @@ -0,0 +1,74 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file atomic_symbol.h + * \brief atomic symbol interface of mxnet + */ +#ifndef MXNET_ATOMIC_SYMBOL_H_ +#define MXNET_ATOMIC_SYMBOL_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./tensor_blob.h" + +namespace mxnet { +class Operator; +/*! + * \brief AtomicSymbol is the base class of all atomic symbols. + * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance + * of AtomicSymbol can be shared in the graphs of different Symbols + */ +class AtomicSymbol { + public: + /*! + * \brief Constructor with param as the argument. + */ + AtomicSymbol(); + /*! + * \brief virtual destructor + */ + virtual ~AtomicSymbol(); + /*! \brief get the descriptions of inputs for this symbol */ + virtual std::vector DescribeArguments() const = 0; + /*! \brief get the descriptions of outputs for this symbol */ + virtual std::vector DescribeReturns() const = 0; + /*! + * \brief set param for the symbol from string + * \param name parameter name + * \param val string for the configuration + */ + virtual void SetParam(const char *name, const char *val) {} + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by DescribeArgs + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) = 0; + /*! + * \brief Copy this AtomicSymbol and returns a pointer to the copied object. + * this is a virtual function because different subclass of AtomicSymbol would copy differently. + * \return a pointer to the copied atomic symbol + */ + virtual AtomicSymbol* Copy() const = 0; + /*! + * \brief Bind this AtomicSymbol to a context and get back a static operator + * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. + * Calling bind from the Symbol wrapper would generate a NArrayOperator. + */ + virtual Operator* Bind(Context ctx) const = 0; + friend class Symbol; +}; + +} // namespace mxnet +#endif // MXNET_ATOMIC_SYMBOL_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 5669f6d0932e..a8fb624d31d0 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -6,22 +6,21 @@ #ifndef MXNET_SYMBOL_H_ #define MXNET_SYMBOL_H_ +#include #include #include #include -#include +#include +#include #include "./base.h" #include "./tensor_blob.h" -using std::shared_ptr; -using std::vector; -using std::map; - namespace mxnet { +class NArrayOperator; /*! - * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol should - * support expressions and often passed by value. While AtomicSymbol have many subclasses, passing by - * value would result in object slicing. + * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol + * should support expressions and often passed by value. While AtomicSymbol have many subclasses, + * passing by value would result in object slicing. * * Symbol is always composite, the head Node is the output node of the symbol. * A atomic symbol can be seen as a special case of the composite symbol with only the head node. @@ -32,59 +31,34 @@ class Symbol { * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol * with input symbols. */ - class Node { - protected: + struct Node { /*! wrapped atomic symbol */ - AtomicSymbol* sym_; + AtomicSymbol* sym_ = nullptr; + /*! name of the node */ + std::string name_ = ""; /*! inputs to this node */ std::vector > in_symbol_; + /*! index of the inputs if the inputs are tuple */ + std::vector in_index_; /*! the output shape of the wrapped symbol */ std::vector out_shape_; /*! - * \brief hide the constructor + * \brief constructor */ - explicit Node(AtomicSymbol* sym) : sym_(sym) {} - - public: - /*! - * \brief wrap the atomic symbol with a new Node and return this Node as shared_ptr - * \param sym the atomic symbol to be wrapped - * \return the shared_ptr to the Node that wraps the sym - */ - static std::shared_ptr Wrap(AtomicSymbol* sym) { - return std::make_shared(sym); - } + explicit Node(AtomicSymbol* sym = NULL, const std::string& name = ""); /*! * \brief destructor */ - virtual ~Node() { delete sym_; } - /*! - * \brief getter for the output shape of the wrapped atomic symbol - * \return const reference to the internal out_shape_ - */ - inline const std::vector& OutShape() const { return out_shape_; } - /*! - * \brief set the in_symbol_ - * \param in_symbol the input symbol to set for this Node. - * \tparam V vector > or its lvalue/rvalue references - */ - template - inline void SetInSymbol(V in_symbol) { - in_symbol_ = std::forward(in_symbol); - } - /*! - * \brief getter for the in_symbol_ - * \return the input symbols for this Node - */ - inline const std::vector >& InSymbol() { return in_symbol_; } - /*! - * \brief getter for the symbol wrapped in this Node - * \return get the pointer to the atomic symbol wrappe in this Node. - */ - inline const AtomicSymbol* Sym() const { return sym_; } + virtual ~Node(); }; /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; + /*! \brief if the head has multiple return values, index is used to specify */ + int index_; + /*! \brief find the nodes that use placeholder arguments */ + std::shared_ptr > > arg_users_; + /*! \brief find arg users */ + void FindArgUsers(); public: /*! @@ -92,141 +66,33 @@ class Symbol { * \param ctx context of the operator * \return returns the pointer to a created NArrayOperator. It is on the user to delete. */ - virtual NArrayOperator* Bind(Context ctx) const; - /*! - * \brief elementwise add to current symbol - * \param src the data to add - * \return reference of self - */ - Symbol &operator += (const Symbol &src); - /*! - * \brief elementwise subtract from current symbol - * \param src the data to substract - * \return reference of self - */ - Symbol &operator -= (const Symbol &src); - /*! - * \brief elementwise multiplication to current symbol - * \param src the data to multiply - * \return reference of self - */ - Symbol &operator *= (const Symbol &src); - /*! - * \brief elementwise division from current symbol - * \param src the data to divide - * \return reference of self - */ - Symbol &operator /= (const Symbol &src); + virtual NArrayOperator* Bind(Context ctx) const { return nullptr; } /*! * \brief copy the symbol * \return a deep copy of the graph */ - virtual Symbol Copy() const { - // use Node* to avoid copying shared_ptr - std::map > old_new; - std::vector stk; - stk.push_back(head_.get()); - // copy nodes - while (!stk.empty()) { - Node* top = stk.back(); - stk.pop_back(); - if (old_new.count(top) == 0) { - old_new[top] = Node::Wrap(top->Sym()->Copy()); - } - for (const std::shared_ptr& n : top->InSymbol()) { - if (old_new->count(n.get()) == 0) { - stk.push_back(n.get()); - } - } - } - // connect nodes - for (auto kv : old_new) { - std::vector > in_symbol; - for (const std::shared_ptr& n : kv.first->InSymbol()) { - in_symbol.push_back(old_new[n.get()]); - } - kv.first->SetInSymbol(std::move(in_symbol)); - } - Symbol s; - s.head_ = old_new[this->head_.get()]; - return s; - } + virtual Symbol Copy() const; /*! * \brief compose with arguments * \param args positional arguments for the symbol * \return a new Symbol which is the composition of current symbol with its arguments */ - virtual Symbol operator() (const vector& args); + virtual Symbol operator () (const std::vector& args) const; /*! * \brief compose with named arguments * \param kwargs keyword arguments for the symbol * \return a new symbol which is the composition of current symbol with its arguments */ - virtual Symbol operator() (const map& kwargs) { - Symbol s = this->Copy(); - } + virtual Symbol operator () (const std::unordered_map& kwargs) const; /*! * \brief get the index th element from the returned tuple. */ - virtual Symbol& operator[] (int index) { - } -}; - -/*! - * \brief AtomicSymbol is the base class of all atomic symbols. - * This is not meant to be used by user, it should be wrapped in Symbol, so that the same instance - * of AtomicSymbol can be shared in the graphs of different Symbols - */ -class AtomicSymbol { - /*! Only accessible from its wrapper Symbol */ - protected: - /*! - * \brief Constructor with param as the argument. - * \param param name value pairs of the param, the constructor call SetParam to set each of them. - */ - explicit AtomicSymbol(const std::map ¶m) { - for (std::map::iterator it = param.begin(); it != param.end(); ++it) { - this->SetParam(it->first.c_str(), it->second.c_str()); - } - } - /*! \brief get the number of inputs for this symbol */ - virtual int InCount() const { return 1; } - /*! \brief get the number of outputs for this symbol */ - virtual int OutCount() const { return 1; } - /*! - * \brief set param for the symbol from string - * \param name parameter name - * \param val string for the configuration - */ - virtual void SetParam(const char *name, const char *val) {} - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by DescribeArgs - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - */ - virtual void InferShape(std::vector *in_shape, std::vector *out_shape) = 0; - /*! - * \brief Copy this AtomicSymbol and returns a shared_ptr to the copied object. - * this is a virtual function because different subclass of AtomicSymbol would copy differently. - * \return a const reference of the shared_ptr to the copied object. - * with return value optimization may be returning const reference is not necessary. - */ - virtual AtomicSymbol* Copy() const = 0; + virtual Symbol operator[] (int index) const; /*! - * \brief Bind this AtomicSymbol to a context and get back a static operator - * Bind function of AtomicSymbol does not return NArrayOperator, but static operator. - * Calling bind from the Symbol wrapper would generate a NArrayOperator. + * \brief arguments information + * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ - virtual Operator* Bind(Context ctx) const = 0; - friend class Symbol; + virtual std::vector ListArgs(); }; } // namespace mxnet From d239267cde83355bfadea0c366107d6baba1ec88 Mon Sep 17 00:00:00 2001 From: linmin Date: Tue, 14 Jul 2015 01:00:31 +0800 Subject: [PATCH 06/14] add implementation --- src/symbol/symbol.cc | 137 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 src/symbol/symbol.cc diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc new file mode 100644 index 000000000000..81e336069f27 --- /dev/null +++ b/src/symbol/symbol.cc @@ -0,0 +1,137 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbol.cc + * \brief symbol of mxnet + */ +#include +#include +#include + +namespace mxnet { + +Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) : sym_(sym), name_(name) {} + +Symbol::Node::~Node() { + if (sym_) { + delete sym_; + } +} + +void Symbol::FindArgUsers() { + arg_users_.reset(new std::vector >); + // depth first traversing + std::vector > stk; + stk.push_back({head_.get(), 0}); + while (!stk.empty()) { + std::pair& back = stk.back(); + if (back.first->in_symbol_.size() == back.second) { + stk.pop_back(); + } else { + Node* next_level = back.first->in_symbol_[back.second].get(); + if (next_level->sym_) { + stk.push_back({next_level, 0}); + } else { // back uses next_level which is a placeholder + arg_users_->push_back({back.first, back.second}); + } + back.second += 1; + } + } +} + +Symbol Symbol::Copy() const { + Symbol s; + std::unordered_map > old_new; + std::vector stk; + stk.push_back(head_.get()); + // copy nodes + while (!stk.empty()) { + Node* back = stk.back(); + stk.pop_back(); + if (old_new.count(back) == 0) { + if (back->sym_) { + old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); + } else { + old_new[back] = std::make_shared(nullptr, back->name_); + } + } + for (const std::shared_ptr& n : back->in_symbol_) { + if (old_new.count(n.get()) == 0) { + stk.push_back(n.get()); + } + } + } + // connect nodes + for (auto kv : old_new) { + for (const std::shared_ptr& n : kv.first->in_symbol_) { + kv.second->in_symbol_.push_back(old_new[n.get()]); + } + } + s.head_ = old_new[this->head_.get()]; + // copy arg_users_ + if (arg_users_) { + s.arg_users_.reset(new std::vector >); + std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(*s.arg_users_), + [&old_new](const std::pair& n) -> std::pair { + return { old_new[n.first].get(), n.second }; + }); + } + return s; +} + +Symbol Symbol::operator () (const std::vector& args) const { + Symbol s = this->Copy(); + if (!s.arg_users_) { // if arg_users_ has not been populated + s.FindArgUsers(); + } + CHECK_LT(args.size(), s.arg_users_->size()) << "Too many args, requires " << s.arg_users_->size() + << " provided " << args.size(); + for (size_t i = 0; i < args.size(); ++i) { + const std::pair& arg_user = (*s.arg_users_)[i]; + arg_user.first->in_symbol_[arg_user.second] = args[i].head_; + CHECK_NE(args[i].index_, -1) << "Argument " << i << " is a tuple, scalar is required"; + arg_user.first->in_index_[arg_user.second] = args[i].index_; + } + return s; +} + +Symbol Symbol::operator () (const std::unordered_map& kwargs) const { + Symbol s = this->Copy(); + if (!s.arg_users_) { // if arg_users_ has not been populated + s.FindArgUsers(); + } + CHECK_LT(kwargs.size(), s.arg_users_->size()) << "Too many args, requires " + << s.arg_users_->size() << " provided " << kwargs.size(); + for (size_t i = 0; i < s.arg_users_->size(); ++i) { + const std::pair& arg_user = (*s.arg_users_)[i]; + const std::string& name = arg_user.first->name_; + if (!(name == "") && kwargs.count(name) != 0) { + const Symbol& bind = kwargs.at(name); + arg_user.first->in_symbol_[arg_user.second] = bind.head_; + CHECK_NE(bind.index_, -1) << "Argument " << name << " is a tuple, scalar is required"; + arg_user.first->in_index_[arg_user.second] = bind.index_; + } + } + // TODO(linmin): report error if kwargs contains non-existing keys + return s; +} + +Symbol Symbol::operator[] (int index) const { + CHECK_EQ(index_, -1) << "Current symbol can't be indexed because it returns a scalar."; + Symbol s = *this; + s.index_ = index; + return s; +} + +std::vector Symbol::ListArgs() { + std::vector ret; + if (!arg_users_) { + FindArgUsers(); + } + std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), + [&](const std::pair& n) -> std::string { + return n.first->in_symbol_[n.second]->name_; + }); + return ret; +} + +} // namespace mxnet From cda674c4c525a9bf729e55ec823ab1ca6d50db05 Mon Sep 17 00:00:00 2001 From: linmin Date: Tue, 14 Jul 2015 01:00:59 +0800 Subject: [PATCH 07/14] add symbol to makefile --- Makefile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 5686bff75998..af1a7730a50e 100644 --- a/Makefile +++ b/Makefile @@ -54,11 +54,11 @@ ifneq ($(ADD_LDFLAGS), NONE) LDFLAGS += $(ADD_LDFLAGS) endif -#BIN = test/test_threaded_engine test/api_registry_test -BIN = test/api_registry_test +#BIN = test/test_threaded_engine test/api_registry_test +BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -84,6 +84,7 @@ narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h operator.o: src/operator/operator.cc operator_cpu.o: src/operator/operator_cpu.cc operator_gpu.o: src/operator/operator_gpu.cu +symbol.o: src/symbol/symbol.cc api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc @@ -124,4 +125,3 @@ doc: clean: $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - - From 0a9128be0e1470639c6903479b9c25584f3a6045 Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 10:37:53 +0800 Subject: [PATCH 08/14] fix a bit --- include/mxnet/symbol.h | 10 +++++----- src/symbol/symbol.cc | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index a8fb624d31d0..62a6803332dd 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -33,9 +33,9 @@ class Symbol { */ struct Node { /*! wrapped atomic symbol */ - AtomicSymbol* sym_ = nullptr; + AtomicSymbol* sym_; /*! name of the node */ - std::string name_ = ""; + std::string name_; /*! inputs to this node */ std::vector > in_symbol_; /*! index of the inputs if the inputs are tuple */ @@ -45,18 +45,18 @@ class Symbol { /*! * \brief constructor */ - explicit Node(AtomicSymbol* sym = NULL, const std::string& name = ""); + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = ""); /*! * \brief destructor */ - virtual ~Node(); + ~Node(); }; /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; /*! \brief if the head has multiple return values, index is used to specify */ int index_; /*! \brief find the nodes that use placeholder arguments */ - std::shared_ptr > > arg_users_; + std::unique_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 81e336069f27..d44f63b52d95 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -9,7 +9,9 @@ namespace mxnet { -Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) : sym_(sym), name_(name) {} +Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) + : sym_(sym), name_(name) { +} Symbol::Node::~Node() { if (sym_) { @@ -91,6 +93,7 @@ Symbol Symbol::operator () (const std::vector& args) const { CHECK_NE(args[i].index_, -1) << "Argument " << i << " is a tuple, scalar is required"; arg_user.first->in_index_[arg_user.second] = args[i].index_; } + s.arg_users_.reset(); return s; } @@ -111,6 +114,7 @@ Symbol Symbol::operator () (const std::unordered_map& kwarg arg_user.first->in_index_[arg_user.second] = bind.index_; } } + s.arg_users_.reset(); // TODO(linmin): report error if kwargs contains non-existing keys return s; } From cd6ffee0072e66672514885cc193689dd783db8c Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 16:08:39 +0800 Subject: [PATCH 09/14] expose c api --- api/mxnet_api.cc | 75 ++++++++++++++++++++++++++++++++++++ api/mxnet_api.h | 70 +++++++++++++++++++++++++++++---- api/python/mxnet/base.py | 1 + include/mxnet/api_registry.h | 65 +++++++++++++++++++++++++++++++ include/mxnet/symbol.h | 2 +- src/api_registry.cc | 22 +++++++++++ 6 files changed, 226 insertions(+), 9 deletions(-) diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 0d0575ba488c..04412e570655 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -246,3 +246,78 @@ int MXFuncInvoke(FunctionHandle fun, (NArray**)(mutate_vars)); // NOLINT(*) API_END(); } + +int MXSymFree(SymbolHandle sym) { + API_BEGIN(); + delete static_cast(sym); + API_END(); +} + +int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, + mx_uint *use_param) { + API_BEGIN(); + auto *sc = static_cast(sym_creator); + *use_param = sc->use_param ? 1 : 0; + API_END(); +} + +int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, + int count, + const char** keys, + const char** vals, + SymbolHandle* out) { + API_BEGIN(); + const SymbolCreatorRegistry::Entry *sc = + static_cast(sym_creator); + sc->body(count, keys, vals, (Symbol**)(out)); + API_END(); +} + +int MXListSymCreators(mx_uint *out_size, + SymbolCreatorHandle **out_array) { + API_BEGIN(); + auto &vec = SymbolCreatorRegistry::List(); + *out_size = static_cast(vec.size()); + *out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) + API_END(); +} + +int MXGetSymCreator(const char *name, + SymbolCreatorHandle *out) { + API_BEGIN(); + *out = SymbolCreatorRegistry::Find(name); + API_END(); +} + +int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, + const char **out_name) { + API_BEGIN(); + auto *f = static_cast(sym_creator); + *out_name = f->name.c_str(); + API_END(); +} + +int MXSymbolCompose(SymbolHandle sym, + mx_uint num_args, + const char** keys, + SymbolHandle* args, + SymbolHandle* out) { + API_BEGIN(); + const Symbol* s = static_cast(sym); + Symbol* ret = new Symbol; + if (keys == NULL) { + std::vector pos_args; + for (mx_uint i = 0; i < num_args; ++i) { + pos_args.push_back(*(Symbol*)(args[i])); + } + *ret = (*s)(pos_args); + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = *(Symbol*)(args[i]); + } + *ret = (*s)(kwargs); + } + *out = ret; + API_END(); +} diff --git a/api/mxnet_api.h b/api/mxnet_api.h index d30a18d571dd..39fba9220859 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -28,6 +28,8 @@ typedef float mx_float; typedef void *NArrayHandle; /*! \brief handle to a mxnet narray function that changes NArray */ typedef const void *FunctionHandle; +/*! \brief handle to a function that takes param and creates symbol */ +typedef const void *SymbolCreatorHandle; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a NArrayOperator */ @@ -217,17 +219,69 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, * \param sym the symbol * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymFree(SymbolHandle *sym); +MXNET_DLL int MXSymFree(SymbolHandle sym); /*! - * \brief set the parameter in to current symbol - * \param sym the symbol - * \param name name of the parameter - * \param val value of the parameter + * \brief query if the symbol creator needs param. + * \param sym_creator the symbol creator handle + * \param use_param describe if the symbol creator requires param + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, + mx_uint *use_param); +/*! + * \brief invoke registered symbol creator through its handle. + * \param sym_creator pointer to the symbolcreator function. + * \param count the number of the key value pairs in the param. + * \param keys an array of c str. + * \param vals the corresponding values of the keys. + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, + int count, + const char** keys, + const char** vals, + SymbolHandle* out); +/*! + * \brief list all the available sym_creator + * most user can use it to list all the needed sym_creators + * \param out_size the size of returned array + * \param out_array the output sym_creators + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXListSymCreators(mx_uint *out_size, + SymbolCreatorHandle **out_array); +/*! + * \brief get the sym_creator by name + * \param name the name of the sym_creator + * \param out the corresponding sym_creator + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXGetSymCreator(const char *name, + SymbolCreatorHandle *out); +/*! + * \brief get the name of sym_creator handle + * \param fun the sym_creator handle + * \param out_name the name of the sym_creator * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymSetParam(SymbolHandle sym, - const char *name, - const char *val); +MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, + const char **out_name); +/*! + * \brief compose the symbol on other symbol + * \param sym the symbol to apply + * \param num_args number of arguments + * \param keys the key of keyword args (optional) + * \param args arguments to sym + * \param out the resulting symbol + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCompose(SymbolHandle sym, + mx_uint num_args, + const char** keys, + SymbolHandle* args, + SymbolHandle* out); + //-------------------------------------------- // Part 4: operator interface on NArray //-------------------------------------------- diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index e7e20394f738..fd20cb77ef81 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -55,6 +55,7 @@ def _load_lib(): mx_float = ctypes.c_float NArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p +SymbolHandle = ctypes.c_void_p #---------------------------- # helper function definition diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 403201f93ac4..0154002e8d1c 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -17,6 +17,7 @@ #include #include "./base.h" #include "./narray.h" +#include "./symbol.h" namespace mxnet { @@ -211,5 +212,69 @@ class FunctionRegistry { static auto __ ## name ## _narray_fun__ = \ ::mxnet::FunctionRegistry::Get()->Register("" # name) +/*! \brief registry of symbol creator */ +class SymbolCreatorRegistry { + public: + /*! \brief SymbolCreator is a function pointer */ + typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**); + /*! \return get a singleton */ + static SymbolCreatorRegistry *Get(); + /*! \brief keep the SymbolCreator function and its meta information */ + struct Entry { + /*! \brief the name of the symbol creator */ + std::string name; + /*! \brief the body of the function */ + SymbolCreator body; + /*! \brief if the creator requires params to construct */ + bool use_param; + /*! \brief constructor */ + explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {} + /*! \brief setter of body */ + inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; } + /*! \brief setter of use_param */ + inline Entry& set_use_param(bool up) { use_param = up; return *this; } + }; + /*! + * \brief register a name symbol under name + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + Entry &Register(const std::string& name); + /*! \return list of functions in the registry */ + inline static const std::vector &List() { + return Get()->fun_list_; + } + /*! + * \brief find an symbolcreator entry with corresponding name + * \param name name of the symbolcreator + * \return the corresponding symbolcreator, can be NULL + */ + inline static const Entry *Find(const std::string &name) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(name); + if (p != fmap.end()) { + return p->second; + } else { + return nullptr; + } + } + private: + /*! \brief list of functions */ + std::vector fun_list_; + /*! \brief map of name->function */ + std::map fmap_; + /*! \brief constructor */ + SymbolCreatorRegistry() {} + /*! \brief destructor */ + ~SymbolCreatorRegistry(); +}; + +/*! + * \brief macro to register symbol creator + */ +#define REGISTER_SYMBOL_CREATOR(name) \ + static auto __ ## name ## _symbol_creator__ = \ + ::mxnet::SymbolCreatorRegistry::Get()->Register("" # name) + } // namespace mxnet #endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 62a6803332dd..a003c48d79d3 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -56,7 +56,7 @@ class Symbol { /*! \brief if the head has multiple return values, index is used to specify */ int index_; /*! \brief find the nodes that use placeholder arguments */ - std::unique_ptr > > arg_users_; + std::shared_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); diff --git a/src/api_registry.cc b/src/api_registry.cc index 0a6423441bbe..c029502287e9 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -29,4 +29,26 @@ FunctionRegistry *FunctionRegistry::Get() { return &instance; } +// SymbolCreatorRegistry + +SymbolCreatorRegistry::Entry& +SymbolCreatorRegistry::Register(const std::string& name) { + CHECK_EQ(fmap_.count(name), 0); + Entry *e = new Entry(name); + fmap_[name] = e; + fun_list_.push_back(e); + return *e; +} + +SymbolCreatorRegistry::~SymbolCreatorRegistry() { + for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { + delete p->second; + } +} + +SymbolCreatorRegistry *SymbolCreatorRegistry::Get() { + static SymbolCreatorRegistry instance; + return &instance; +} + } // namespace mxnet From 66d576c8d84db14e71b2eeaa4355127d89a722e7 Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 16:09:58 +0800 Subject: [PATCH 10/14] symbol python api --- api/python/mxnet/__init__.py | 3 ++ api/python/mxnet/base.py | 2 +- api/python/mxnet/symbol.py | 64 ++++++++++++++++++++++++++++ api/python/mxnet/symbol_creator.py | 68 ++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 api/python/mxnet/symbol.py create mode 100644 api/python/mxnet/symbol_creator.py diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py index 28b1659efb75..dabf0b795412 100644 --- a/api/python/mxnet/__init__.py +++ b/api/python/mxnet/__init__.py @@ -13,6 +13,9 @@ from .context import Context, current_context from .narray import NArray from .function import _FunctionRegistry +from .symbol import Symbol +from .symbol_creator import _SymbolCreatorRegistry # this is a global function registry that can be used to invoke functions op = NArray._init_function_registry(_FunctionRegistry()) +sym = Symbol._init_symbol_creator_registry(_SymbolCreatorRegistry()) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py index fd20cb77ef81..239e284293b4 100644 --- a/api/python/mxnet/base.py +++ b/api/python/mxnet/base.py @@ -55,6 +55,7 @@ def _load_lib(): mx_float = ctypes.c_float NArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p +SymbolCreatorHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p #---------------------------- @@ -132,4 +133,3 @@ def ctypes2numpy_shared(cptr, shape): size *= s dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape) - diff --git a/api/python/mxnet/symbol.py b/api/python/mxnet/symbol.py new file mode 100644 index 000000000000..2b0648580072 --- /dev/null +++ b/api/python/mxnet/symbol.py @@ -0,0 +1,64 @@ +# coding: utf-8 +"""Symbol support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB +from .base import c_array +from .base import mx_uint, mx_float, SymbolHandle +from .base import check_call, MXNetError +from .narray import NArray, _new_empty_handle + +class Symbol(object): + """SymbolCreator is a function that takes Param and return symbol""" + _registry = None + + @staticmethod + def _init_symbol_creator_registry(symbol_creator_registry): + _registry = symbol_creator_registry + return _registry + + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + self.handle = handle + + def __call__(self, *args, **kwargs): + """Compose Symbols + + Parameters + ---------- + args: + provide positional arguments + kwargs: + provide keyword arguments + Returns + ------- + the resulting symbol + """ + assert (len(args) == 0 or len(kwargs) == 0) + for arg in args: + assert isinstance(arg, Symbol) + for key, val in kwargs: + assert isinstance(val, Symbol) + num_args = len(args) + len(kwargs) + if len(kwargs) != 0: + keys = c_array(ctypes.c_char_p, map(c_str, kwargs.keys())) + args = c_array(SymbolHandle, kwargs.values()) + else: + keys = None + args = c_array(SymbolHandle, args) + + out = SymbolHandle() + check_call(_LIB.MXSymbolCompose( + self.handle, + num_args, + keys, + args, + ctypes.byref(out))) + return Symbol(out) diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py new file mode 100644 index 000000000000..8bcc781f831f --- /dev/null +++ b/api/python/mxnet/symbol_creator.py @@ -0,0 +1,68 @@ +# coding: utf-8 +"""Symbol support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB +from .base import c_array +from .base import mx_uint, mx_float, NArrayHandle +from .base import check_call, MXNetError +from .narray import NArray, _new_empty_handle + +class _SymbolCreator(object): + """SymbolCreator is a function that takes Param and return symbol""" + + def __init__(self, handle, name): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolCreatorHandle + the function handle of the function + + name : string + the name of the function + """ + self.handle = handle + self.name = name + check_call(_LIB.MXSymCreatorDescribe( + self.handle, + ctypes.byref(use_param))) + self.use_param = use_param.value + + def __call__(self, **kwargs): + """Invoke creator of symbol by passing kwargs + + Parameters + ---------- + params : kwargs + provide the params necessary for the symbol creation + Returns + ------- + the resulting symbol + """ + keys = c_array(ctypes.c_char_p, map(c_str, kwargs.keys())) + vals = c_array(ctypes.c_char_p, map(c_str, map(str, kwargs.values()))) + sym_handle = SymbolHandle() + check_call(_LIB.MXSymCreatorInvoke( + self.handle, + mx_uint(len(kwargs)), + keys, + vals, + ctypes.byref(sym_handle))) + return Symbol(sym_handle) + +class _SymbolCreatorRegistry(object): + """Function Registry""" + def __init__(self): + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(_LIB.MXListSymCreators(ctypes.byref(size), + ctypes.byref(plist))) + hmap = {} + for i in range(size.value): + hdl = plist[i] + name = ctypes.c_char_p() + check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name))) + hmap[name.value] = _SymbolCreator(hdl, name.value) + self.__dict__.update(hmap) From c7d5e59ec8e357ceb42761f1c0e7fc593c79a658 Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 16:13:11 +0800 Subject: [PATCH 11/14] fix lint and doc --- api/mxnet_api.cc | 6 +++--- api/mxnet_api.h | 2 +- include/mxnet/api_registry.h | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 04412e570655..7e351e52cb03 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -269,7 +269,7 @@ int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, API_BEGIN(); const SymbolCreatorRegistry::Entry *sc = static_cast(sym_creator); - sc->body(count, keys, vals, (Symbol**)(out)); + sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*) API_END(); } @@ -308,13 +308,13 @@ int MXSymbolCompose(SymbolHandle sym, if (keys == NULL) { std::vector pos_args; for (mx_uint i = 0; i < num_args; ++i) { - pos_args.push_back(*(Symbol*)(args[i])); + pos_args.push_back(*(Symbol*)(args[i])); // NOLINT(*) } *ret = (*s)(pos_args); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = *(Symbol*)(args[i]); + kwargs[keys[i]] = *(Symbol*)(args[i]); // NOLINT(*) } *ret = (*s)(kwargs); } diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 39fba9220859..fb4b8710e2e6 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -261,7 +261,7 @@ MXNET_DLL int MXGetSymCreator(const char *name, SymbolCreatorHandle *out); /*! * \brief get the name of sym_creator handle - * \param fun the sym_creator handle + * \param sym_creator the sym_creator handle * \param out_name the name of the sym_creator * \return 0 when success, -1 when failure happens */ diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 0154002e8d1c..91083c2ea11d 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -258,6 +258,7 @@ class SymbolCreatorRegistry { return nullptr; } } + private: /*! \brief list of functions */ std::vector fun_list_; From cb4b45466efcbe7f2d21fca457c5742567db39a9 Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 16:59:09 +0800 Subject: [PATCH 12/14] add template function to create atomic symbol --- include/mxnet/symbol.h | 48 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index a003c48d79d3..8aacc0b27080 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -93,7 +93,55 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); + /*! + * \brief create atomic symbol wrapped in symbol + * \param param the parameter stored as key value pairs + * \return the constructed Symbol + */ + template + static Symbol CreateSymbol(const std::vector >& param) { + Symbol* s; + std::vector keys(param.size()); + std::vector vals(param.size()); + for (auto p : param) { + keys.push_back(p.first.c_str()); + vals.push_back(p.second.c_str()); + } + CreateSymbol(param.size(), &keys[0], &vals[0], &s); + Symbol ret = *s; + delete s; + return ret; + } + /*! + * \brief create + */ + template + friend void CreateSymbol(int, const char**, const char**, Symbol**); }; +template +void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out) { + Symbol* s = new Symbol; + Atomic* atom = new Atomic; + for (int i = 0; i < num_param; ++i) { + atom->SetParam(keys[i], vals[i]); + } + std::vector args = atom->DescribeArguments(); + std::vector rets = atom->DescribeReturns(); + // set head_ + s->head_ = std::make_shared(atom, ""); + // set index_ + s->index_ = rets.size() > 1 ? -1 : 0; + // set head_->in_index_ + s->head_->in_index_ = std::vector(args.size(), 0); + // set head_->in_symbol_ + for (auto name : args) { + s->head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + } + // set head_->out_shape_ + s->head_->out_shape_ = std::vector(rets.size()); + *out = s; +} + } // namespace mxnet #endif // MXNET_SYMBOL_H_ From 007b15b103019a16e89ee84033e1f7995161dfeb Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 17:35:42 +0800 Subject: [PATCH 13/14] resolve some errors and lints --- api/python/mxnet/symbol.py | 21 +++++++++++++++------ api/python/mxnet/symbol_creator.py | 13 +++++++------ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/api/python/mxnet/symbol.py b/api/python/mxnet/symbol.py index 2b0648580072..6f4146d162e3 100644 --- a/api/python/mxnet/symbol.py +++ b/api/python/mxnet/symbol.py @@ -4,10 +4,9 @@ import ctypes from .base import _LIB -from .base import c_array -from .base import mx_uint, mx_float, SymbolHandle -from .base import check_call, MXNetError -from .narray import NArray, _new_empty_handle +from .base import c_array, c_str +from .base import SymbolHandle +from .base import check_call class Symbol(object): """SymbolCreator is a function that takes Param and return symbol""" @@ -15,6 +14,16 @@ class Symbol(object): @staticmethod def _init_symbol_creator_registry(symbol_creator_registry): + """Initialize symbol creator registry + + Parameters + ---------- + symbol_creator_registry: + pass in symbol_creator_registry + Returns + ------- + the passed in registry + """ _registry = symbol_creator_registry return _registry @@ -44,11 +53,11 @@ def __call__(self, *args, **kwargs): assert (len(args) == 0 or len(kwargs) == 0) for arg in args: assert isinstance(arg, Symbol) - for key, val in kwargs: + for _, val in kwargs: assert isinstance(val, Symbol) num_args = len(args) + len(kwargs) if len(kwargs) != 0: - keys = c_array(ctypes.c_char_p, map(c_str, kwargs.keys())) + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) args = c_array(SymbolHandle, kwargs.values()) else: keys = None diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py index 8bcc781f831f..e8a49149ec35 100644 --- a/api/python/mxnet/symbol_creator.py +++ b/api/python/mxnet/symbol_creator.py @@ -4,10 +4,10 @@ import ctypes from .base import _LIB -from .base import c_array -from .base import mx_uint, mx_float, NArrayHandle -from .base import check_call, MXNetError -from .narray import NArray, _new_empty_handle +from .base import c_array, c_str +from .base import mx_uint, SymbolHandle +from .base import check_call +from .symbol import Symbol class _SymbolCreator(object): """SymbolCreator is a function that takes Param and return symbol""" @@ -25,6 +25,7 @@ def __init__(self, handle, name): """ self.handle = handle self.name = name + use_param = mx_uint() check_call(_LIB.MXSymCreatorDescribe( self.handle, ctypes.byref(use_param))) @@ -41,8 +42,8 @@ def __call__(self, **kwargs): ------- the resulting symbol """ - keys = c_array(ctypes.c_char_p, map(c_str, kwargs.keys())) - vals = c_array(ctypes.c_char_p, map(c_str, map(str, kwargs.values()))) + keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) + vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) sym_handle = SymbolHandle() check_call(_LIB.MXSymCreatorInvoke( self.handle, From 93f080191dbfef0c2e14bb66beed1bd62c5dc63a Mon Sep 17 00:00:00 2001 From: linmin Date: Wed, 15 Jul 2015 17:45:19 +0800 Subject: [PATCH 14/14] fix docs --- include/mxnet/symbol.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 8aacc0b27080..a623cfa1502e 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -113,10 +113,14 @@ class Symbol { return ret; } /*! - * \brief create + * \brief c api for CreateSymbol, this can be registered with SymbolCreatorRegistry + * \param num_param the number of params + * \param keys the key for the params + * \param vals values of the params + * \param out stores the returning symbol */ template - friend void CreateSymbol(int, const char**, const char**, Symbol**); + friend void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out); }; template