MPP: implement minTSO Scheduler (#4044)
close #4256
fzhedu authored Mar 15, 2022
1 parent cc8a5c5 commit 4eb79e3
Showing 9 changed files with 380 additions and 39 deletions.
50 changes: 46 additions & 4 deletions dbms/src/Flash/Mpp/MPPTask.cpp
@@ -11,6 +11,7 @@
#include <Flash/Mpp/MPPTask.h>
#include <Flash/Mpp/MPPTaskManager.h>
#include <Flash/Mpp/MPPTunnelSet.h>
#include <Flash/Mpp/MinTSOScheduler.h>
#include <Flash/Mpp/Utils.h>
#include <Flash/Mpp/getMPPTaskLog.h>
#include <Interpreters/ProcessList.h>
@@ -42,6 +43,7 @@ MPPTask::MPPTask(const mpp::TaskMeta & meta_, const ContextPtr & context_)
, id(meta.start_ts(), meta.task_id())
, log(getMPPTaskLog("MPPTask", id))
, mpp_task_statistics(id, meta.address())
, scheduled(false)
{}

MPPTask::~MPPTask()
@@ -251,7 +253,6 @@ void MPPTask::runImpl()
LOG_WARNING(log, "task not in initializing state, skip running");
return;
}

Stopwatch stopwatch;
GET_METRIC(tiflash_coprocessor_request_count, type_run_mpp_task).Increment();
GET_METRIC(tiflash_coprocessor_handling_request_count, type_run_mpp_task).Increment();
@@ -260,13 +261,16 @@ void MPPTask::runImpl()
GET_METRIC(tiflash_coprocessor_request_duration_seconds, type_run_mpp_task).Observe(stopwatch.elapsedSeconds());
});
String err_msg;
LOG_INFO(log, "task starts running");
try
{
LOG_FMT_INFO(log, "task starts preprocessing");
preprocess();
needed_threads = estimateCountOfNewThreads();
LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_threads: {} , receiver_threads: {}", needed_threads, dag_context->tunnel_set->getRemoteTunnelCnt(), dag_context->getNewThreadCountOfExchangeReceiver());

scheduleOrWait();

int new_thd_cnt = estimateCountOfNewThreads();
LOG_FMT_DEBUG(log, "Estimate new thread count of query :{} including tunnel_thds: {} , receiver_thds: {}", new_thd_cnt, dag_context->tunnel_set->getRemoteTunnelCnt(), dag_context->getNewThreadCountOfExchangeReceiver());
LOG_FMT_INFO(log, "task starts running");
memory_tracker = current_memory_tracker;
if (status.load() != RUNNING)
{
@@ -380,6 +384,7 @@ void MPPTask::cancel(const String & reason)
}
else if (previous_status == RUNNING && switchStatus(RUNNING, CANCELLED))
{
scheduleThisTask();
context->getProcessList().sendCancelToQuery(context->getCurrentQueryId(), context->getClientInfo().current_user, true);
closeAllTunnels(reason);
/// runImpl is running, leave remaining work to runImpl
@@ -394,11 +399,48 @@ bool MPPTask::switchStatus(TaskStatus from, TaskStatus to)
return status.compare_exchange_strong(from, to);
}

void MPPTask::scheduleOrWait()
{
if (!manager->tryToScheduleTask(shared_from_this()))
{
LOG_FMT_INFO(log, "task waits for schedule");
Stopwatch stopwatch;
{
std::unique_lock lock(schedule_mu);
schedule_cv.wait(lock, [&] { return scheduled; });
}
LOG_FMT_INFO(log, "task waits for {} ms to schedule and starts to run in parallel.", stopwatch.elapsedMilliseconds());
}
}

void MPPTask::scheduleThisTask()
{
std::unique_lock lock(schedule_mu);
if (!scheduled)
{
LOG_FMT_INFO(log, "task gets schedule");
scheduled = true;
schedule_cv.notify_one();
}
}

int MPPTask::estimateCountOfNewThreads()
{
if (dag_context == nullptr || dag_context->getBlockIO().in == nullptr || dag_context->tunnel_set == nullptr)
throw Exception("It should not estimate the threads for the uninitialized task" + id.toString());

// Estimated count of new threads from InputStreams (including ExchangeReceiver) and remote MPPTunnels.
return dag_context->getBlockIO().in->estimateNewThreadCount() + 1
+ dag_context->tunnel_set->getRemoteTunnelCnt();
}

int MPPTask::getNeededThreads()
{
if (needed_threads == 0)
{
throw Exception(" the needed_threads of task " + id.toString() + " is not initialized!");
}
return needed_threads;
}

} // namespace DB
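
The heart of the change in runImpl() is a plain condition-variable handshake: scheduleOrWait() blocks the task's thread until the scheduled flag is set, and scheduleThisTask() sets the flag and notifies; the scheduler calls it once enough threads are free, and cancel() calls it too so a cancelled task is never left blocked. A minimal, self-contained sketch of that handshake follows; everything outside the diff above (the WaitingTask class and main()) is illustrative only.

// Minimal sketch of the scheduleOrWait()/scheduleThisTask() handshake.
// WaitingTask and main() are illustrative names, not part of the commit.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class WaitingTask
{
public:
    // Runs on the task's own thread: block until the scheduler admits us.
    void scheduleOrWait()
    {
        std::unique_lock lock(schedule_mu);
        schedule_cv.wait(lock, [&] { return scheduled; });
    }

    // Called by the scheduler (or by cancel) to release the waiting task.
    void scheduleThisTask()
    {
        std::unique_lock lock(schedule_mu);
        if (!scheduled)
        {
            scheduled = true;
            schedule_cv.notify_one();
        }
    }

private:
    std::mutex schedule_mu;
    std::condition_variable schedule_cv;
    bool scheduled = false; // protected by schedule_mu, as in MPPTask
};

int main()
{
    WaitingTask task;
    std::thread worker([&] {
        task.scheduleOrWait(); // blocks until admitted
        std::cout << "task admitted, running\n";
    });
    task.scheduleThisTask(); // the "scheduler" admits the task
    worker.join();
    return 0;
}

The predicate form of wait() re-checks scheduled after every wakeup, so spurious wakeups and a notification that arrives before the wait begins are both handled.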
18 changes: 15 additions & 3 deletions dbms/src/Flash/Mpp/MPPTask.h
@@ -46,13 +46,13 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

void prepare(const mpp::DispatchTaskRequest & task_request);

void preprocess();

void run();

void registerTunnel(const MPPTaskId & id, MPPTunnelPtr tunnel);

int estimateCountOfNewThreads();
int getNeededThreads();

void scheduleThisTask();

// tunnel and error_message
std::pair<MPPTunnelPtr, String> getTunnel(const ::mpp::EstablishMPPConnectionRequest * request);
@@ -76,6 +76,12 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>

bool switchStatus(TaskStatus from, TaskStatus to);

void preprocess();

void scheduleOrWait();

int estimateCountOfNewThreads();

tipb::DAGRequest dag_req;

ContextPtr context;
@@ -104,6 +110,12 @@ class MPPTask : public std::enable_shared_from_this<MPPTask>
Exception err;

friend class MPPTaskManager;

int needed_threads;

std::mutex schedule_mu;
std::condition_variable schedule_cv;
bool scheduled;
};

using MPPTaskPtr = std::shared_ptr<MPPTask>;
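
For reference, the estimate that runImpl() now caches in needed_threads is a simple sum: the threads the InputStream tree will spawn (ExchangeReceiver threads included), plus one (presumably for the task's own thread), plus one per remote MPPTunnel. A toy sketch of the same arithmetic, with made-up inputs standing in for the values the real code reads:

#include <iostream>

// Mirrors the formula in MPPTask::estimateCountOfNewThreads() above; the
// arguments are stand-ins for in->estimateNewThreadCount() and
// tunnel_set->getRemoteTunnelCnt().
int estimateCountOfNewThreads(int input_stream_threads, int remote_tunnel_cnt)
{
    return input_stream_threads + 1 + remote_tunnel_cnt;
}

int main()
{
    // e.g. 8 threads from the InputStream tree and 3 remote tunnels -> 12
    std::cout << estimateCountOfNewThreads(8, 3) << '\n';
    return 0;
}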
71 changes: 48 additions & 23 deletions dbms/src/Flash/Mpp/MPPTaskManager.cpp
@@ -8,8 +8,9 @@

namespace DB
{
MPPTaskManager::MPPTaskManager()
: log(&Poco::Logger::get("TaskManager"))
MPPTaskManager::MPPTaskManager(MPPTaskSchedulerPtr scheduler_)
: scheduler(std::move(scheduler_))
, log(&Poco::Logger::get("TaskManager"))
{}

MPPTaskPtr MPPTaskManager::findTaskWithTimeout(const mpp::TaskMeta & meta, std::chrono::seconds timeout, std::string & errMsg)
@@ -25,15 +26,15 @@ MPPTaskPtr MPPTaskManager::findTaskWithTimeout(const mpp::TaskMeta & meta, std::
{
return false;
}
else if (query_it->second.to_be_cancelled)
else if (query_it->second->to_be_cancelled)
{
/// if the query is cancelled, return true to stop waiting timeout.
LOG_WARNING(log, fmt::format("Query {} is cancelled, all its tasks are invalid.", id.start_ts));
cancelled = true;
return true;
}
it = query_it->second.task_map.find(id);
return it != query_it->second.task_map.end();
it = query_it->second->task_map.find(id);
return it != query_it->second->task_map.end();
});
if (cancelled)
{
Expand All @@ -50,25 +51,26 @@ MPPTaskPtr MPPTaskManager::findTaskWithTimeout(const mpp::TaskMeta & meta, std::

void MPPTaskManager::cancelMPPQuery(UInt64 query_id, const String & reason)
{
MPPQueryTaskSet task_set;
MPPQueryTaskSetPtr task_set;
{
/// cancel task may take a long time, so first
/// set a flag, so we can cancel task one by
/// one without holding the lock
std::lock_guard<std::mutex> lock(mu);
auto it = mpp_query_map.find(query_id);
if (it == mpp_query_map.end() || it->second.to_be_cancelled)
if (it == mpp_query_map.end() || it->second->to_be_cancelled)
return;
it->second.to_be_cancelled = true;
it->second->to_be_cancelled = true;
task_set = it->second;
scheduler->deleteCancelledQuery(query_id, *this);
cv.notify_all();
}
LOG_WARNING(log, fmt::format("Begin cancel query: {}", query_id));
FmtBuffer fmt_buf;
fmt_buf.fmtAppend("Remaining task in query {} are: ", query_id);
// TODO: cancel tasks in order rather than issuing so many threads to cancel tasks
std::vector<std::thread> cancel_workers;
for (const auto & task : task_set.task_map)
for (const auto & task : task_set->task_map)
{
fmt_buf.fmtAppend("{} ", task.first.toString());
std::thread t(&MPPTask::cancel, task.second, std::ref(reason));
@@ -79,7 +81,7 @@ void MPPTaskManager::cancelMPPQuery(UInt64 query_id, const String & reason)
{
worker.join();
}
MPPQueryTaskSet canceled_task_set;
MPPQueryTaskSetPtr canceled_task_set;
{
std::lock_guard<std::mutex> lock(mu);
/// just to double check the query still exists
@@ -89,6 +91,7 @@ void MPPTaskManager::cancelMPPQuery(UInt64 query_id, const String & reason)
/// hold the canceled task set, so the mpp task will not be destructed while holding the
/// `mu` of MPPTaskManager, otherwise it might cause deadlock
canceled_task_set = it->second;
scheduler->deleteThenSchedule(query_id, *this);
mpp_query_map.erase(it);
}
}
Expand All @@ -99,17 +102,26 @@ bool MPPTaskManager::registerTask(MPPTaskPtr task)
{
std::unique_lock<std::mutex> lock(mu);
const auto & it = mpp_query_map.find(task->id.start_ts);
if (it != mpp_query_map.end() && it->second.to_be_cancelled)
if (it != mpp_query_map.end() && it->second->to_be_cancelled)
{
LOG_WARNING(log, "Do not register task: " + task->id.toString() + " because the query is to be cancelled.");
cv.notify_all();
return false;
}
if (it != mpp_query_map.end() && it->second.task_map.find(task->id) != it->second.task_map.end())
if (it != mpp_query_map.end() && it->second->task_map.find(task->id) != it->second->task_map.end())
{
throw Exception("The task " + task->id.toString() + " has been registered");
}
mpp_query_map[task->id.start_ts].task_map.emplace(task->id, task);
if (it == mpp_query_map.end()) /// the first one
{
auto ptr = std::make_shared<MPPQueryTaskSet>();
ptr->task_map.emplace(task->id, task);
mpp_query_map.insert({task->id.start_ts, ptr});
}
else
{
mpp_query_map[task->id.start_ts]->task_map.emplace(task->id, task);
}
task->manager = this;
cv.notify_all();
return true;
@@ -121,23 +133,24 @@ void MPPTaskManager::unregisterTask(MPPTask * task)
auto it = mpp_query_map.find(task->id.start_ts);
if (it != mpp_query_map.end())
{
if (it->second.to_be_cancelled)
if (it->second->to_be_cancelled)
return;
auto task_it = it->second.task_map.find(task->id);
if (task_it != it->second.task_map.end())
auto task_it = it->second->task_map.find(task->id);
if (task_it != it->second->task_map.end())
{
it->second.task_map.erase(task_it);
if (it->second.task_map.empty())
it->second->task_map.erase(task_it);
if (it->second->task_map.empty())
{
/// remove query task map if the task is the last one
scheduler->deleteThenSchedule(task->id.start_ts, *this);
mpp_query_map.erase(it);
}
return;
}
}
LOG_ERROR(log, "The task " + task->id.toString() + " cannot be found and fail to unregister");
}

MPPTaskManager::~MPPTaskManager() {}

std::vector<UInt64> MPPTaskManager::getCurrentQueries()
{
std::vector<UInt64> ret;
@@ -154,9 +167,9 @@ std::vector<MPPTaskPtr> MPPTaskManager::getCurrentTasksForQuery(UInt64 query_id)
std::vector<MPPTaskPtr> ret;
std::lock_guard<std::mutex> lock(mu);
const auto & it = mpp_query_map.find(query_id);
if (it == mpp_query_map.end() || it->second.to_be_cancelled)
if (it == mpp_query_map.end() || it->second->to_be_cancelled)
return ret;
for (const auto & task_it : it->second.task_map)
for (const auto & task_it : it->second->task_map)
ret.push_back(task_it.second);
return ret;
}
@@ -167,10 +180,22 @@ String MPPTaskManager::toString()
String res("(");
for (auto & query_it : mpp_query_map)
{
for (auto & it : query_it.second.task_map)
for (auto & it : query_it.second->task_map)
res += it.first.toString() + ", ";
}
return res + ")";
}

MPPQueryTaskSetPtr MPPTaskManager::getQueryTaskSetWithoutLock(UInt64 query_id)
{
auto it = mpp_query_map.find(query_id);
return it == mpp_query_map.end() ? nullptr : it->second;
}

bool MPPTaskManager::tryToScheduleTask(const MPPTaskPtr & task)
{
std::lock_guard lock(mu);
return scheduler->tryToSchedule(task, *this);
}

} // namespace DB
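
With these changes the manager admits every new task through scheduler->tryToSchedule() while holding its mutex, and returns a query's thread budget through deleteThenSchedule() / deleteCancelledQuery() when the query finishes or is cancelled. MinTSOScheduler itself is among the 9 changed files but its implementation is not visible in this excerpt, so the sketch below only illustrates the min-TSO idea under assumed semantics: the query holding the minimum start_ts (TSO) is always admitted so at least one query keeps making progress, other queries are admitted only while a global thread budget allows, and the rest wait to be woken. The single global waiting queue and all names are simplifications of the real scheduler.

#include <cstdint>
#include <map>
#include <queue>

// Toy stand-in for an MPPTask from the scheduler's point of view.
struct TaskStub
{
    uint64_t start_ts;  // the query's TSO, used as the query id
    int needed_threads; // what MPPTask::getNeededThreads() would report
};

class MinTSOSchedulerSketch
{
public:
    explicit MinTSOSchedulerSketch(int thread_hard_limit_)
        : thread_hard_limit(thread_hard_limit_)
    {}

    // Admission check; false means the task should block in scheduleOrWait().
    bool tryToSchedule(const TaskStub & task)
    {
        // The query holding the minimum TSO is always admitted, so at least
        // one query keeps making progress even under thread pressure.
        bool is_min_tso_query
            = running_queries.empty() || task.start_ts <= running_queries.begin()->first;
        if (is_min_tso_query || used_threads + task.needed_threads <= thread_hard_limit)
        {
            used_threads += task.needed_threads;
            ++running_queries[task.start_ts];
            return true;
        }
        waiting_tasks.push(task); // the real code queues per query inside MPPQueryTaskSet
        return false;
    }

    // Called when a query finishes or is cancelled: release its threads and
    // wake as many waiting tasks as the freed budget allows (the real code
    // would call MPPTask::scheduleThisTask() on each woken task).
    void deleteThenSchedule(uint64_t query_id, int released_threads)
    {
        running_queries.erase(query_id);
        used_threads -= released_threads;
        while (!waiting_tasks.empty()
               && used_threads + waiting_tasks.front().needed_threads <= thread_hard_limit)
        {
            used_threads += waiting_tasks.front().needed_threads;
            ++running_queries[waiting_tasks.front().start_ts];
            waiting_tasks.pop();
        }
    }

private:
    const int thread_hard_limit;
    int used_threads = 0;
    std::map<uint64_t, int> running_queries; // start_ts -> count of scheduled tasks
    std::queue<TaskStub> waiting_tasks;
};

In the real commit the per-query bookkeeping lives in the scheduled_task / used_threads / waiting_tasks fields added to MPPQueryTaskSet in MPPTaskManager.h below.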
20 changes: 17 additions & 3 deletions dbms/src/Flash/Mpp/MPPTaskManager.h
@@ -1,6 +1,7 @@
#pragma once

#include <Flash/Mpp/MPPTask.h>
#include <Flash/Mpp/MinTSOScheduler.h>
#include <common/logger_useful.h>
#include <kvproto/mpp.pb.h>

@@ -19,16 +20,24 @@ struct MPPQueryTaskSet
/// to MPPQueryTaskSet is protected by the mutex in MPPTaskManager
bool to_be_cancelled = false;
MPPTaskMap task_map;
/// only used in scheduler
UInt32 scheduled_task = 0;
UInt32 used_threads = 0;
std::queue<MPPTaskPtr> waiting_tasks;
};

using MPPQueryTaskSetPtr = std::shared_ptr<MPPQueryTaskSet>;

/// a map from the mpp query id to mpp query task set, we use
/// the start ts of a query as the query id as TiDB will guarantee
/// the uniqueness of the start ts
using MPPQueryMap = std::unordered_map<UInt64, MPPQueryTaskSet>;
using MPPQueryMap = std::unordered_map<UInt64, MPPQueryTaskSetPtr>;

// MPPTaskManager holds all running MPP tasks. It's a single instance held in Context.
class MPPTaskManager : private boost::noncopyable
{
MPPTaskSchedulerPtr scheduler;

std::mutex mu;

MPPQueryMap mpp_query_map;
@@ -38,17 +47,22 @@ class MPPTaskManager : private boost::noncopyable
std::condition_variable cv;

public:
MPPTaskManager();
~MPPTaskManager();
explicit MPPTaskManager(MPPTaskSchedulerPtr scheduler);

~MPPTaskManager() = default;

std::vector<UInt64> getCurrentQueries();

std::vector<MPPTaskPtr> getCurrentTasksForQuery(UInt64 query_id);

MPPQueryTaskSetPtr getQueryTaskSetWithoutLock(UInt64 query_id);

bool registerTask(MPPTaskPtr task);

void unregisterTask(MPPTask * task);

bool tryToScheduleTask(const MPPTaskPtr & task);

MPPTaskPtr findTaskWithTimeout(const mpp::TaskMeta & meta, std::chrono::seconds timeout, std::string & errMsg);

void cancelMPPQuery(UInt64 query_id, const String & reason);
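
MinTSOScheduler.h is included above but its contents are not shown in this excerpt, so the interface below is only a reconstruction from the three calls MPPTaskManager makes on its scheduler; the actual header, the ownership type behind MPPTaskSchedulerPtr, and whether an abstract base class exists at all are assumptions.

#include <memory>

namespace DB
{
class MPPTask;
class MPPTaskManager;
using MPPTaskPtr = std::shared_ptr<MPPTask>; // stand-in for the real alias in MPPTask.h
using UInt64 = unsigned long long;           // stand-in for the real common type alias

class MPPTaskScheduler
{
public:
    virtual ~MPPTaskScheduler() = default;

    // Called from MPPTaskManager::tryToScheduleTask() under the manager's mutex;
    // returning false means the task blocks in MPPTask::scheduleOrWait().
    virtual bool tryToSchedule(const MPPTaskPtr & task, MPPTaskManager & task_manager) = 0;

    // Called while a query is being cancelled, before its tasks are cancelled.
    virtual void deleteCancelledQuery(UInt64 query_id, MPPTaskManager & task_manager) = 0;

    // Called when the last task of a query unregisters (or after cancellation),
    // so the freed threads can be handed to waiting queries.
    virtual void deleteThenSchedule(UInt64 query_id, MPPTaskManager & task_manager) = 0;
};

using MPPTaskSchedulerPtr = std::unique_ptr<MPPTaskScheduler>; // ownership type is a guess
} // namespace DB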
