Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Adreno][OpenCL] Get rid of extra memory copy #12286

Merged
merged 7 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/tvm/relay/op/strategy/adreno.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,21 @@ def schedule_pool_adreno(attrs, outs, target):
if attrs.layout == "NCHW4c":
return topi.adreno.schedule_pool(outs, attrs.layout)
return topi.cuda.schedule_pool(outs, attrs.layout)


@schedule_injective.register(["adreno"])
def schedule_injective_adreno(attrs, outs, target):
"""schedule injective ops for cuda"""
echuraev marked this conversation as resolved.
Show resolved Hide resolved
with target:
return topi.adreno.schedule_injective(outs)


@concatenate_strategy.register(["adreno"])
def concatenate_strategy_adreno(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_concat(topi.transform.concatenate),
wrap_topi_schedule(topi.adreno.schedule_injective),
name="concatenate.cuda",
)
return strategy
1 change: 1 addition & 0 deletions python/tvm/topi/adreno/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .conv2d_alter_op import *
from .conv2d_nchw_winograd import *
from .conv2d_nhwc_winograd import *
from .injective import schedule_injective
29 changes: 18 additions & 11 deletions python/tvm/topi/adreno/conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,28 +279,35 @@ def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
if "pad_temp" in pad_data.op.name:
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
else:
bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])

pad_data, kernel = s[conv].op.input_tensors

s[pad_data].compute_inline()

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
if "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])
elif "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()
# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# tile and bind spatial axes
n, fc, y, x, fb = s[latest_blocked].op.axis
Expand Down
29 changes: 18 additions & 11 deletions python/tvm/topi/adreno/conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,28 +275,35 @@ def schedule_conv2d_NHWC(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
if "pad_temp" in pad_data.op.name:
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
else:
bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])

pad_data, kernel = s[conv].op.input_tensors

s[pad_data].compute_inline()

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
if (
autotvm.GLOBAL_SCOPE.in_tuning
or isinstance(kernel.op, tvm.te.ComputeOp)
and "filter_pack" in kernel.op.tag
):
if "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])
WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
bind_data_copy(s[WT])
elif "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()
# create cache stage
AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
bind_data_copy(s[AT])

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
s[output].compute_inline()

# tile and bind spatial axes
n, y, x, fc, fb = s[latest_blocked].op.axis
Expand Down
10 changes: 7 additions & 3 deletions python/tvm/topi/adreno/depthwise_conv2d_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,17 @@ def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
if "pad_temp" in pad_data.op.name:
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
else:
bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])

pad_data, kernel = s[conv].op.input_tensors

s[pad_data].compute_inline()
if "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
Expand Down
10 changes: 7 additions & 3 deletions python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,17 @@ def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
): # len(latest.op.axis) == 4:
# manage scheduling of datacopy
pad_data, kernel = s[conv].op.input_tensors
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
if "pad_temp" in pad_data.op.name:
pack_data = pad_data.op.input_tensors[0]
bind_data_copy(s[pack_data])
else:
bind_data_copy(s[pad_data])
bind_data_copy(s[kernel])

pad_data, kernel = s[conv].op.input_tensors

s[pad_data].compute_inline()
if "pad_temp" in pad_data.op.name:
s[pad_data].compute_inline()

s[conv].set_scope("local")
if latest_blocked == latest and output != latest:
Expand Down
66 changes: 66 additions & 0 deletions python/tvm/topi/adreno/injective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from tvm import te
from .utils import bind_data_copy
from .. import utils


def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.

Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.

Returns
-------
sch: Schedule
The updated schedule.
"""

bind_data_copy(sch[out])
return sch


def schedule_injective(outs):
"""Schedule for injective op.

Parameters
----------
outs: Array of Tensor
The computation graph description of injective in the format
of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

tvm.te.schedule.AutoInlineInjective(s)
for out in outs:
if not utils.is_empty_shape(out.shape):
schedule_injective_from_existing(s, out)
return s
21 changes: 20 additions & 1 deletion python/tvm/topi/adreno/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,19 @@ def add_pad(
pad_after[x_axis] -= in_width + pad_before[x_axis] + pad_after[x_axis] - input_latest_w
if input_latest_h < in_height + pad_before[y_axis] + pad_after[y_axis]:
pad_after[y_axis] -= in_height + pad_before[y_axis] + pad_after[y_axis] - input_latest_h
return nn.pad(data, pad_before, pad_after, name="pad_temp")
if (
pad_before[0] == 0
and pad_before[1] == 0
and pad_before[2] == 0
and pad_before[3] == 0
and pad_after[0] == 0
and pad_after[1] == 0
and pad_after[2] == 0
and pad_after[3] == 0
):
return data
else:
return nn.pad(data, pad_before, pad_after, name="pad_temp")


def bind_data_copy(stage, axis_to_vectorize=None):
Expand Down Expand Up @@ -522,6 +534,13 @@ def bind_data_copy(stage, axis_to_vectorize=None):
stage.bind(thread, te.thread_axis("threadIdx.x"))
if shape[-1] == 4:
stage.vectorize(axes[-1])
elif shape[-1] > 1024:
echuraev marked this conversation as resolved.
Show resolved Hide resolved
ftc = numpy.prod(shape[:-1])
div = get_div(ftc, 1024)
by, ty = stage.split(axes[-1], factor=div)
stage.bind(fused, te.thread_axis("blockIdx.x"))
stage.bind(by, te.thread_axis("blockIdx.y"))
stage.bind(ty, te.thread_axis("threadIdx.y"))
else:
stage.bind(fused, te.thread_axis("blockIdx.x"))
stage.bind(*axes[-1:], te.thread_axis("threadIdx.x"))
Expand Down
1 change: 0 additions & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,6 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module = transform::InferType()(relay_module);
relay_module = transform::LabelOps()(relay_module);
relay_module = transform::AnnotateMemoryScope(config_)(relay_module);

echuraev marked this conversation as resolved.
Show resolved Hide resolved
ICHECK(relay_module.defined());

return relay_module;
Expand Down
Loading