Skip to content

Adding JIT functionality to Plan finalization. #24

Adding JIT functionality to Plan finalization.

Adding JIT functionality to Plan finalization. #24

Workflow file for this run

name: JAX tests

# `on` is read as a boolean by generic YAML 1.1 parsers; GitHub's loader
# treats it as the trigger key, so silence yamllint's truthy rule here.
on:  # yamllint disable-line rule:truthy
  pull_request:
  schedule:
    # Monday through Friday at 03:53 UTC (check expressions at https://crontab.guru/).
    - cron: '53 3 * * 1-5'
  workflow_dispatch:
# Runs of this workflow on the same ref share one concurrency group;
# cancel-in-progress stops an in-flight run when a newer one starts.
concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true
jobs:
  test:
    runs-on: ${{ matrix.os }}
    strategy:
      # Let every OS/Python combination finish even if one of them fails.
      fail-fast: false
      matrix:
        # How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/
        # macos-14 runners are Apple Silicon (arm64); ubuntu-latest is x64.
        os: ["ubuntu-latest", "macos-14"]
        python-version: ["3.11"]
    steps:
      - name: Checkout source
        uses: actions/checkout@v4
        with:
          # 0 = fetch the full history instead of a shallow clone.
          # NOTE(review): presumably needed for tag-based versioning — confirm.
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v5
        # No `architecture:` pin: setup-python@v5 defaults to the runner's own
        # architecture, so macos-14 gets a native arm64 interpreter (and arm64
        # wheels) instead of an x64 one running under Rosetta.
        with:
          python-version: ${{ matrix.python-version }}
      - name: Setup Graphviz
        uses: ts-graphviz/setup-graphviz@v2
      - name: Install
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e '.[test-jax]'
          # Verify jax
          python -c 'import jax; print(jax.numpy.arange(10))'
      - name: Run tests
        run: |
          # exclude tests that rely on structured types since JAX doesn't support these
          # exclude tests that rely on randomness because JAX is picky about this.
          # TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization".
          pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization"
        env:
          CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
          # Quoted so they are explicit strings rather than YAML booleans;
          # GitHub already passed the text "true" for the unquoted form.
          JAX_ENABLE_X64: "true"
          ENABLE_PJRT_COMPATIBILITY: "true"