From 77e5c42e3d331f226b9c2fc40fc69e5cb23546bb Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:03:28 +0100 Subject: [PATCH] Update readme (#81) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3b092a4..cc93e55 100644 --- a/README.md +++ b/README.md @@ -48,14 +48,14 @@ It has a couple of key features: ```python import numpy as np import ndonnx as ndx - from jax.experimental import array_api as jxp + import jax.numpy as jnp def mean_drop_outliers(a, low=-5, high=5): xp = a.__array_namespace__() return xp.mean(a[(low < a) & (a < high)]) np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 5])) - jax_result = mean_drop_outliers(jxp.asarray([-10, 0.5, 1, 5])) + jax_result = mean_drop_outliers(jnp.asarray([-10, 0.5, 1, 5])) onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 5])) assert np_result == onnx_result.to_numpy() == jax_result == 0.75