# Adding JIT functionality to Plan finalization (#24).
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
# Display name of this workflow in the GitHub Actions UI.
name: JAX tests
# Events that trigger this workflow. Empty values are the GitHub Actions idiom
# for "use the trigger with its default settings".
on:
  pull_request:
  schedule:
    # Every weekday at 03:53 UTC, see https://crontab.guru/
    - cron: "53 3 * * 1-5"
  workflow_dispatch:
# Runs share a concurrency group per workflow + git ref. When a newer run starts
# in the same group, any in-progress run is cancelled.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true
jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Let every matrix entry finish even if another one fails.
      fail-fast: false
      matrix:
        # How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/
        # macos-14 runners are arm64, so the Python step below deliberately does
        # not pin an interpreter architecture.
        os: ["ubuntu-latest", "macos-14"]
        python-version: ["3.11"]

    steps:
      - name: Checkout source
        # v4 runs on Node 20; v3 targeted the deprecated Node 16 runtime.
        uses: actions/checkout@v4
        with:
          # Fetch full history rather than a shallow clone (presumably needed for
          # git-derived package versioning — confirm).
          fetch-depth: 0

      - name: Set up Python
        # v5 defaults `architecture` to the runner's own architecture. The former
        # hard-coded `architecture: x64` selected an x86-64 interpreter on the
        # arm64 macos-14 runner instead of a native arm64 one.
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v2

      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test-jax]'
          # Verify jax
          python -c 'import jax; print(jax.numpy.arange(10))'

      - name: Run tests
        run: |
          # exclude tests that rely on structured types since JAX doesn't support these
          # exclude tests that rely on randomness because JAX is picky about this.
          # TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization".
          pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization"
        env:
          CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
          # Quoted so these are explicit strings rather than YAML booleans that
          # the runner has to coerce when exporting them to the environment.
          JAX_ENABLE_X64: "true"
          ENABLE_PJRT_COMPATIBILITY: "true"