Skip to content

Commit

Permalink
[USMP] HillClimb stability patch (#10547)
Browse files Browse the repository at this point in the history
This patch increases the stability of the hill-climb allocation algorithm.

Change-Id: I56414ae661fa856baeddce00f4717a9f5a9e2954
  • Loading branch information
d-smirnov authored Jul 6, 2022
1 parent c57320b commit c98626c
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 45 deletions.
50 changes: 22 additions & 28 deletions src/tir/usmp/algo/hill_climb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace algo {
* Works by continuously invoking 'greedy-by-size' allocation,
* assessing the result, and introducing permutations to the allocation
* order which hopefully will lead to more 'compact' memory allocation.
* Do not forget to use srand for repeatable results
*/
class HillClimbAllocator : public GreedyBase {
private:
Expand All @@ -59,18 +60,18 @@ class HillClimbAllocator : public GreedyBase {
/*
* Initial sorting routine
*/
void sort_vector(std::vector<BufferInfo>* buffer_info_vec) {
std::sort(buffer_info_vec->begin(), buffer_info_vec->end(),
[](const BufferInfo& a, const BufferInfo& b) {
if (a->size_bytes->value == b->size_bytes->value) {
if (a->conflicts.size() == b->conflicts.size()) {
return std::string(a->name_hint->data) > std::string(b->name_hint->data);
} else {
return a->conflicts.size() > b->conflicts.size();
}
}
return a->size_bytes->value > b->size_bytes->value;
});
template <typename T>
void sort_vector(std::vector<T>* buffer_info_vec) {
std::sort(buffer_info_vec->begin(), buffer_info_vec->end(), [](const T& a, const T& b) {
if (a->size_bytes->value == b->size_bytes->value) {
if (a->conflicts.size() == b->conflicts.size()) {
return std::string(a->name_hint->data) > std::string(b->name_hint->data);
} else {
return a->conflicts.size() > b->conflicts.size();
}
}
return a->size_bytes->value > b->size_bytes->value;
});
}

/*
Expand Down Expand Up @@ -156,33 +157,21 @@ class HillClimbAllocator : public GreedyBase {
void collect_neighbor_lists(const BufferInfoNode* buf,
std::vector<const BufferInfoNode*>* first_level,
std::vector<const BufferInfoNode*>* second_level, const TPos& _pos) {
std::unordered_map<int, const BufferInfoNode*> first_level_set;
std::unordered_map<int, const BufferInfoNode*> second_level_set;

auto buf_pos = _pos(buf);
for (const auto& c1 : buf->conflicts) {
const auto* c1_buf = c1.as<BufferInfoNode>();
int c1_pos = _pos(c1_buf);
if (buf_pos > c1_pos) {
first_level_set[c1_pos] = c1_buf;
first_level->push_back(c1_buf);
}
int c2_pos = -1;
for (const auto& c2 : c1_buf->conflicts) {
const auto c2_buf = c2.as<BufferInfoNode>();
if (c1_pos > (c2_pos = _pos(c2_buf))) {
second_level_set[c2_pos] = c2_buf;
second_level->push_back(c2_buf);
}
}
}

// std::vector<const BufferInfoNode*> first_level;
for (const auto& i : first_level_set) {
first_level->push_back(i.second);
}
// std::vector<const BufferInfoNode*> second_level;
for (const auto& i : second_level_set) {
second_level->push_back(i.second);
}
}

public:
Expand All @@ -202,7 +191,7 @@ class HillClimbAllocator : public GreedyBase {
buffer_info_vec.push_back(std::move(buffer_info));
}

sort_vector(&buffer_info_vec);
sort_vector<BufferInfo>(&buffer_info_vec);

// populate positional index map
std::unordered_map<const BufferInfoNode*, int> _pos_map;
Expand Down Expand Up @@ -283,12 +272,17 @@ class HillClimbAllocator : public GreedyBase {
max_pool_buf.push_back(buf);
}
}

sort(max_pool_buf.begin(), max_pool_buf.end(),
[&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); });
// pick highest
const BufferInfoNode* node = max_pool_buf[rnd_func() % max_pool_buf.size()];
std::vector<const BufferInfoNode*> first_level;
std::vector<const BufferInfoNode*> second_level;
collect_neighbor_lists(node, &first_level, &second_level, _pos);
sort(first_level.begin(), first_level.end(),
[&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); });
sort(second_level.begin(), second_level.end(),
[&_pos](const auto* a, const auto* b) { return _pos(a) < _pos(b); });

// retry if no first-level neighbors were collected
if (!first_level.size()) {
Expand Down
50 changes: 38 additions & 12 deletions tests/python/relay/aot/test_crt_aot_usmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from collections import OrderedDict
import re

import random
import numpy as np
import pytest

Expand Down Expand Up @@ -100,23 +102,47 @@ def test_synthetic(interface_api, use_unpacked_api, test_runner):


@pytest.mark.parametrize(
"workspace_byte_alignment,constant_byte_alignment,main_workspace_size,main_constant_size",
"workspace_byte_alignment,constant_byte_alignment,"
"main_workspace_size,main_constant_size,usmp_algo",
[
(8, 8, 17280, 948),
(16, 8, 17280, 948),
(256, 8, 17792, 948),
(8, 16, 17280, 956),
(16, 16, 17280, 956),
(256, 16, 17792, 956),
(8, 256, 17280, 1804),
(16, 256, 17280, 1804),
(256, 256, 17792, 1804),
(8, 8, 17280, 948, "greedy_by_conflicts"),
(16, 8, 17280, 948, "greedy_by_conflicts"),
(256, 8, 17792, 948, "greedy_by_conflicts"),
(8, 16, 17280, 956, "greedy_by_conflicts"),
(16, 16, 17280, 956, "greedy_by_conflicts"),
(256, 16, 17792, 956, "greedy_by_conflicts"),
(8, 256, 17280, 1804, "greedy_by_conflicts"),
(16, 256, 17280, 1804, "greedy_by_conflicts"),
(256, 256, 17792, 1804, "greedy_by_conflicts"),
(8, 8, 22032, 948, "greedy_by_size"),
(16, 8, 22032, 948, "greedy_by_size"),
(256, 8, 22976, 948, "greedy_by_size"),
(8, 16, 22032, 956, "greedy_by_size"),
(16, 16, 22032, 956, "greedy_by_size"),
(256, 16, 22976, 956, "greedy_by_size"),
(8, 256, 22032, 1804, "greedy_by_size"),
(16, 256, 22032, 1804, "greedy_by_size"),
(256, 256, 22976, 1804, "greedy_by_size"),
(8, 8, 11424, 948, "hill_climb"),
(16, 8, 11424, 948, "hill_climb"),
(256, 8, 11920, 948, "hill_climb"),
(8, 16, 11424, 956, "hill_climb"),
(16, 16, 11424, 956, "hill_climb"),
(256, 16, 11920, 956, "hill_climb"),
(8, 256, 11424, 1804, "hill_climb"),
(16, 256, 11424, 1804, "hill_climb"),
(256, 256, 11920, 1804, "hill_climb"),
],
)
def test_memory_planning(
workspace_byte_alignment, constant_byte_alignment, main_workspace_size, main_constant_size
workspace_byte_alignment,
constant_byte_alignment,
main_workspace_size,
main_constant_size,
usmp_algo,
):
"""Checks calculated workspace against known values"""
random.seed(0)
mod, params = tvm.relay.testing.synthetic.get_workload()
target = "c"
runtime = Runtime("crt")
Expand All @@ -133,7 +159,7 @@ def test_memory_planning(
"tir.disable_vectorize": True,
"tir.disable_storage_rewrite": True,
"tir.usmp.enable": True,
"tir.usmp.algorithm": "greedy_by_conflicts",
"tir.usmp.algorithm": usmp_algo,
},
):
lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params)
Expand Down
12 changes: 7 additions & 5 deletions tests/python/unittest/test_tir_usmp_algo_hill_climb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tvm import WorkspacePoolInfo, PoolInfoProperties


def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
def _check_max_workspace_size(buffer_pool_allocations, pool_info, size, tolerance=0):
"""Helper to check maximum allocated memory size"""
max_workspace_size = 0
for buffer_info, pool_allocation in buffer_pool_allocations.items():
Expand All @@ -33,7 +33,7 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
max_workspace_size = size_candidate
_diff = max_workspace_size.value - size
return (
(max_workspace_size.value == size),
(max_workspace_size.value == size if tolerance == 0 else tolerance > 100 * _diff / size),
"'{}': expected {} got {}, diff {:0.2f}% ({} bytes)".format(
pool_info.pool_name, size, max_workspace_size, 100 * _diff / size, _diff
),
Expand Down Expand Up @@ -335,7 +335,7 @@ def find_maximum_from_intervals(intervals):
def test_intervals(intervals):
"""Tests supplied intervals"""
random.seed(0)
result = run_intervals(intervals)
result = run_intervals(intervals, 5)
assert result["tir.usmp.algo.hill_climb"] == True, f" {result}"


Expand All @@ -355,7 +355,7 @@ def test_random_intervals(interval_len=16):
return run_intervals(intervals)


def run_intervals(intervals):
def run_intervals(intervals, tolerance=0):
"""Helper to run intervals"""
expected_mem = find_maximum_from_intervals(intervals)
pools = [WorkspacePoolInfo("default", [])]
Expand Down Expand Up @@ -391,7 +391,9 @@ def run_intervals(intervals):
print()

_verify_all_conflicts(buffer_info_arr)
result[alg], msg = _check_max_workspace_size(buffer_info_arr, pools[0], expected_mem)
result[alg], msg = _check_max_workspace_size(
buffer_info_arr, pools[0], expected_mem, tolerance
)
if not result[alg]:
print(alg, msg)

Expand Down

0 comments on commit c98626c

Please sign in to comment.