Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triton #6798

Merged
merged 67 commits into from
Jun 7, 2024
Merged

Triton #6798

Show file tree
Hide file tree
Changes from 65 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
6dccf0a
Update infra_triggers.tf
ManfeiBai Oct 4, 2023
9828123
Skeleton trition support
bhavya01 Mar 20, 2024
99bf48d
Merge branch 'master' into triton
bhavya01 Mar 20, 2024
b89e558
Fix bugs
bhavya01 Mar 21, 2024
64189bd
Fix custom call invocation
bhavya01 Mar 21, 2024
0c208ef
Refactor to include gpu custom call and create triton dir
bhavya01 Mar 22, 2024
b553ba7
Lint fixes
bhavya01 Mar 22, 2024
c5129e6
python lint fix
bhavya01 Mar 22, 2024
48e7127
Updated base image for CI
bhavya01 Mar 27, 2024
e04fc97
Update github workflow gcr image
bhavya01 Mar 28, 2024
37bf127
Merge branch 'master' into custom
bhavya01 Mar 28, 2024
6061895
Remove xrt build and test file
bhavya01 Mar 28, 2024
f59ddbf
Add temporary test to run triton kernel
bhavya01 Mar 28, 2024
158aed4
Fix tests
bhavya01 Mar 28, 2024
87b92c5
Update payload for xla gpu custom call
bhavya01 Mar 29, 2024
847ccc5
Update gpu runner
bhavya01 Mar 29, 2024
eca6d52
Merge branch 'master' into triton
bhavya01 Apr 4, 2024
2348ca3
Extract payload from triton kernel programatically
bhavya01 Apr 12, 2024
110c8c6
Merge branch 'master' into triton
bhavya01 Apr 12, 2024
a226150
Lint fixes
bhavya01 Apr 12, 2024
4c1f4f5
Only build triton files for GPU
bhavya01 Apr 12, 2024
431f822
build pytorch for ampere gpus
bhavya01 Apr 13, 2024
4bade16
c++ lint fix
bhavya01 Apr 13, 2024
1c5b47d
Python lint fix
bhavya01 Apr 13, 2024
3138a92
Fix torch cuda arch list
bhavya01 Apr 13, 2024
3f00cfd
Use a bigger machine for CI build
bhavya01 Apr 13, 2024
e729cfb
Add triton test to run_tests.sh
bhavya01 Apr 13, 2024
8e304c0
Update triton env variable
bhavya01 Apr 15, 2024
27bdc3a
Set up a separate CI for triton tests
bhavya01 Apr 15, 2024
9a3ef84
Fix github workflow to add _triton.yml
bhavya01 Apr 15, 2024
ade444d
Rebuild torch xla for triton tests
bhavya01 Apr 15, 2024
cb0bb85
Create a separate CI tab for triton tests
bhavya01 Apr 16, 2024
015b1ad
Separate build and test phase for triton
bhavya01 Apr 16, 2024
a18028a
Fix flags for docker run container
bhavya01 Apr 16, 2024
993ee92
Update triton.yml to output docker image
bhavya01 Apr 16, 2024
a87b782
Add a python binding to register custom calls and remove jax files
bhavya01 May 10, 2024
bf05d1b
Fix lint
bhavya01 May 10, 2024
4582fe8
Merge main
bhavya01 May 10, 2024
9680167
Merge master
bhavya01 May 10, 2024
a7b94c6
Merge master after updating
bhavya01 May 10, 2024
e14636a
Update CI to use cuda plugin
bhavya01 May 10, 2024
256d819
Install jaxlib while setting up triton tests
bhavya01 May 10, 2024
c616e64
Install triton package while running triton tests
bhavya01 May 10, 2024
60b8d18
Experimental: Build pytorch with cuda
bhavya01 May 13, 2024
2bde624
Revert build pytorch with CUDA
bhavya01 May 14, 2024
e6c4e0a
Merge branch 'master' into triton
bhavya01 May 14, 2024
14ee545
Remove ansible path for triton CI
bhavya01 May 14, 2024
25acb26
Style fixes
bhavya01 May 20, 2024
6b0ac18
[Experimental] test new CI
bhavya01 May 28, 2024
4d97150
[Experimental] Set XLA_CUDA=0 for cuda arch in ansible
bhavya01 May 28, 2024
e079049
[Experimental] Update CI to build pytorch cuda with ansible
bhavya01 May 29, 2024
d9c89b6
Update CI
bhavya01 May 30, 2024
7a6c809
Fix CI workflow file
bhavya01 May 30, 2024
6b1954d
Fix CI workflow
bhavya01 May 30, 2024
21797a6
Fix the wheels installed for tests requiring torch cuda
bhavya01 May 30, 2024
e6e89d3
Add compute_capability=8.6 for xla cuda plugin
bhavya01 May 31, 2024
ac45fe1
update TORCH_CUDA_ARCH_LIST
bhavya01 May 31, 2024
f828fbb
Experimental build torch and torch_xla cuda wheels
bhavya01 May 31, 2024
ac56c00
Merge branch 'master' into triton
bhavya01 May 31, 2024
c3b8653
Update build_and_test.yml
bhavya01 May 31, 2024
a1168c6
Update dlpack test to only use one device
bhavya01 May 31, 2024
39551a2
Remove compute capability 8.6 from cuda plugin
bhavya01 May 31, 2024
35e0869
Remove triton.sh
bhavya01 May 31, 2024
f95d898
Default empty torch_cuda_arch_list in ansible config
bhavya01 May 31, 2024
291104d
Merge branch 'master' into triton
bhavya01 Jun 5, 2024
f5c9b1a
Revert CI changes
bhavya01 Jun 6, 2024
5b23969
Revert CI changes pt2
bhavya01 Jun 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions .github/workflows/_build_torch_with_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,61 @@ on:
type: string
description: Base image for builds
torch-commit:
required: true
type: string
description: torch-commit
required: true
type: string
description: torch-commit
runner:
required: false
type: string
description: Runner type for the test
default: linux.12xlarge
secrets:
gcloud-service-key:
required: true
description: Secret to access Bazel build cache
jobs:
build:
runs-on: ${{ inputs.runner }}
container:
image: ${{ inputs.dev-image }}
options: "--gpus all --shm-size 16g"
env:
GCLOUD_SERVICE_KEY: ${{ secrets.gcloud-service-key }}
GOOGLE_APPLICATION_CREDENTIALS: /tmp/default_credentials.json
BAZEL_JOBS: 16
BAZEL_REMOTE_CACHE: 1
_GLIBCXX_USE_CXX11_ABI: 0
steps:
# See https://github.com/actions/checkout/issues/1014#issuecomment-1906802802
- name: Clean up workspace
run: |
ls -la
rm -rvf ${GITHUB_WORKSPACE}/*
- name: Setup gcloud
shell: bash
run: |
echo "${GCLOUD_SERVICE_KEY}" > $GOOGLE_APPLICATION_CREDENTIALS
- name: Setup CUDA environment
shell: bash
run: |
echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
- name: Check GPU
run: nvidia-smi
- name: Checkout PyTorch Repo
uses: actions/checkout@v4
with:
repository: pytorch/pytorch
path: pytorch
ref: ${{ inputs.torch-commit }}
submodules: recursive
- name: Checkout PyTorch/XLA Repo
uses: actions/checkout@v4
with:
path: pytorch/xla
- name: Build
shell: bash
run: |
cd pytorch
USE_CUDA=1 python setup.py bdist_wheel
cd pytorch/xla/infra/ansible
ansible-playbook playbook.yaml -vvv -e "stage=build arch=amd64 accelerator=cuda src_root=${GITHUB_WORKSPACE} cuda_compute_capabilities=compute_86 torch_cuda_arch_list=8.6 build_pytorch_with_cuda=1 bundle_libtpu=0 build_cpp_tests=1 git_versioned_xla_build=1 cache_suffix=-ci" --skip-tags=fetch_srcs,install_deps
- name: Upload wheel
uses: actions/upload-artifact@v4
with:
name: torch-with-cuda
path: pytorch/dist/*.whl
name: torch-cuda-wheels
path: /dist/*.whl
33 changes: 13 additions & 20 deletions .github/workflows/_test_requiring_torch_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ on:
timeout-minutes:
required: false
type: number
default: 30
default: 270
description: |
Set the maximum (in minutes) how long the workflow should take to finish
timeout-minutes:
Expand All @@ -46,30 +46,14 @@ jobs:
run: |
ls -la
rm -rvf ${GITHUB_WORKSPACE}/*
- name: Fetch torch/torch_xla/torchvision wheels
uses: actions/download-artifact@v4
with:
name: torch-xla-wheels
path: /tmp/wheels/
- name: Remove torch wheel built with CUDA disabled
shell: bash
run: |
rm -rf /tmp/wheels/torch-*
- name: Fetch the torch wheel built with CUDA enabled
uses: actions/download-artifact@v4
with:
name: torch-with-cuda
path: /tmp/wheels/
- name: Fetch CUDA plugin
uses: actions/download-artifact@v4
with:
name: cuda-plugin
name: torch-cuda-wheels
path: /tmp/wheels/
- name: Setup CUDA environment
shell: bash
run: |
echo "XLA_REGISTER_INSTALLED_PLUGINS=1" >> $GITHUB_ENV

echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
- name: Check GPU
Expand All @@ -81,7 +65,6 @@ jobs:
# TODO: Add these in setup.py
pip install fsspec
pip install rich

echo "Import check..."
python -c "import torch, torch_xla, torchvision"
echo "Import check done."
Expand All @@ -98,9 +81,19 @@ jobs:
uses: actions/checkout@v4
with:
path: pytorch/xla
- name: Extra CI deps
shell: bash
run: |
set -x
pip install expecttest unittest-xml-reporting
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --no-deps triton==2.3.0
- name: Test
shell: bash
run: |
set -xue
PJRT_DEVICE=CUDA TRITON_PTXAS_PATH=/usr/local/cuda-12.1/bin/ptxas python pytorch/xla/test/test_triton.py
PJRT_DEVICE=CUDA python pytorch/xla/test/test_operations.py -v
PJRT_DEVICE=CUDA python pytorch/xla/test/dynamo/test_dynamo.py -v
PJRT_DEVICE=CUDA python pytorch/xla/test/dynamo/test_dynamo.py -v
26 changes: 14 additions & 12 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ jobs:
# note that to build a torch wheel with CUDA enabled, we do not need a GPU runner.
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
runner: linux.8xlarge.nvidia.gpu
runner: linux.24xlarge
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

build-cuda-plugin:
name: "Build XLA CUDA plugin"
Expand All @@ -52,6 +54,17 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-cuda-with-pytorch-cuda-enabled:
name: "GPU tests requiring torch CUDA"
uses: ./.github/workflows/_test_requiring_torch_cuda.yml
needs: [build-torch-with-cuda, get-torch-commit]
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
runner: linux.g5.4xlarge.nvidia.gpu
timeout-minutes: 300
collect-coverage: false
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}

test-python-cpu:
name: "CPU tests"
uses: ./.github/workflows/_test.yml
Expand All @@ -78,17 +91,6 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-cuda-with-pytorch-cuda-enabled:
name: "GPU tests requiring torch CUDA"
uses: ./.github/workflows/_test_requiring_torch_cuda.yml
needs: [build-torch-with-cuda, build-torch-xla, build-cuda-plugin, get-torch-commit]
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
runner: linux.8xlarge.nvidia.gpu
timeout-minutes: 300
collect-coverage: false
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}

test-tpu:
name: "TPU tests"
uses: ./.github/workflows/_tpu_ci.yml
Expand Down
2 changes: 2 additions & 0 deletions infra/ansible/config/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ build_env:

cuda:
TF_CUDA_COMPUTE_CAPABILITIES: "{{ cuda_compute_capabilities }}"
TORCH_CUDA_ARCH_LIST: "{{ torch_cuda_arch_list | default('') }}"
XLA_CUDA: 1
USE_CUDA: "{{ build_pytorch_with_cuda | default(0) }}"
bhavya01 marked this conversation as resolved.
Show resolved Hide resolved

tpu:
ACCELERATOR: tpu
Expand Down
3 changes: 1 addition & 2 deletions infra/ansible/roles/build_srcs/tasks/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
cmd: python setup.py bdist_wheel
chdir: "{{ (src_root, 'pytorch') | path_join }}"
creates: "{{ (src_root, 'pytorch/dist/*.whl') | path_join }}"
# Set `USE_CUDA=0` as PyTorch cannot be used with GPU in eager and XLA mode.
environment: "{{ env_vars | combine({'USE_CUDA': 0}) }}"
environment: "{{ env_vars }}"

- name: Find PyTorch *.whl files in pytorch/dist
ansible.builtin.find:
Expand Down
23 changes: 0 additions & 23 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,18 +2711,6 @@ def test_dlpack_pytorch_cuda_to_xla(self):
t2_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))

cuda1 = torch.device('cuda:1')
bhavya01 marked this conversation as resolved.
Show resolved Hide resolved
t3_cuda = torch.tensor(5, device=cuda1)
dlt3 = torch.utils.dlpack.to_dlpack(t3_cuda)
xla_t3 = xdlpack.from_dlpack(dlt3)
self.assertEqual(xla_t3.device.type, 'xla')
self.assertEqual(
xla_t3.device.index,
t3_cuda.device.index,
msg='both value should 1. xla_t3.device should be xla:1.')
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self):
Expand All @@ -2743,17 +2731,6 @@ def test_dlpack_pytorch_cuda_to_xla_protocol_conversion(self):
t2_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t2.cpu(), t2_cuda.cpu()))

cuda1 = torch.device('cuda:1')
t3_cuda = torch.tensor(5, device=cuda1)
xla_t3 = xdlpack.from_dlpack(t3_cuda)
self.assertEqual(xla_t3.device.type, 'xla')
self.assertEqual(
xla_t3.device.index,
t3_cuda.device.index,
msg='both value should 1. xla_t3.device should be xla:1.')
t3_cuda.fill_(6)
self.assertTrue(torch.allclose(xla_t3.cpu(), t3_cuda.cpu()))

@onlyIfTorchSupportsCUDA
@onlyIfPJRTDeviceIsCUDA
def test_dlpack_xla_to_pytorch_cuda(self):
Expand Down
68 changes: 68 additions & 0 deletions test/test_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
import torch
from torch import nn as nn
import unittest

import torch_xla.experimental.triton as xla_triton
import torch_xla
from torch_xla import runtime as xr

import triton
import triton.language as tl


@triton.jit
def add_kernel(
x_ptr, # *Pointer* to first input vector.
y_ptr, # *Pointer* to second input vector.
output_ptr, # *Pointer* to output vector.
n_elements, # Size of the vector.
BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.
# NOTE: `constexpr` so it can be used as a shape value.
):
# Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
# There are multiple 'programs' processing different data. We identify which program
# we are here:
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
# This program will process inputs that are offset from the initial data.
# For instance, if you had a vector of length 256 and block_size of 64, the programs
# would each access the elements [0:64, 64:128, 128:192, 192:256].
# Note that offsets is a list of pointers:
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create a mask to guard memory operations against out-of-bounds accesses.
mask = offsets < n_elements
# Load x and y from DRAM, masking out any extra elements in case the input is not a
# multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# Write x + y back to DRAM.
tl.store(output_ptr + offsets, output, mask=mask)


class TritonTest(unittest.TestCase):

@unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.")
def test_gpu_custom_call_triton_add(self):
size = 16

x = torch.arange(size, dtype=torch.int64).to("xla")
y = torch.arange(size, dtype=torch.int64).to("xla")
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)
payload = xla_triton.triton_call(
x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
[output.shape], [torch.int64])
output_torch = x + y
self.assertTrue(torch.allclose(output[0].cpu(), output_torch.cpu()))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ ptxla_cc_library(
"@xla//xla/service:hlo_verifier",
"@xla//xla/service:sharding_propagation",
"@xla//xla/service/spmd:spmd_partitioner",
"@xla//xla/service:custom_call_target_registry",
],
)

Expand Down
45 changes: 35 additions & 10 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/pjrt/distributed/distributed.h"
#include "xla/python/profiler/internal/traceme_wrapper.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/hlo_parser.h"

namespace torch_xla {
Expand Down Expand Up @@ -202,6 +203,24 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
return replica_groups;
}

std::vector<at::Tensor> XlaCustomCall(
const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes, bool is_tpu) {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
}

if (is_tpu) {
return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes));
}
return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call(
bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes));
}

std::vector<std::pair<int64_t, int64_t>> CreateSourceTargetPairs(
const py::list& pairs) {
std::vector<std::pair<int64_t, int64_t>> source_target_pairs;
Expand Down Expand Up @@ -2401,16 +2420,22 @@ void InitXlaModuleBindings(py::module m) {
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes)
-> std::vector<at::Tensor> {
std::vector<at::ScalarType> dtypes;
dtypes.reserve(output_dtypes.size());
for (auto& dtype : output_dtypes) {
dtypes.push_back(
reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
}

auto xtensors = tensor_methods::tpu_custom_call(
bridge::GetXlaTensors(inputs), payload, output_shapes, dtypes);
return bridge::AtenFromXlaTensors(xtensors);
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/true);
});
m.def("_xla_gpu_custom_call",
[](const std::vector<at::Tensor>& inputs, const std::string& payload,
const std::vector<std::vector<int64_t>>& output_shapes,
const std::vector<py::object>& output_dtypes)
-> std::vector<at::Tensor> {
return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
/*is_tpu=*/false);
});
m.def("_xla_register_custom_call_target",
[](const std::string& fn_name, const py::capsule& function_ptr,
const std::string& platform) {
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
fn_name, function_ptr.get_pointer(), platform);
});
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
Expand Down
Loading
Loading