Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed May 22, 2024
1 parent 9e22442 commit a995b89
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 18 deletions.
2 changes: 1 addition & 1 deletion experimental/torch_xla2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ the instructions below from scratch (fresh venv / conda environment.)

### 1. Installing `torch_xla2`

The following instructions assume you are in the `torch_xla2 directory:
The following instructions assume you are in the `torch_xla2` directory:

```
$ git clone https://github.com/pytorch/xla.git
Expand Down
3 changes: 0 additions & 3 deletions experimental/torch_xla2/test/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def test_interpolate(self):
self.assertIn("func.func private @clip(%arg0: tensor<500xf32>", module_str)
self.assertIn("stablehlo.minimum", module_str)

# Test with dynamic export


def test_constant(self):

# Check Accuracy
Expand Down
22 changes: 8 additions & 14 deletions experimental/torch_xla2/torch_xla2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torch.utils import _pytree as pytree
from torch_xla2 import tensor
from torch_xla2.ops import ops_registry
from jax.experimental import export
import jax
import jax.numpy as jnp
import sympy
Expand Down Expand Up @@ -122,11 +121,9 @@ def _get_dim(d):
return symbolic_shapes[str(d)]
return d

def is_scalar(meta):
val = meta['val']
return isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)

if is_scalar(arg_meta):
val = arg_meta['val']
is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
if is_scalar:
return jax.ShapeDtypeStruct([], type(arg_meta['val']))

tensor_meta = arg_meta['tensor_meta']
Expand Down Expand Up @@ -162,13 +159,13 @@ def _build_symbolic_constraints(symbol_name, torch_constraint):
==> ("a >= 5", "a <= 10",)
"""
if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool:
raise TypeError(f"No symbolic constraint handler for: {constraint}")
raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")

constraints = []
symbol = sympy.Symbol(symbol_name)
if (torch_constraint.lower != 2):
if torch_constraint.lower != 2:
constraints.append(symbol >= torch_constraint.lower)
if (not torch_constraint.upper.is_infinite):
if not torch_constraint.upper.is_infinite:
constraints.append(symbol <= torch_constraint.upper)

return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
Expand Down Expand Up @@ -202,10 +199,7 @@ def _build_symbolic_shape(sym, constraint, free_symbols):
symbolic_shapes = {}
symbol_variables = [(s,v) for s,v in range_constraints.items() if s.is_symbol]
symbol_exprs = [(s,v) for s,v in range_constraints.items() if not s.is_symbol]
for sym, constraint in symbol_variables:
symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
symbolic_shapes[str(sym)] = symbolic_shape
for sym, constraint in symbol_exprs:
for sym, constraint in symbol_variables + symbol_exprs:
symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
symbolic_shapes[str(sym)] = symbolic_shape
return symbolic_shapes
Expand All @@ -232,5 +226,5 @@ def exported_program_to_stablehlo(exported_program):
"""
weights, func = exported_program_to_jax(exported_program)
jax_avals = extract_avals(exported_program)
jax_export = export.export(func)(weights, (jax_avals,))
jax_export = jax.experimental.export.export(func)(weights, (jax_avals,))
return jax_export

0 comments on commit a995b89

Please sign in to comment.