diff --git a/examples/aot_mlp/README.md b/examples/aot_mlp/README.md new file mode 100644 index 00000000..e8057b70 --- /dev/null +++ b/examples/aot_mlp/README.md @@ -0,0 +1,23 @@ +# AOT MLP Example + +This example demonstrates export, compilation, and inference of +a simple Multi-Layer Perceptron (MLP) model. +The model is a four-layer neural network. + +To run this example, you should clone the repository to your local device and +install the requirements in a virtual environment: + +```bash +git clone https://github.com/iree-org/iree-turbine.git +cd iree-turbine/examples/aot_mlp +python -m venv mlp.venv +source ./mlp.venv/bin/activate +pip install -r requirements.txt +``` + +Once the requirements are installed, you should be able to run the example. + +```bash +python mlp_export_simple.py +``` + diff --git a/examples/aot_mlp/mlp_export_simple.py b/examples/aot_mlp/mlp_export_simple.py index 30d7ae95..b9ae79c9 100644 --- a/examples/aot_mlp/mlp_export_simple.py +++ b/examples/aot_mlp/mlp_export_simple.py @@ -4,23 +4,44 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import iree.runtime as rt import logging -import unittest +import numpy as np import torch import torch.nn as nn +import unittest import iree.turbine.aot as aot +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + class MLP(nn.Module): - def __init__(self): + """ + Multi-Layer Perceptron (MLP) model class. + Defines a neural network with four linear layers and sigmoid activations. + """ + + def __init__(self) -> None: super().__init__() + # Define model layers self.layer0 = nn.Linear(8, 8, bias=True) self.layer1 = nn.Linear(8, 4, bias=True) self.layer2 = nn.Linear(4, 2, bias=True) self.layer3 = nn.Linear(2, 2, bias=True) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MLP model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after forward pass. + """ x = self.layer0(x) x = torch.sigmoid(x) x = self.layer1(x) @@ -38,25 +59,38 @@ def forward(self, x: torch.Tensor): compiled_binary = exported.compile(save_to=None) -def infer(): - import numpy as np - import iree.runtime as rt +def run_inference() -> np.ndarray: + """ + Runs inference on the compiled model. + Returns: + np.ndarray: The result of inference as a NumPy array. + """ config = rt.Config("local-task") vmm = rt.load_vm_module( - rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), + rt.VmModule.wrap_buffer( + config.vm_instance, compiled_binary.map_memory() + ), config, ) x = np.random.rand(97, 8).astype(np.float32) y = vmm.main(x) - print(y.to_host()) + logger.debug(f"Inference result: {y.to_host()}") + return y.to_host() class ModelTest(unittest.TestCase): - def testMLPExportSimple(selfs): - infer() + def test_mlp_export_simple(self) -> None: + """Tests if the model export and inference work as expected.""" + output = run_inference() + self.assertIsNotNone(output, "Inference output should not be None") + self.assertEqual( + output.shape, (97, 2), + "Output shape doesn't match the expected (97, 2)" + ) if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) + # Run unit tests unittest.main() + diff --git a/examples/aot_mlp/requirements.txt b/examples/aot_mlp/requirements.txt new file mode 100644 index 00000000..aa30906e --- /dev/null +++ b/examples/aot_mlp/requirements.txt @@ -0,0 +1,4 @@ +numpy +torch +iree-turbine +