Skip to content

Commit

Permalink
Python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
oir committed Feb 19, 2024
1 parent 76f03e9 commit a0e20a1
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 6 deletions.
41 changes: 35 additions & 6 deletions python/barkeep.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <iostream>
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -28,8 +29,6 @@ enum class DType { Int, Float, AtomicInt, AtomicFloat };
enum class DType { Int, Float, AtomicInt };
#endif

#include <iostream>

struct PyFileStream : public std::stringbuf, public std::ostream {
py::object file_;

Expand All @@ -53,7 +52,7 @@ class Animation_ : public Animation {

Animation_(py::object file = py::none(),
std::string message = "",
AnimationStyle style = Ellipsis,
std::variant<AnimationStyle, Strings> style = Ellipsis,
double interval = 0.,
bool no_tty = false)
: Animation({.out = nullptr,
Expand Down Expand Up @@ -165,7 +164,7 @@ class ProgressBar_ : public ProgressBar<T> {
std::string message = "",
std::optional<double> speed = std::nullopt,
std::string speed_unit = "it/s",
ProgressBarStyle style = Blocks,
std::variant<ProgressBarStyle, BarParts> style = Blocks,
double interval = 0.,
bool no_tty = false)
: ProgressBar<T>(nullptr,
Expand Down Expand Up @@ -259,6 +258,36 @@ PYBIND11_MODULE(barkeep, m) {
#endif
.export_values();

py::class_<BarParts>(m, "BarParts")
.def(py::init<std::string,
std::string,
std::vector<std::string>,
std::vector<std::string>,
std::string,
std::string,
std::string,
std::string,
std::string,
std::string,
std::string,
std::string,
std::string,
std::string>(),
"left"_a,
"right"_a,
"fill"_a,
"empty"_a,
"incomplete_left_modifier"_a = "",
"complete_left_modifier"_a = "",
"middle_modifier"_a = "",
"right_modifier"_a = "",
"percent_left_modifier"_a = "",
"percent_right_modifier"_a = "",
"value_left_modifier"_a = "",
"value_right_modifier"_a = "",
"speed_left_modifier"_a = "",
"speed_right_modifier"_a = "");

auto async_display = py::class_<AsyncDisplay>(m, "AsyncDisplay")
.def("show", &AsyncDisplay::show)
.def("done", &AsyncDisplay::done);
Expand All @@ -267,7 +296,7 @@ PYBIND11_MODULE(barkeep, m) {
.def(py::init([](py::object file,
std::string msg,
double interval,
AnimationStyle style,
std::variant<AnimationStyle, Strings> style,
bool no_tty) {
return Animation_(file, msg, style, interval, no_tty);
}),
Expand Down Expand Up @@ -423,7 +452,7 @@ PYBIND11_MODULE(barkeep, m) {
py::object file,
std::string msg,
std::optional<double> interval,
ProgressBarStyle style,
std::variant<ProgressBarStyle, BarParts> style,
std::optional<double> speed,
std::string speed_unit,
std::optional<std::string> fmt,
Expand Down
43 changes: 43 additions & 0 deletions python/tests/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from barkeep import (
Animation,
AnimationStyle,
BarParts,
Counter,
DType,
ProgressBar,
Expand Down Expand Up @@ -79,6 +80,17 @@ def test_animation(i: int, sty: AnimationStyle):
check_anim(check_and_get_parts(out.getvalue()), "Working", animation_stills[i])


def test_custom_animation():
out = io.StringIO()

anim = Animation(message="Working", style=["a", "b", "c"], interval=0.1, file=out)
anim.show()
time.sleep(1)
anim.done()

check_anim(check_and_get_parts(out.getvalue()), "Working", ["a", "b", "c"])


@pytest.mark.parametrize("dtype", dtypes, indirect=True)
@pytest.mark.parametrize("amount", [0, 3])
@pytest.mark.parametrize("discount", [None, 1])
Expand Down Expand Up @@ -277,6 +289,37 @@ def test_progress_bar(dtype, sty, no_tty):
last_spaces = spaces


@pytest.mark.parametrize("dtype", dtypes, indirect=True)
@pytest.mark.parametrize("no_tty", [True, False])
def test_custom_progress_bar(dtype, no_tty):
out = io.StringIO()

bar = ProgressBar(
value=0,
total=50,
message="Computing",
interval=0.001,
file=out,
dtype=dtype,
style=BarParts(left="[", right="]", fill=[")"], empty=[" "]),
no_tty=no_tty,
)
bar.show()
for _ in range(50):
time.sleep(0.0013)
bar += 1
bar.done()

parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)

# Check that space is shrinking
last_spaces = 100000
for part in parts:
spaces = part.count(" ")
assert spaces <= last_spaces
last_spaces = spaces


@pytest.mark.parametrize("dtype", dtypes, indirect=True)
@pytest.mark.parametrize("above", [True, False])
@pytest.mark.parametrize("no_tty", [True, False])
Expand Down

0 comments on commit a0e20a1

Please sign in to comment.