diff --git a/jax/api_util.py b/jax/api_util.py index bbb20226ff1a..8f8cffbb9381 100644 --- a/jax/api_util.py +++ b/jax/api_util.py @@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...], try: hash(static_arg) except TypeError: - logging.warning( - "Static argument (index %s) of type %s for function %s is " - "non-hashable. As this can lead to unexpected cache-misses, it " - "will raise an error in a near future.", i, type(static_arg), - f.__name__) - # e.g. ndarrays, DeviceArrays - fixed_args[i] = WrapHashably(static_arg) # type: ignore + raise ValueError( + "Non-hashable static arguments are not supported, as this can lead " + f"to unexpected cache-misses. Static argument (index {i}) of type " + f"{type(static_arg)} for function {f.__name__} is non-hashable.") else: fixed_args[i] = Hashable(static_arg) # type: ignore diff --git a/tests/api_test.py b/tests/api_test.py index b06523475374..a1f62c6de318 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -415,6 +415,18 @@ def test_jit_reference_dropping(self): del g # no more references to x assert x() is None # x is gone + def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self): + if self.jit != jax.api._python_jit: + raise unittest.SkipTest("this test only applies to _python_jit") + f = lambda x, y: x + 3 + jitted_f = self.jit(f, static_argnums=(1,)) + + msg = ("Non-hashable static arguments are not supported, as this can lead " + "to unexpected cache-misses. Static argument (index 1) of type " + " for function is non-hashable.") + with self.assertRaisesRegex(ValueError, re.escape(msg)): + jitted_f(1, np.asarray(1)) + def test_cpp_jit_raises_on_non_hashable_static_argnum(self): if version < (0, 1, 58): raise unittest.SkipTest("Disabled because it depends on some future " @@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self): jitted_f(1, 1) - msg = ( - """Non-hashable static arguments are not supported. An error occured while trying to hash an object of type , 1. The error was: -TypeError: unhashable type: 'numpy.ndarray'""") + msg = ("Non-hashable static arguments are not supported. An error occured " + "while trying to hash an object of type , 1. " + "The error was:\nTypeError: unhashable type: 'numpy.ndarray'") with self.assertRaisesRegex(ValueError, re.escape(msg)): jitted_f(1, np.asarray(1))