This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from winstywang/master
rename
- Loading branch information
Showing
19 changed files
with
395 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file static_operator.h | ||
* \brief static operator interface of mxnet | ||
*/ | ||
#ifndef MXNET_STATIC_OPERATOR_H_ | ||
#define MXNET_STATIC_OPERATOR_H_ | ||
// this file will be seen by cuda, no c++11 for now | ||
#include <dmlc/base.h> | ||
#include <vector> | ||
#include "./base.h" | ||
#include "./tensor_blob.h" | ||
|
||
namespace mxnet { | ||
/*! | ||
* \brief static StaticOperator interface (current interface have not yet todo with scheduler), | ||
* StaticOperator is a stateful object that can be used to call forward and backprop | ||
* | ||
* This interface relies on pre-allocated memory in TBlob, the caller need to set | ||
* the memory region in TBlob correctly before calling Forward and Backward | ||
* | ||
* \sa TBlob, TShape | ||
*/ | ||
class StaticOperator { | ||
public: | ||
/*! | ||
* \brief get types of input argument of this oeprator | ||
* \return a vector corresponding to type of each argument | ||
* this order is same as the order of inputs in Forward, InferShape and Backward | ||
*/ | ||
virtual std::vector<ArgType> DescribeArgs() const { | ||
// default most of layers only have one data argument | ||
return std::vector<ArgType>(1, kDataArg); | ||
} | ||
/*! | ||
* \brief describe property of op | ||
* \return a bit map in int | ||
*/ | ||
virtual int DescribeProperty() const { | ||
// default most of layer only conatin internal state | ||
return kContainInteralState; | ||
} | ||
/*! | ||
* \brief set param for the StaticOperator from string | ||
* \param name parameter name | ||
* \param val string for configuration | ||
*/ | ||
virtual void SetParam(const char *name, const char *val) {} | ||
/*! | ||
* \brief inter the shapes of outputs and unknown input arguments | ||
* \param in_shape the shape of input arguments of the StaticOperator | ||
* this should be of same length as the vector returned by DescribeArgs | ||
* in_shape allows unknown elements, which are checked by shape.ndim() == 0. | ||
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape | ||
* For known shapes, InferShape will check shape consistency | ||
* | ||
* common practice: set the shape of data input, and usually weight's shape can be infered | ||
* | ||
* \param out_shape the shape of outputs of the StaticOperator | ||
* InferShape will modify the vector to fill output TShape | ||
*/ | ||
virtual void InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape) = 0; | ||
/*! | ||
* \brief perform a forward operation of StaticOperator, save the output to TBlob | ||
* \param opt option on Forward such as whether this is training phase | ||
* \param ctx runtime context | ||
* \param in_data array of input data, it is const | ||
* \param out_data array of output data, | ||
* the space of TBlob in out_data must be pre-allocated with InferShape | ||
*/ | ||
virtual void Forward(Option opt, | ||
RunContext ctx, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_data) = 0; | ||
/*! | ||
* \brief perform a backward operation of the StaticOperator to get the gradient | ||
* \param ctx runtime context | ||
* \param grad_next the gradient value we get from output of the StaticOperator | ||
* \param in_data the array of input data | ||
* \param out_grad array of output gradient, there could be three possible TBlob | ||
* in the each element in the array | ||
* \param req request types of the gradient saving operation | ||
* only inplace will change input data | ||
* \sa GradReqType | ||
*/ | ||
virtual void Backward(RunContext ctx, | ||
const std::vector<TBlob> &grad_next, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_grad, | ||
const std::vector<GradReqType> &req) = 0; | ||
/*! | ||
* \brief factory unction, create a new StaticOperator | ||
* \param type the type of StaticOperator | ||
* \param ctx the context device type of StaticOperator | ||
* \return a pointer of StaticOperator object | ||
*/ | ||
static StaticOperator *Create(const char *type, Context ctx); | ||
}; | ||
} // namespace mxnet | ||
#endif // MXNET_STATIC_OPERATOR_H_ |
Oops, something went wrong.