Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In order to be compatible with iree-turbine, make iree-turbine can training #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions examples/mlp_train/ut_mlp_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

device = 'cuda'

# [ y = W_n * x_n + W_{n-1} * x_{n-1} + ... + W_1 * x_1 + b ]
torch.cuda.manual_seed_all(0)
x = torch.linspace(-1, 1, 100).reshape(-1)
y = 3 * x + 2 + torch.randn(x.size()) * 0.2

# cvt to tensor
x = torch.tensor(x, dtype=torch.float32).to(device)
y = torch.tensor(y, dtype=torch.float32).to(device)
print(x)
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.weight = nn.Parameter(torch.randn(1, requires_grad=True))
print(self.weight)
self.bias = nn.Parameter(torch.randn(1, requires_grad=True))

def forward(self, x : torch.Tensor):
out = x * self.weight + self.bias
return out


# model = SimpleMLP().to(device)
mod = SimpleMLP().to(device)

model = torch.compile(mod, backend='turbine_cpu')

learning_rate = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss_func = nn.MSELoss()

epochs = 2000
for epoch in range(epochs):
y_pred = model(x)
# print(y_pred)

loss = loss_func(y_pred.to(device), y.to(device))

optimizer.zero_grad()
# loss = y_pred.sum()
# loss = loss.to(device)
loss.backward()

optimizer.step()

if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

predicted = model(x).detach().cpu().numpy()
plt.plot(x.cpu().numpy(), y.cpu().numpy(), 'ro', label='Original data')
plt.plot(x.cpu().numpy(), predicted, label='Fitted line')
plt.legend()
plt.savefig('fitting_result.png')
plt.close()
59 changes: 51 additions & 8 deletions shark_turbine/dynamo/backends/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import functools
import sys
import os

from ...runtime.device import (
DeviceState,
Expand All @@ -16,6 +17,7 @@
)

from iree.compiler.api import (
_initializeGlobalCL,
Invocation,
Session,
Source,
Expand All @@ -38,11 +40,31 @@
import torch
from torch._dynamo.backends.common import aot_autograd
from ..passes import turbine_cpu_pass_pipeline
from typing import Any, List
from functorch.compile import min_cut_rematerialization_partition

DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
DEFAULT_COMPILER_FLAGS = (
"--iree-input-type=torch",
)

global_cl_options = []
if os.getenv("mlir_print_ir_after_all") is not None:
global_cl_options.append("--mlir-print-ir-after-all")
global_cl_options.append("--mlir-print-ir-after-change")

if os.getenv("mlir_print_ir_before_all") is not None:
global_cl_options.append("--mlir-print-ir-before-all")


if len(global_cl_options) != 0:
_initializeGlobalCL("dynamo", *global_cl_options)

def device_from_inputs(example_inputs) -> torch.device:
for x in example_inputs:
if hasattr(x, "device"):
return x.device

def _base_backend(gm: torch.fx.GraphModule, example_inputs):
def _base_backend(gm: torch.fx.GraphModule, example_inputs, is_fw=True):
# Set up the session, context and invocation.
# Note that we do this on one in-memory module in a few phases:
# 1. Build it from the FX graph.
Expand All @@ -52,7 +74,18 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
# 4. Output to an mmap buffer.
session = Session()
session.set_flags(*DEFAULT_COMPILER_FLAGS)
session.set_flags("--iree-hal-target-backends=llvm-cpu")

device = device_from_inputs(example_inputs)


device_index = None
device_type = device.type
if device_type == "cpu":
session.set_flags("--iree-hal-target-backends=llvm-cpu")
elif device_type == "cuda":
device_index = device.index
session.set_flags("--iree-hal-target-backends=cuda")

context = session.context
importer = FxImporter(context=context)
module = importer.module
Expand All @@ -65,6 +98,8 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
gm = turbine_cpu_pass_pipeline(gm, example_inputs)

# Import phase.
print("before import graph")
print(gm.print_readable(), file=sys.stderr)
importer.import_graph_module(gm)
print(module, file=sys.stderr)
with context:
Expand All @@ -80,7 +115,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
inv.output_vm_bytecode(output)

# Set up for runtime.
device_state = _get_device_state()
device_state = _get_device_state(device_type, device_index)
# TODO: Switch to wrap_buffer once https://github.com/openxla/iree/issues/14926
# is fixed.
# vmfb_module = VmModule.wrap_buffer(
Expand All @@ -94,14 +129,22 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
)
output.close()

return SpecializedExecutable(vmfb_module, device_state)
return SpecializedExecutable(vmfb_module, device_state, importer.anticipated_return_value)

