Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error in equality method for bearing elements #1102

Merged
merged 10 commits into from
Sep 19, 2024
213 changes: 74 additions & 139 deletions ross/bearing_seal_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,49 +394,35 @@ def __eq__(self, other):
>>> bearing1 == bearing2
True
"""
compared_attributes = [
"kxx",
"kyy",
"kxy",
"kyx",
"cxx",
"cyy",
"cxy",
"cyx",
"mxx",
"myy",
"mxy",
"myx",
"frequency",
]
attributes_comparison = False

if isinstance(other, self.__class__):
init_args = []
for arg in signature(self.__init__).parameters:
if arg not in ["kwargs"]:
init_args.append(arg)

init_args_comparison = []
for arg in init_args:
comparison = getattr(self, arg) == getattr(other, arg)
try:
comparison = all(comparison)
except TypeError:
pass
init_args = set(signature(self.__init__).parameters).intersection(
self.__dict__.keys()
)

init_args_comparison.append(comparison)
coefficients = {
attr.replace("_interpolated", "")
for attr in self.__dict__.keys()
if "_interpolated" in attr
}

init_args_comparison = all(init_args_comparison)
attributes_comparison = all(
(
(
np.array(getattr(self, attr)) == np.array(getattr(other, attr))
).all()
for attr in compared_attributes
)
)
compared_attributes = list(coefficients.union(init_args))
compared_attributes.sort()

for attr in compared_attributes:
self_attr = np.array(getattr(self, attr))
other_attr = np.array(getattr(other, attr))

if self_attr.shape == other_attr.shape:
attributes_comparison = (self_attr == other_attr).all()
else:
attributes_comparison = False

if not attributes_comparison:
return attributes_comparison

return init_args_comparison and attributes_comparison
return False
return attributes_comparison

def __hash__(self):
return hash(self.tag)
Expand All @@ -448,25 +434,20 @@ def save(self, file):
data = {}

# save initialization args and coefficients
args = list(signature(self.__init__).parameters)
args += [
"kxx",
"kyy",
"kxy",
"kyx",
"cxx",
"cyy",
"cxy",
"cyx",
"mxx",
"myy",
"mxy",
"myx",
]
brg_data = {}
for arg in args:
if arg not in ["kwargs"]:
brg_data[arg] = self.__dict__[arg]
init_args = set(signature(self.__init__).parameters).intersection(
self.__dict__.keys()
)

coefficients = {
attr.replace("_interpolated", "")
for attr in self.__dict__.keys()
if "_interpolated" in attr
}

args = list(coefficients.union(init_args))
args.sort()

brg_data = {arg: self.__dict__[arg] for arg in args}

# change np.array to lists so that we can save in .toml as list(floats)
for k, v in brg_data.items():
Expand All @@ -481,7 +462,18 @@ def save(self, file):
except (TypeError, AttributeError):
pass

data[f"{self.__class__.__name__}_{self.tag}"] = brg_data
diff_args = set(signature(self.__init__).parameters).difference(
self.__dict__.keys()
)
diff_args.discard("kwargs")

class_name = (
self.__class__.__name__
if not diff_args
else self.__class__.__bases__[0].__name__
)

data[f"{class_name}_{self.tag}"] = brg_data

with open(file, "w") as f:
toml.dump(data, f)
Expand Down Expand Up @@ -943,6 +935,9 @@ class BearingFluidFlow(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.

Returns
-------
Expand Down Expand Up @@ -1088,7 +1083,7 @@ class SealElement(BearingElement):
Default is None.
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
Default is 0.5.
color : str, optional
A color to be used when the element is represented.
Default is "#77ACA2".
Expand Down Expand Up @@ -1131,12 +1126,11 @@ def __init__(
seal_leakage=None,
tag=None,
n_link=None,
scale_factor=1.0,
scale_factor=None,
color="#77ACA2",
**kwargs,
):
# make seals with half the bearing size as a default
seal_scale_factor = scale_factor / 2

super().__init__(
n=n,
frequency=frequency,
Expand All @@ -1154,11 +1148,11 @@ def __init__(
myy=myy,
tag=tag,
n_link=n_link,
scale_factor=seal_scale_factor,
color=color,
)

self.seal_leakage = seal_leakage
# make seals with half the bearing size as a default
self.scale_factor = scale_factor if scale_factor else self.scale_factor / 2


class BallBearingElement(BearingElement):
Expand Down Expand Up @@ -1201,6 +1195,9 @@ class BallBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.

References
----------
Expand Down Expand Up @@ -1234,6 +1231,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
):
Kb = 13.0e6
kyy = (
Expand Down Expand Up @@ -1273,10 +1271,9 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)

self.color = "#77ACA2"


class RollerBearingElement(BearingElement):
"""A bearing element for roller bearings.
Expand Down Expand Up @@ -1318,6 +1315,9 @@ class RollerBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.

References
----------
Expand Down Expand Up @@ -1351,6 +1351,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
):
Kb = 1.0e9
kyy = Kb * n_rollers**0.9 * l_rollers**0.8 * fs**0.1 * (np.cos(alpha)) ** 1.9
Expand Down Expand Up @@ -1384,10 +1385,9 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)

self.color = "#77ACA2"


class MagneticBearingElement(BearingElement):
"""Magnetic bearing.
Expand Down Expand Up @@ -1428,6 +1428,9 @@ class MagneticBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.

----------
See the following reference for the electromagnetic parameters g0, i0, ag, nw, alpha:
Expand Down Expand Up @@ -1470,6 +1473,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
**kwargs,
):
self.g0 = g0
Expand Down Expand Up @@ -1561,6 +1565,7 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)


Expand Down Expand Up @@ -1910,76 +1915,6 @@ def __repr__(self):
f" frequency={self.frequency}, tag={self.tag!r})"
)

def __eq__(self, other):
"""Equality method for comparasions.

Parameters
----------
other : object
The second object to be compared with.

Returns
-------
bool
True if the comparison is true; False otherwise.

Examples
--------
>>> bearing1 = bearing_example()
>>> bearing2 = bearing_example()
>>> bearing1 == bearing2
True
"""
compared_attributes = [
"kxx",
"kyy",
"kxy",
"kyx",
"kzz",
"cxx",
"cyy",
"cxy",
"cyx",
"czz",
"mxx",
"myy",
"mxy",
"myx",
"mzz",
"frequency",
]
if isinstance(other, self.__class__):
init_args = []
for arg in signature(self.__init__).parameters:
if arg not in ["kwargs"]:
init_args.append(arg)

init_args_comparison = []
for arg in init_args:
comparison = getattr(self, arg) == getattr(other, arg)
try:
comparison = all(comparison)
except TypeError:
pass

init_args_comparison.append(comparison)

init_args_comparison = all(init_args_comparison)
attributes_comparison = all(
(
(
np.array(getattr(self, attr)) == np.array(getattr(other, attr))
).all()
for attr in compared_attributes
)
)

return init_args_comparison and attributes_comparison
return False

def __hash__(self):
return hash(self.tag)

def dof_mapping(self):
"""Degrees of freedom mapping.

Expand Down
Loading