@danielkelshaw
Created August 10, 2023 13:31
============================= test session starts ==============================
platform darwin -- Python 3.10.0, pytest-7.4.0, pluggy-1.2.0
rootdir: /Users/djk21/coding/jax-cfd
plugins: xdist-3.3.1
created: 16/16 workers
16 workers [458 items]
........................................................................ [ 15%]
........................................F............................... [ 31%]
.....................F.................................................. [ 46%]
....F...................F....................F.......................F.. [ 62%]
.........F......F.......F............................................... [ 78%]
........................................................................ [ 94%]
....F...F................. [100%]
==================================== ERRORS ====================================
________________ ERROR collecting jax_cfd/ml/equations_test.py _________________
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/equations_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
jax_cfd/ml/__init__.py:17: in <module>
    import jax_cfd.ml.advections
jax_cfd/ml/advections.py:8: in <module>
    from jax_cfd.ml import interpolations
jax_cfd/ml/interpolations.py:11: in <module>
    from jax_cfd.ml import layers
jax_cfd/ml/layers.py:8: in <module>
    import haiku as hk
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module>
    from haiku import experimental
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module>
    from haiku._src.base import current_name
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module>
    from haiku._src.typing import ( # pylint: disable=g-multiple-import
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module>
    from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import
E ModuleNotFoundError: No module named 'typing_extensions'
__________________ ERROR collecting jax_cfd/ml/layers_test.py __________________
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/layers_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
jax_cfd/ml/__init__.py:17: in <module>
    import jax_cfd.ml.advections
jax_cfd/ml/advections.py:8: in <module>
    from jax_cfd.ml import interpolations
jax_cfd/ml/interpolations.py:11: in <module>
    from jax_cfd.ml import layers
jax_cfd/ml/layers.py:8: in <module>
    import haiku as hk
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module>
    from haiku import experimental
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module>
    from haiku._src.base import current_name
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module>
    from haiku._src.typing import ( # pylint: disable=g-multiple-import
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module>
    from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import
E ModuleNotFoundError: No module named 'typing_extensions'
_______________ ERROR collecting jax_cfd/ml/layers_util_test.py ________________
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/layers_util_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
jax_cfd/ml/__init__.py:17: in <module>
    import jax_cfd.ml.advections
jax_cfd/ml/advections.py:8: in <module>
    from jax_cfd.ml import interpolations
jax_cfd/ml/interpolations.py:11: in <module>
    from jax_cfd.ml import layers
jax_cfd/ml/layers.py:8: in <module>
    import haiku as hk
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module>
    from haiku import experimental
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module>
    from haiku._src.base import current_name
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module>
    from haiku._src.typing import ( # pylint: disable=g-multiple-import
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module>
    from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import
E ModuleNotFoundError: No module named 'typing_extensions'
__________________ ERROR collecting jax_cfd/ml/towers_test.py __________________
ImportError while importing test module '/Users/djk21/coding/jax-cfd/jax_cfd/ml/towers_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.10.0/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
jax_cfd/ml/__init__.py:17: in <module>
    import jax_cfd.ml.advections
jax_cfd/ml/advections.py:8: in <module>
    from jax_cfd.ml import interpolations
jax_cfd/ml/interpolations.py:11: in <module>
    from jax_cfd.ml import layers
jax_cfd/ml/layers.py:8: in <module>
    import haiku as hk
venv/lib/python3.10/site-packages/haiku/__init__.py:20: in <module>
    from haiku import experimental
venv/lib/python3.10/site-packages/haiku/experimental/__init__.py:21: in <module>
    from haiku._src.base import current_name
venv/lib/python3.10/site-packages/haiku/_src/base.py:27: in <module>
    from haiku._src.typing import ( # pylint: disable=g-multiple-import
venv/lib/python3.10/site-packages/haiku/_src/typing.py:22: in <module>
    from typing_extensions import Protocol, runtime_checkable # pylint: disable=multiple-statements,g-multiple-import
E ModuleNotFoundError: No module named 'typing_extensions'
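All four collection errors above share one root cause: `haiku/_src/typing.py` imports `typing_extensions`, which is not installed in this virtualenv, so every test module under `jax_cfd/ml` fails before any of its tests run. A minimal sketch for checking the environment before rerunning pytest, using only the standard library; the helper name `missing_modules` is hypothetical. If it reports `typing_extensions`, installing that package into the venv (for example `pip install typing_extensions`) is the likely fix.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path.

    Hypothetical helper: importlib.util.find_spec locates a top-level
    module without importing (executing) it, and returns None if absent.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level modules that the jax_cfd.ml import chain in this log depends on.
missing = missing_modules(["typing_extensions", "haiku", "jax"])
print("missing:", missing or "none")
```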
=================================== FAILURES ===================================
________ AdvectionTest.test_mass_conservation_dirichlet_dichlet_advect _________
[gw8] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.collocated.advection_test.AdvectionTest testMethod=test_mass_conservation_dirichlet_dichlet_advect>
shape = (101,), method = <function _euler_step.<locals>.step at 0x138bd57e0>
    @parameterized.named_parameters(
        dict(
            testcase_name='dichlet_advect',
            shape=(101,),
            method=_euler_step(advection.advect_linear)),)
    def test_mass_conservation_dirichlet(self, shape, method):
      cfl_number = 0.1
      dt = cfl_number / shape[0]
      num_steps = 1000
      grid = grids.Grid(shape, domain=([-1., 1.],))
      bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
      c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),))
      def u(grid):
        x = grid.mesh((0.5,))[0]
        return grids.GridArray(-jnp.sin(jnp.pi * x), (0.5,), grid)
      def c0(grid):
        x = grid.mesh((0.5,))[0]
        return grids.GridArray(x, (0.5,), grid)
      v = (bc.impose_bc(u(grid)),)
      c = c_bc.impose_bc(c0(grid))
      ct = c
      advect = jax.jit(functools.partial(method, v=v, dt=dt))
      initial_mass = np.sum(c.data)
      for _ in range(num_steps):
        ct = advect(ct)
        current_total_mass = np.sum(ct.data)
>       self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)
jax_cfd/collocated/advection_test.py:107:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax_cfd/base/test_util.py:87: in assertAllClose
    np.testing.assert_allclose(expected, actual, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x14c604160>, array(-9.536743e-07, dtype=float32), array(-2.861023e-06, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=1e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=1e-06
E
E Mismatched elements: 1 / 1 (100%)
E Max absolute difference: 1.9073486e-06
E Max relative difference: 0.6666667
E x: array(-9.536743e-07, dtype=float32)
E y: array(-2.861023e-06, dtype=float32)
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
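This check compares two float32 sums over 101 cell values with `atol=1e-6`, and the reported values (`-9.536743e-07` and `-2.861023e-06`) are both tiny, so the tolerance sits close to float32 noise. As a rough sanity check (a sketch, not part of jax-cfd), the classic worst-case bound for recursive floating-point summation, |fl(sum) - sum| <= gamma_{n-1} * sum(|x_i|) with gamma_k = k*u / (1 - k*u) and unit roundoff u = eps/2, can be evaluated for data shaped like this test's initial condition c(x) = x. The 101 cell centres on [-1, 1] below are an assumption mirroring `grid.mesh((0.5,))`, and NumPy often uses partial pairwise summation, so treat the result as a loose upper estimate.

```python
import numpy as np


def recursive_sum_error_bound(values, dtype=np.float32):
    """Worst-case error bound for summing `values` left to right in `dtype`.

    Implements |fl(sum) - sum| <= gamma_{n-1} * sum(|x_i|), where
    gamma_k = k*u / (1 - k*u) and u = eps/2 is the unit roundoff.
    """
    values = np.asarray(values, dtype=np.float64)
    k = values.size - 1
    u = float(np.finfo(dtype).eps) / 2
    gamma = k * u / (1 - k * u)
    return gamma * np.abs(values).sum()


# Assumed stand-in for c(x) = x sampled at 101 cell centres on [-1, 1].
n_cells = 101
h = 2.0 / n_cells
x = -1.0 + (np.arange(n_cells) + 0.5) * h

bound = recursive_sum_error_bound(x)
print(f"worst-case float32 summation error: {bound:.2e}")
```

The bound (a few times 1e-4) is far above `atol=1e-6`, which suggests the tolerance is tighter than what float32 summation guarantees in the worst case, although typical errors are much smaller than the bound.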
___________ AdvectionTest.test_neumann_bc_one_step_linear_1d_neumann ___________
[gw8] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.collocated.advection_test.AdvectionTest testMethod=test_neumann_bc_one_step_linear_1d_neumann>
shape = (1000,), method = <function advect_linear at 0x138bd4c10>
    @parameterized.named_parameters(
        dict(
            testcase_name='linear_1d_neumann',
            shape=(1000,),
            method=advection.advect_linear),)
    def test_neumann_bc_one_step(self, shape, method):
      grid = grids.Grid(shape, domain=([-1., 1.],))
      bc = boundaries.neumann_boundary_conditions(grid.ndim)
      c_bc = boundaries.neumann_boundary_conditions(grid.ndim)
      def u(grid):
        x = grid.mesh((0.5,))[0]
        return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid)
      def c0(grid):
        x = grid.mesh((0.5,))[0]
        return grids.GridArray(jnp.cos(jnp.pi * x), (0.5,), grid)
      def dcdt(grid):
        x = grid.mesh((0.5,))[0]
        return grids.GridArray(jnp.pi * jnp.sin(2 * jnp.pi * x), (0.5,), grid)
      v = (bc.impose_bc(u(grid)),)
      c = c_bc.impose_bc(c0(grid))
      advect = jax.jit(functools.partial(method, v=v))
      ct = advect(c)
>     self.assertAllClose(ct, dcdt(grid), atol=1e-4)
jax_cfd/collocated/advection_test.py:137:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax_cfd/base/test_util.py:87: in assertAllClose
    np.testing.assert_allclose(expected, actual, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x14c6a51b0>, array([ 0.01972914, 0.05921721, 0.09869038, 0.13811885...0.25632194, -0.21695682,
-0.17755595, -0.13812703, -0.09867631, -0.0592115 , -0.01973734],
dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0.0001', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=0.0001
E
E Mismatched elements: 135 / 1000 (13.5%)
E Max absolute difference: 0.00017357
E Max relative difference: 0.00055049
E x: array([ 0.019729, 0.059217, 0.09869 , 0.138119, 0.177547, 0.216946,
E 0.256315, 0.295639, 0.334918, 0.374153, 0.413313, 0.452399,
E 0.491425, 0.530392, 0.569254, 0.608042, 0.646725, 0.685304,...
E y: array([ 0.019739, 0.059214, 0.098679, 0.13813 , 0.177557, 0.216957,
E 0.256323, 0.29565 , 0.334929, 0.374154, 0.413322, 0.452422,
E 0.491453, 0.530406, 0.569274, 0.608054, 0.646736, 0.685317,...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
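For reference, the expected field `dcdt` in `test_neumann_bc_one_step` matches the exact conservative advection tendency for c = u = cos(pi x), assuming `advect_linear` returns -d(uc)/dx (the assumption is consistent with how the test builds its expectation):

```latex
\frac{\partial c}{\partial t}
  = -\frac{\partial (u c)}{\partial x}
  = -\frac{\partial}{\partial x}\cos^2(\pi x)
  = 2\pi \cos(\pi x)\sin(\pi x)
  = \pi \sin(2\pi x)
```

Since the expected values are exact, the 1.7e-4 mismatch against `atol=1e-4` most likely reflects the discrete scheme itself: its truncation error (including the Neumann boundary treatment) plus float32 rounding, which a difference quotient over a cell width of h = 0.002 amplifies by roughly 1/h. The log alone does not separate these contributions.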
______ TimeSteppingTest.test_implicit_solve_harmonic_oscillator_implicit _______
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_implicit_solve_harmonic_oscillator_implicit>
implicit_terms = <function <lambda> at 0x15332e440>
implicit_solve = <function <lambda> at 0x15332e4d0>
initial_state = array([1., 1.])
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_implicit_solve(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      """Tests that time integration is accurate for a range of test cases."""
      del dt, explicit_terms, inner_steps, outer_steps, closed_form # unused
      del tolerances # unused
      # Verifies that `implicit_solve` solves (y - eta * F(y)) = x
      # This does not test the integrator, but rather verifies that the test
      # case is valid.
      eta = 0.3
      solved_state = implicit_solve(initial_state, eta)
      reconstructed_state = solved_state - eta * implicit_terms(solved_state)
>     np.testing.assert_allclose(reconstructed_state, initial_state)
jax_cfd/spectral/time_stepping_test.py:159:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1534de560>, array([0.9999999 , 0.99999994], dtype=float32), array([1., 1.]))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=0
E
E Mismatched elements: 1 / 2 (50%)
E Max absolute difference: 1.1920929e-07
E Max relative difference: 1.1920929e-07
E x: array([1., 1.], dtype=float32)
E y: array([1., 1.])
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
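In `test_implicit_solve`, the reported max absolute difference, `1.1920929e-07`, is exactly float32 machine epsilon (2**-23). The reconstructed state is float32 while `initial_state` is a float64 array, and `np.testing.assert_allclose` defaults to `rtol=1e-7`, which is slightly smaller than one float32 ulp at 1.0. The sketch below reproduces the comparison with the two float32 values printed in the log (`0.9999999` and `0.99999994`, taken here as 1 - 2**-23 and 1 - 2**-24) using NumPy's documented criterion `|actual - desired| <= atol + rtol * |desired|`; the looser `rtol=4 * eps32` at the end is an illustrative choice, not jax-cfd's.

```python
import numpy as np

eps32 = np.finfo(np.float32).eps  # float32 machine epsilon, 2**-23

# float32 values printed in the log as 0.9999999 and 0.99999994.
actual = np.array([1 - 2.0**-23, 1 - 2.0**-24], dtype=np.float32)
desired = np.array([1.0, 1.0])  # float64, like initial_state

rtol, atol = 1e-7, 0.0  # np.testing.assert_allclose defaults
abs_diff = np.abs(actual.astype(np.float64) - desired)
passes = abs_diff <= atol + rtol * np.abs(desired)
print(abs_diff, passes)  # only the second element is within tolerance

# A tolerance of a few float32 ulps makes the same check pass.
np.testing.assert_allclose(actual, desired, rtol=4 * eps32)
```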
____________ TimeSteppingTest.test_integration_constant_derivative _____________
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_constant_derivative>
explicit_terms = <function <lambda> at 0x15332da20>
implicit_terms = <function <lambda> at 0x15332dab0>
implicit_solve = <function <lambda> at 0x15332db40>, dt = 0.01, inner_steps = 10
outer_steps = 5, initial_state = array([1., 1., 1.])
closed_form = <function <lambda> at 0x15332dbd0>
tolerances = [1e-12, 1e-12, 1e-12, 1e-12, 1e-12]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1534df250>, array([[1.5, 1.5, 1.5],
[2. , 2. , 2. ],
[2...99986, 2.4999986],
[2.999998 , 2.999998 , 2.999998 ],
[3.4999976, 3.4999976, 3.4999976]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=1e-12', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=1e-12
E
E Mismatched elements: 15 / 15 (100%)
E Max absolute difference: 2.3841858e-06
E Max relative difference: 6.811964e-07
E x: array([[1.5, 1.5, 1.5],
E [2. , 2. , 2. ],
E [2.5, 2.5, 2.5],...
E y: array([[1.5 , 1.5 , 1.5 ],
E [1.999999, 1.999999, 1.999999],
E [2.499999, 2.499999, 2.499999],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
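The `1e-12` tolerance in `test_integration_constant_derivative` is far below the resolution of float32, the dtype of `actual` in this failure. Adjacent float32 values near the expected results (1.5 to 3.5) are on the order of 1e-7 apart, so an `atol` of 1e-12 can only pass if every element is bit-exact. The sketch below checks this with `np.spacing` and expresses the reported max difference, `2.3841858e-06`, in float32 ulps at 3.5. One possibility, not confirmed by this log, is that such tolerances were written for runs with JAX's `jax_enable_x64` flag enabled; the commented-out line only shows where that flag would be set.

```python
import numpy as np

# Gap between adjacent float32 values around the expected results.
for value in (1.5, 2.5, 3.5):
    print(value, np.spacing(np.float32(value)))

# The reported max difference, expressed in float32 ulps at 3.5.
ulps = 2.3841858e-06 / np.spacing(np.float32(3.5))
print(round(float(ulps)))

# Assumption: to run in 64-bit mode, JAX's x64 flag must be set before any
# arrays are created, e.g.:
# import jax; jax.config.update("jax_enable_x64", True)
```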
________ TimeSteppingTest.test_integration_harmonic_oscillator_explicit ________
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_harmonic_oscillator_explicit>
explicit_terms = <function <lambda> at 0x15332e320>
implicit_terms = <function zeros_like at 0x1155ec5e0>
implicit_solve = <function <lambda> at 0x15332e3b0>, dt = 0.01, inner_steps = 20
outer_steps = 5, initial_state = array([1., 1.])
closed_form = <function harmonic_oscillator at 0x15332d6c0>
tolerances = [0.01, 3e-05, 6e-08, 5e-11, 6e-08]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1533a64d0>, array([[ 1.1787359 , 0.7813972 ],
[ 1.3104794 , ... [ 1.3899782 , 0.26069334],
[ 1.414063 , -0.02064916],
[ 1.3817736 , -0.30116856]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=6e-08', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=6e-08
E
E Mismatched elements: 9 / 10 (90%)
E Max absolute difference: 3.5762787e-07
E Max relative difference: 1.2809022e-05
E x: array([[ 1.178736, 0.781397],
E [ 1.310479, 0.531643],
E [ 1.389978, 0.260693],...
E y: array([[ 1.178736, 0.781397],
E [ 1.310479, 0.531643],
E [ 1.389978, 0.260693],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
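For the explicit harmonic-oscillator case, the tolerance list `[0.01, 3e-05, 6e-08, 5e-11, 6e-08]` contains entries below the float32 spacing at the solution's magnitude (values near 1.4, spacing 2**-23, about 1.19e-7), and the reported difference `3.5762787e-07` equals three such ulps. A minimal sketch of a dtype-aware floor on `atol`, assuming it is acceptable to relax each tolerance to a few ulps of the result dtype; `float_aware_atol` and the default `ulps=4` are hypothetical choices, not jax-cfd API.

```python
import numpy as np


def float_aware_atol(atol, expected, dtype, ulps=4):
    """Hypothetical helper: never request more precision than `dtype` has.

    Returns max(atol, ulps * spacing), where spacing is the gap between
    adjacent `dtype` values at the largest |expected| entry.
    """
    scale = np.abs(np.asarray(expected, dtype=dtype)).max()
    return max(atol, ulps * float(np.spacing(scale)))


tolerances = [0.01, 3e-05, 6e-08, 5e-11, 6e-08]  # from the log
expected_scale = [1.414063]  # roughly the largest |x| in the trajectory
relaxed = [float_aware_atol(t, expected_scale, np.float32) for t in tolerances]
print(relaxed)

# The reported max difference (3 float32 ulps near 1.4) is now within atol.
assert 3.5762787e-07 <= relaxed[2]
```

Whether to relax the tolerances this way or to run these cases in float64 is a judgment call for the test's authors; the helper only makes the float32 floor explicit.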
________ TimeSteppingTest.test_integration_harmonic_oscillator_implicit ________
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_harmonic_oscillator_implicit>
explicit_terms = <function zeros_like at 0x1155ec5e0>
implicit_terms = <function <lambda> at 0x15332e440>
implicit_solve = <function <lambda> at 0x15332e4d0>, dt = 0.01, inner_steps = 20
outer_steps = 5, initial_state = array([1., 1.])
closed_form = <function harmonic_oscillator at 0x15332d6c0>
tolerances = [0.01, 2e-05, 2e-06, 1e-06, 6e-06]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1535d6e60>, array([[ 1.1787359 , 0.7813972 ],
[ 1.3104794 , ... [ 1.3899857 , 0.26069555],
[ 1.4140741 , -0.02064833],
[ 1.3817878 , -0.30117026]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=2e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=2e-06
E
E Mismatched elements: 7 / 10 (70%)
E Max absolute difference: 1.4543533e-05
E Max relative difference: 5.331295e-05
E x: array([[ 1.178736, 0.781397],
E [ 1.310479, 0.531643],
E [ 1.389978, 0.260693],...
E y: array([[ 1.178739, 0.781399],
E [ 1.310485, 0.531645],
E [ 1.389986, 0.260696],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
_________ TimeSteppingTest.test_integration_linear_derivative_explicit _________
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_explicit>
explicit_terms = <function <lambda> at 0x15332dc60>
implicit_terms = <function <lambda> at 0x15332dcf0>
implicit_solve = <function <lambda> at 0x15332dd80>, dt = 0.01, inner_steps = 20
outer_steps = 5, initial_state = array([0., 1., 2.])
closed_form = <function <lambda> at 0x15332de10>
tolerances = [0.05, 0.0001, 1e-06, 1e-09, 1e-06]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1534df010>, array([[0. , 1.2214028, 2.4428055],
[0. ...21195, 3.644239 ],
[0. , 2.225541 , 4.451082 ],
[0. , 2.7182808, 5.4365616]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=1e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=1e-06
E
E Mismatched elements: 2 / 15 (13.3%)
E Max absolute difference: 1.9073486e-06
E Max relative difference: 3.508373e-07
E x: array([[0. , 1.221403, 2.442806],
E [0. , 1.491825, 2.983649],
E [0. , 1.822119, 3.644238],...
E y: array([[0. , 1.221403, 2.442806],
E [0. , 1.491825, 2.98365 ],
E [0. , 1.822119, 3.644239],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
_________ TimeSteppingTest.test_integration_linear_derivative_implicit _________
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_implicit>
explicit_terms = <function <lambda> at 0x15332dea0>
implicit_terms = <function <lambda> at 0x15332df30>
implicit_solve = <function <lambda> at 0x15332dfc0>, dt = 0.01, inner_steps = 20
outer_steps = 5, initial_state = array([0., 1., 2.])
closed_form = <function <lambda> at 0x15332e050>
tolerances = [0.05, 5e-05, 1e-05, 1e-05, 3e-05]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1535d6290>, array([[0. , 1.2214028, 2.4428055],
[0. ...21304, 3.644261 ],
[0. , 2.2255597, 4.4511194],
[0. , 2.7183104, 5.4366207]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=5e-05', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=5e-05
E
E Mismatched elements: 1 / 15 (6.67%)
E Max absolute difference: 5.722046e-05
E Max relative difference: 1.0525005e-05
E x: array([[0. , 1.221403, 2.442806],
E [0. , 1.491825, 2.983649],
E [0. , 1.822119, 3.644238],...
E y: array([[0. , 1.221406, 2.442811],
E [0. , 1.491831, 2.983662],
E [0. , 1.82213 , 3.644261],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
______ TimeSteppingTest.test_integration_linear_derivative_semi_implicit _______
[gw9] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.spectral.time_stepping_test.TimeSteppingTest testMethod=test_integration_linear_derivative_semi_implicit>
explicit_terms = <function <lambda> at 0x15332e0e0>
implicit_terms = <function <lambda> at 0x15332e170>
implicit_solve = <function <lambda> at 0x15332e200>, dt = 0.01, inner_steps = 20
outer_steps = 5, initial_state = array([0., 1., 2.])
closed_form = <function <lambda> at 0x15332e290>
tolerances = [0.0001, 2e-05, 2e-06, 1e-06, 2e-05]
    @parameterized.named_parameters(ALL_TEST_PROBLEMS)
    def test_integration(
        self,
        explicit_terms,
        implicit_terms,
        implicit_solve,
        dt,
        inner_steps,
        outer_steps,
        initial_state,
        closed_form,
        tolerances,
    ):
      # Compute closed-form solution.
      time = dt * inner_steps * (1 + np.arange(outer_steps))
      expected = jax.vmap(closed_form, in_axes=(None, 0))(
          initial_state, time)
      # Compute trajectory using time-stepper.
      for atol, time_stepper in zip(tolerances, ALL_TIME_STEPPERS):
        with self.subTest(time_stepper.__name__):
          equation = CustomODE(explicit_terms, implicit_terms, implicit_solve)
          semi_implicit_step = time_stepper(equation, dt)
          integrator = funcutils.trajectory(
              funcutils.repeated(semi_implicit_step, inner_steps), outer_steps)
          _, actual = integrator(initial_state)
>         np.testing.assert_allclose(expected, actual, atol=atol, rtol=0)
jax_cfd/spectral/time_stepping_test.py:187:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1534dee60>, array([[0. , 1.2214028, 2.4428055],
[0. ...21099, 3.6442199],
[0. , 2.2255282, 4.4510565],
[0. , 2.7182617, 5.4365234]], dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0, atol=2e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=0, atol=2e-06
E
E Mismatched elements: 10 / 15 (66.7%)
E Max absolute difference: 4.005432e-05
E Max relative difference: 7.367635e-06
E x: array([[0. , 1.221403, 2.442806],
E [0. , 1.491825, 2.983649],
E [0. , 1.822119, 3.644238],...
E y: array([[0. , 1.221401, 2.442801],
E [0. , 1.49182 , 2.983639],
E [0. , 1.82211 , 3.64422 ],...
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
_ SubgridModelsTest.test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model _
[gw14] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.base.subgrid_models_test.SubgridModelsTest testMethod=test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model>
cs = 0.12, velocity = <function sinusoidal_velocity_field at 0x13359bb50>
forcing = None, shape = (100, 100), step = (1.0, 1.0), density = 1.0
viscosity = 0.0001, convect = <function convect_linear at 0x1116b0ee0>
pressure_solve = <function solve_fast_diag at 0x129bb1c60>, dt = 0.001
time_steps = 1000, divergence_atol = 0.001, momentum_atol = 0.001
    @parameterized.named_parameters(
        dict(
            testcase_name='sinusoidal_velocity_base',
            cs=0.0,
            velocity=sinusoidal_velocity_field,
            forcing=None,
            shape=(100, 100),
            step=(1., 1.),
            density=1.,
            viscosity=1e-4,
            convect=advection.convect_linear,
            pressure_solve=pressure.solve_cg,
            dt=1e-3,
            time_steps=1000,
            divergence_atol=1e-3,
            momentum_atol=2e-3),
        dict(
            testcase_name='gaussian_force_upwind_with_subgrid_model',
            cs=0.12,
            velocity=zero_velocity_field,
            forcing=gaussian_forcing,
            shape=(40, 40, 40),
            step=(1., 1., 1.),
            density=1.,
            viscosity=0,
            convect=_convect_upwind,
            pressure_solve=pressure.solve_cg,
            dt=1e-3,
            time_steps=100,
            divergence_atol=1e-4,
            momentum_atol=1e-4),
        dict(
            testcase_name='sinusoidal_velocity_with_subgrid_model',
            cs=0.12,
            velocity=sinusoidal_velocity_field,
            forcing=None,
            shape=(100, 100),
            step=(1., 1.),
            density=1.,
            viscosity=1e-4,
            convect=advection.convect_linear,
            pressure_solve=pressure.solve_fast_diag,
            dt=1e-3,
            time_steps=1000,
            divergence_atol=1e-3,
            momentum_atol=1e-3),
    )
    def test_divergence_and_momentum(
        self,
        cs,
        velocity,
        forcing,
        shape,
        step,
        density,
        viscosity,
        convect,
        pressure_solve,
        dt,
        time_steps,
        divergence_atol,
        momentum_atol,
    ):
      grid = grids.Grid(shape, step)
      kwargs = dict(
          density=density,
          viscosity=viscosity,
          cs=cs,
          dt=dt,
          grid=grid,
          convect=convect,
          pressure_solve=pressure_solve,
          forcing=forcing)
      # Explicit and implicit navier-stokes solvers:
      explicit_eq = subgrid_models.explicit_smagorinsky_navier_stokes(**kwargs)
      implicit_eq = subgrid_models.implicit_smagorinsky_navier_stokes(**kwargs)
      v_initial = velocity(grid)
      v_final = funcutils.repeated(explicit_eq, time_steps)(v_initial)
      # TODO(dkochkov) consider adding more thorough tests for these models.
      with self.subTest('divergence free'):
        divergence = fd.divergence(v_final)
        self.assertLess(jnp.max(divergence.data), divergence_atol)
      with self.subTest('conservation of momentum'):
        initial_momentum = momentum(v_initial, density)
        final_momentum = momentum(v_final, density)
        if forcing is not None:
          expected_change = (
              jnp.array([f.data for f in forcing(v_initial)]).sum() *
              jnp.array(grid.step).prod() * dt * time_steps)
        else:
          expected_change = 0
        expected_momentum = initial_momentum + expected_change
>       self.assertAllClose(expected_momentum, final_momentum, atol=momentum_atol)
jax_cfd/base/subgrid_models_test.py:211:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax_cfd/base/test_util.py:87: in assertAllClose
    np.testing.assert_allclose(expected, actual, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x147687130>, array(-0.00071716, dtype=float32), array(-0.00175476, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=0.001', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E AssertionError:
E Not equal to tolerance rtol=1e-07, atol=0.001
E
E Mismatched elements: 1 / 1 (100%)
E Max absolute difference: 0.0010376
E Max relative difference: 0.59130436
E x: array(-0.000717, dtype=float32)
E y: array(-0.001755, dtype=float32)
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
_______________ AdvectionTest.test_mass_conservation_van_leer_1D _______________
[gw0] darwin -- Python 3.10.0 /Users/djk21/coding/jax-cfd/venv/bin/python
self = <jax_cfd.base.advection_test.AdvectionTest testMethod=test_mass_conservation_van_leer_1D>
shape = (101,), method = <function _euler_step.<locals>.step at 0x129e8f9a0>
    @parameterized.named_parameters(
        dict(
            testcase_name='van_leer_1D',
            shape=(101,),
            method=_euler_step(advection.advect_van_leer)),
    )
    def test_mass_conservation(self, shape, method):
      offset = 0.5
      cfl_number = 0.1
      dt = cfl_number / shape[0]
      num_steps = 1000
      grid = grids.Grid(shape, domain=([-1., 1.],))
      bc = boundaries.dirichlet_boundary_conditions(grid.ndim)
      c_bc = boundaries.dirichlet_boundary_conditions(grid.ndim, ((-1., 1.),))
      def u(grid, offset):
        x = grid.mesh((offset,))[0]
        return grids.GridArray(-jnp.sin(jnp.pi * x), (offset,), grid)
      def c0(grid, offset):
        x = grid.mesh((offset,))[0]
        return grids.GridArray(x, (offset,), grid)
      v = (bc.impose_bc(u(grid, 1.)),)
      c = c_bc.impose_bc(c0(grid, offset))
      ct = c
      advect = jax.jit(functools.partial(method, v=v, dt=dt))
      initial_mass = np.sum(c.data)
      for _ in range(num_steps):
        ct = advect(ct)
        current_total_mass = np.sum(ct.data)
>       self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)
jax_cfd/base/advection_test.py:442:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jax_cfd/base/test_util.py:87: in assertAllClose
    np.testing.assert_allclose(expected, actual, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (<function assert_allclose.<locals>.compare at 0x1468c7010>, array(0., dtype=float32), array(-2.861023e-06, dtype=float32))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=1e-07, atol=1e-06', 'verbose': True}
    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError:
E           Not equal to tolerance rtol=1e-07, atol=1e-06
E
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 2.861023e-06
E           Max relative difference: 1.
E            x: array(0., dtype=float32)
E            y: array(-2.861023e-06, dtype=float32)
../../.pyenv/versions/3.10.0/lib/python3.10/contextlib.py:79: AssertionError
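Because the test calls `self.assertAllClose(current_total_mass, initial_mass, atol=1e-6)`, `x` here is `current_total_mass` (0.0) and `y` is `initial_mass` (-2.861023e-06). With `c0(x) = x` sampled at cell centres on a grid symmetric about 0, the initial mass is analytically zero, so one plausible reading is that `y` is float32 roundoff in the initial sum: 2.861023e-06 is 24 float32 machine epsilons, while `atol=1e-6` is only about 8. A sketch under the assumption that the cell centres are `x_i = -1 + (i + 0.5) * h` with `h = 2/101`; it does not reproduce jax-cfd's exact arithmetic, so the float32 residue will not match the logged value:

```python
import numpy as np

# Hypothetical stand-in for the test's initial condition c0(x) = x: 101 cell
# centres on [-1, 1] (offset 0.5), computed directly in float32.
n = 101
h = np.float32(2.0 / n)
idx = np.arange(n, dtype=np.float32)
x32 = np.float32(-1.0) + (idx + np.float32(0.5)) * h

# The same coordinates in float64 for comparison.
x64 = -1.0 + (np.arange(n) + 0.5) * (2.0 / n)

mass64 = float(np.sum(x64))  # analytically 0 by symmetry; float64 residue is tiny
mass32 = np.sum(x32)         # float32 rounding can leave a residue orders of magnitude larger
eps32 = float(np.finfo(np.float32).eps)

print(f"float64 mass: {mass64:.3e}")
print(f"float32 mass: {float(mass32):.3e}")
print(f"logged residue in units of float32 eps: {2.861023e-06 / eps32:.1f}")
```

If that reading is right, loosening `atol` or summing in float64 would make the check less sensitive to roundoff; the log alone does not rule out a genuine conservation error in `advect_van_leer`.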
=============================== warnings summary ===============================
venv/lib/python3.10/site-packages/jax/_src/pjit.py:288: 16 warnings
/Users/djk21/coding/jax-cfd/venv/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information.
warnings.warn(
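This deprecation concerns passing `backend=` or `device=` to `jax.jit`; the log does not show which call site passes them. A generic sketch of the replacement the warning suggests, placing the argument on a device with `jax.device_put` and calling `jax.jit` without those arguments (`scale` is a hypothetical function):

```python
import jax
import jax.numpy as jnp


def scale(x):
    return 2.0 * x


# Deprecated pattern flagged by the warning (not executed here):
#   jax.jit(scale, device=jax.devices()[0])
# Replacement: put the input on the chosen device, then jit without device/backend.
device = jax.devices()[0]
x = jax.device_put(jnp.arange(3.0), device)
y = jax.jit(scale)(x)
print(y)
```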
jax_cfd/base/grids_test.py::GridArrayTest::test_tree_util
/Users/djk21/coding/jax-cfd/jax_cfd/base/grids_test.py:32: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten.
flat, treedef = jax.tree_flatten(array)
jax_cfd/base/grids_test.py::GridArrayTest::test_tree_util
/Users/djk21/coding/jax-cfd/jax_cfd/base/grids_test.py:33: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.
roundtripped = jax.tree_unflatten(treedef, flat)
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd3
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd4
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_identity_nd5
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_poisson_1d1
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_poisson_2d_fft0
jax_cfd/base/fast_diagonalization_test.py::FastDiagonalizationTest::test_random_1d_fft0
/Users/djk21/coding/jax-cfd/venv/lib/python3.10/site-packages/jax/_src/lax/lax.py:511: ComplexWarning: Casting complex values to real discards the imaginary part
return _convert_element_type(operand, new_dtype, weak_type=False)
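These `ComplexWarning`s are raised inside JAX's `lax.py` when a complex array is converted to a real dtype during the fast-diagonalization tests; the log does not show which jax-cfd call performs the cast. If dropping the imaginary part is intended, taking the real part explicitly states that intent. A NumPy analogue of the pattern, not jax-cfd code:

```python
import warnings

import numpy as np

z = np.array([1.0 + 2.0j, 3.0 - 1.0j], dtype=np.complex64)

# Implicit cast: NumPy warns that the imaginary part is discarded.
with warnings.catch_warnings(record=True) as implicit:
    warnings.simplefilter("always")
    implicit_real = z.astype(np.float32)

# Explicit cast: take the real part first, so no warning is emitted.
with warnings.catch_warnings(record=True) as explicit:
    warnings.simplefilter("always")
    explicit_real = np.real(z).astype(np.float32)

print(len(implicit), len(explicit))
```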
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_along_axis_shapes3
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:118: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten.
arrays, tree_def = jax.tree_flatten(inputs)
jax_cfd/base/array_utils_test.py: 74 warnings
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:127: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.
return tuple(jax.tree_unflatten(tree_def, leaves) for leaves in splits)
jax_cfd/base/resize_test.py: 32 warnings
jax_cfd/base/array_utils_test.py: 12 warnings
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:60: DeprecationWarning: jax.tree_flatten is deprecated: use jax.tree_util.tree_flatten.
arrays, tree_def = jax.tree_flatten(inputs)
jax_cfd/base/resize_test.py: 32 warnings
jax_cfd/base/array_utils_test.py: 10 warnings
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils.py:71: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.
return jax.tree_unflatten(tree_def, sliced)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:137: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves.
self.assertEqual(jax.tree_leaves(split_a)[0].shape[axis], idx)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:141: DeprecationWarning: jax.tree_structure is deprecated: use jax.tree_util.tree_structure.
actual_tree_def = jax.tree_structure(reconstruction)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:142: DeprecationWarning: jax.tree_structure is deprecated: use jax.tree_util.tree_structure.
expected_tree_def = jax.tree_structure(pytree)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:145: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves.
actual_values = jax.tree_leaves(reconstruction)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:146: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves.
expected_values = jax.tree_leaves(pytree)
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:160: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves.
actual_shape = jax.tree_leaves(double_concat)[0].shape[axis]
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat0
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat1
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat2
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat3
jax_cfd/base/array_utils_test.py::ArrayUtilsTest::test_split_and_concat4
/Users/djk21/coding/jax-cfd/jax_cfd/base/array_utils_test.py:161: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree_util.tree_leaves.
expected_shape = jax.tree_leaves(pytree)[0].shape[axis] * 2
jax_cfd/base/subgrid_models_test.py: 15 warnings
/Users/djk21/coding/jax-cfd/jax_cfd/base/subgrid_models.py:98: DeprecationWarning: jax.tree_unflatten is deprecated: use jax.tree_util.tree_unflatten.
return jax.tree_unflatten(jax.tree_util.tree_structure(s_ij), viscosities)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
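Most of the warnings above come from the deprecated top-level aliases `jax.tree_flatten`, `jax.tree_unflatten`, `jax.tree_leaves` and `jax.tree_structure`, used in `grids_test.py`, `array_utils.py`, `array_utils_test.py` and `subgrid_models.py`. Each warning names its `jax.tree_util` replacement; a minimal sketch of the renamed calls on a made-up pytree:

```python
import jax
import jax.numpy as jnp

pytree = {"u": jnp.ones(2), "v": (jnp.zeros(3), jnp.arange(4.0))}

# Deprecated aliases from the log and their jax.tree_util replacements:
#   jax.tree_flatten    -> jax.tree_util.tree_flatten
#   jax.tree_unflatten  -> jax.tree_util.tree_unflatten
#   jax.tree_leaves     -> jax.tree_util.tree_leaves
#   jax.tree_structure  -> jax.tree_util.tree_structure
leaves, treedef = jax.tree_util.tree_flatten(pytree)
roundtripped = jax.tree_util.tree_unflatten(treedef, leaves)

# The round trip preserves the tree structure.
print(jax.tree_util.tree_structure(roundtripped) == treedef)
print(len(jax.tree_util.tree_leaves(pytree)))
```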
=========================== short test summary info ============================
FAILED jax_cfd/collocated/advection_test.py::AdvectionTest::test_mass_conservation_dirichlet_dichlet_advect
FAILED jax_cfd/collocated/advection_test.py::AdvectionTest::test_neumann_bc_one_step_linear_1d_neumann
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_implicit_solve_harmonic_oscillator_implicit
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_constant_derivative
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_harmonic_oscillator_explicit
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_harmonic_oscillator_implicit
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_explicit
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_implicit
FAILED jax_cfd/spectral/time_stepping_test.py::TimeSteppingTest::test_integration_linear_derivative_semi_implicit
FAILED jax_cfd/base/subgrid_models_test.py::SubgridModelsTest::test_divergence_and_momentum_sinusoidal_velocity_with_subgrid_model
FAILED jax_cfd/base/advection_test.py::AdvectionTest::test_mass_conservation_van_leer_1D
ERROR jax_cfd/ml/equations_test.py
ERROR jax_cfd/ml/layers_test.py
ERROR jax_cfd/ml/layers_util_test.py
ERROR jax_cfd/ml/towers_test.py
====== 11 failed, 447 passed, 242 warnings, 4 errors in 118.92s (0:01:58) ======