Skip to content

Commit

Permalink
Look for nvcc in CUDA_HOME
Browse files Browse the repository at this point in the history
  • Loading branch information
amiller27 committed Aug 26, 2024
1 parent daf9628 commit dcef513
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions bindings/torch/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from pathlib import Path
import re
from setuptools import setup
from pkg_resources import parse_version
Expand All @@ -8,7 +9,7 @@
import sys
import torch
from glob import glob
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(SCRIPT_DIR))
Expand Down Expand Up @@ -80,9 +81,14 @@ def find_cl_path():
cpp_standard = 14

# Get CUDA version and make sure the targeted compute capability is compatible
if os.system("nvcc --version") == 0:
nvcc_out = subprocess.check_output(["nvcc", "--version"]).decode()
cuda_version = re.search(r"release (\S+),", nvcc_out)
nvcc_version_result = subprocess.run(
[str(Path(CUDA_HOME) / "bin" / "nvcc"), "--version"],
text=True,
check=False,
stdout=subprocess.PIPE,
)
if nvcc_version_result.returncode == 0:
cuda_version = re.search(r"release (\S+),", nvcc_version_result.stdout)

if cuda_version:
cuda_version = parse_version(cuda_version.group(1))
Expand Down

0 comments on commit dcef513

Please sign in to comment.