Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Enable Clone Function for Task-Level Classes #12796

Merged
merged 10 commits into from
Sep 16, 2022
88 changes: 54 additions & 34 deletions include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class Mutator;

/*! \brief Mutator is designed to mutate the trace to explore the design space. */
class MutatorNode : public runtime::Object {
Expand All @@ -57,12 +58,21 @@ class MutatorNode : public runtime::Object {
virtual Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) = 0;

/*!
* \brief Clone the mutator.
* \return The cloned mutator.
*/
virtual Mutator Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.Mutator";
TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
};

/*! \brief The mutator with customized methods on the python-side. */
class PyMutatorNode : public MutatorNode {
/*!
* \brief Managed reference to MutatorNode
* \sa MutatorNode
*/
class Mutator : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -76,39 +86,16 @@ class PyMutatorNode : public MutatorNode {
*/
using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(
const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>;
/*!
* \brief Clone the mutator.
* \return The cloned mutator.
*/
using FClone = runtime::TypedPackedFunc<Mutator()>;
/*!
* \brief Get the mutator as string with name.
* \return The string of the mutator.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};

/*!
* \brief Managed reference to MutatorNode
* \sa MutatorNode
*/
class Mutator : public runtime::ObjectRef {
public:
/*! \brief Create a Mutator that mutates the decision of instruction Sample-Perfect-Tile */
TVM_DLL static Mutator MutateTileSize();
/*!
Expand Down Expand Up @@ -136,16 +123,49 @@ class Mutator : public runtime::ObjectRef {
* \brief Create a mutator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The mutator created.
*/
TVM_DLL static Mutator PyMutator(
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyMutatorNode::FApply f_apply, //
PyMutatorNode::FAsString f_as_string);
TVM_DLL static Mutator PyMutator(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
};

/*! \brief The mutator with customized methods on the python-side. */
class PyMutatorNode : public MutatorNode {
public:
using FInitializeWithTuneContext = Mutator::FInitializeWithTuneContext;
using FApply = Mutator::FApply;
using FClone = Mutator::FClone;
using FAsString = Mutator::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_clone` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Optional<tir::Trace> Apply(const tir::Trace& trace,
support::LinearCongruentialEngine::TRandState* rand_state) final;
Mutator Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyMutator";
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
86 changes: 53 additions & 33 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class Postproc;

/*!
* \brief Rules to apply a postprocessor to a schedule.
Expand All @@ -54,12 +55,21 @@ class PostprocNode : public runtime::Object {
*/
virtual bool Apply(const tir::Schedule& sch) = 0;

/*!
* \brief Clone the postprocessor.
* \return The cloned postprocessor.
*/
virtual Postproc Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
};

/*! \brief The postprocessor with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
/*!
* \brief Managed reference to PostprocNode
* \sa PostprocNode
*/
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -72,49 +82,28 @@ class PyPostprocNode : public PostprocNode {
* \return Whether the postprocessor was successfully applied.
*/
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
/*!
* \brief Clone the postprocessor.
* \return The cloned postprocessor.
*/
using FClone = runtime::TypedPackedFunc<Postproc()>;
/*!
* \brief Get the postprocessor function as string with name.
* \return The string of the postprocessor function.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
};

/*!
* \brief Managed reference to PostprocNode
* \sa PostprocNode
*/
class Postproc : public runtime::ObjectRef {
public:
/*!
* \brief Create a postprocessor with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The postprocessor created.
*/
TVM_DLL static Postproc PyPostproc(
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyPostprocNode::FApply f_apply, //
PyPostprocNode::FAsString f_as_string);
TVM_DLL static Postproc PyPostproc(FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
/*!
* \brief Create a postprocessor that checks if all loops are static
* \return The postprocessor created
Expand Down Expand Up @@ -164,6 +153,37 @@ class Postproc : public runtime::ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
};

/*! \brief The postprocessor with customized methods on the python-side. */
class PyPostprocNode : public PostprocNode {
public:
using FInitializeWithTuneContext = Postproc::FInitializeWithTuneContext;
using FApply = Postproc::FApply;
using FClone = Postproc::FClone;
using FAsString = Postproc::FAsString;
/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_clone` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
bool Apply(const tir::Schedule& sch) final;
Postproc Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyPostproc";
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
86 changes: 54 additions & 32 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ namespace tvm {
namespace meta_schedule {

class TuneContext;
class ScheduleRule;

/*! \brief Rules to modify a block in a schedule. */
class ScheduleRuleNode : public runtime::Object {
Expand All @@ -59,12 +60,21 @@ class ScheduleRuleNode : public runtime::Object {
virtual runtime::Array<tir::Schedule> Apply(const tir::Schedule& sch,
const tir::BlockRV& block) = 0;

/*!
* \brief Deep clone the schedule rule.
* \return The cloned schedule rule.
*/
virtual ScheduleRule Clone() const = 0;

static constexpr const char* _type_key = "meta_schedule.ScheduleRule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `InitializeWithTuneContext` method.
Expand All @@ -84,33 +94,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode {
* \return The string of the schedule rule.
*/
using FAsString = runtime::TypedPackedFunc<String()>;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

/*!
* \brief Managed reference to ScheduleRuleNode
* \sa ScheduleRuleNode
*/
class ScheduleRule : public runtime::ObjectRef {
public:
/*!
* \brief The function type of `Clone` method.
* \return The cloned schedule rule.
*/
using FClone = runtime::TypedPackedFunc<ScheduleRule()>;
/*!
* \brief Create an auto-inline rule that inlines spatial blocks if it satisfies some conditions
* \param into_producer If allows to inline a block into its producer
Expand Down Expand Up @@ -249,16 +237,50 @@ class ScheduleRule : public runtime::ObjectRef {
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
* \param f_apply The packed function of `Apply`.
* \param f_clone The packed function of `Clone`.
* \param f_as_string The packed function of `AsString`.
* \return The schedule rule created.
*/
TVM_DLL static ScheduleRule PyScheduleRule(
PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
PyScheduleRuleNode::FApply f_apply, //
PyScheduleRuleNode::FAsString f_as_string);
FInitializeWithTuneContext f_initialize_with_tune_context, //
FApply f_apply, //
FClone f_clone, //
FAsString f_as_string);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
};

/*! \brief The schedule rule with customized methods on the python-side. */
class PyScheduleRuleNode : public ScheduleRuleNode {
public:
using FInitializeWithTuneContext = ScheduleRule::FInitializeWithTuneContext;
using FApply = ScheduleRule::FApply;
using FClone = ScheduleRule::FClone;
using FAsString = ScheduleRule::FAsString;

/*! \brief The packed function to the `InitializeWithTuneContext` function. */
FInitializeWithTuneContext f_initialize_with_tune_context;
/*! \brief The packed function to the `Apply` function. */
FApply f_apply;
/*! \brief The packed function to the `AsString` function. */
FAsString f_as_string;
/*! \brief The packed function to the `Clone` function. */
FClone f_clone;

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_initialize_with_tune_context` is not visited
// `f_apply` is not visited
// `f_as_string` is not visited
// `f_clone` is not visited
}

void InitializeWithTuneContext(const TuneContext& context) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block) final;
ScheduleRule Clone() const final;

static constexpr const char* _type_key = "meta_schedule.PyScheduleRule";
TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode);
};

} // namespace meta_schedule
} // namespace tvm

Expand Down
Loading