Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make Column and Row iterable #55

Merged
58 changes: 31 additions & 27 deletions src/safeds/data/tabular/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing
from numbers import Number
from typing import Any, Callable
from typing import Any, Callable, Iterator

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,6 +34,10 @@ def name(self) -> str:
"""
return self._name

@property
def statistics(self) -> ColumnStatistics:
return ColumnStatistics(self)

@property
def type(self) -> ColumnType:
"""
Expand All @@ -46,9 +50,35 @@ def type(self) -> ColumnType:
"""
return self._type

def __eq__(self, other: object) -> bool:
if not isinstance(other, Column):
return NotImplemented
if self is other:
return True
return self._data.equals(other._data) and self.name == other.name

def __getitem__(self, index: int) -> Any:
return self.get_value(index)

def __hash__(self) -> int:
return hash(self._data)

def __iter__(self) -> Iterator[Any]:
return iter(self._data)

def __len__(self) -> int:
return len(self._data)

def __repr__(self) -> str:
tmp = self._data.to_frame()
tmp.columns = [self.name]
return tmp.__repr__()

def __str__(self) -> str:
tmp = self._data.to_frame()
tmp.columns = [self.name]
return tmp.__str__()

def get_value(self, index: int) -> Any:
"""
Return column value at specified index, starting at 0.
Expand All @@ -73,10 +103,6 @@ def get_value(self, index: int) -> Any:

return self._data[index]

@property
def statistics(self) -> ColumnStatistics:
return ColumnStatistics(self)

def count(self) -> int:
"""
Return the number of elements in the column.
Expand Down Expand Up @@ -223,26 +249,6 @@ def get_unique_values(self) -> list[typing.Any]:
"""
return list(self._data.unique())

def __eq__(self, other: object) -> bool:
if not isinstance(other, Column):
return NotImplemented
if self is other:
return True
return self._data.equals(other._data) and self.name == other.name

def __hash__(self) -> int:
return hash(self._data)

def __str__(self) -> str:
tmp = self._data.to_frame()
tmp.columns = [self.name]
return tmp.__str__()

def __repr__(self) -> str:
tmp = self._data.to_frame()
tmp.columns = [self.name]
return tmp.__repr__()

def _ipython_display_(self) -> DisplayHandle:
"""
Return a display object for the column to be used in Jupyter Notebooks.
Expand Down Expand Up @@ -378,7 +384,6 @@ def sum(self) -> float:
return self._column._data.sum()

def variance(self) -> float:

"""
Return the variance of the column. The column has to be numerical.

Expand All @@ -401,7 +406,6 @@ def variance(self) -> float:
return self._column._data.var()

def standard_deviation(self) -> float:

"""
Return the standard deviation of the column. The column has to be numerical.

Expand Down
17 changes: 17 additions & 0 deletions src/safeds/data/tabular/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, data: typing.Iterable, schema: TableSchema):
def __getitem__(self, column_name: str) -> Any:
return self.get_value(column_name)

def __iter__(self) -> typing.Iterator[Any]:
return iter(self.get_column_names())

def __len__(self) -> int:
return len(self._data)

def get_value(self, column_name: str) -> Any:
"""
Return the value of a specified column.
Expand All @@ -34,6 +40,17 @@ def get_value(self, column_name: str) -> Any:
raise UnknownColumnNameError([column_name])
return self._data[self.schema._get_column_index_by_name(column_name)]

def count(self) -> int:
"""
Return the number of columns in this row.

Returns
-------
count : int
The number of columns.
"""
return len(self._data)

def has_column(self, column_name: str) -> bool:
"""
Return whether the row contains a given column.
Expand Down
4 changes: 2 additions & 2 deletions src/safeds/exceptions/_data_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class UnknownColumnNameError(Exception):
class UnknownColumnNameError(KeyError):
"""
Exception raised for trying to access an invalid column name.

Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(self, column_name: str):
super().__init__(f"Column '{column_name}' already exists.")


class IndexOutOfBoundsError(Exception):
class IndexOutOfBoundsError(IndexError):
"""
Exception raised for trying to access an element by an index that does not exist in the underlying data.

Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/data/tabular/_column/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_from_columns() -> None:
assert column1._type == column2._type


def negative_test_from_columns() -> None:
def test_from_columns_negative() -> None:
column1 = Column(pd.Series([1, 4]), "A")
column2 = Column(pd.Series(["2", "5"]), "B")

Expand Down
7 changes: 3 additions & 4 deletions tests/safeds/data/tabular/_column/test_count.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd
from safeds.data.tabular import Table
from safeds.data.tabular import Column


def test_count_valid() -> None:
table = Table(pd.DataFrame(data={"col1": [1, 2, 3, 4, 5], "col2": [2, 3, 4, 5, 6]}))
assert table.get_column("col1").count() == 5
column = Column([1, 2, 3, 4, 5], "col1")
assert column.count() == 5
6 changes: 6 additions & 0 deletions tests/safeds/data/tabular/_column/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from safeds.data.tabular import Column


def test_iter() -> None:
column = Column([0, "1"], "testColumn")
assert list(column) == [0, "1"]
6 changes: 6 additions & 0 deletions tests/safeds/data/tabular/_column/test_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from safeds.data.tabular import Column


def test_count_valid() -> None:
column = Column([1, 2, 3, 4, 5], "col1")
assert len(column) == 5
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_count() -> None:
row = Row(
[0, "1"],
TableSchema(
{"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
),
)
assert row.count() == 2
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_iter() -> None:
row = Row(
[0, "1"],
TableSchema(
{"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
),
)
assert list(row) == ["testColumn1", "testColumn2"]
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_count() -> None:
row = Row(
[0, "1"],
TableSchema(
{"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
),
)
assert len(row) == 2