Skip to content

Commit

Permalink
ENH: Update LabelOverlapMeasures to use dynamic threading
Browse files Browse the repository at this point in the history
This follows the LabelStatisticImageFilter with concurent merging for
each chunk.
  • Loading branch information
blowekamp authored and dzenanz committed Aug 2, 2022
1 parent 83480a0 commit 19fab83
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "itkImageToImageFilter.h"
#include "itkNumericTraits.h"

#include <mutex>
#include <unordered_map>

namespace itk
Expand Down Expand Up @@ -235,26 +235,21 @@ class ITK_TEMPLATE_EXPORT LabelOverlapMeasuresImageFilter : public ImageToImageF
void
BeforeThreadedGenerateData() override;

void
AfterThreadedGenerateData() override;

/** Multi-thread version GenerateData. */
void
ThreadedGenerateData(const RegionType &, ThreadIdType) override;

void
DynamicThreadedGenerateData(const RegionType &) override
{
itkExceptionMacro("This class requires threadId so it must use classic multi-threading model");
}
DynamicThreadedGenerateData(const RegionType &) override;

// Override since the filter produces all of its output
void
EnlargeOutputRequestedRegion(DataObject * data) override;

void
MergeMap(MapType & m1, MapType & m2) const;

private:
std::vector<MapType> m_LabelSetMeasuresPerThread;
MapType m_LabelSetMeasures;
MapType m_LabelSetMeasures;

std::mutex m_Mutex;
}; // end of class

} // end namespace itk
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ LabelOverlapMeasuresImageFilter<TLabelImage>::LabelOverlapMeasuresImageFilter()

// This filter requires two input images
this->SetNumberOfRequiredInputs(2);
this->DynamicMultiThreadingOff();
}

template <typename TLabelImage>
Expand All @@ -57,85 +56,74 @@ template <typename TLabelImage>
void
LabelOverlapMeasuresImageFilter<TLabelImage>::BeforeThreadedGenerateData()
{
ThreadIdType numberOfWorkUnits = this->GetNumberOfWorkUnits();

// Resize the thread temporaries
this->m_LabelSetMeasuresPerThread.resize(numberOfWorkUnits);

// Initialize the temporaries
for (ThreadIdType n = 0; n < numberOfWorkUnits; ++n)
{
this->m_LabelSetMeasuresPerThread[n].clear();
}
Superclass::BeforeThreadedGenerateData();

// Initialize the final map
this->m_LabelSetMeasures.clear();
}

template <typename TLabelImage>
void
LabelOverlapMeasuresImageFilter<TLabelImage>::AfterThreadedGenerateData()
LabelOverlapMeasuresImageFilter<TLabelImage>::MergeMap(MapType & m1, MapType & m2) const
{
// Run through the map for each thread and accumulate the set measures.
for (ThreadIdType n = 0; n < this->GetNumberOfWorkUnits(); ++n)
for (auto m2_value : m2)
{
// Iterate over the map for this thread
for (MapConstIterator threadIt = this->m_LabelSetMeasuresPerThread[n].begin();
threadIt != this->m_LabelSetMeasuresPerThread[n].end();
++threadIt)
// Does this label exist in the cumulative structure yet?
auto m1It = m1.find(m2_value.first);
if (m1It == m1.end())
{
// Does this label exist in the cumulative structure yet?
auto mapIt = this->m_LabelSetMeasures.find(threadIt->first);
if (mapIt == this->m_LabelSetMeasures.end())
{
// Create a new entry
using MapValueType = typename MapType::value_type;
mapIt = this->m_LabelSetMeasures.insert(MapValueType(threadIt->first, LabelSetMeasures())).first;
}

// Accumulate the information from this thread
mapIt->second.m_Source += threadIt->second.m_Source; // segmentation which will be compared (TP+FP)
mapIt->second.m_Target += threadIt->second.m_Target; // Ground Truth segmentation (TP+FN)
mapIt->second.m_Union += threadIt->second.m_Union; // (TP+FN+FP)
mapIt->second.m_Intersection += threadIt->second.m_Intersection; //(TP)
mapIt->second.m_SourceComplement += threadIt->second.m_SourceComplement; //(FP)
mapIt->second.m_TargetComplement += threadIt->second.m_TargetComplement; //(FN)
} // end of thread map iterator loop
} // end of thread loop
// move m2 entry into m1, this reuses the histogram if needed.
m1.emplace(m2_value.first, std::move(m2_value.second));
}
else
{
typename MapType::mapped_type & labelStats = m1It->second;

// Accumulate the information into m1
labelStats.m_Source += m2_value.second.m_Source; // segmentation which will be compared (TP+FP)
labelStats.m_Target += m2_value.second.m_Target; // Ground Truth segmentation (TP+FN)
labelStats.m_Union += m2_value.second.m_Union; // (TP+FN+FP)
labelStats.m_Intersection += m2_value.second.m_Intersection; //(TP)
labelStats.m_SourceComplement += m2_value.second.m_SourceComplement; //(FP)
labelStats.m_TargetComplement += m2_value.second.m_TargetComplement; //(FN)
}
}
}

