Skip to content

Commit

Permalink
[fiber] Implement std concurrency interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
salkinium committed Apr 14, 2024
1 parent a224213 commit 3526783
Show file tree
Hide file tree
Showing 13 changed files with 1,022 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/modm/processing/fiber.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@
#include "fiber/fiber.hpp"
#include "fiber/scheduler.hpp"
#include "fiber/functions.hpp"
#include "fiber/mutex.hpp"
#include "fiber/shared_mutex.hpp"
#include "fiber/semaphore.hpp"
#include "fiber/latch.hpp"
68 changes: 68 additions & 0 deletions src/modm/processing/fiber/latch.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "functions.hpp"
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::latch` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/latch
class latch
{
latch(const latch&) = delete;
latch& operator=(const latch&) = delete;

using count_t = uint16_t;
std::atomic<count_t> count;
public:
constexpr explicit
latch(std::ptrdiff_t expected)
: count(expected) {}

[[nodiscard]] static constexpr std::ptrdiff_t
max() { return count_t(-1); }

void inline
count_down(std::ptrdiff_t n=1)
{
count -= n;
}

[[nodiscard]] bool inline
try_wait() const
{
return count.load() == 0;
}

void inline
wait() const
{
while(not try_wait()) modm::this_fiber::yield();
}

void inline
arrive_and_wait(std::ptrdiff_t n=1)
{
count_down(n);
wait();
}
};

/// @}

}
7 changes: 6 additions & 1 deletion src/modm/processing/fiber/module.lb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def is_enabled(env):
not env.has_module(":processing:protothread")

def prepare(module, options):
module.depends(":processing:timer")
module.depends(":processing:timer", ":architecture:atomic")

module.add_query(
EnvironmentQuery(name="__enabled", factory=is_enabled))
Expand Down Expand Up @@ -77,3 +77,8 @@ def build(env):
env.copy("task.hpp")
env.copy("functions.hpp")
env.copy("fiber.hpp")

env.copy("mutex.hpp")
env.copy("shared_mutex.hpp")
env.copy("semaphore.hpp")
env.copy("latch.hpp")
150 changes: 150 additions & 0 deletions src/modm/processing/fiber/mutex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "functions.hpp"
#include <modm/architecture/interface/atomic_lock.hpp>
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::mutex` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/mutex
class mutex
{
mutex(const mutex&) = delete;
mutex& operator=(const mutex&) = delete;

std::atomic_bool locked{false};
public:
constexpr mutex() = default;

[[nodiscard]] bool inline
try_lock()
{
bool expected{false};
return locked.compare_exchange_strong(expected, true);
}

void inline
lock()
{
while(not try_lock()) modm::this_fiber::yield();
}

void inline
unlock()
{
locked = false;
}
};

/// Implements the `std::timed_mutex` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/timed_mutex
class timed_mutex : public mutex
{
public:
template< typename Rep, typename Period >
[[nodiscard]] bool
try_lock_for(std::chrono::duration<Rep, Period> sleep_duration)
{
return this_fiber::poll_for(sleep_duration, [this](){ return try_lock(); });
}

template< class Clock, class Duration >
[[nodiscard]] bool
try_lock_until(std::chrono::time_point<Clock, Duration> sleep_time)
{
return this_fiber::poll_until(sleep_time, [this](){ return try_lock(); });
}
};

/// Implements the `std::recursive_mutex` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/recursive_mutex
class recursive_mutex
{
recursive_mutex(const recursive_mutex&) = delete;
recursive_mutex& operator=(const recursive_mutex&) = delete;
using count_t = uint16_t;

static constexpr fiber::id NoOwner{fiber::id(-1)};
volatile fiber::id owner{NoOwner};
static constexpr count_t countMax{count_t(-1)};
volatile count_t count{1};

public:
constexpr recursive_mutex() = default;

[[nodiscard]] bool inline
try_lock()
{
const auto id = modm::this_fiber::get_id();
{
modm::atomic::Lock _;
if (owner == NoOwner) {
owner = id;
// count = 1; is implicit
return true;
}
if (owner == id and count < countMax) {
count++;
return true;
}
}
return false;
}

void inline
lock()
{
while(not try_lock()) modm::this_fiber::yield();
}

void inline
unlock()
{
modm::atomic::Lock _;
if (count > 1) count--;
else {
// count = 1; is implicit
owner = NoOwner;
}
}
};

/// Implements the `std::timed_recursive_mutex` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/recursive_mutex
class timed_recursive_mutex : public recursive_mutex
{
public:
template< typename Rep, typename Period >
[[nodiscard]] bool
try_lock_for(std::chrono::duration<Rep, Period> sleep_duration)
{
return this_fiber::poll_for(sleep_duration, [this](){ return try_lock(); });
}

template< class Clock, class Duration >
[[nodiscard]] bool
try_lock_until(std::chrono::time_point<Clock, Duration> sleep_time)
{
return this_fiber::poll_until(sleep_time, [this](){ return try_lock(); });
}
};

/// @}

}
84 changes: 84 additions & 0 deletions src/modm/processing/fiber/semaphore.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) 2023, Niklas Hauser
*
* This file is part of the modm project.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
// ----------------------------------------------------------------------------

#pragma once

#include "functions.hpp"
#include <atomic>

namespace modm::fiber
{

/// @ingroup modm_processing_fiber
/// @{

/// Implements the `std::counting_semaphore` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/counting_semaphore
template< std::ptrdiff_t LeastMaxValue = 255 >
class counting_semaphore
{
counting_semaphore(const counting_semaphore&) = delete;
counting_semaphore& operator=(const counting_semaphore&) = delete;

static_assert(LeastMaxValue <= uint16_t(-1), "counting_semaphore uses a 16-bit counter!");
using count_t = std::conditional_t<LeastMaxValue < 256, uint8_t, uint16_t>;
std::atomic<count_t> count{};
public:
constexpr explicit
counting_semaphore(std::ptrdiff_t desired)
: count(desired) {}

[[nodiscard]] static constexpr std::ptrdiff_t
max() { return count_t(-1); }

[[nodiscard]] bool inline
try_acquire()
{
count_t current = count.load();
do if (current == 0) return false;
while(count.compare_exchange_weak(current, current - 1) == false);
return true;
}

void inline
acquire()
{
while(not try_acquire()) modm::this_fiber::yield();
}

void inline
release()
{
count++;
}

template< typename Rep, typename Period >
[[nodiscard]] bool
try_acquire_for(std::chrono::duration<Rep, Period> sleep_duration)
{
return this_fiber::poll_for(sleep_duration, [this](){ return try_acquire(); });
}

template< class Clock, class Duration >
[[nodiscard]] bool
try_acquire_until(std::chrono::time_point<Clock, Duration> sleep_time)
{
return this_fiber::poll_until(sleep_time, [this](){ return try_acquire(); });
}
};

/// Implements the `std::binary_semaphore` interface for fibers.
/// @see https://en.cppreference.com/w/cpp/thread/counting_semaphore
using binary_semaphore = counting_semaphore<1>;

/// @}

}
Loading

0 comments on commit 3526783

Please sign in to comment.