forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_deploy.py
104 lines (84 loc) · 3.38 KB
/
_deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# mypy: allow-untyped-defs
import io
import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii
def _save_storages(importer, obj):
    """Pickle ``obj``, keeping storages out of the pickle byte stream.

    Storages met while pickling are appended to a side list and replaced in
    the stream by a ``("storage", index)`` persistent id. Objects exposing
    ``__reduce_deploy__`` are replaced by a ``("reduce_deploy", id(obj), ...)``
    persistent id, which is memoized in the module-level
    ``_serialized_reduces`` dict under ``id(obj)``.

    Args:
        importer: Honoured only when it is a ``torch.package.PackageImporter``;
            any other value is treated as "no package importer".
        obj: The object to pickle.

    Returns:
        A 4-tuple ``(data, storages, dtypes, zip_reader)``: the pickle bytes,
        the storages collected in encounter order, the dtype recorded for each
        storage (same order), and the package importer's ``zip_reader`` or
        ``None`` when no package importer is in use.
    """
    package_importer = (
        importer if isinstance(importer, torch.package.PackageImporter) else None
    )
    # Resolve through the package first (when there is one), then fall back to
    # the regular system importer.
    lookup: Importer = (
        sys_importer
        if package_importer is None
        else OrderedImporter(package_importer, sys_importer)
    )

    collected_storages = []
    collected_dtypes = []

    def persistent_id(candidate):
        is_typed = isinstance(candidate, torch.storage.TypedStorage)
        if is_typed or torch.is_storage(candidate):
            # TODO: Once we decide to break serialization FC, we can
            # remove the TypedStorage case
            dtype = candidate.dtype if is_typed else torch.uint8
            collected_storages.append(candidate)
            collected_dtypes.append(dtype)
            return ("storage", len(collected_storages) - 1)

        if not hasattr(candidate, "__reduce_deploy__"):
            # Not special: let the pickler serialize it the normal way.
            return None

        key = id(candidate)
        if _serialized_reduces.get(key) is None:
            _serialized_reduces[key] = (
                "reduce_deploy",
                key,
                *candidate.__reduce_deploy__(lookup),
            )
        return _serialized_reduces[key]

    # Write the pickle data for `obj`
    buffer = io.BytesIO()
    pickler = create_pickler(buffer, lookup)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    zip_reader = package_importer.zip_reader if package_importer else None
    return (buffer.getvalue(), collected_storages, collected_dtypes, zip_reader)
def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Unpickle ``obj_bytes``, resolving persistent ids written at save time.

    ``("storage", index)`` ids are resolved against ``serialized_storages`` and
    ``serialized_dtypes``. ``("reduce_deploy", reduce_id, func, args)`` ids are
    rebuilt by calling ``func`` with the cached package for ``zip_reader`` and
    memoized in the module-level ``_loaded_reduces`` dict under ``reduce_id``.

    Args:
        id: Key under which the loaded object is stored in ``_deploy_objects``.
        zip_reader: Reader identifying the package to import from, or ``None``
            to use only the system importer.
        obj_bytes: The pickle byte stream to load.
        serialized_storages: Storages referenced by index from the stream.
        serialized_dtypes: Dtype for each entry of ``serialized_storages``.

    Returns:
        The unpickled object (also recorded in ``_deploy_objects[id]``).
    """

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        kind = _maybe_decode_ascii(saved_id[0])
        payload = saved_id[1:]

        if kind == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            index = payload[0]
            backing, dtype = serialized_storages[index], serialized_dtypes[index]
            return torch.storage.TypedStorage(
                wrap_storage=backing.untyped(), dtype=dtype
            )

        if kind != "reduce_deploy":
            return None

        reduce_id, func, args = payload
        if reduce_id not in _loaded_reduces:
            # First sighting of this reduce id: rebuild once and cache it.
            _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
        return _loaded_reduces[reduce_id]

    # Resolve through the package first (when a reader is given), then fall
    # back to the regular system importer.
    importer: Importer = (
        sys_importer
        if zip_reader is None
        else OrderedImporter(_get_package(zip_reader), sys_importer)
    )

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[method-assign]
    loaded = unpickler.load()
    _deploy_objects[id] = loaded
    return loaded
def _get_package(zip_reader):
    """Return the ``PackageImporter`` for ``zip_reader``, creating it on first use.

    Importers are cached in the module-level ``_raw_packages`` dict keyed by
    ``zip_reader``, so repeated calls with the same reader return the same
    importer instance.
    """
    if zip_reader in _raw_packages:
        return _raw_packages[zip_reader]

    package = PackageImporter(zip_reader)
    _raw_packages[zip_reader] = package
    return package
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}