forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
init_flatbuffer_module.cpp
126 lines (119 loc) · 4.55 KB
/
init_flatbuffer_module.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include <torch/csrc/python_headers.h>
#include <libshm.h>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <string>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
namespace py = pybind11;
/// Copies `bytes` into a freshly allocated buffer aligned to
/// FLATBUFFERS_MAX_ALIGNMENT.
///
/// The allocation size is `bytes.size()` rounded up to the next multiple of
/// FLATBUFFERS_MAX_ALIGNMENT that is strictly greater than `bytes.size()`, so
/// the buffer always has at least one byte of padding. Only the first
/// `bytes.size()` bytes are written; the padding is left uninitialized.
///
/// @param bytes  Data to copy.
/// @return Shared owner of the aligned copy; its deleter matches the
///         platform allocator that produced the buffer.
/// @throws std::bad_alloc if the aligned allocation fails.
static std::shared_ptr<char> copyStr(const std::string& bytes) {
  const size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
      FLATBUFFERS_MAX_ALIGNMENT;

  void* raw = nullptr;
#ifdef _WIN32
  raw = _aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT);
  auto* const release = _aligned_free;
#elif defined(__APPLE__)
  // posix_memalign reports failure through its return code; do not rely on
  // the value it may have left in `raw` in that case.
  if (::posix_memalign(&raw, FLATBUFFERS_MAX_ALIGNMENT, size) != 0) {
    raw = nullptr;
  }
  auto* const release = free;
#else
  // `size` is a multiple of the alignment, as aligned_alloc expects.
  raw = aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size);
  auto* const release = free;
#endif

  // Check before the copy below: writing through a null pointer is undefined
  // behavior, so allocation failure must surface as an exception instead.
  if (raw == nullptr) {
    throw std::bad_alloc();
  }

  std::shared_ptr<char> bytes_copy(static_cast<char*>(raw), release);
  std::memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  return bytes_copy;
}
/// Creates the `torch._C_flatbuffer` Python module and registers the
/// flatbuffer load/save bindings on it.
///
/// @return The newly created module object, or nullptr if PyModule_Create
///         failed (the Python error it set is left in place for the caller).
extern "C"
#ifdef _WIN32
    __declspec(dllexport)
#endif
    PyObject* initModuleFlatbuffer() {
  using namespace torch::jit;

  // The method table must outlive this call because the static module
  // definition below keeps a pointer to it; hence static storage.
  static PyMethodDef methods[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT,
      "torch._C_flatbuffer",
      nullptr,
      -1,
      methods,
  }; // NOLINT

  PyObject* module = PyModule_Create(&torchmodule);
  if (module == nullptr) {
    // Do not wrap a null object in a pybind11 handle; propagate the failure.
    return nullptr;
  }
  auto pym = py::handle(module).cast<py::module>();

  // ---- Loading ------------------------------------------------------------
  pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
    return torch::jit::load_mobile_module_from_file(filename);
  });
  pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    // The parser is handed an aligned copy rather than the string's storage.
    auto aligned = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(
        aligned, bytes.size());
  });
  pym.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files;
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
  pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    auto aligned = copyStr(bytes);
    ExtraFilesMap extra_files;
    return torch::jit::parse_and_initialize_jit_module(
        aligned, bytes.size(), extra_files);
  });

  // ---- Saving -------------------------------------------------------------
  // NOTE: pybind11 does not pick up C++ default arguments from a callable, so
  // Python callers still have to pass `_extra_files` explicitly.
  pym.def(
      "_save_mobile_module",
      [](const torch::jit::mobile::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_mobile_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_jit_module",
      [](const torch::jit::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_jit_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_mobile_module_to_bytes",
      [](const torch::jit::mobile::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto buffer =
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
        // py::bytes copies the serialized data into a Python bytes object.
        return py::bytes(
            reinterpret_cast<char*>(buffer.data()), buffer.size());
      });
  pym.def(
      "_save_jit_module_to_bytes",
      [](const torch::jit::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto buffer =
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(buffer.data()), buffer.size());
      });

  // ---- Introspection ------------------------------------------------------
  pym.def(
      "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
        // NOTE(review): pybind11 normally calls bound functions with the GIL
        // already held, so this acquire looks redundant — confirm before
        // removing.
        py::gil_scoped_acquire acquire;
        // NOTE(review): unlike the *_from_bytes loaders, the content is passed
        // without an aligned copy or a length — confirm the callee accepts
        // that.
        mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
            flatbuffer_content.data());
        py::dict result;
        result["bytecode_version"] = minfo.bytecode_version;
        result["operator_version"] = minfo.operator_version;
        result["function_names"] = minfo.function_names;
        result["type_names"] = minfo.type_names;
        result["opname_to_num_args"] = minfo.opname_to_num_args;
        return result;
      });

  return module;
}