def _base_backend_fw(gm: torch.fx.GraphModule, example_inputs):
return _base_backend(gm, example_inputs, is_fw=True)

backend = aot_autograd(fw_compiler=_base_backend)
def _base_backend_bw(gm: torch.fx.GraphModule, example_inputs):
return _base_backend(gm, example_inputs, is_fw=False)

backend = aot_autograd(fw_compiler=_base_backend_fw, bw_compiler=_base_backend_bw, partition_fn=functools.partial(min_cut_rematerialization_partition, compiler="turbine_cpu"))

# IREE runtime globals. For the CPU right now, there is no device selection,
# so it is easy.
@functools.lru_cache(maxsize=None)
def _get_device_state() -> DeviceState:
return DeviceState(driver="local-task")
def _get_device_state(device_type, device_index) -> DeviceState:
if device_type == "cpu":
return DeviceState(driver="local-task")
elif device_type == "cuda":
return DeviceState(driver="cuda", enumerated_info={'device_id':device_index})

47 changes: 34 additions & 13 deletions shark_turbine/dynamo/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import List, Optional, Sequence, Union
from dataclasses import dataclass
import torch.nn as nn
from iree.runtime import (
asdevicearray,
create_hal_module,
Expand All @@ -31,7 +32,7 @@
)

from ..runtime.device import Device, DeviceState

from ..dynamo.tensor import dtype_to_element_type

@functools.lru_cache(maxsize=None)
def get_vm_instance() -> VmInstance:
Expand Down Expand Up @@ -64,12 +65,14 @@ class SpecializedExecutable:
"entry_function",
"user_module",
"vm_context",
"anticipated_return_value",
]

def __init__(
self,
user_module: VmModule,
device_state: DeviceState,
anticipated_return_value: list[bool],
entry_name: str = "main",
):
self.user_module = user_module
Expand All @@ -81,6 +84,7 @@ def __init__(
),
)
self.device_state = device_state
self.anticipated_return_value = anticipated_return_value
self.entry_function = self.user_module.lookup_function(entry_name)

def __call__(self, *inputs):
Expand All @@ -101,26 +105,43 @@ def _inputs_to_device(self, inputs: list, arg_list: VmVariantList):
# TODO: We are assuming the worst case here which is that we have unknown Torch
# tensors that we send to the CPU and make continguous. Ideally, we would have
# fast paths for our own backends and interop.
device = self.device_state.device
device_name = self.device_state.torch_device
for input in inputs:
input_cpu = input.cpu().contiguous()
# Since this is already a fallback case, just use the numpy array interop.
# It isn't great, but meh... fallback case.
device_array = asdevicearray(self.device_state.device, input_cpu)
arg_list.push_ref(device_array._buffer_view)

# input_cpu = input.cpu().contiguous()
# # Since this is already a fallback case, just use the numpy array interop.
# # It isn't great, but meh... fallback case.
# device_array = asdevicearray(self.device_state.device, input_cpu)
# arg_list.push_ref(device_array._buffer_view)
if not input.is_contiguous():
input = input.cpu().contiguous()

if input.device.type.startswith("cpu"):
if device_name.startswith("cuda"):
input = input.to("cuda")

if(isinstance(input, nn.Parameter)):
buffer_view = device.allocator.import_buffer(device, input.data, dtype_to_element_type(input.dtype))
else:
buffer_view = device.allocator.import_buffer(device, input, dtype_to_element_type(input.dtype))
arg_list.push_ref(buffer_view)

def _returns_to_user(self, ret_list: VmVariantList):
# TODO: This is also not good that we are moving back to the CPU like this.
# We should be returning a custom Tensor implementation which represents
# our device data and has synchronization hooks for accessing it.
device = self.device_state.device
num_returns = len(ret_list)
# num_returns = len(ret_list)
num_returns = len(self.anticipated_return_value)
user_returns = [None] * num_returns
for i in range(num_returns):
device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(i))
device_array = DeviceArray(device, device_buffer_view)
host_array = device_array.to_host()
user_returns[i] = torch_from_numpy(host_array) # type: ignore
ret_list_idx = 0 # self.anticipated_return_value could have None type elements, so here use ret_list_idx

for i in range(num_returns):
if self.anticipated_return_value[i]:
device_buffer_view = HalBufferView.__iree_vm_cast__(ret_list.get_as_ref(ret_list_idx))
ret_list_idx += 1
element_type = HalElementType(device_buffer_view.element_type)
user_returns[i] = device.allocator.export_buffer(device, device_buffer_view, element_type)
return user_returns


Expand Down