Commit

Move op dispatching logic into an Environment class; and use Mode to capture dispatcher instead of tensor. (#7009)
qihqi authored May 8, 2024
1 parent a006372 commit 825ba0d
Showing 28 changed files with 2,555 additions and 2,525 deletions.
40 changes: 40 additions & 0 deletions experimental/torch_xla2/docs/ops_registry.md
@@ -0,0 +1,40 @@
# Ops Registry

## Background

In the [How it works](how_it_works.md) doc, we mentioned two important pieces:

1. A mechanism to route `ATen` ops to implementations written in
Jax or in PyTorch, and

2. The ops themselves.


The Ops Registry is there to help us organize the ops themselves.

An op implementation can be written in terms of Jax, or in terms of other PyTorch ops.
The latter are also known as "decompositions". For decompositions,
one needs to be careful not to introduce circular dependencies.

Here we simply store the operator implementations in a dictionary
whose keys are the torch / ATen callables that we wish to override, and
whose values are instances of the `Operator` class.

The `Operator` class has this schema:

```python
@dataclasses.dataclass
class Operator:
    torch_op: TorchCallable
    func: Union[TorchCallable, JaxCallable]
    is_jax_function: bool
    is_user_defined: bool
    needs_env: bool
```

`torch_op` is the torch callable being overridden, and `func` is its implementation. `is_jax_function` is True if `func` is implemented using Jax, and False if `func` is implemented using other torch ops. We can use this information to decide how to call it.
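
For illustration, registering an op implementation could look roughly like the sketch below. This is a hedged example, not the actual torch_xla2 registration API: the `all_ops` dictionary and the `register_op` decorator are hypothetical names; only the `Operator` schema above is taken from this doc.

```python
import jax.numpy as jnp
import torch

# Hypothetical registry dictionary: maps a torch / ATen callable to its
# Operator entry. `all_ops` and `register_op` are illustrative names.
all_ops = {}

def register_op(torch_op, is_jax_function=True, needs_env=False):
    def inner(func):
        all_ops[torch_op] = Operator(
            torch_op=torch_op,
            func=func,
            is_jax_function=is_jax_function,
            is_user_defined=False,
            needs_env=needs_env,
        )
        return func
    return inner

@register_op(torch.ops.aten.abs, is_jax_function=True)
def _aten_abs(x):
    # Implemented in terms of Jax, so it receives and returns jax arrays.
    return jnp.abs(x)
```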

If `needs_env` is True, `func` will receive an extra kwarg named `env`.
This is the "Environment" in which the op operates. In particular,
the environment contains the Jax random number generator key, which can be useful for ops like `aten::rand`.
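
To make the `needs_env` case concrete, an implementation of `aten::rand` might look roughly like this, reusing the hypothetical `register_op` decorator from the previous sketch; the `env.prng_key` attribute is likewise an assumption, not a documented field.

```python
import jax
import torch

@register_op(torch.ops.aten.rand, is_jax_function=True, needs_env=True)
def _aten_rand(size, *, env=None, **kwargs):
    # Because needs_env=True, the dispatcher passes the Environment as `env`.
    # `env.prng_key` is a hypothetical attribute holding the Jax PRNG key;
    # the point is that randomness state lives in the Environment rather
    # than in global state.
    return jax.random.uniform(env.prng_key, shape=tuple(size))
```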

31 changes: 13 additions & 18 deletions experimental/torch_xla2/examples/basic_training.py
@@ -10,7 +10,11 @@
from torch.utils import _pytree as pytree
import torchvision
import torchvision.transforms as transforms
import torch_xla2
import torch_xla2.tensor


xla_env = torch_xla2.tensor.Environment(0)
mode = xla_env.mode()

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
@@ -80,6 +84,7 @@ def forward(self, x):


model = GarmentClassifier()
model = xla_env.to_xla(model)

loss_fn = torch.nn.CrossEntropyLoss()

@@ -96,13 +101,6 @@ def forward(self, x):
print('Total loss for this batch: {}'.format(loss.item()))

# Optimizers specified in the torch.optim package

# NEW: Move model to XLA device
state_dict = model.state_dict()
state_dict = pytree.tree_map_only(torch.Tensor,
torch_xla2.tensor.move_to_device, state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

def train_one_epoch(epoch_index, tb_writer):
@@ -115,14 +113,14 @@ def train_one_epoch(epoch_index, tb_writer):
for i, data in enumerate(training_loader):
# Every data instance is an input + label pair
# NEW: Move data to XLA device
data = pytree.tree_map_only(torch.Tensor,
torch_xla2.tensor.move_to_device, data)
data = xla_env.to_xla(data)
inputs, labels = data

# Zero your gradients for every batch!
optimizer.zero_grad()

# Make predictions for this batch

outputs = model(inputs)

# Compute the loss and its gradients
@@ -169,14 +167,11 @@ def train_one_epoch(epoch_index, tb_writer):
# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
# NOTE: move to XLA device
vinputs, vlabels = pytree.tree_map_only(
torch.Tensor,
torch_xla2.tensor.move_to_device,
vdata)
voutputs = model(vinputs) # call model's forward
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
# NOTE: move to XLA device
vinputs, vlabels = xla_env.to_xla(vdata)
voutputs = model(vinputs) # call model's forward
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss

avg_vloss = running_vloss / (i + 1)
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
12 changes: 3 additions & 9 deletions experimental/torch_xla2/examples/basic_training_jax.py
@@ -8,7 +8,7 @@
import torchvision
import torchvision.transforms as transforms
import torch_xla2
import torch_xla2.extra
import torch_xla2.interop
import jax
import optax
import numpy as np
@@ -91,7 +91,7 @@ def forward(self, x):

def jax_loss(weights, data, label):
pred = jax_func(weights, data)
loss = torch_xla2.extra.call_torch(loss_fn, pred, label)
loss = torch_xla2.interop.call_torch(loss_fn, pred, label)
return loss

grad_fn = jax.jit(jax.value_and_grad(jax_loss))
@@ -155,12 +155,6 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):
# Make sure gradient tracking is on, and do a pass over the data
model.train(True)

# NEW: Move model to XLA device
state_dict = model.state_dict()
state_dict = pytree.tree_map_only(torch.Tensor,
torch_xla2.tensor.move_to_device, state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)

avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer)

running_vloss = 0.0
@@ -174,7 +168,7 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer):

vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata)
voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward
vloss = torch_xla2.extra.call_torch(loss_fn, voutputs, vlabels)
vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels)
running_vloss += vloss

avg_vloss = running_vloss / (i + 1)
13 changes: 6 additions & 7 deletions experimental/torch_xla2/examples/eager_mode.py
@@ -1,10 +1,9 @@

from torch_xla2.tensor import move_to_device
import torch_xla2
from torch import nn
from torch.nn import functional as F
import torch
from torch.utils import _pytree as pytree

xla_env = torch_xla2.default_env()


class MyModel(nn.Module):
@@ -22,21 +21,21 @@ def forward(self, x):
return x

m = MyModel()
m = xla_env.to_xla(m)

# Execute this model using torch
inputs = (torch.randn(3, 3, 28, 28), )
inputs = xla_env.to_xla(inputs)

inputs, state_dict = pytree.tree_map_only(torch.Tensor, move_to_device, (inputs, m.state_dict()))
m.load_state_dict(state_dict, strict=False, assign=True)
print(m(*inputs))
print('---=====')

from torch_xla2.extra import jax_jit
from torch_xla2.interop import jax_jit

@jax_jit
def model_func(param, inputs):
return torch.func.functional_call(m, param, inputs)

print(model_func(state_dict, inputs))
print(model_func(m.state_dict(), inputs))


2 changes: 1 addition & 1 deletion experimental/torch_xla2/test/gemma/test_gemma.py
@@ -74,7 +74,7 @@ def test_gemma(self):

weights, jax_func = torch_xla2.extract_jax(model)
inputs_jax = pytree.tree_map_only(
torch.Tensor, torch_xla2.tensor.move_to_device, inputs)
torch.Tensor, torch_xla2.tensor.t2j, inputs)

import jax
print(jax.jit(jax_func)(weights, inputs_jax))
5 changes: 1 addition & 4 deletions experimental/torch_xla2/test/llama/test_llama.py
@@ -1,8 +1,5 @@
import unittest
import jax
import torch
from torch._functorch.make_functional import make_functional_with_buffers
from torch_xla2 import tensor, ops # pylint: disable=unused-import
from torch_xla2 import tensor # pylint: disable=unused-import
import torch_xla2

from .. import test_base
8 changes: 5 additions & 3 deletions experimental/torch_xla2/test/test_context.py
@@ -1,20 +1,22 @@
import unittest

import torch
import torch_xla2
from torch_xla2 import tensor

xla_env = tensor.Environment(0)


class TestContext(unittest.TestCase):

def test_mode_context_manager(self):
with torch_xla2.mode():
with xla_env:
x = torch.full((3, 3), -1)
self.assertIsInstance(x, tensor.XLATensor2)
y = x.abs()
self.assertIsInstance(y, tensor.XLATensor2)

@staticmethod
@torch_xla2.mode()
@xla_env
def _test_mode_decorator():
x = torch.full((3, 3), -1)
y = x.abs()
17 changes: 9 additions & 8 deletions experimental/torch_xla2/test/test_core_aten_ops.py
@@ -1,7 +1,6 @@
import unittest

import torch
from torch_xla2 import ops_registry
from torch_xla2 import tensor

from . import test_base
@@ -34,12 +33,13 @@ def run_export_and_compare(testcase,
rtol=1e-5,
equal_nan=True,
ignore_indices=False):

with testcase.subTest("torch_eval"):
res = func(*args, **kwargs)
with testcase.subTest("torch_xla2_eval"):
args2, kwargs2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device,
(args, kwargs))
res2 = func(*args2, **kwargs2)
args2, kwargs2 = testcase.env.to_xla((args, kwargs))
with testcase.env:
res2 = func(*args2, **kwargs2)
res2 = pytree.tree_map_only(tensor.XLATensor2, lambda t: t.torch(), res2)
# import pdb; pdb.set_trace()
with testcase.subTest("torch_xla2_diff:" + str(atol)):
@@ -61,11 +61,11 @@ class TestCoreAtenOps(unittest.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
ops_registry.print_missing_ops()

def setUp(self):
super().setUp()
torch.manual_seed(0)
self.env = tensor.Environment(0)

def test_aten_abs_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
@@ -2109,7 +2109,7 @@ def test_aten_logit_0(self):
def test_aten_logit_1(self):
args = (torch.randn((10, 10)).to(torch.float16),)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.logit, args, kwargs)
run_export_and_compare(self, torch.ops.aten.logit, args, kwargs, atol=0.01,)

def test_aten_logit_2(self):
args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
@@ -3639,8 +3639,9 @@ def test_aten__softmax_1(self):
def _compare_sorted_result(self, args):
res = torch.ops.aten.sort(*args)
with self.subTest("torch_xla2_eval"):
args2 = pytree.tree_map_only(torch.Tensor, tensor.move_to_device, args)
res2 = torch.ops.aten.sort(*args2)
args2 = self.env.to_xla(args)
with self.env:
res2 = torch.ops.aten.sort(*args2)

# The second argument is the sorted index. These might not be
# identical from torch vs. jax; but both can be correct
64 changes: 0 additions & 64 deletions experimental/torch_xla2/test/test_extra.py

This file was deleted.

6 changes: 4 additions & 2 deletions experimental/torch_xla2/test/test_functions.py
@@ -3,12 +3,14 @@
from absl.testing import parameterized
import torch
import torch_xla2
import torch_xla2.functions
import torch_xla2.tensor


class TestTorchFunctions(parameterized.TestCase):

def setUp(self):
self.env = torch_xla2.tensor.Environment(0)

@parameterized.named_parameters(
('tensor_2d', lambda: torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])),
('tensor_1d', lambda: torch.tensor([0, 1],)),
@@ -32,7 +34,7 @@ class TestTorchFunctions(parameterized.TestCase):
def test_tensor_constructor(self, func: Callable[[], torch.Tensor]):
expected = func()

with torch_xla2.functions.XLAFunctionMode():
with self.env:
actual = func()
self.assertIsInstance(actual, torch_xla2.tensor.XLATensor2)
