diff --git a/experimental/torch_xla2/README.md b/experimental/torch_xla2/README.md index 5d538490429..5cb283ab288 100644 --- a/experimental/torch_xla2/README.md +++ b/experimental/torch_xla2/README.md @@ -21,7 +21,7 @@ the instructions below from scratch (fresh venv / conda environment.) ### 1. Installing `torch_xla2` -The following instructions assume you are in the `torch_xla2 directory: +The following instructions assume you are in the `torch_xla2` directory: ``` $ git clone https://github.com/pytorch/xla.git diff --git a/experimental/torch_xla2/test/test_exports.py b/experimental/torch_xla2/test/test_exports.py index aec955b360b..ce465324a4c 100644 --- a/experimental/torch_xla2/test/test_exports.py +++ b/experimental/torch_xla2/test/test_exports.py @@ -54,9 +54,6 @@ def test_interpolate(self): self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str) self.assertIn("stablehlo.minimum", module_str) - # Test with dynamic export - - def test_constant(self): # Check Accuracy diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 7bf5589b8c1..387d9889386 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -6,7 +6,6 @@ from torch.utils import _pytree as pytree from torch_xla2 import tensor from torch_xla2.ops import ops_registry -from jax.experimental import export import jax import jax.numpy as jnp import sympy @@ -122,11 +121,9 @@ def _get_dim(d): return symbolic_shapes[str(d)] return d - def is_scalar(meta): - val = meta['val'] - return isinstance(val, float) or isinstance(val, int) or isinstance(val, bool) - - if is_scalar(arg_meta): + val = arg_meta['val'] + is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool) + if is_scalar: return jax.ShapeDtypeStruct([], type(arg_meta['val'])) tensor_meta = arg_meta['tensor_meta'] @@ -162,13 +159,13 @@ def _build_symbolic_constraints(symbol_name, torch_constraint): ==> ("a >= 5", "a <= 10",) """ if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool: - raise TypeError(f"No symbolic constraint handler for: {constraint}") + raise TypeError(f"No symbolic constraint handler for: {torch_constraint}") constraints = [] symbol = sympy.Symbol(symbol_name) - if (torch_constraint.lower != 2): + if torch_constraint.lower != 2: constraints.append(symbol >= torch_constraint.lower) - if (not torch_constraint.upper.is_infinite): + if not torch_constraint.upper.is_infinite: constraints.append(symbol <= torch_constraint.upper) return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) @@ -202,10 +199,7 @@ def _build_symbolic_shape(sym, constraint, free_symbols): symbolic_shapes = {} symbol_variables = [(s,v) for s,v in range_constraints.items() if s.is_symbol] symbol_exprs = [(s,v) for s,v in range_constraints.items() if not s.is_symbol] - for sym, constraint in symbol_variables: - symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) - symbolic_shapes[str(sym)] = symbolic_shape - for sym, constraint in symbol_exprs: + for sym, constraint in symbol_variables + symbol_exprs: symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) symbolic_shapes[str(sym)] = symbolic_shape return symbolic_shapes @@ -232,5 +226,5 @@ def exported_program_to_stablehlo(exported_program): """ weights, func = exported_program_to_jax(exported_program) jax_avals = extract_avals(exported_program) - jax_export = export.export(func)(weights, (jax_avals,)) + jax_export = jax.experimental.export.export(func)(weights, (jax_avals,)) return jax_export