# pyproject.toml (forked from jax-ml/jax)

# Build requirements and the backend used to build this project.
[build-system]
build-backend = "setuptools.build_meta"
requires = [
    "setuptools",
    "wheel",
]

# Global mypy settings. Per-module exceptions are declared separately in
# [[tool.mypy.overrides]] tables.
[tool.mypy]
no_implicit_optional = true
show_error_codes = true
# Error codes turned off for the whole project, written as a single
# comma-separated string.
disable_error_code = "attr-defined, name-defined, annotation-unchecked"

# Override: mypy does not report missing imports for the modules listed
# below.
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
    "absl.*",
    "colorama.*",
    "importlib_metadata.*",
    "IPython.*",
    "numpy.*",
    "opt_einsum.*",
    "scipy.*",
    "libtpu.*",
    "jaxlib.mlir.*",
    "iree.*",
    "rich.*",
    "optax.*",
    "flax.*",
    "tensorflow.*",
    "tensorflowjs.*",
    "tensorflow.io.*",
    "tensorstore.*",
    "web_pdb.*",
    "etils.*",
    "google.colab.*",
    "pygments.*",
    "jraph.*",
    "matplotlib.*",
    "tensorboard_plugin_profile.convert.*",
    "jaxlib.*",
    "pytest.*",
    "zstandard.*",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "setuptools.*",
]

# Override: all mypy errors are ignored in these specific modules.
[[tool.mypy.overrides]]
ignore_errors = true
module = [
    "jax.interpreters.autospmd",
    "jax.lax.lax_parallel",
    "jax.experimental.jax2tf.tests.primitive_harness",
]

# pytest configuration.
[tool.pytest.ini_options]
# Extra command-line options: also collect doctests from *.rst files.
addopts = "--doctest-glob='*.rst'"

doctest_optionflags = [
    "NUMBER",
    "NORMALIZE_WHITESPACE",
]

# Custom markers, each given as "name: description".
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]

# Warning filters. The leading "error" entry is followed by targeted
# "ignore"/"default" exceptions. Do not reorder: pytest resolves overlapping
# filters by position in this list.
filterwarnings = [
    "error",
    "ignore:No GPU/TPU found, falling back to CPU.:UserWarning",
    "ignore:xmap is an experimental feature and probably has bugs!",
    "ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning",
    "ignore:can't resolve package from __spec__ or __package__:ImportWarning",
    "ignore:Using or importing the ABCs.*:DeprecationWarning",
    "ignore:numpy.ufunc size changed",
    "ignore:.*experimental feature",
    "ignore:The distutils.* is deprecated.*:DeprecationWarning",
    "default:Error reading persistent compilation cache entry for 'jit__lambda_'",
    "default:Error writing persistent compilation cache entry for 'jit__lambda_'",
    "ignore:backend and device argument on jit is deprecated.*:DeprecationWarning",
    # TODO(skyewm): remove when jaxlib >= 0.4.12 is released (needs
    # https://github.com/openxla/xla/commit/fb9dc3db0999bf14c78d95cb7c3aa6815221ddc7)
    "ignore:ml_dtypes.float8_e4m3b11 is deprecated.",
    "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
    "ignore:np.find_common_type is deprecated.*:DeprecationWarning",
    "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
    "ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
]

# pylint: top-level options.
[tool.pylint.master]
# Package named in pylint's extension-package whitelist option.
extension-pkg-whitelist = "numpy"

# pylint: which messages are switched off and on.
# The table name contains a space, so the key must stay quoted.
[tool.pylint."messages control"]
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    # NOTE(review): "protected-name" does not look like a standard pylint
    # message name (compare "protected-access" below) - confirm it is intended.
    "protected-name",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    "blacklisted-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
enable = "c-extension-no-member"

# pylint: formatting checks.
[tool.pylint.format]
# NOTE(review): the value below is a single space; confirm it matches the
# indentation unit the codebase actually uses (whitespace may have been
# collapsed when this file was copied).
indent-string = " "