Skip to content

Commit

Permalink
Merge pull request #1102 from jguarato/fix/error-element-equality
Browse files Browse the repository at this point in the history
Fix error in equality method for bearing elements
  • Loading branch information
raphaeltimbo authored Sep 19, 2024
2 parents fa13ef3 + 69a747f commit 21fc845
Showing 1 changed file with 74 additions and 139 deletions.
213 changes: 74 additions & 139 deletions ross/bearing_seal_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,49 +394,35 @@ def __eq__(self, other):
>>> bearing1 == bearing2
True
"""
compared_attributes = [
"kxx",
"kyy",
"kxy",
"kyx",
"cxx",
"cyy",
"cxy",
"cyx",
"mxx",
"myy",
"mxy",
"myx",
"frequency",
]
attributes_comparison = False

if isinstance(other, self.__class__):
init_args = []
for arg in signature(self.__init__).parameters:
if arg not in ["kwargs"]:
init_args.append(arg)

init_args_comparison = []
for arg in init_args:
comparison = getattr(self, arg) == getattr(other, arg)
try:
comparison = all(comparison)
except TypeError:
pass
init_args = set(signature(self.__init__).parameters).intersection(
self.__dict__.keys()
)

init_args_comparison.append(comparison)
coefficients = {
attr.replace("_interpolated", "")
for attr in self.__dict__.keys()
if "_interpolated" in attr
}

init_args_comparison = all(init_args_comparison)
attributes_comparison = all(
(
(
np.array(getattr(self, attr)) == np.array(getattr(other, attr))
).all()
for attr in compared_attributes
)
)
compared_attributes = list(coefficients.union(init_args))
compared_attributes.sort()

for attr in compared_attributes:
self_attr = np.array(getattr(self, attr))
other_attr = np.array(getattr(other, attr))

if self_attr.shape == other_attr.shape:
attributes_comparison = (self_attr == other_attr).all()
else:
attributes_comparison = False

if not attributes_comparison:
return attributes_comparison

return init_args_comparison and attributes_comparison
return False
return attributes_comparison

def __hash__(self):
return hash(self.tag)
Expand All @@ -448,25 +434,20 @@ def save(self, file):
data = {}

# save initialization args and coefficients
args = list(signature(self.__init__).parameters)
args += [
"kxx",
"kyy",
"kxy",
"kyx",
"cxx",
"cyy",
"cxy",
"cyx",
"mxx",
"myy",
"mxy",
"myx",
]
brg_data = {}
for arg in args:
if arg not in ["kwargs"]:
brg_data[arg] = self.__dict__[arg]
init_args = set(signature(self.__init__).parameters).intersection(
self.__dict__.keys()
)

coefficients = {
attr.replace("_interpolated", "")
for attr in self.__dict__.keys()
if "_interpolated" in attr
}

args = list(coefficients.union(init_args))
args.sort()

brg_data = {arg: self.__dict__[arg] for arg in args}

# change np.array to lists so that we can save in .toml as list(floats)
for k, v in brg_data.items():
Expand All @@ -481,7 +462,18 @@ def save(self, file):
except (TypeError, AttributeError):
pass

data[f"{self.__class__.__name__}_{self.tag}"] = brg_data
diff_args = set(signature(self.__init__).parameters).difference(
self.__dict__.keys()
)
diff_args.discard("kwargs")

class_name = (
self.__class__.__name__
if not diff_args
else self.__class__.__bases__[0].__name__
)

data[f"{class_name}_{self.tag}"] = brg_data

with open(file, "w") as f:
toml.dump(data, f)
Expand Down Expand Up @@ -943,6 +935,9 @@ class BearingFluidFlow(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.
Returns
-------
Expand Down Expand Up @@ -1088,7 +1083,7 @@ class SealElement(BearingElement):
Default is None.
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
Default is 0.5.
color : str, optional
A color to be used when the element is represented.
Default is "#77ACA2".
Expand Down Expand Up @@ -1131,12 +1126,11 @@ def __init__(
seal_leakage=None,
tag=None,
n_link=None,
scale_factor=1.0,
scale_factor=None,
color="#77ACA2",
**kwargs,
):
# make seals with half the bearing size as a default
seal_scale_factor = scale_factor / 2

super().__init__(
n=n,
frequency=frequency,
Expand All @@ -1154,11 +1148,11 @@ def __init__(
myy=myy,
tag=tag,
n_link=n_link,
scale_factor=seal_scale_factor,
color=color,
)

self.seal_leakage = seal_leakage
# make seals with half the bearing size as a default
self.scale_factor = scale_factor if scale_factor else self.scale_factor / 2


class BallBearingElement(BearingElement):
Expand Down Expand Up @@ -1201,6 +1195,9 @@ class BallBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.
References
----------
Expand Down Expand Up @@ -1234,6 +1231,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
):
Kb = 13.0e6
kyy = (
Expand Down Expand Up @@ -1273,10 +1271,9 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)

self.color = "#77ACA2"


class RollerBearingElement(BearingElement):
"""A bearing element for roller bearings.
Expand Down Expand Up @@ -1318,6 +1315,9 @@ class RollerBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.
References
----------
Expand Down Expand Up @@ -1351,6 +1351,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
):
Kb = 1.0e9
kyy = Kb * n_rollers**0.9 * l_rollers**0.8 * fs**0.1 * (np.cos(alpha)) ** 1.9
Expand Down Expand Up @@ -1384,10 +1385,9 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)

