OP Lowering Guide

Background

PyTorch wraps the C++ ATen tensor library, which offers a wide range of operations implemented on GPU and CPU. PyTorch/XLA is a PyTorch extension; one of its purposes is to convert PyTorch operations to XLA operations. Lowering is the process of converting a higher-level representation into a lower-level one. In this document, I will refer to the process of converting PyTorch operations to XLA operations as lowering. The XLA compiler also lowers XlaOp to HLO, but that is beyond the scope of this document. Operations for which we have not yet provided an XLA lowering are forwarded to the CPU, where the ATen implementation is called. Forwarded operations cause a significant slowdown, so every operation used in a model must be lowered to achieve the best performance.
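The choice between a lowering and a CPU fallback can be modelled in a few lines of Python. This is a toy sketch, not PyTorch/XLA code: `LOWERINGS`, `dispatch`, and the string op names are hypothetical, and the counter names only imitate the aten::op / xla::op naming mentioned in the Unit Test section.

```python
from collections import Counter

# Hypothetical registry: maps an ATen op name to a function returning the
# XLA-style primitives (as strings) that implement it.
LOWERINGS = {
    "aten::add": lambda: ["xla::Add"],
    "aten::relu": lambda: ["xla::Max"],  # relu(x) = max(x, 0)
}

# Per-op call counters, named like the aten::op / xla::op counters.
counters = Counter()


def dispatch(op_name):
    """Return the work performed for op_name in this toy model."""
    short_name = op_name.split("::", 1)[1]
    lowering = LOWERINGS.get(op_name)
    if lowering is not None:
        # Lowered op: it becomes part of the XLA graph on the device.
        counters["xla::" + short_name] += 1
        return lowering()
    # No lowering: copy the data to the CPU, run the ATen kernel, copy back.
    counters["aten::" + short_name] += 1
    return ["transfer_to_cpu", op_name, "transfer_to_device"]
```

The extra transfers in the fallback branch illustrate one reason forwarded operations are slow: the data has to leave the device and come back for every unlowered op.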

Before you start

Follow the instructions here to install the required dependencies and build PyTorch and PyTorch/XLA from source. You do not need access to a TPU to implement a lowering; we recommend experimenting on a workstation configured to use XLA:CPU.

Understanding the operation

You can find the definitions of the C++ ATen operations in native_functions.yaml. After you build PyTorch/XLA from source, you will also find our default implementation (which forwards to the PyTorch native CPU kernels) in xla/torch_xla/csrc/aten_xla_type_default.h/cpp. PyTorch operations can usually be mapped to the PyTorch tensor API easily. If that is not the case, we recommend searching for the PyTorch native implementation in the PyTorch repo. The goal is to lower each PyTorch operation into a sequence of the XLA operations defined here.
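As an illustration of turning one PyTorch operation into a sequence of primitive operations, consider `addcmul`, which PyTorch defines as `self + value * tensor1 * tensor2`. The sketch below uses a hypothetical `Builder` class that only mimics the style of an XLA builder (each call records an op and returns a handle); it is not the XLA client API, and `lower_addcmul` is likewise a hypothetical name.

```python
class Builder:
    """Hypothetical XlaBuilder-like recorder: each call appends an op and
    returns its index, which later ops use as a handle."""

    def __init__(self):
        self.ops = []

    def _emit(self, kind, *args):
        self.ops.append((kind, args))
        return len(self.ops) - 1

    def parameter(self, index):
        return self._emit("param", index)

    def constant(self, value):
        return self._emit("const", value)

    def mul(self, lhs, rhs):
        return self._emit("mul", lhs, rhs)

    def add(self, lhs, rhs):
        return self._emit("add", lhs, rhs)

    def evaluate(self, root, params):
        """Interpret the recorded ops, which are already in topological order."""
        values = []
        for kind, args in self.ops:
            if kind == "param":
                values.append(params[args[0]])
            elif kind == "const":
                values.append(args[0])
            else:
                op = (lambda a, b: a * b) if kind == "mul" else (lambda a, b: a + b)
                values.append(_elementwise(op, values[args[0]], values[args[1]]))
        return values[root]


def _elementwise(op, a, b):
    """Apply op elementwise; a scalar operand is broadcast to the list length.
    At least one operand must be a list."""
    n = len(a) if isinstance(a, list) else len(b)
    a = a if isinstance(a, list) else [a] * n
    b = b if isinstance(b, list) else [b] * n
    return [op(x, y) for x, y in zip(a, b)]


def lower_addcmul(b, self_, tensor1, tensor2, value):
    # addcmul(self, tensor1, tensor2, value) = self + value * tensor1 * tensor2
    product = b.mul(tensor1, tensor2)
    scaled = b.mul(b.constant(value), product)
    return b.add(self_, scaled)
```

The lowering only records a graph of three primitive ops plus a constant instead of computing a result directly, which mirrors how PyTorch/XLA builds a lazy graph that XLA later compiles as a whole.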

File structure

All files mentioned below live under the xla/torch_xla/csrc folder, with the exception of xla_native_functions.yaml.

  1. xla_native_functions.yaml contains the list of all operators that are lowered. Each operator name must directly match a PyTorch operator listed in native_functions.yaml. This file is the interface for adding new XLA operators and is an input to PyTorch's codegen machinery, which generates three files from it: aten_xla_type.h, aten_xla_type_default.h, and aten_xla_type_default.cpp.
  2. aten_xla_type.h/.cpp are the entry points from PyTorch into the PyTorch/XLA world. aten_xla_type.h is auto-generated from a combination of xla_native_functions.yaml and the PyTorch core native_functions.yaml file, and contains declarations for the kernels that must be defined in aten_xla_type.cpp. Those kernels construct an XLATensor from the input at::Tensor and other parameters, and the resulting XLATensor must be converted back to an at::Tensor before returning to the PyTorch world.
  3. aten_xla_type_default.h/.cpp are also auto-generated and contain our default implementations of the PyTorch operations, which simply fall back to the underlying CPU implementation. These functions are used when a lowering is not explicitly defined in xla_native_functions.yaml + aten_xla_type.cpp.
  4. tensor.h contains the XLATensor declarations. These declarations map one to one onto the at::Tensor operations declared in aten_xla_type.h.
  5. tensor_methods.cpp contains the implementations of the XLATensor methods declared in tensor.h. Each one constructs the corresponding ir::op from the parameters' ir::Value and wraps it inside an XLATensor. IR stands for intermediate representation.
  6. The ops/ directory contains all ir::ops declarations and definitions. Smaller nodes can go in ops/ops.h/.cpp; more complicated nodes can go in a separate file. All ops inherit from ir::ops::Node and provide a way to lower their input ir::Value into a sequence of XlaOp.
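The path a call takes through these files can be modelled as a toy Python pipeline. Every class and function here (`Node`, `DeviceData`, `Abs`, `XLATensorStub`, `aten_abs`) is a hypothetical stand-in named after the layer it imitates, not the real C++ API; the docstrings and comments note which file each layer corresponds to.

```python
class Node:
    """Stand-in for an IR node (ops/ layer): operands plus a lowering."""

    def __init__(self, operands=()):
        self.operands = list(operands)

    def lower(self, emitted):
        """Append XLA-style ops to `emitted` and return this node's result name."""
        raise NotImplementedError


class DeviceData(Node):
    """Leaf node standing for tensor data already on the device."""

    def __init__(self, name):
        super().__init__()
        self.name = name

    def lower(self, emitted):
        emitted.append(f"Parameter({self.name})")
        return self.name


class Abs(Node):
    """A small op; in the real tree such a node could live in ops/ops.h/.cpp."""

    def __init__(self, operand):
        super().__init__([operand])

    def lower(self, emitted):
        operand = self.operands[0].lower(emitted)
        result = f"Abs({operand})"
        emitted.append(result)
        return result


class XLATensorStub:
    """tensor.h layer: a tensor that wraps an IR value."""

    def __init__(self, ir_value):
        self.ir_value = ir_value

    # tensor_methods.cpp layer: build the IR node from the input's IR value.
    @staticmethod
    def abs(tensor):
        return XLATensorStub(Abs(tensor.ir_value))


def aten_abs(at_tensor_name):
    """aten_xla_type.cpp layer: at::Tensor in, at::Tensor out.

    The at::Tensor is modelled as a plain name: bridging in wraps it as
    device data, and bridging out simply returns the XLATensorStub.
    """
    xla_input = XLATensorStub(DeviceData(at_tensor_name))
    return XLATensorStub.abs(xla_input)
```

Nothing is lowered when `aten_abs` returns; `lower` runs only when the graph is walked, in the spirit of the lazy execution of the real stack.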

Unit Test

Our CircleCI runs the PyTorch native Python tests on every change and every day. Those tests use the XLA implementation if we provide a lowering. We usually don't need to add Python tests for PyTorch/XLA unless we want to verify XLA-specific behavior (like dynamic shapes) or the PyTorch native test is skipped for some reason. If a Python test is required, add it to xla/test/test_operations.py. We also need to add C++ tests to xla/test/cpp/test_aten_xla_tensor.cpp. These tests should call the PyTorch C++ API and verify that our implementation yields the same result as the PyTorch native implementation. They should also verify that the XLA implementation is called when the tensor is an XLA tensor, by checking the aten::op and xla::op counters.
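The three checks a lowering test makes (same result as the reference, the lowering was used, no CPU fallback happened) can be sketched as a small Python helper. This is a hypothetical model of the pattern, not the actual helpers in test_aten_xla_tensor.cpp: `check_lowering` and the counter dictionary are assumptions, and `math.isclose` stands in for a tensor all-close comparison.

```python
import math
from collections import Counter


def all_close(actual, expected, rel_tol=1e-5, abs_tol=1e-8):
    """Elementwise tolerance check standing in for a tensor all-close."""
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, e in zip(actual, expected)
    )


def check_lowering(op, xla_impl, reference_impl, inputs, counters):
    """Hypothetical helper: run both implementations and check the counters.

    `counters` maps names such as "xla::addcmul" or "aten::addcmul" to call
    counts and is expected to be updated by `xla_impl`.
    """
    before = Counter(counters)
    expected = reference_impl(*inputs)
    actual = xla_impl(*inputs)
    assert all_close(actual, expected), (actual, expected)
    changed = {name for name in counters if counters[name] != before[name]}
    # Any aten:: counter moving means the op fell back to the CPU.
    assert not any(name.startswith("aten::") for name in changed), changed
    # The xla:: counter for this op must move: the lowering was actually used.
    assert f"xla::{op}" in changed, changed
```

Checking the counters as well as the values matters because a CPU fallback produces correct numbers too; only the counters reveal that the lowering was bypassed.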

Tips

The process of lowering is breaking down a PyTorch operation into a sequence of XlaOp. To provide a good lowering of a PyTorch operation, you need a good grasp of what XLA is capable of. Reading the XlaOp document and looking at how similar ops are lowered is the best way to get there. You can find a minimal op lowering example in this PR, and a slightly more complicated example with a backward lowering in this PR.
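Backward ops are lowered the same way as forward ones. For example, PyTorch's `threshold_backward(grad_output, self, threshold)` (used for the ReLU gradient) passes `grad_output` through wherever `self > threshold` and produces zero elsewhere, which maps naturally onto a comparison followed by a select. The helpers below are plain-Python stand-ins named after XLA's `Gt` and `Select` operations; they are illustrative assumptions, not calls into XLA, and `lower_threshold_backward` is a hypothetical name.

```python
def gt(lhs, rhs):
    """Elementwise lhs > rhs, in the spirit of XLA's Gt."""
    return [x > y for x, y in zip(lhs, rhs)]


def select(pred, on_true, on_false):
    """Elementwise choice, in the spirit of XLA's Select."""
    return [t if p else f for p, t, f in zip(pred, on_true, on_false)]


def broadcast_scalar(value, n):
    """Turn a scalar into a length-n list (XLA would broadcast it instead)."""
    return [value] * n


def lower_threshold_backward(grad_output, self_, threshold):
    # grad_output where self > threshold, 0 elsewhere
    n = len(self_)
    mask = gt(self_, broadcast_scalar(threshold, n))
    return select(mask, grad_output, broadcast_scalar(0.0, n))
```

Expressing the gradient as a mask plus a select, rather than as a multiplication by the mask, keeps the lowering to a single comparison and a single elementwise choice.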