diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index cf677348aecc7..847089dcbf324 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -344,6 +344,19 @@ void NewIRInterpreter::UpdateSyncOpNum() { VLOG(4) << "Update sync op num, sync op num is: " << sync_op_num_; } +void NewIRInterpreter::UpdateNcclOpNum() { + static std::set nccl_op_set = { + "pd.sync_batch_norm_", "pd.sync_batch_norm", "pd.sync_batch_norm_grad"}; + int64_t nccl_op_num = 0; + for (auto& ins : vec_instruction_base_) { + if (nccl_op_set.count(ins->Name())) { + nccl_op_num = nccl_op_num + 1; + } + } + nccl_op_num_ = nccl_op_num; + VLOG(4) << "Update nccl op num, nccl op num is: " << nccl_op_num; +} + // Note(zhangbo): // When there is a KQueueSync type OP in the model, breadth traversal is better // than depth traversal. For example: OP(O) ->(direct_run)-> OP(A) @@ -852,7 +865,7 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, VLOG(4) << "Done PreAnalysis"; // Run - if (FLAGS_enable_new_ir_in_executor_trace_run || + if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode " @@ -867,7 +880,7 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, is_build_ = true; is_shared_results_build_ = true; } else { - if (FLAGS_enable_new_ir_in_executor_trace_run || + if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 || ((execution_config_.used_for_jit || execution_config_.used_for_cinn) && (sync_op_num_ == 0))) { TraceRunImpl(); @@ -1182,6 +1195,9 @@ void NewIRInterpreter::PreAnalysis() { UpdateSyncOpNum(); VLOG(4) << "Done UpdateSyncOpNum"; + + UpdateNcclOpNum(); + VLOG(4) << "Done UpdateNcclOpNum"; } ::ir::Value NewIRInterpreter::GetValueByName(const std::string& var_name) { diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index 3669d8f8dd970..841e9136a2ecc 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -84,6 +84,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { private: // build graph void UpdateSyncOpNum(); + void UpdateNcclOpNum(); void AnalyseExecuteOrderForTrace( std::map> op_downstream_map, InstructionSchedulingPriorityLess compare); @@ -148,6 +149,7 @@ class NewIRInterpreter : public InterpreterBaseImpl { // used for Trace int64_t sync_op_num_{-1}; + int64_t nccl_op_num_{-1}; std::vector trace_execute_order_; std::vector hookfuncs_; diff --git a/paddle/fluid/ir/dialect/op_generator/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py index 4248142dc974b..da1d7cbdde090 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -20,10 +20,14 @@ from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, - vjp_interface_gen_op_list, + gen_op_vjp_str, ) from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from vjp_interface_gen_op_list import ( + vjp_interface_declare_gen_op_list, + vjp_interface_implementation_gen_op_list, +) # ===================================== # String Template for h file code gen @@ -286,6 +290,9 @@ def __init__(self, op_yaml_item, op_compat_item): self.attribute_build_arg_type_list = ( self.parse_attribute_build_arg_type_list() ) + self.attribute_gen_arg_type_list = ( + self.parse_attribute_gen_arg_type_list() + ) self.attribute_data_type_list = self.parse_attribute_data_type_list() self.attribute_default_value_list = ( self.parse_attribute_default_value_list() @@ -584,6 +591,17 @@ def parse_attribute_build_arg_type_list(self): type_list.append(self.get_phi_dtype_name(temp_type)) return type_list + def parse_attribute_gen_arg_type_list(self): + type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + assert ( + attribute_info['typename'] in self.attr_types_map + ), f"{self.op_phi_name} : Attr type error." + + temp_type = self.attr_types_map[attribute_info['typename']][1] + type_list.append(self.get_phi_dtype_name(temp_type)) + return type_list + def parse_attribute_type_list(self): type_list = [] for attribute_info in self.op_yaml_item['attrs']: @@ -741,7 +759,7 @@ def OpGenerator( if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_gen_op_list + and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list ): op_interfaces += ["VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str(op_info) @@ -1038,12 +1056,18 @@ def OpGenerator( op_vjp_str = '' # TODO(chenzhiyang) add vjp gen code - # if op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_gen_op_list: - # op_vjp_str = gen_op_vjp_str(op_class_name, - # op_info.backward_name, - # op_name, - # op_info_items[op_info.op_phi_name[0]], - # op_info_items[op_info.backward_name]) + if ( + op_info.backward_name + and op_info.op_phi_name[0] + in vjp_interface_implementation_gen_op_list + ): + op_vjp_str = gen_op_vjp_str( + op_class_name, + op_info.backward_name, + op_name, + op_info_items[op_info.op_phi_name[0]], + op_info_items[op_info.backward_name], + ) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py index e6bde5bfb0846..8762c6328e1b6 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -13,7 +13,7 @@ # limitations under the License. # generator interfaces -from vjp_interface_gen_op_list import vjp_interface_gen_op_list +from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ @@ -23,57 +23,61 @@ """ OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE = """ - {input_type} {input_name}(std::make_shared(op_obj.{input_name}())); -""" + {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE = """ - Tensor {output_grad_name}(std::make_shared((out_grads[{idx1}][{idx2}]); -""" + Tensor {output_grad_name}(std::make_shared(out_grads[{idx1}][{idx2}]));""" OP_VJP_FORWARD_OUTPUT_GRAD_LIST_TEMPLATE = """ - std::vector {output_grad_name}(std::make_shared((out_grads[{idx1}]); -""" + std::vector {output_grad_name}(std::make_shared(out_grads[{idx1}]));""" -OP_VJP_CALL_VJP_TEMPLATE = """ - Tensor std::vector> tensor_res = - primitive::{op_phi_name}_vjp({inputs_list}, stop_gradients); -""" +OP_VJP_ATTRIBUTE_TEMPLATE = """ + {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().data();""" -OP_VJP_STOPGRADIENT_TEMPLATE = """ - if(!stop_gradients[{idx1}][{idx2}]){{ - res[{idx1}][{idx2}] = std::static_pointer_cast( - tensor_res[idx1][idx2].impl()) - ->getValue() - .dyn_cast(); - }} -""" +OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE = """ + {attr_type} {attr_name} = {default_value};""" -OP_VJP_DEFINE_TEMPLATE = """ -std::vector> {op_class_name}::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients){{ - {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); - VLOG(6) << "Prepare inputs of {op_grad_name}"; +OP_VJP_CALL_VJP_TEMPLATE = """ std::vector> tensor_res = + primitive::{op_phi_name}_vjp( + {inputs_list}stop_gradients);""" - {forward_input_code} - {forward_output_code} - {forward_output_grad_code} +OP_VJP_STOPGRADIENT_TEMPLATE = """ + std::vector> res(tensor_res.size()); + for (size_t i = 0; i < tensor_res.size(); ++i) {{ + res[i].resize(tensor_res[i].size()); + for (size_t j = 0; j < tensor_res[i].size(); ++j) {{ + if(tensor_res[i][j].defined()){{ + res[i][j] = std::static_pointer_cast(tensor_res[i][j].impl())->getValue().dyn_cast(); + }} + }} + }}""" + +OP_VJP_DEFINE_TEMPLATE = """ +std::vector> {op_class_name}::Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients){{ + {op_class_name} op_obj = op->dyn_cast<{op_class_name}>(); - VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; - {attribute_code} +VLOG(6) << "Prepare inputs of {op_grad_name}"; +{forward_input_code} +{forward_output_grad_code} - VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; - {call_vjp_code} +VLOG(6) << "Vjp prepare Prepare attributes of {op_grad_name}"; +{attribute_code} - std::vector> res(1, std::vector(1)); - {stop_gradient_input_grad_code} +VLOG(4) << "Vjp prepare call {op_phi_name}'s vjp inteface"; +{call_vjp_code} - return res; +VLOG(4) << "Vjp prepare stop gradient of {op_grad_name}"; +{stop_gradient_input_grad_code} + return res; }} """ +input_types_map = { + 'paddle::dialect::DenseTensorType': 'Tensor', + 'ir::VectorType': 'Tensor[]', +} + def gen_op_vjp_str( op_class_name, @@ -82,19 +86,62 @@ def gen_op_vjp_str( op_info, op_grad_info, ): + bw_input_list = op_grad_info.input_name_list forward_input_code = '' - forward_output_code = '' forward_output_grad_code = '' + build_args_str = '' + grad_idx = -1 + for idx in range(len(bw_input_list)): + build_args_str += bw_input_list[idx] + ", " + if ( + bw_input_list[idx] in op_info.input_name_list + or bw_input_list[idx] in op_info.output_name_list + ): + forward_input_code += ( + OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( + input_type=input_types_map[ + op_grad_info.input_type_list[idx] + ], + input_name=bw_input_list[idx], + ) + ) + else: + grad_idx += 1 + forward_output_grad_code += ( + OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], idx1=grad_idx, idx2=0 + ) + ) + op_attribute_list = op_grad_info.attribute_name_list attribute_code = '' - call_vjp_code = '' - stop_gradient_input_grad_code = '' + for idx in range(len(op_attribute_list)): + build_args_str += op_attribute_list[idx] + ", " + if op_attribute_list[idx] in op_info.attribute_name_list: + attribute_code += OP_VJP_ATTRIBUTE_TEMPLATE.format( + attr_type=op_grad_info.attribute_gen_arg_type_list[idx], + attr_name=op_attribute_list[idx], + attr_parse_type=op_grad_info.attribute_type_list[idx], + ) + else: + attribute_code += OP_VJP_ATTRIBUTE_DEFAULT_TEMPLATE.format( + attr_type=op_grad_info.attribute_gen_arg_type_list[idx], + attr_name=op_attribute_list[idx], + default_value=op_grad_info.attribute_default_value_list[idx], + ) + if op_phi_name[-1] == '_': + op_phi_name = op_phi_name[:-1] + call_vjp_code = OP_VJP_CALL_VJP_TEMPLATE.format( + op_phi_name=op_phi_name, + inputs_list=build_args_str, + ) + stop_gradient_input_grad_code = OP_VJP_STOPGRADIENT_TEMPLATE str = OP_VJP_DEFINE_TEMPLATE.format( op_class_name=op_class_name, op_grad_name=op_grad_name, op_phi_name=op_phi_name, + res_size=len(op_info.input_name_list), forward_input_code=forward_input_code, - forward_output_code=forward_output_code, forward_output_grad_code=forward_output_grad_code, attribute_code=attribute_code, call_vjp_code=call_vjp_code, @@ -119,6 +166,6 @@ def gen_exclusive_interface_str(op_info): exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) - if op_info.op_phi_name[0] in vjp_interface_gen_op_list: + if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: exclusive_interface_str += "\n static std::vector> Vjp(ir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py index c9a866dae1528..fd7d61897d858 100644 --- a/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/ir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -21,4 +21,5 @@ # TODO(wanghao107) # remove this file and support Vjp methods # code gen. -vjp_interface_gen_op_list = ["tanh", "mean", "divide", "sum", "add"] +vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"] +vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"] diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc index b08b6dc77f5fb..b41cbdab51991 100644 --- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc +++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc @@ -24,100 +24,6 @@ namespace paddle { namespace dialect { -using IntArray = paddle::experimental::IntArray; - -std::vector> TanhOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - TanhOp op_obj = op->dyn_cast(); - Tensor out(std::make_shared(op_obj.out())); - Tensor grad_out(std::make_shared(out_grads[0][0])); - std::vector> tensor_res = - primitive::tanh_vjp(out, grad_out, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = - std::static_pointer_cast(tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> Tanh_Op::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - // TODO(wanghao107) - // we don't support inplace now, - // so use the non-inplace version instead currently. - // Support inplace in the future. - Tanh_Op op_obj = op->dyn_cast(); - Tensor out(std::make_shared(op_obj.out())); - Tensor grad_out(std::make_shared(out_grads[0][0])); - std::vector> tensor_res = - primitive::tanh_vjp(out, grad_out, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = - std::static_pointer_cast(tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> MeanOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - MeanOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor out_grad(std::make_shared(out_grads[0][0])); - - IntArray axis = op->attribute("axis") - .dyn_cast() - .data(); - bool keepdim = op->attribute("keepdim").dyn_cast().data(); - bool reduce_all = false; - std::vector> tensor_res = primitive::mean_vjp( - x, out_grad, axis, keepdim, reduce_all, stop_gradients); - std::vector> res(1, std::vector(1)); - if (tensor_res[0][0].defined()) { - res[0][0] = - std::static_pointer_cast(tensor_res[0][0].impl()) - ->getValue() - .dyn_cast(); - } - return res; -} - -std::vector> DivideOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - DivideOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor y(std::make_shared(op_obj.y())); - Tensor out(std::make_shared(op_obj.out())); - Tensor out_grad(std::make_shared(out_grads[0][0])); - - int axis = -1; - std::vector> tensor_res = - primitive::divide_vjp(x, y, out, out_grad, axis, stop_gradients); - std::vector> res(2, std::vector(1)); - for (size_t i = 0; i < 2; ++i) { - if (tensor_res[i][0].defined()) { - res[i][0] = std::static_pointer_cast( - tensor_res[i][0].impl()) - ->getValue() - .dyn_cast(); - } - } - return res; -} - std::vector> SumOp::Vjp( ir::Operation* op, const std::vector>& out_grads, @@ -144,53 +50,5 @@ std::vector> SumOp::Vjp( } return res; } - -std::vector> AddOp::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - AddOp op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor y(std::make_shared(op_obj.y())); - Tensor out_grad(std::make_shared(out_grads[0][0])); - int axis = -1; - - std::vector> tensor_res = - primitive::add_vjp(x, y, out_grad, axis, stop_gradients); - std::vector> res(2, std::vector(1)); - for (size_t i = 0; i < 2; ++i) { - if (tensor_res[i][0].defined()) { - res[i][0] = std::static_pointer_cast( - tensor_res[i][0].impl()) - ->getValue() - .dyn_cast(); - } - } - return res; -} - -std::vector> Add_Op::Vjp( - ir::Operation* op, - const std::vector>& out_grads, - const std::vector>& stop_gradients) { - Add_Op op_obj = op->dyn_cast(); - Tensor x(std::make_shared(op_obj.x())); - Tensor y(std::make_shared(op_obj.y())); - Tensor out_grad(std::make_shared(out_grads[0][0])); - int axis = -1; - - std::vector> tensor_res = - primitive::add_vjp(x, y, out_grad, axis, stop_gradients); - std::vector> res(2, std::vector(1)); - for (size_t i = 0; i < 2; ++i) { - if (tensor_res[i][0].defined()) { - res[i][0] = std::static_pointer_cast( - tensor_res[i][0].impl()) - ->getValue() - .dyn_cast(); - } - } - return res; -} } // namespace dialect } // namespace paddle diff --git a/python/paddle/text/datasets/wmt16.py b/python/paddle/text/datasets/wmt16.py index 13e610bfd62cb..79bd13c9538e5 100644 --- a/python/paddle/text/datasets/wmt16.py +++ b/python/paddle/text/datasets/wmt16.py @@ -59,13 +59,13 @@ class WMT16(Dataset): Args: data_file(str): path to data tar file, can be set None if - :attr:`download` is True. Default None - mode(str): 'train', 'test' or 'val'. Default 'train' + :attr:`download` is True. Default None. + mode(str): 'train', 'test' or 'val'. Default 'train'. src_dict_size(int): word dictionary size for source language word. Default -1. trg_dict_size(int): word dictionary size for target language word. Default -1. lang(str): source language, 'en' or 'de'. Default 'en'. download(bool): whether to download dataset automatically if - :attr:`data_file` is not set. Default True + :attr:`data_file` is not set. Default True. Returns: Dataset: Instance of WMT16 dataset. The instance of dataset has 3 fields: @@ -77,30 +77,37 @@ class WMT16(Dataset): .. code-block:: python - import paddle - from paddle.text.datasets import WMT16 - - class SimpleNet(paddle.nn.Layer): - def __init__(self): - super().__init__() - - def forward(self, src_ids, trg_ids, trg_ids_next): - return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next) - - paddle.disable_static() - - wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50) - - for i in range(10): - src_ids, trg_ids, trg_ids_next = wmt16[i] - src_ids = paddle.to_tensor(src_ids) - trg_ids = paddle.to_tensor(trg_ids) - trg_ids_next = paddle.to_tensor(trg_ids_next) - - model = SimpleNet() - src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next) - print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy()) - + >>> import paddle + >>> from paddle.text.datasets import WMT16 + + >>> class SimpleNet(paddle.nn.Layer): + ... def __init__(self): + ... super().__init__() + ... + ... def forward(self, src_ids, trg_ids, trg_ids_next): + ... return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next) + + >>> wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50) + + >>> for i in range(10): + ... src_ids, trg_ids, trg_ids_next = wmt16[i] + ... src_ids = paddle.to_tensor(src_ids) + ... trg_ids = paddle.to_tensor(trg_ids) + ... trg_ids_next = paddle.to_tensor(trg_ids_next) + ... + ... model = SimpleNet() + ... src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next) + ... print(src_ids.item(), trg_ids.item(), trg_ids_next.item()) + 89 32 33 + 79 18 19 + 55 26 27 + 147 36 37 + 106 22 23 + 135 50 51 + 54 43 44 + 217 30 31 + 146 51 52 + 55 24 25 """ def __init__( @@ -257,9 +264,9 @@ def get_dict(self, lang, reverse=False): .. code-block:: python - from paddle.text.datasets import WMT16 - wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50) - en_dict = wmt16.get_dict('en') + >>> from paddle.text.datasets import WMT16 + >>> wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50) + >>> en_dict = wmt16.get_dict('en') """ dict_size = ( diff --git a/python/paddle/text/viterbi_decode.py b/python/paddle/text/viterbi_decode.py index d0e8d120faa79..591f7ae6033e5 100644 --- a/python/paddle/text/viterbi_decode.py +++ b/python/paddle/text/viterbi_decode.py @@ -42,20 +42,27 @@ def viterbi_decode( Returns: scores(Tensor): The output tensor containing the score for the Viterbi sequence. The shape is [batch_size] and the data type is float32 or float64. - paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length] - and the data type is int64. + paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length] + and the data type is int64. - Example: + Examples: .. code-block:: python - import paddle - paddle.seed(102) - batch_size, seq_len, num_tags = 2, 4, 3 - emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32') - length = paddle.randint(1, seq_len + 1, [batch_size]) - tags = paddle.randint(0, num_tags, [batch_size, seq_len]) - transition = paddle.rand((num_tags, num_tags), dtype='float32') - scores, path = paddle.text.viterbi_decode(emission, transition, length, False) # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]] + >>> import paddle + >>> paddle.seed(2023) + >>> batch_size, seq_len, num_tags = 2, 4, 3 + >>> emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32') + >>> length = paddle.randint(1, seq_len + 1, [batch_size]) + >>> tags = paddle.randint(0, num_tags, [batch_size, seq_len]) + >>> transition = paddle.rand((num_tags, num_tags), dtype='float32') + >>> scores, path = paddle.text.viterbi_decode(emission, transition, length, False) + >>> print(scores) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [2.57385254, 2.04533720]) + >>> print(path) + Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0], + [1, 1]]) """ if in_dygraph_mode(): return _C_ops.viterbi_decode( @@ -95,7 +102,7 @@ class ViterbiDecoder(Layer): Decode the highest scoring sequence of tags computed by transitions and potentials and get the viterbi path. Args: - transitions (`Tensor`): The transition matrix. Its dtype is float32 and has a shape of `[num_tags, num_tags]`. + transitions (`Tensor`): The transition matrix. Its dtype is float32 and has a shape of `[num_tags, num_tags]`. include_bos_eos_tag (`bool`, optional): If set to True, the last row and the last column of transitions will be considered as start tag, the second to last row and the second to last column of transitions will be considered as stop tag. Defaults to ``True``. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please @@ -104,27 +111,34 @@ class ViterbiDecoder(Layer): Shape: potentials (Tensor): The input tensor of unary emission. This is a 3-D tensor with shape of [batch_size, sequence_length, num_tags]. The data type is float32 or float64. - lengths (Tensor): The input tensor of length of each sequence. This is a 1-D tensor with shape of + lengths (Tensor): The input tensor of length of each sequence. This is a 1-D tensor with shape of [batch_size]. The data type is int64. Returns: scores(Tensor): The output tensor containing the score for the Viterbi sequence. The shape is [batch_size] and the data type is float32 or float64. - paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length] + paths(Tensor): The output tensor containing the highest scoring tag indices. The shape is [batch_size, sequence_length] and the data type is int64. - Example: + Examples: .. code-block:: python - import paddle - paddle.seed(102) - batch_size, seq_len, num_tags = 2, 4, 3 - emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32') - length = paddle.randint(1, seq_len + 1, [batch_size]) - tags = paddle.randint(0, num_tags, [batch_size, seq_len]) - transition = paddle.rand((num_tags, num_tags), dtype='float32') - decoder = paddle.text.ViterbiDecoder(transition, include_bos_eos_tag=False) - scores, path = decoder(emission, length) # scores: [3.37089300, 1.56825531], path: [[1, 0, 0], [1, 1, 0]] + >>> import paddle + >>> paddle.seed(2023) + >>> batch_size, seq_len, num_tags = 2, 4, 3 + >>> emission = paddle.rand((batch_size, seq_len, num_tags), dtype='float32') + >>> length = paddle.randint(1, seq_len + 1, [batch_size]) + >>> tags = paddle.randint(0, num_tags, [batch_size, seq_len]) + >>> transition = paddle.rand((num_tags, num_tags), dtype='float32') + >>> decoder = paddle.text.ViterbiDecoder(transition, include_bos_eos_tag=False) + >>> scores, path = decoder(emission, length) + >>> print(scores) + Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + [2.57385254, 2.04533720]) + >>> print(path) + Tensor(shape=[2, 2], dtype=int64, place=Place(cpu), stop_gradient=True, + [[0, 0], + [1, 1]]) """ def __init__(self, transitions, include_bos_eos_tag=True, name=None): diff --git a/python/paddle/vision/models/_utils.py b/python/paddle/vision/models/_utils.py index a3b68363ebf4e..a0a4bc95dd886 100644 --- a/python/paddle/vision/models/_utils.py +++ b/python/paddle/vision/models/_utils.py @@ -21,13 +21,13 @@ def _make_divisible(v, divisor=8, min_value=None): """ - This function ensures that all layers have a channel number that is divisible by divisor + This function ensures that all layers have a channel number that is divisible by divisor. You can also see at https://github.com/keras-team/keras/blob/8ecef127f70db723c158dbe9ed3268b3d610ab55/keras/applications/mobilenet_v2.py#L505 Args: - divisor (int): The divisor for number of channels. Default: 8. + divisor (int, optional): The divisor for number of channels. Default: 8. min_value (int, optional): The minimum value of number of channels, if it is None, - the default is divisor. Default: None. + the default is divisor. Default: None. """ if min_value is None: min_value = divisor @@ -50,22 +50,25 @@ class IntermediateLayerGetter(nn.LayerDict): So if `model` is passed, `model.feature1` can be returned, but not `model.feature1.layer2`. Args: - model (nn.Layer): model on which we will extract the features - return_layers (Dict[name, new_name]): a dict containing the names of the layers for + + model (nn.Layer): Model on which we will extract the features. + return_layers (Dict[name, new_name]): A dict containing the names of the layers for which the activations will be returned as the key of the dict, and the value of the dict is the name of the returned activation (which the user can specify). Examples: + .. code-block:: python - import paddle - m = paddle.vision.models.resnet18(pretrained=False) - # extract layer1 and layer3, giving as names `feat1` and feat2` - new_m = paddle.vision.models._utils.IntermediateLayerGetter(m, - {'layer1': 'feat1', 'layer3': 'feat2'}) - out = new_m(paddle.rand([1, 3, 224, 224])) - print([(k, v.shape) for k, v in out.items()]) - # [('feat1', [1, 64, 56, 56]), ('feat2', [1, 256, 14, 14])] + >>> import paddle + >>> m = paddle.vision.models.resnet18(pretrained=False) + + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = paddle.vision.models._utils.IntermediateLayerGetter(m, + ... {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(paddle.rand([1, 3, 224, 224])) + >>> print([(k, v.shape) for k, v in out.items()]) + [('feat1', [1, 64, 56, 56]), ('feat2', [1, 256, 14, 14])] """ __annotations__ = { diff --git a/python/paddle/vision/models/alexnet.py b/python/paddle/vision/models/alexnet.py index 4239395c03319..7a6e2b0328ae5 100644 --- a/python/paddle/vision/models/alexnet.py +++ b/python/paddle/vision/models/alexnet.py @@ -75,7 +75,7 @@ class AlexNet(nn.Layer): Args: num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. Returns: :ref:`api_paddle_nn_Layer`. An instance of AlexNet model. @@ -83,16 +83,14 @@ class AlexNet(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import AlexNet + >>> import paddle + >>> from paddle.vision.models import AlexNet - alexnet = AlexNet() - - x = paddle.rand([1, 3, 224, 224]) - out = alexnet(x) - - print(out.shape) - # [1, 1000] + >>> alexnet = AlexNet() + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = alexnet(x) + >>> print(out.shape) + [1, 1000] """ def __init__(self, num_classes=1000): @@ -197,7 +195,7 @@ def alexnet(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`AlexNet `. Returns: @@ -206,19 +204,19 @@ def alexnet(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import alexnet + >>> import paddle + >>> from paddle.vision.models import alexnet - # build model - model = alexnet() + >>> # Build model + >>> model = alexnet() - # build model and load imagenet pretrained weight - # model = alexnet(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = alexnet(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ return _alexnet('alexnet', pretrained, **kwargs) diff --git a/python/paddle/vision/models/densenet.py b/python/paddle/vision/models/densenet.py index e5ab689d8465a..ccf6573f5588a 100644 --- a/python/paddle/vision/models/densenet.py +++ b/python/paddle/vision/models/densenet.py @@ -209,7 +209,7 @@ class DenseNet(nn.Layer): bn_size (int, optional): Expansion of growth rate in the middle layer. Default: 4. dropout (float, optional): Dropout rate. Default: :math:`0.0`. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -218,17 +218,17 @@ class DenseNet(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import DenseNet + >>> import paddle + >>> from paddle.vision.models import DenseNet - # build model - densenet = DenseNet() + >>> # Build model + >>> densenet = DenseNet() - x = paddle.rand([1, 3, 224, 224]) - out = densenet(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = densenet(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__( @@ -360,7 +360,7 @@ def densenet121(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`DenseNet `. Returns: @@ -369,20 +369,20 @@ def densenet121(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import densenet121 + >>> import paddle + >>> from paddle.vision.models import densenet121 - # build model - model = densenet121() + >>> # Build model + >>> model = densenet121() - # build model and load imagenet pretrained weight - # model = densenet121(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = densenet121(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ return _densenet('densenet121', 121, pretrained, **kwargs) @@ -393,7 +393,7 @@ def densenet161(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`DenseNet `. Returns: @@ -402,13 +402,20 @@ def densenet161(pretrained=False, **kwargs): Examples: .. code-block:: python - from paddle.vision.models import densenet161 + >>> import paddle + >>> from paddle.vision.models import densenet161 - # build model - model = densenet161() + >>> # Build model + >>> model = densenet161() - # build model and load imagenet pretrained weight - # model = densenet161(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = densenet161(pretrained=True) + + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) + + >>> print(out.shape) + [1, 1000] """ return _densenet('densenet161', 161, pretrained, **kwargs) @@ -419,7 +426,7 @@ def densenet169(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`DenseNet `. Returns: @@ -428,20 +435,20 @@ def densenet169(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import densenet169 + >>> import paddle + >>> from paddle.vision.models import densenet169 - # build model - model = densenet169() + >>> # Build model + >>> model = densenet169() - # build model and load imagenet pretrained weight - # model = densenet169(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = densenet169(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ return _densenet('densenet169', 169, pretrained, **kwargs) @@ -452,7 +459,7 @@ def densenet201(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`DenseNet `. Returns: @@ -461,19 +468,19 @@ def densenet201(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import densenet201 + >>> import paddle + >>> from paddle.vision.models import densenet201 - # build model - model = densenet201() + >>> # Build model + >>> model = densenet201() - # build model and load imagenet pretrained weight - # model = densenet201(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> # Build model and load imagenet pretrained weight + >>> # model = densenet201(pretrained=True) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ return _densenet('densenet201', 201, pretrained, **kwargs) @@ -484,7 +491,7 @@ def densenet264(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`DenseNet `. Returns: @@ -493,19 +500,19 @@ def densenet264(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import densenet264 + >>> import paddle + >>> from paddle.vision.models import densenet264 - # build model - model = densenet264() + >>> # Build model + >>> model = densenet264() - # build model and load imagenet pretrained weight - # model = densenet264(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = densenet264(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ return _densenet('densenet264', 264, pretrained, **kwargs) diff --git a/python/paddle/vision/models/googlenet.py b/python/paddle/vision/models/googlenet.py index b5e6e72fc5301..617ce182c5039 100644 --- a/python/paddle/vision/models/googlenet.py +++ b/python/paddle/vision/models/googlenet.py @@ -110,7 +110,7 @@ class GoogLeNet(nn.Layer): Args: num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -119,17 +119,17 @@ class GoogLeNet(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import GoogLeNet + >>> import paddle + >>> from paddle.vision.models import GoogLeNet - # build model - model = GoogLeNet() + >>> # Build model + >>> model = GoogLeNet() - x = paddle.rand([1, 3, 224, 224]) - out, out1, out2 = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out, out1, out2 = model(x) - print(out.shape, out1.shape, out2.shape) - # [1, 1000] [1, 1000] [1, 1000] + >>> print(out.shape, out1.shape, out2.shape) + [1, 1000] [1, 1000] [1, 1000] """ def __init__(self, num_classes=1000, with_pool=True): @@ -236,7 +236,7 @@ def googlenet(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`GoogLeNet `. Returns: @@ -245,20 +245,20 @@ def googlenet(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import googlenet + >>> import paddle + >>> from paddle.vision.models import googlenet - # build model - model = googlenet() + >>> # Build model + >>> model = googlenet() - # build model and load imagenet pretrained weight - # model = googlenet(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = googlenet(pretrained=True) - x = paddle.rand([1, 3, 224, 224]) - out, out1, out2 = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out, out1, out2 = model(x) - print(out.shape, out1.shape, out2.shape) - # [1, 1000] [1, 1000] [1, 1000] + >>> print(out.shape, out1.shape, out2.shape) + [1, 1000] [1, 1000] [1, 1000] """ model = GoogLeNet(**kwargs) arch = "googlenet" diff --git a/python/paddle/vision/models/inceptionv3.py b/python/paddle/vision/models/inceptionv3.py index edc1ce42d9732..9482d7b12e208 100644 --- a/python/paddle/vision/models/inceptionv3.py +++ b/python/paddle/vision/models/inceptionv3.py @@ -491,7 +491,7 @@ class InceptionV3(nn.Layer): Args: num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -500,16 +500,16 @@ class InceptionV3(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import InceptionV3 + >>> import paddle + >>> from paddle.vision.models import InceptionV3 - inception_v3 = InceptionV3() + >>> inception_v3 = InceptionV3() - x = paddle.rand([1, 3, 299, 299]) - out = inception_v3(x) + >>> x = paddle.rand([1, 3, 299, 299]) + >>> out = inception_v3(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__(self, num_classes=1000, with_pool=True): @@ -591,7 +591,7 @@ def inception_v3(pretrained=False, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`InceptionV3 `. Returns: @@ -600,20 +600,20 @@ def inception_v3(pretrained=False, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import inception_v3 + >>> import paddle + >>> from paddle.vision.models import inception_v3 - # build model - model = inception_v3() + >>> # Build model + >>> model = inception_v3() - # build model and load imagenet pretrained weight - # model = inception_v3(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = inception_v3(pretrained=True) - x = paddle.rand([1, 3, 299, 299]) - out = model(x) + >>> x = paddle.rand([1, 3, 299, 299]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ model = InceptionV3(**kwargs) arch = "inception_v3" diff --git a/python/paddle/vision/models/lenet.py b/python/paddle/vision/models/lenet.py index 64854e65b9c20..75d17d9ed80c2 100644 --- a/python/paddle/vision/models/lenet.py +++ b/python/paddle/vision/models/lenet.py @@ -24,7 +24,7 @@ class LeNet(nn.Layer): Args: num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 10. + will not be defined. Default: 10. Returns: :ref:`api_paddle_nn_Layer`. An instance of LeNet model. @@ -32,16 +32,16 @@ class LeNet(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import LeNet + >>> import paddle + >>> from paddle.vision.models import LeNet - model = LeNet() + >>> model = LeNet() - x = paddle.rand([1, 1, 28, 28]) - out = model(x) + >>> x = paddle.rand([1, 1, 28, 28]) + >>> out = model(x) - print(out.shape) - # [1, 10] + >>> print(out.shape) + [1, 10] """ def __init__(self, num_classes=10): diff --git a/python/paddle/vision/models/mobilenetv1.py b/python/paddle/vision/models/mobilenetv1.py index aed3407c5df0f..3c55d971a4cbc 100644 --- a/python/paddle/vision/models/mobilenetv1.py +++ b/python/paddle/vision/models/mobilenetv1.py @@ -70,7 +70,7 @@ class MobileNetV1(nn.Layer): Args: scale (float, optional): Scale of channels in each layer. Default: 1.0. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -79,16 +79,16 @@ class MobileNetV1(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import MobileNetV1 + >>> import paddle + >>> from paddle.vision.models import MobileNetV1 - model = MobileNetV1() + >>> model = MobileNetV1() - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -268,7 +268,7 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): Args: pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + on ImageNet. Default: False. scale (float, optional): Scale of channels in each layer. Default: 1.0. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`MobileNetV1 `. @@ -278,23 +278,23 @@ def mobilenet_v1(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import mobilenet_v1 + >>> import paddle + >>> from paddle.vision.models import mobilenet_v1 - # build model - model = mobilenet_v1() + >>> # Build model + >>> model = mobilenet_v1() - # build model and load imagenet pretrained weight - # model = mobilenet_v1(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = mobilenet_v1(pretrained=True) - # build mobilenet v1 with scale=0.5 - model_scale = mobilenet_v1(scale=0.5) + >>> # build mobilenet v1 with scale=0.5 + >>> model_scale = mobilenet_v1(scale=0.5) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ model = _mobilenet( 'mobilenetv1_' + str(scale), pretrained, scale=scale, **kwargs diff --git a/python/paddle/vision/models/mobilenetv2.py b/python/paddle/vision/models/mobilenetv2.py index 47051ff4ab9c3..bc6bc5d49e915 100644 --- a/python/paddle/vision/models/mobilenetv2.py +++ b/python/paddle/vision/models/mobilenetv2.py @@ -81,7 +81,7 @@ class MobileNetV2(nn.Layer): Args: scale (float, optional): Scale of channels in each layer. Default: 1.0. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -90,16 +90,16 @@ class MobileNetV2(nn.Layer): Examples: .. code-block:: python - import paddle - from paddle.vision.models import MobileNetV2 + >>> import paddle + >>> from paddle.vision.models import MobileNetV2 - model = MobileNetV2() + >>> model = MobileNetV2() - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -206,8 +206,7 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. Args: - pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained on ImageNet. Default: False. scale (float, optional): Scale of channels in each layer. Default: 1.0. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`MobileNetV2 `. @@ -217,23 +216,23 @@ def mobilenet_v2(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import mobilenet_v2 + >>> import paddle + >>> from paddle.vision.models import mobilenet_v2 - # build model - model = mobilenet_v2() + >>> # Build model + >>> model = mobilenet_v2() - # build model and load imagenet pretrained weight - # model = mobilenet_v2(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = mobilenet_v2(pretrained=True) - # build mobilenet v2 with scale=0.5 - model = mobilenet_v2(scale=0.5) + >>> # Build mobilenet v2 with scale=0.5 + >>> model = mobilenet_v2(scale=0.5) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ model = _mobilenet( 'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs diff --git a/python/paddle/vision/models/mobilenetv3.py b/python/paddle/vision/models/mobilenetv3.py index 98236bec695fc..a35058c9243f0 100644 --- a/python/paddle/vision/models/mobilenetv3.py +++ b/python/paddle/vision/models/mobilenetv3.py @@ -41,11 +41,12 @@ class SqueezeExcitation(nn.Layer): Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3. This code is based on the torchvision code with modifications. You can also see at https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L127 + Args: - input_channels (int): Number of channels in the input image - squeeze_channels (int): Number of squeeze channels - activation (Callable[..., paddle.nn.Layer], optional): ``delta`` activation. Default: ``paddle.nn.ReLU`` - scale_activation (Callable[..., paddle.nn.Layer]): ``sigma`` activation. Default: ``paddle.nn.Sigmoid`` + input_channels (int): Number of channels in the input image. + squeeze_channels (int): Number of squeeze channels. + activation (Callable[..., paddle.nn.Layer], optional): ``delta`` activation. Default: ``paddle.nn.ReLU``. + scale_activation (Callable[..., paddle.nn.Layer]): ``sigma`` activation. Default: ``paddle.nn.Sigmoid``. """ def __init__( @@ -190,7 +191,7 @@ class MobileNetV3(nn.Layer): last_channel (int): The number of channels on the penultimate layer. scale (float, optional): Scale of channels in each layer. Default: 1.0. num_classes (int, optional): Output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. """ @@ -280,7 +281,7 @@ class MobileNetV3Small(MobileNetV3): Args: scale (float, optional): Scale of channels in each layer. Default: 1.0. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -289,17 +290,17 @@ class MobileNetV3Small(MobileNetV3): Examples: .. code-block:: python - import paddle - from paddle.vision.models import MobileNetV3Small + >>> import paddle + >>> from paddle.vision.models import MobileNetV3Small - # build model - model = MobileNetV3Small(scale=1.0) + >>> # Build model + >>> model = MobileNetV3Small(scale=1.0) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -333,7 +334,7 @@ class MobileNetV3Large(MobileNetV3): Args: scale (float, optional): Scale of channels in each layer. Default: 1.0. num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer - will not be defined. Default: 1000. + will not be defined. Default: 1000. with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. Returns: @@ -342,17 +343,17 @@ class MobileNetV3Large(MobileNetV3): Examples: .. code-block:: python - import paddle - from paddle.vision.models import MobileNetV3Large + >>> import paddle + >>> from paddle.vision.models import MobileNetV3Large - # build model - model = MobileNetV3Large(scale=1.0) + >>> # Build model + >>> model = MobileNetV3Large(scale=1.0) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ def __init__(self, scale=1.0, num_classes=1000, with_pool=True): @@ -427,8 +428,7 @@ def mobilenet_v3_small(pretrained=False, scale=1.0, **kwargs): `"Searching for MobileNetV3" `_. Args: - pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained on ImageNet. Default: False. scale (float, optional): Scale of channels in each layer. Default: 1.0. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`MobileNetV3Small `. @@ -438,23 +438,23 @@ def mobilenet_v3_small(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import mobilenet_v3_small + >>> import paddle + >>> from paddle.vision.models import mobilenet_v3_small - # build model - model = mobilenet_v3_small() + >>> # Build model + >>> model = mobilenet_v3_small() - # build model and load imagenet pretrained weight - # model = mobilenet_v3_small(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = mobilenet_v3_small(pretrained=True) - # build mobilenet v3 small model with scale=0.5 - model = mobilenet_v3_small(scale=0.5) + >>> # Build mobilenet v3 small model with scale=0.5 + >>> model = mobilenet_v3_small(scale=0.5) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ model = _mobilenet_v3( "mobilenet_v3_small", scale=scale, pretrained=pretrained, **kwargs @@ -467,8 +467,7 @@ def mobilenet_v3_large(pretrained=False, scale=1.0, **kwargs): `"Searching for MobileNetV3" `_. Args: - pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained - on ImageNet. Default: False. + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained on ImageNet. Default: False. scale (float, optional): Scale of channels in each layer. Default: 1.0. **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`MobileNetV3Large `. @@ -478,23 +477,23 @@ def mobilenet_v3_large(pretrained=False, scale=1.0, **kwargs): Examples: .. code-block:: python - import paddle - from paddle.vision.models import mobilenet_v3_large + >>> import paddle + >>> from paddle.vision.models import mobilenet_v3_large - # build model - model = mobilenet_v3_large() + >>> # Build model + >>> model = mobilenet_v3_large() - # build model and load imagenet pretrained weight - # model = mobilenet_v3_large(pretrained=True) + >>> # Build model and load imagenet pretrained weight + >>> # model = mobilenet_v3_large(pretrained=True) - # build mobilenet v3 large model with scale=0.5 - model = mobilenet_v3_large(scale=0.5) + >>> # Build mobilenet v3 large model with scale=0.5 + >>> model = mobilenet_v3_large(scale=0.5) - x = paddle.rand([1, 3, 224, 224]) - out = model(x) + >>> x = paddle.rand([1, 3, 224, 224]) + >>> out = model(x) - print(out.shape) - # [1, 1000] + >>> print(out.shape) + [1, 1000] """ model = _mobilenet_v3( "mobilenet_v3_large", scale=scale, pretrained=pretrained, **kwargs