self.color = "#77ACA2"


class MagneticBearingElement(BearingElement):
"""Magnetic bearing.
Expand Down Expand Up @@ -1428,6 +1428,9 @@ class MagneticBearingElement(BearingElement):
scale_factor : float, optional
The scale factor is used to scale the bearing drawing.
Default is 1.
color : str, optional
A color to be used when the element is represented.
Default is '#355d7a'.
----------
See the following reference for the electromagnetic parameters g0, i0, ag, nw, alpha:
Expand Down Expand Up @@ -1470,6 +1473,7 @@ def __init__(
tag=None,
n_link=None,
scale_factor=1,
color="#355d7a",
**kwargs,
):
self.g0 = g0
Expand Down Expand Up @@ -1561,6 +1565,7 @@ def __init__(
tag=tag,
n_link=n_link,
scale_factor=scale_factor,
color=color,
)


Expand Down Expand Up @@ -1910,76 +1915,6 @@ def __repr__(self):
f" frequency={self.frequency}, tag={self.tag!r})"
)

def __eq__(self, other):
"""Equality method for comparasions.
Parameters
----------
other : object
The second object to be compared with.
Returns
-------
bool
True if the comparison is true; False otherwise.
Examples
--------
>>> bearing1 = bearing_example()
>>> bearing2 = bearing_example()
>>> bearing1 == bearing2
True
"""
compared_attributes = [
"kxx",
"kyy",
"kxy",
"kyx",
"kzz",
"cxx",
"cyy",
"cxy",
"cyx",
"czz",
"mxx",
"myy",
"mxy",
"myx",
"mzz",
"frequency",
]
if isinstance(other, self.__class__):
init_args = []
for arg in signature(self.__init__).parameters:
if arg not in ["kwargs"]:
init_args.append(arg)

init_args_comparison = []
for arg in init_args:
comparison = getattr(self, arg) == getattr(other, arg)
try:
comparison = all(comparison)
except TypeError:
pass

init_args_comparison.append(comparison)

init_args_comparison = all(init_args_comparison)
attributes_comparison = all(
(
(
np.array(getattr(self, attr)) == np.array(getattr(other, attr))
).all()
for attr in compared_attributes
)
)

return init_args_comparison and attributes_comparison
return False

def __hash__(self):
return hash(self.tag)

def dof_mapping(self):
"""Degrees of freedom mapping.
Expand Down

0 comments on commit 21fc845

Please sign in to comment.