Skip to content

Commit

Permalink
[MetaSchedule] Distributed Measurement
Browse files Browse the repository at this point in the history
  • Loading branch information
Kathryn-cat committed Jun 17, 2022
1 parent 6732a9e commit 36ad76c
Show file tree
Hide file tree
Showing 10 changed files with 364 additions and 13 deletions.
21 changes: 21 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ class DatabaseNode : public runtime::Object {
* \return An array of top K tuning records for the given workload.
*/
virtual Array<TuningRecord> GetTopK(const Workload& workload, int top_k) = 0;
/*!
* \brief Get all tuning records from the database.
* \return An Array of all the tuning records in the database.
*/
virtual Array<TuningRecord> GetAllTuningRecords() = 0;
/*!
* \brief Get the size of the database.
* \return The size of the database.
Expand Down Expand Up @@ -224,6 +229,11 @@ class PyDatabaseNode : public DatabaseNode {
* \return An array of top K tuning records for the given workload.
*/
using FGetTopK = runtime::TypedPackedFunc<Array<TuningRecord>(const Workload&, int)>;
/*!
* \brief The function type of `GetAllTuningRecords` method.
* \return An Array of all the tuning records in the database.
*/
using FGetAllTuningRecords = runtime::TypedPackedFunc<Array<TuningRecord>()>;
/*!
* \brief The function type of `Size` method.
* \return The size of the database.
Expand All @@ -238,6 +248,8 @@ class PyDatabaseNode : public DatabaseNode {
FCommitTuningRecord f_commit_tuning_record;
/*! \brief The packed function to the `GetTopK` function. */
FGetTopK f_get_top_k;
/*! \brief The packed function to the `GetAllTuningRecords` function. */
FGetAllTuningRecords f_get_all_tuning_records;
/*! \brief The packed function to the `Size` function. */
FSize f_size;

Expand All @@ -249,6 +261,7 @@ class PyDatabaseNode : public DatabaseNode {
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
// `f_get_top_k` is not visited
// `f_get_all_tuning_records` is not visited
// `f_size` is not visited
}

Expand All @@ -273,6 +286,12 @@ class PyDatabaseNode : public DatabaseNode {
return f_get_top_k(workload, top_k);
}

/*!
 * \brief Fetch every tuning record through the user-provided packed function.
 * \return All tuning records stored in the backing database.
 * \note Aborts with an ICHECK failure if the Python subclass did not supply
 *       an implementation for `get_all_tuning_records`.
 */
Array<TuningRecord> GetAllTuningRecords() final {
  // Fail fast with a clear message when the Python-side override is missing.
  ICHECK(f_get_all_tuning_records != nullptr)
      << "PyDatabase's GetAllTuningRecords method not implemented!";
  return f_get_all_tuning_records();
}

int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
Expand Down Expand Up @@ -302,13 +321,15 @@ class Database : public runtime::ObjectRef {
* \param f_commit_workload The packed function of `CommitWorkload`.
* \param f_commit_tuning_record The packed function of `CommitTuningRecord`.
* \param f_get_top_k The packed function of `GetTopK`.
* \param f_get_all_tuning_records The packed function of `GetAllTuningRecords`.
* \param f_size The packed function of `Size`.
* \return The created database.
*/
TVM_DLL static Database PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload,
PyDatabaseNode::FCommitWorkload f_commit_workload,
PyDatabaseNode::FCommitTuningRecord f_commit_tuning_record,
PyDatabaseNode::FGetTopK f_get_top_k,
PyDatabaseNode::FGetAllTuningRecords f_get_all_tuning_records,
PyDatabaseNode::FSize f_size);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Database, runtime::ObjectRef, DatabaseNode);
};
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class TuneContextNode : public runtime::Object {

/*! \brief Initialize members that needs initialization with tune context. */
void Initialize();
/*! \brief Construct the measure candidate given initial IR module and trace. */
MeasureCandidate _GetMeasureCandidate(const IRModule& mod, const tir::Trace& trace);
/*! \brief Set the measure candidates from the SearchStrategy */
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
/*!
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
return _ffi_api.DatabaseGetTopK(self, workload, top_k) # type: ignore # pylint: disable=no-member

def get_all_tuning_records(self) -> List[TuningRecord]:
    """Get all the tuning records from the database.

    Returns
    -------
    tuning_records : List[TuningRecord]
        All tuning records from the database.
    """
    # Thin wrapper over the C++ FFI entry point registered for this database.
    return _ffi_api.DatabaseGetAllTuningRecords(self)  # type: ignore # pylint: disable=no-member

def __len__(self) -> int:
"""Get the number of records in the database.
Expand All @@ -229,6 +239,7 @@ def __init__(
f_commit_workload: Callable = None,
f_commit_tuning_record: Callable = None,
f_get_top_k: Callable = None,
f_get_all_tuning_records: Callable = None,
f_size: Callable = None,
):
"""Constructor."""
Expand All @@ -239,6 +250,7 @@ def __init__(
f_commit_workload,
f_commit_tuning_record,
f_get_top_k,
f_get_all_tuning_records,
f_size,
)

Expand All @@ -258,6 +270,7 @@ class PyDatabase:
"commit_workload",
"commit_tuning_record",
"get_top_k",
"get_all_tuning_records",
"__len__",
],
}
Expand Down Expand Up @@ -317,6 +330,16 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""
raise NotImplementedError

def get_all_tuning_records(self) -> List[TuningRecord]:
    """Get all the tuning records from the database.

    Returns
    -------
    tuning_records : List[TuningRecord]
        All tuning records from the database.

    Raises
    ------
    NotImplementedError
        Always — concrete ``PyDatabase`` subclasses must override this hook.
    """
    raise NotImplementedError

def __len__(self) -> int:
"""Get the number of records in the database.
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
)
)[: int(top_k)]

def get_all_tuning_records(self) -> List[TuningRecord]:
    """Get all the tuning records from the database.

    Returns
    -------
    tuning_records : List[TuningRecord]
        All tuning records from the database.

    Note
    ----
    Returns the internal record list itself, not a copy — callers must not
    mutate the returned list.
    """
    return self.records

def __len__(self) -> int:
return len(self.records)

Expand Down
23 changes: 12 additions & 11 deletions python/tvm/meta_schedule/testing/dataset_sample_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def sample_candidates(task, task_name, model_name):
-------
None
"""
candidate_path = os.path.join(
args.candidate_cache_dir, model_name, task_name + "_candidates.json"
)
workload_path = os.path.join(args.candidate_cache_dir, model_name, task_name + "_workload.json")
database = ms.database.JSONDatabase(
path_workload=workload_path,
path_tuning_record=candidate_path,
)
sample_init_population = tvm.get_global_func(
"meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
)
Expand All @@ -128,7 +136,7 @@ def sample_candidates(task, task_name, model_name):
context.initialize()
context.pre_tuning(
context.generate_design_space(),
database=ms.database.MemoryDatabase(), # type: ignore
database=database,
cost_model=ms.cost_model.RandomModel(), # type: ignore
)

Expand All @@ -148,16 +156,9 @@ def sample_candidates(task, task_name, model_name):
all_states = all_states[: args.num_samples_per_task]

workload = ms.database.Workload(context.mod)
file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
with open(file_path, "w", encoding="utf8") as file:
for i, state in enumerate(all_states):
tuning_record = ms.database.TuningRecord(state.trace, workload)
json_str = json.dumps(tuning_record.as_json())
assert "\n" not in json_str, "Failed to generate single line string."
if i == len(all_states) - 1:
file.write(json_str)
else:
file.write(json_str + "\n")
database.commit_workload(context.mod)
for state in all_states:
database.commit_tuning_record(ms.database.TuningRecord(state.trace, workload))


args = _parse_args() # pylint: disable=invalid-name
Expand Down
199 changes: 199 additions & 0 deletions python/tvm/meta_schedule/testing/distributed_measure_candidates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring

import argparse
import glob
import os
import time

from tqdm import tqdm # type: ignore
from tvm import meta_schedule as ms
from tvm.target import Target


def _parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--candidate_cache_dir", type=str, help="Please provide the full path to the candidates."
)
parser.add_argument(
"--result_cache_dir", type=str, help="Please provide the full path to the result database."
)
parser.add_argument(
"--target",
type=str,
default="nvidia/nvidia-v100",
help="Please specify the target hardware for tuning context.",
)
parser.add_argument(
"--rpc_host", type=str, help="Please provide the private IPv4 address for the tracker."
)
parser.add_argument(
"--rpc_port", type=int, default=4445, help="Please provide the port for the tracker."
)
parser.add_argument(
"--rpc_key",
type=str,
default="p3.2xlarge",
help="Please provide the key for the rpc servers.",
)
parser.add_argument(
"--builder_timeout_sec",
type=int,
default=10,
help="The time for the builder session to time out.",
)
parser.add_argument(
"--min_repeat_ms", type=int, default=100, help="The time for preheating the gpu."
)
parser.add_argument(
"--runner_timeout_sec",
type=int,
default=100,
help="The time for the runner session to time out.",
)
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="The batch size of candidates sent to builder and runner each time.",
)
return parser.parse_args()


# pylint: disable=too-many-locals
def measure_candidates(database, builder, runner):
    """Send the candidates to builder and runner for distributed measurement,
    and save the results in a new json database.

    Parameters
    ----------
    database : JSONDatabase
        The database for candidates to be measured.
    builder : Builder
        The builder for building the candidates.
    runner : Runner
        The runner for measuring the candidates.

    Returns
    -------
    None
    """
    candidates, runner_results, build_fail_indices, run_fail_indices = [], [], [], []
    build_time, run_time = 0.0, 0.0
    context = ms.TuneContext(target=Target(args.target))
    # Rebuild a MeasureCandidate from every stored (workload, trace) pair.
    tuning_records = database.get_all_tuning_records()
    for record in tuning_records:
        candidates.append(context.get_measure_candidate(record.workload.mod, record.trace))
    # Measure in slices of args.batch_size so builder/runner sessions stay bounded.
    for idx in range(0, len(candidates), args.batch_size):
        batch_candidates = candidates[idx : idx + args.batch_size]
        context.set_measure_candidates(batch_candidates)
        build_start_time = time.time()
        context.send_to_builder(builder)
        build_end_time = time.time()
        context.send_to_runner(runner)
        # join() blocks until the runner results for this batch are available.
        batch_runner_results = context.join()
        run_end_time = time.time()
        runner_results.extend(batch_runner_results)
        for i, result in enumerate(context.builder_results):
            if result.error_msg is None:
                # Successful build: delete the exported artifact to save disk.
                ms.utils.remove_build_dir(result.artifact_path)
            else:
                # Record the global (not batch-local) index of the failed build.
                build_fail_indices.append(i + idx)
        context.clear_measure_state()
        build_time += build_end_time - build_start_time
        run_time += run_end_time - build_end_time

    # Mirror the input layout: <result_cache_dir>/<model_name>/<same file names>.
    model_name, workload_name = database.path_workload.split("/")[-2:]
    record_name = database.path_tuning_record.split("/")[-1]
    new_database = ms.database.JSONDatabase(
        path_workload=os.path.join(args.result_cache_dir, model_name, workload_name),
        path_tuning_record=os.path.join(args.result_cache_dir, model_name, record_name),
    )
    # NOTE(review): assumes all records share a single workload, and that at
    # least one record exists — an empty database raises IndexError here;
    # confirm the upstream sampling step always commits records.
    workload = tuning_records[0].workload
    new_database.commit_workload(workload.mod)
    # Only records whose run succeeded are committed; failed runs are indexed.
    for i, (record, result) in enumerate(zip(tuning_records, runner_results)):
        if result.error_msg is None:
            new_database.commit_tuning_record(
                ms.database.TuningRecord(
                    trace=record.trace,
                    workload=workload,
                    run_secs=[v.value for v in result.run_secs],
                    target=Target(args.target),
                )
            )
        else:
            run_fail_indices.append(i)
    # Strip the trailing "workload.json" (13 chars) to name the failure report.
    fail_indices_name = workload_name[:-13] + "failed_indices.txt"
    with open(
        os.path.join(args.result_cache_dir, model_name, fail_indices_name), "w", encoding="utf8"
    ) as file:
        file.write(" ".join([str(n) for n in run_fail_indices]))
    print(
        f"Builder time: {build_time}, Runner time: {run_time}\n\
Failed number of builds: {len(build_fail_indices)},\
Failed number of runs: {len(run_fail_indices)}"
    )


# Parsed once at import time; shared by measure_candidates() and main().
args = _parse_args()  # pylint: disable=invalid-name


def main():
    """Measure every candidate database found under ``candidate_cache_dir``.

    For each model directory, pairs every ``*_workload.json`` with its
    ``*_candidates.json``, measures the candidates via local build + RPC run,
    and writes the results under ``result_cache_dir`` with the same layout.

    Returns
    -------
    None
    """
    builder = ms.builder.LocalBuilder(timeout_sec=args.builder_timeout_sec)
    runner = ms.runner.RPCRunner(
        rpc_config=ms.runner.RPCConfig(
            tracker_host=args.rpc_host,
            tracker_port=args.rpc_port,
            tracker_key=args.rpc_key,
            session_timeout_sec=args.runner_timeout_sec,
        ),
        evaluator_config=ms.runner.EvaluatorConfig(
            number=1,
            repeat=1,
            min_repeat_ms=args.min_repeat_ms,
            enable_cpu_cache_flush=False,
        ),
        max_workers=os.cpu_count(),
    )
    if not os.path.isdir(args.candidate_cache_dir):
        raise Exception("Please provide a correct candidate cache dir.")
    try:
        os.makedirs(args.result_cache_dir, exist_ok=True)
    except OSError:
        print(f"Directory {args.result_cache_dir} cannot be created successfully.")
    # sorted() makes the processing order deterministic across runs/filesystems.
    model_dirs = sorted(glob.glob(os.path.join(args.candidate_cache_dir, "*")))
    for model_dir in model_dirs:
        # os.path.basename is portable, unlike splitting on "/" (fixes Windows paths).
        model_name = os.path.basename(os.path.normpath(model_dir))
        os.makedirs(os.path.join(args.result_cache_dir, model_name), exist_ok=True)
        all_tasks = sorted(glob.glob(os.path.join(model_dir, "*.json")))
        # Match only workload files; a bare `"workload" in path` test would also
        # pick up candidate files of tasks whose *name* contains "workload".
        workload_paths = [path for path in all_tasks if path.endswith("_workload.json")]
        for workload_path in tqdm(workload_paths):
            # "<task>_workload.json" -> "<task>_candidates.json" (replace the
            # trailing "workload.json", 13 chars).
            candidate_path = workload_path[:-13] + "candidates.json"
            database = ms.database.JSONDatabase(
                path_workload=workload_path,
                path_tuning_record=candidate_path,
            )
            measure_candidates(database, builder, runner)


if __name__ == "__main__":
    main()
Loading

0 comments on commit 36ad76c

Please sign in to comment.