forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Resize.h
109 lines (96 loc) · 3.21 KB
/
Resize.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#pragma once
#include <ATen/ATen.h>
#include <TH/THTensor.hpp>
namespace at { namespace native {
// These functions are called by native::resize_ as well as by the (legacy) TH resize.
// They are not in TH/THTensor.cpp because code in the at namespace is easier
// to benchmark than TH code; gbenchmark could not be made to call functions from THTensor.cpp.
// Ensure self's storage can hold the elements a tensor of `new_size` storage
// elements (starting at self's storage_offset) addresses.
//
// - If new_size <= 0 the tensor addresses no elements, so nothing is allocated
//   or resized, regardless of storage_offset. An empty tensor's storage may
//   have any numel (the same rule checkInBoundsForStorage applies).
// - Otherwise a storage is created if self has none, and it is grown to
//   new_size + storage_offset elements if it is currently smaller.
//   Storage is never shrunk here.
static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size) {
  // Guard on new_size alone: with the old (new_size + storage_offset > 0)
  // guard, an empty tensor with a positive offset would allocate/grow storage
  // to storage_offset elements even though it never touches them.
  if (new_size <= 0) {
    return;
  }
  if (!THTensor_getStoragePtr(self)) {
    THTensor_stealAndSetStoragePtr(self, THStorage_new(self->dtype()));
  }
  const int64_t required_numel = new_size + self->storage_offset();
  if (required_numel > self->storage().numel()) {
    THStorage_resize(THTensor_getStoragePtr(self), required_numel);
  }
}
/**
 * Resize the CPU tensor `self` to `size`.
 *
 * If `stride` is given, those strides are used. Otherwise contiguous strides
 * are used. Afterwards maybe_resize_storage_cpu is asked to make the storage
 * large enough.
 *
 * If self's sizes already equal `size` (and its strides equal `stride`, when
 * given), self is returned untouched.
 *
 * @return self
 */
inline TensorImpl* resize_impl_cpu_(
    TensorImpl* self,
    IntList size,
    c10::optional<IntList> stride) {
  if (self->sizes() == size) {
    if (!stride || self->strides() == stride) {
      return self;
    }
  }

  int64_t needed_elems = 0;
  if (!stride) {
    self->set_sizes_contiguous(size);
    needed_elems = self->numel();
  } else {
    const IntList new_strides = *stride;
    self->set_sizes_and_strides(size, new_strides);
    // NB: with explicit strides the required storage size can differ from numel.
    // It is 1 + sum((size[d] - 1) * stride[d]), or 0 when some dim has size 0.
    // FIXME: Don't rely on storage_size being negative because this
    // may not be true for some edge cases.
    needed_elems = 1;
    for (size_t d = 0; d != size.size(); ++d) {
      if (size[d] == 0) {
        needed_elems = 0;
        break;
      }
      needed_elems += new_strides[d] * (size[d] - 1);
    }
  }

  maybe_resize_storage_cpu(self, needed_elems);
  return self;
}
/**
 * Number of storage elements, counted from the storage offset, that a strided
 * tensor with the given sizes and strides spans:
 *   1 + sum_i (sizes[i] - 1) * strides[i],
 * or 0 if any size is 0 (the tensor addresses no elements).
 * A 0-dim tensor (empty `sizes`) spans 1 element.
 *
 * `sizes` and `strides` must have the same length. This is checked, because
 * strides is indexed by every dimension of sizes and ArrayRef indexing is
 * not bounds-checked.
 */
static inline int64_t computeStorageSize(IntList sizes, IntList strides) {
  AT_CHECK(
      sizes.size() == strides.size(),
      "computeStorageSize: sizes ", sizes, " and strides ", strides,
      " must have the same length");
  int64_t storage_size = 1;
  for (size_t dim = 0; dim < sizes.size(); ++dim) {
    if (sizes[dim] == 0) {
      // A zero-sized dim means no element is ever addressed.
      return 0;
    }
    storage_size += strides[dim] * (sizes[dim] - 1);
  }
  return storage_size;
}
/**
 * Verify that a tensor view described by (size, stride, storage_offset) only
 * addresses elements that exist in `new_storage`.
 *
 * Views that computeStorageSize reports as spanning 0 elements (e.g. any dim
 * of size 0) are always accepted, since such a tensor's storage may have any
 * numel.
 *
 * Raises via AT_CHECK when storage_offset plus the required span exceeds
 * new_storage.numel().
 */
static inline void checkInBoundsForStorage(
    IntList size,
    IntList stride,
    int64_t storage_offset,
    const Storage& new_storage) {
  const int64_t needed = computeStorageSize(size, stride);
  if (needed != 0) {
    const int64_t available = new_storage.numel();
    AT_CHECK(
        storage_offset + needed <= available,
        "setStorage: sizes ", size, ", strides ", stride, ","
        " and storage offset ", storage_offset,
        " requiring a storage size of ", needed + storage_offset,
        " are out of bounds for storage with numel ", available);
  }
}
/**
 * Set self's sizes, strides, and storage_offset.
 *
 * All arguments are validated before self is modified:
 *  - size and stride must have the same length;
 *  - storage_offset must be non-negative;
 *  - (size, stride, storage_offset) must be in bounds for self's storage.
 * A violation raises an error (AT_CHECK) and leaves self unchanged.
 */
inline void setStrided(
    const Tensor& self,
    IntList size,
    IntList stride,
    int64_t storage_offset) {
  // Check the lengths first: the bounds check below indexes `stride` by every
  // dimension of `size`, so a shorter `stride` would be read out of bounds.
  AT_CHECK(
      size.size() == stride.size(),
      "setStrided: mismatch in length of sizes (", size.size(),
      ") and strides (", stride.size(), ")");
  AT_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);

  auto* self_ = self.unsafeGetTensorImpl();
  checkInBoundsForStorage(size, stride, storage_offset, self_->storage());

  /* storage offset */
  self_->set_storage_offset(storage_offset);
  /* size and stride */
  if (self_->sizes() == size && self_->strides() == stride) {
    return;
  }
  self_->set_sizes_and_strides(size, stride);
}
}}