diff --git a/ross/bearing_seal_element.py b/ross/bearing_seal_element.py index 2cdb663d..458e8f7f 100644 --- a/ross/bearing_seal_element.py +++ b/ross/bearing_seal_element.py @@ -394,49 +394,35 @@ def __eq__(self, other): >>> bearing1 == bearing2 True """ - compared_attributes = [ - "kxx", - "kyy", - "kxy", - "kyx", - "cxx", - "cyy", - "cxy", - "cyx", - "mxx", - "myy", - "mxy", - "myx", - "frequency", - ] + attributes_comparison = False + if isinstance(other, self.__class__): - init_args = [] - for arg in signature(self.__init__).parameters: - if arg not in ["kwargs"]: - init_args.append(arg) - - init_args_comparison = [] - for arg in init_args: - comparison = getattr(self, arg) == getattr(other, arg) - try: - comparison = all(comparison) - except TypeError: - pass + init_args = set(signature(self.__init__).parameters).intersection( + self.__dict__.keys() + ) - init_args_comparison.append(comparison) + coefficients = { + attr.replace("_interpolated", "") + for attr in self.__dict__.keys() + if "_interpolated" in attr + } - init_args_comparison = all(init_args_comparison) - attributes_comparison = all( - ( - ( - np.array(getattr(self, attr)) == np.array(getattr(other, attr)) - ).all() - for attr in compared_attributes - ) - ) + compared_attributes = list(coefficients.union(init_args)) + compared_attributes.sort() + + for attr in compared_attributes: + self_attr = np.array(getattr(self, attr)) + other_attr = np.array(getattr(other, attr)) + + if self_attr.shape == other_attr.shape: + attributes_comparison = (self_attr == other_attr).all() + else: + attributes_comparison = False + + if not attributes_comparison: + return attributes_comparison - return init_args_comparison and attributes_comparison - return False + return attributes_comparison def __hash__(self): return hash(self.tag) @@ -448,25 +434,20 @@ def save(self, file): data = {} # save initialization args and coefficients - args = list(signature(self.__init__).parameters) - args += [ - "kxx", - "kyy", - "kxy", - "kyx", - "cxx", - "cyy", - "cxy", - "cyx", - "mxx", - "myy", - "mxy", - "myx", - ] - brg_data = {} - for arg in args: - if arg not in ["kwargs"]: - brg_data[arg] = self.__dict__[arg] + init_args = set(signature(self.__init__).parameters).intersection( + self.__dict__.keys() + ) + + coefficients = { + attr.replace("_interpolated", "") + for attr in self.__dict__.keys() + if "_interpolated" in attr + } + + args = list(coefficients.union(init_args)) + args.sort() + + brg_data = {arg: self.__dict__[arg] for arg in args} # change np.array to lists so that we can save in .toml as list(floats) for k, v in brg_data.items(): @@ -481,7 +462,18 @@ def save(self, file): except (TypeError, AttributeError): pass - data[f"{self.__class__.__name__}_{self.tag}"] = brg_data + diff_args = set(signature(self.__init__).parameters).difference( + self.__dict__.keys() + ) + diff_args.discard("kwargs") + + class_name = ( + self.__class__.__name__ + if not diff_args + else self.__class__.__bases__[0].__name__ + ) + + data[f"{class_name}_{self.tag}"] = brg_data with open(file, "w") as f: toml.dump(data, f) @@ -943,6 +935,9 @@ class BearingFluidFlow(BearingElement): scale_factor : float, optional The scale factor is used to scale the bearing drawing. Default is 1. + color : str, optional + A color to be used when the element is represented. + Default is '#355d7a'. Returns ------- @@ -1088,7 +1083,7 @@ class SealElement(BearingElement): Default is None. scale_factor : float, optional The scale factor is used to scale the bearing drawing. - Default is 1. + Default is 0.5. color : str, optional A color to be used when the element is represented. Default is "#77ACA2". @@ -1131,12 +1126,11 @@ def __init__( seal_leakage=None, tag=None, n_link=None, - scale_factor=1.0, + scale_factor=None, color="#77ACA2", **kwargs, ): - # make seals with half the bearing size as a default - seal_scale_factor = scale_factor / 2 + super().__init__( n=n, frequency=frequency, @@ -1154,11 +1148,11 @@ def __init__( myy=myy, tag=tag, n_link=n_link, - scale_factor=seal_scale_factor, color=color, ) - self.seal_leakage = seal_leakage + # make seals with half the bearing size as a default + self.scale_factor = scale_factor if scale_factor else self.scale_factor / 2 class BallBearingElement(BearingElement): @@ -1201,6 +1195,9 @@ class BallBearingElement(BearingElement): scale_factor : float, optional The scale factor is used to scale the bearing drawing. Default is 1. + color : str, optional + A color to be used when the element is represented. + Default is '#355d7a'. References ---------- @@ -1234,6 +1231,7 @@ def __init__( tag=None, n_link=None, scale_factor=1, + color="#355d7a", ): Kb = 13.0e6 kyy = ( @@ -1273,10 +1271,9 @@ def __init__( tag=tag, n_link=n_link, scale_factor=scale_factor, + color=color, ) - self.color = "#77ACA2" - class RollerBearingElement(BearingElement): """A bearing element for roller bearings. @@ -1318,6 +1315,9 @@ class RollerBearingElement(BearingElement): scale_factor : float, optional The scale factor is used to scale the bearing drawing. Default is 1. + color : str, optional + A color to be used when the element is represented. + Default is '#355d7a'. References ---------- @@ -1351,6 +1351,7 @@ def __init__( tag=None, n_link=None, scale_factor=1, + color="#355d7a", ): Kb = 1.0e9 kyy = Kb * n_rollers**0.9 * l_rollers**0.8 * fs**0.1 * (np.cos(alpha)) ** 1.9 @@ -1384,10 +1385,9 @@ def __init__( tag=tag, n_link=n_link, scale_factor=scale_factor, + color=color, ) - self.color = "#77ACA2" - class MagneticBearingElement(BearingElement): """Magnetic bearing. @@ -1428,6 +1428,9 @@ class MagneticBearingElement(BearingElement): scale_factor : float, optional The scale factor is used to scale the bearing drawing. Default is 1. + color : str, optional + A color to be used when the element is represented. + Default is '#355d7a'. ---------- See the following reference for the electromagnetic parameters g0, i0, ag, nw, alpha: @@ -1470,6 +1473,7 @@ def __init__( tag=None, n_link=None, scale_factor=1, + color="#355d7a", **kwargs, ): self.g0 = g0 @@ -1561,6 +1565,7 @@ def __init__( tag=tag, n_link=n_link, scale_factor=scale_factor, + color=color, ) @@ -1910,76 +1915,6 @@ def __repr__(self): f" frequency={self.frequency}, tag={self.tag!r})" ) - def __eq__(self, other): - """Equality method for comparasions. - - Parameters - ---------- - other : object - The second object to be compared with. - - Returns - ------- - bool - True if the comparison is true; False otherwise. - - Examples - -------- - >>> bearing1 = bearing_example() - >>> bearing2 = bearing_example() - >>> bearing1 == bearing2 - True - """ - compared_attributes = [ - "kxx", - "kyy", - "kxy", - "kyx", - "kzz", - "cxx", - "cyy", - "cxy", - "cyx", - "czz", - "mxx", - "myy", - "mxy", - "myx", - "mzz", - "frequency", - ] - if isinstance(other, self.__class__): - init_args = [] - for arg in signature(self.__init__).parameters: - if arg not in ["kwargs"]: - init_args.append(arg) - - init_args_comparison = [] - for arg in init_args: - comparison = getattr(self, arg) == getattr(other, arg) - try: - comparison = all(comparison) - except TypeError: - pass - - init_args_comparison.append(comparison) - - init_args_comparison = all(init_args_comparison) - attributes_comparison = all( - ( - ( - np.array(getattr(self, attr)) == np.array(getattr(other, attr)) - ).all() - for attr in compared_attributes - ) - ) - - return init_args_comparison and attributes_comparison - return False - - def __hash__(self): - return hash(self.tag) - def dof_mapping(self): """Degrees of freedom mapping.