diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index 5091555a508..e7f5258b9dd 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -46,8 +46,6 @@ spec: pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html cd /src/pytorch/xla - # TODO: pallas test requires JAX, now we need to explicitly set TPU_LIBRARY_PATH for JAX, need a permanent fix. - TPU_LIBRARY_PATH=/usr/local/lib/python3.10/site-packages/torch_xla/lib/libtpu.so test/tpu/run_tests.sh volumeMounts: - mountPath: /dev/shm name: dshm