template <typename TLabelImage>
void
LabelOverlapMeasuresImageFilter<TLabelImage>::ThreadedGenerateData(const RegionType & outputRegionForThread,
ThreadIdType threadId)
LabelOverlapMeasuresImageFilter<TLabelImage>::DynamicThreadedGenerateData(const RegionType & outputRegionForThread)
{

MapType localStatistics;

ImageRegionConstIterator<LabelImageType> itS(this->GetSourceImage(), outputRegionForThread);
ImageRegionConstIterator<LabelImageType> itT(this->GetTargetImage(), outputRegionForThread);

// Support progress methods/callbacks
ProgressReporter progress(this, threadId, 2 * outputRegionForThread.GetNumberOfPixels());
TotalProgressReporter progress(this, this->GetSourceImage()->GetLargestPossibleRegion().GetNumberOfPixels());

for (itS.GoToBegin(), itT.GoToBegin(); !itS.IsAtEnd(); ++itS, ++itT)
{
LabelType sourceLabel = itS.Get();
LabelType targetLabel = itT.Get();

// Is the label already in this thread?
auto mapItS = this->m_LabelSetMeasuresPerThread[threadId].find(sourceLabel);
auto mapItT = this->m_LabelSetMeasuresPerThread[threadId].find(targetLabel);
// Does the label exist in the local map?
auto mapItS = localStatistics.find(sourceLabel);
auto mapItT = localStatistics.find(targetLabel);

if (mapItS == this->m_LabelSetMeasuresPerThread[threadId].end())
if (mapItS == localStatistics.end())
{
// Create a new label set measures object
using MapValueType = typename MapType::value_type;
mapItS = this->m_LabelSetMeasuresPerThread[threadId].insert(MapValueType(sourceLabel, LabelSetMeasures())).first;
mapItS = localStatistics.insert(MapValueType(sourceLabel, LabelSetMeasures())).first;
}

if (mapItT == this->m_LabelSetMeasuresPerThread[threadId].end())
if (mapItT == localStatistics.end())
{
// Create a new label set measures object
using MapValueType = typename MapType::value_type;
mapItT = this->m_LabelSetMeasuresPerThread[threadId].insert(MapValueType(targetLabel, LabelSetMeasures())).first;
mapItT = localStatistics.insert(MapValueType(targetLabel, LabelSetMeasures())).first;
}

mapItS->second.m_Source++;
Expand All @@ -157,6 +145,35 @@ LabelOverlapMeasuresImageFilter<TLabelImage>::ThreadedGenerateData(const RegionT

progress.CompletedPixel();
}


// Merge localStatistics and m_LabelSetMeasures concurrently safe in a
// local copy, this thread may do multiple merges.
while (true)
{

{
std::unique_lock<std::mutex> lock(m_Mutex);

if (m_LabelSetMeasures.empty())
{
swap(m_LabelSetMeasures, localStatistics);
break;
}
else
{
// copy the output map to thread local storage
MapType tomerge;
swap(m_LabelSetMeasures, tomerge);

// allow other threads to merge data
lock.unlock();

// Merge tomerge into localStatistics, locally
MergeMap(localStatistics, tomerge);
}
} // release lock
}
}

template <typename TLabelImage>
Expand Down

0 comments on commit 19fab83

Please sign in to comment.