diff --git a/cyipopt/scipy_interface.py b/cyipopt/scipy_interface.py index d9a7fd1..238a4e6 100644 --- a/cyipopt/scipy_interface.py +++ b/cyipopt/scipy_interface.py @@ -682,6 +682,8 @@ def _minimize_ipopt_iv(fun, x0, args, kwargs, method, jac, hess, hessp, tol = np.asarray(tol)[()] if tol.ndim != 0 or not np.issubdtype(tol.dtype, np.number) or tol <= 0: raise ValueError('`tol` must be a positive scalar.') + else: # tol should be a float, not an array + tol = float(tol) options = dict() if options is None else options if not isinstance(options, dict): diff --git a/cyipopt/tests/unit/test_scipy_optional.py b/cyipopt/tests/unit/test_scipy_optional.py index 22ddc13..fbd0669 100644 --- a/cyipopt/tests/unit/test_scipy_optional.py +++ b/cyipopt/tests/unit/test_scipy_optional.py @@ -30,6 +30,22 @@ def test_minimize_ipopt_import_error_if_no_scipy(): cyipopt.minimize_ipopt(None, None) +@pytest.mark.skipif("scipy" not in sys.modules, + reason="Test only valid if Scipy available.") +def test_tol_type_issue_235(): + # from: https://github.com/mechmotum/cyipopt/issues/235 + + def fun(x): + return np.sum(x ** 2) + + # tol should not raise an error + cyipopt.minimize_ipopt( + fun=fun, + x0=np.zeros(2), + tol=1e-9, + ) + + @pytest.mark.skipif("scipy" not in sys.modules, reason="Test only valid if Scipy available.") def test_minimize_ipopt_input_validation():