diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java index 95bc98af58..6d5b300507 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java @@ -645,6 +645,29 @@ public TaskGraph prebuiltTask(String id, String entryPoint, String filename, Acc return this; } + /** + * + * @param id + * Task-id + * @param entryPoint + * Kernel's name of the entry point + * @param klass + * Class that can access the resource within the JAR file. + * @param resource + * Input file that represents the kernel source. It could be either SPIR-V, OpenCL C, or PTX code. + * @param accessorParameters + * {@link AccessorParameters} that contains the accessor for each input and output parameter to the kernel. + * @return {@link TaskGraph}. + */ + @Override + public TaskGraph prebuiltTask(String id, String entryPoint, Class klass, String resource, AccessorParameters accessorParameters) { + checkTaskName(id); + PrebuiltTaskPackage prebuiltTask = TaskPackage.createPrebuiltTask(id, entryPoint, resource, accessorParameters); + prebuiltTask.withClass(klass); + taskGraphImpl.addPrebuiltTask(prebuiltTask); + return this; + } + /** * Obtains the task-schedule name that was assigned. * @@ -849,11 +872,11 @@ long getTotalBytesCopyOut() { return taskGraphImpl.getTotalBytesCopyOut(); } - protected String getProfileLog() { + String getProfileLog() { return taskGraphImpl.getProfileLog(); } - public Collection getOutputs() { + Collection getOutputs() { return taskGraphImpl.getOutputs(); } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraphInterface.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraphInterface.java index 8de9590415..3ad647ab97 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraphInterface.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraphInterface.java @@ -481,6 +481,8 @@ TaskGraphInte * Kernel's name of the entry point * @param filename * Input OpenCL C Kernel + * @param accessorParameters + * {@link AccessorParameters} that contains the accessor for each input and output parameter of the kernel. * @param atomics * Atomics region. * @return {@link TaskGraphInterface} @@ -488,6 +490,22 @@ TaskGraphInte */ TaskGraphInterface prebuiltTask(String id, String entryPoint, String filename, AccessorParameters accessorParameters, int[] atomics); + /** + * + * @param id + * Task-id + * @param entryPoint + * Kernel's name of the entry point + * @param klass + * Class that can access the resource within the JAR file. + * @param resource + * Input file that represents the kernel source. It could be either SPIR-V, OpenCL C, or PTX code. + * @param accessorParameters + * {@link AccessorParameters} that contains the accessor for each input and output parameter to the kernel. + * @return {@link TaskGraphInterface} + */ + TaskGraphInterface prebuiltTask(String id, String entryPoint, Class klass, String resource, AccessorParameters accessorParameters); + /** * Obtains the task-schedule name that was assigned. * diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/common/PrebuiltTaskPackage.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/common/PrebuiltTaskPackage.java index 923deb4120..2a55232b91 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/common/PrebuiltTaskPackage.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/common/PrebuiltTaskPackage.java @@ -26,6 +26,7 @@ public class PrebuiltTaskPackage extends TaskPackage { private final Object[] args; private final Access[] accesses; private int[] atomics; + private Class klass; PrebuiltTaskPackage(String id, String entryPoint, String fileName, AccessorParameters accessorParameters) { super(id, null); @@ -70,4 +71,12 @@ public boolean isPrebuiltTask() { public int[] getAtomics() { return atomics; } + + public void withClass(Class klass) { + this.klass = klass; + } + + public Class getClassJar() { + return klass; + } } diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java index 8577c322f1..93ee7e7b21 100644 --- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java +++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/types/arrays/TornadoNativeArray.java @@ -32,7 +32,7 @@ *

* *

- * The constant {@link ARRAY_HEADER} represents the size of the header in bytes. + * The constant {@link TornadoNativeArray#ARRAY_HEADER} represents the size of the header in bytes. *

*/ public abstract sealed class TornadoNativeArray // diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java index d09912004f..67f92d4112 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java @@ -24,6 +24,7 @@ package uk.ac.manchester.tornado.drivers.opencl.runtime; import java.io.IOException; +import java.io.InputStream; import java.lang.foreign.MemorySegment; import java.nio.file.Files; import java.nio.file.Path; @@ -53,14 +54,8 @@ import uk.ac.manchester.tornado.api.memory.XPUBuffer; import uk.ac.manchester.tornado.api.profiler.ProfilerType; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.CharArray; -import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; -import uk.ac.manchester.tornado.api.types.arrays.LongArray; -import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; +import uk.ac.manchester.tornado.api.types.tensors.Tensor; import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider; import uk.ac.manchester.tornado.drivers.opencl.OCLBackendImpl; import uk.ac.manchester.tornado.drivers.opencl.OCLCodeCache; @@ -309,31 +304,39 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } - private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { - final OCLDeviceContextInterface deviceContext = getDeviceContext(); - final PrebuiltTask executable = (PrebuiltTask) task; - if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint()); + private byte[] getSource(PrebuiltTask prebuiltTask) { + byte[] source; + Class klass = prebuiltTask.getKlassJar(); + if (klass != null) { + try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename()); + source = inputStream.readAllBytes(); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e); + } + } else { + final Path path = Paths.get(prebuiltTask.getFilename()); + TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename()); + try { + source = Files.readAllBytes(path); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e); + } } + return source; + } - final Path path = Paths.get(executable.getFilename()); - TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename()); - try { - final byte[] source = Files.readAllBytes(path); + private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { + final OCLDeviceContextInterface deviceContext = getDeviceContext(); + final PrebuiltTask prebuiltTask = (PrebuiltTask) task; + if (deviceContext.isCached(task.getId(), prebuiltTask.getEntryPoint())) + return deviceContext.getInstalledCode(task.getId(), prebuiltTask.getEntryPoint()); - OCLInstalledCode installedCode; - if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) { - // A) for FPGA - installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); - } else { - // B) for CPU multi-core or GPU - installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source); - } - return installedCode; - } catch (IOException e) { - e.printStackTrace(); - } - return null; + byte[] source = getSource(prebuiltTask); + if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) + return deviceContext.installCode(task.getId(), prebuiltTask.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); + else + return deviceContext.installCode(prebuiltTask.meta(), task.getId(), prebuiltTask.getEntryPoint(), source); } private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) { @@ -404,8 +407,7 @@ public int[] checkAtomicsForTask(SchedulableTask task) { @Override public int[] checkAtomicsForTask(SchedulableTask task, int[] array, int paramIndex, Object value) { - if (value instanceof AtomicInteger) { - AtomicInteger ai = (AtomicInteger) value; + if (value instanceof AtomicInteger ai) { if (TornadoAtomicIntegerNode.globalAtomicsParameters.containsKey(task.meta().getCompiledResolvedJavaMethod())) { HashMap values = TornadoAtomicIntegerNode.globalAtomicsParameters.get(task.meta().getCompiledResolvedJavaMethod()); int index = values.get(paramIndex); @@ -535,21 +537,7 @@ private XPUBuffer createDeviceBuffer(Class type, Object object, OCLDeviceCont result = new OCLVectorWrapper(deviceContext, object, batchSize); } else if (object instanceof MemorySegment) { result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof IntArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof FloatArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof DoubleArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof LongArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof ShortArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof ByteArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof CharArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof HalfFloatArray) { + } else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) { result = new OCLMemorySegmentWrapper(deviceContext, batchSize); } else { result = new OCLXPUBuffer(deviceContext, object); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java index 2632bf16cf..a90efbc95d 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java @@ -26,6 +26,7 @@ import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXCodeUtil.buildKernelName; import java.io.IOException; +import java.io.InputStream; import java.lang.foreign.MemorySegment; import java.nio.file.Files; import java.nio.file.Path; @@ -49,14 +50,8 @@ import uk.ac.manchester.tornado.api.memory.XPUBuffer; import uk.ac.manchester.tornado.api.profiler.ProfilerType; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.CharArray; -import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; -import uk.ac.manchester.tornado.api.types.arrays.LongArray; -import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; +import uk.ac.manchester.tornado.api.types.tensors.Tensor; import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider; import uk.ac.manchester.tornado.drivers.ptx.PTX; import uk.ac.manchester.tornado.drivers.ptx.PTXBackendImpl; @@ -201,24 +196,38 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } + private byte[] getSource(PrebuiltTask prebuiltTask) { + byte[] source; + Class klass = prebuiltTask.getKlassJar(); + if (klass != null) { + try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename()); + source = inputStream.readAllBytes(); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e); + } + } else { + final Path path = Paths.get(prebuiltTask.getFilename()); + TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename()); + try { + source = Files.readAllBytes(path); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e); + } + } + return source; + } + private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { final PTXDeviceContext deviceContext = getDeviceContext(); - final PrebuiltTask executable = (PrebuiltTask) task; - String functionName = buildKernelName(executable.getEntryPoint(), executable); - if (deviceContext.isCached(executable.getEntryPoint(), executable)) { + final PrebuiltTask prebuiltTask = (PrebuiltTask) task; + String functionName = buildKernelName(prebuiltTask.getEntryPoint(), prebuiltTask); + if (deviceContext.isCached(prebuiltTask.getEntryPoint(), prebuiltTask)) return deviceContext.getInstalledCode(functionName); - } - final Path path = Paths.get(executable.getFilename()); - TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename()); - try { - byte[] source = Files.readAllBytes(path); - source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend()); - return deviceContext.installCode(functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled()); - } catch (IOException e) { - e.printStackTrace(); - } - return null; + byte[] source = getSource(prebuiltTask); + source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend()); + return deviceContext.installCode(functionName, source, prebuiltTask.getEntryPoint(), task.meta().isPrintKernelEnabled()); } @Override @@ -229,8 +238,7 @@ public boolean isFullJITMode(SchedulableTask task) { @Override public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { String methodName; - if (task instanceof PrebuiltTask) { - PrebuiltTask prebuiltTask = (PrebuiltTask) task; + if (task instanceof PrebuiltTask prebuiltTask) { methodName = prebuiltTask.getEntryPoint(); } else { CompilableTask compilableTask = (CompilableTask) task; @@ -260,21 +268,7 @@ private XPUBuffer createDeviceBuffer(Class type, Object object, long batchSiz result = new PTXVectorWrapper(getDeviceContext(), object, batchSize); } else if (object instanceof MemorySegment) { result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof IntArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof FloatArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof DoubleArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof LongArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof ShortArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof ByteArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof CharArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof HalfFloatArray) { + } else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) { result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); } else { result = new PTXObjectWrapper(getDeviceContext(), object); @@ -527,8 +521,7 @@ public void sync(long executionPlanId) { @Override public boolean equals(Object obj) { - if (obj instanceof PTXTornadoDevice) { - final PTXTornadoDevice other = (PTXTornadoDevice) obj; + if (obj instanceof PTXTornadoDevice other) { return (other.deviceIndex == deviceIndex); } return false; diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVCodeCache.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVCodeCache.java index 32048ebdac..b655967b04 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVCodeCache.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVCodeCache.java @@ -114,8 +114,8 @@ public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, Strin } writeBufferToFile(buffer, spirvFile); - return installSPIRVBinary(meta, id, entryPoint, spirvFile); + return installSPIRVBinary(meta, id, entryPoint, spirvFile, null); } - public abstract SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile); + public abstract SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile, Class klassJar); } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java index d9d600d9a2..c4f995fcf7 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVDeviceContext.java @@ -386,8 +386,8 @@ public SPIRVInstalledCode installBinary(TaskMetaData meta, String id, String ent return codeCache.installSPIRVBinary(meta, id, entryPoint, code); } - public SPIRVInstalledCode installBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile) { - return codeCache.installSPIRVBinary(meta, id, entryPoint, pathToFile); + public SPIRVInstalledCode installBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile, Class klassJar) { + return codeCache.installSPIRVBinary(meta, id, entryPoint, pathToFile, klassJar); } public boolean isCached(String id, String entryPoint) { diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java index 6321bda461..63087c4217 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVLevelZeroCodeCache.java @@ -24,12 +24,18 @@ package uk.ac.manchester.tornado.drivers.spirv; import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVTool; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; +import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; +import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.common.logging.Logger; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVInstalledCode; import uk.ac.manchester.tornado.drivers.spirv.graal.SPIRVLevelZeroInstalledCode; @@ -55,14 +61,28 @@ public SPIRVLevelZeroCodeCache(SPIRVDeviceContext deviceContext) { } @Override - public synchronized SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile) { + public synchronized SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile, Class klassJar) { ZeModuleHandle module = new ZeModuleHandle(); ZeModuleDescriptor moduleDesc = new ZeModuleDescriptor(); ZeBuildLogHandle buildLog = new ZeBuildLogHandle(); moduleDesc.setFormat(ZeModuleFormat.ZE_MODULE_FORMAT_IL_SPIRV); moduleDesc.setBuildFlags("-ze-opt-level 2 -ze-opt-large-register-file"); - checkBinaryFileExists(pathToFile); + if (klassJar == null) { + checkBinaryFileExists(pathToFile); + } else { + try (InputStream inputStream = klassJar.getClassLoader().getResourceAsStream(pathToFile)) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", pathToFile); + byte[] source = inputStream.readAllBytes(); + String tempFileName = createSPIRVTempDirectoryName() + "temp.spv"; + OutputStream stream = new FileOutputStream(tempFileName); + stream.write(source); + stream.close(); + pathToFile = tempFileName; + } catch (IOException e) { + throw new TornadoRuntimeException(e); + } + } SPIRVContext spirvContext = deviceContext.getSpirvContext(); SPIRVLevelZeroContext levelZeroContext = (SPIRVLevelZeroContext) spirvContext; diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java index ca0d074c34..4f60d7f8b6 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/SPIRVOCLCodeCache.java @@ -26,13 +26,17 @@ import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; +import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import uk.ac.manchester.beehivespirvtoolkit.lib.SPIRVTool; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.Disassembler; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPIRVDisassemblerOptions; import uk.ac.manchester.beehivespirvtoolkit.lib.disassembler.SPVFileReader; import uk.ac.manchester.tornado.api.exceptions.TornadoBailoutRuntimeException; +import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException; import uk.ac.manchester.tornado.drivers.opencl.OCLErrorCode; import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice; @@ -66,14 +70,28 @@ private byte[] readFile(String spirvFile) { } @Override - public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String pathToFile) { - - checkBinaryFileExists(pathToFile); + public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, String entryPoint, String fileName, Class klassJar) { + + if (klassJar == null) { + checkBinaryFileExists(fileName); + } else { + try (InputStream inputStream = klassJar.getClassLoader().getResourceAsStream(fileName)) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", fileName); + byte[] source = inputStream.readAllBytes(); + String tempFileName = createSPIRVTempDirectoryName() + "temp.spv"; + OutputStream stream = new FileOutputStream(tempFileName); + stream.write(source); + stream.close(); + fileName = tempFileName; + } catch (IOException e) { + throw new TornadoRuntimeException(e); + } + } if (meta.isPrintKernelEnabled()) { SPVFileReader reader; try { - reader = new SPVFileReader(pathToFile); + reader = new SPVFileReader(fileName); } catch (FileNotFoundException e) { throw new TornadoBailoutRuntimeException(e.getMessage()); } @@ -86,36 +104,34 @@ public SPIRVInstalledCode installSPIRVBinary(TaskMetaData meta, String id, Strin } } - byte[] binary = readFile(pathToFile); + byte[] binary = readFile(fileName); + + long contextId = deviceContext.getSpirvContext().getOpenCLLayer().getContextId(); + long programPointer; - long contextId = deviceContext.getSpirvContext().getOpenCLLayer().getContextId(); - long programPointer; + SPIRVOCLNativeDispatcher spirvoclNativeCompiler = new SPIRVOCLNativeDispatcher(); + int[] errorCode = new int[1];programPointer=spirvoclNativeCompiler.clCreateProgramWithIL(contextId,binary,new long[]{binary.length},errorCode);if(errorCode[0]!=OCLErrorCode.CL_SUCCESS) + { + throw new TornadoRuntimeException("[ERROR] - clCreateProgramWithIL failed"); + } - SPIRVOCLNativeDispatcher spirvoclNativeCompiler = new SPIRVOCLNativeDispatcher(); - int[] errorCode = new int[1]; - programPointer = spirvoclNativeCompiler.clCreateProgramWithIL(contextId, binary, new long[] { binary.length }, errorCode); - if (errorCode[0] != OCLErrorCode.CL_SUCCESS) { - throw new TornadoRuntimeException("[ERROR] - clCreateProgramWithIL failed"); - } + OCLTargetDevice oclDevice = (OCLTargetDevice) deviceContext.getDevice().getDeviceRuntime(); + int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, "");if(status!=OCLErrorCode.CL_SUCCESS) + { + String log = spirvoclNativeCompiler.clGetProgramBuildInfo(programPointer, oclDevice.getDevicePointer()); + System.out.println(log); + throw new TornadoRuntimeException("[ERROR] - clBuildProgram failed"); + } - OCLTargetDevice oclDevice = (OCLTargetDevice) deviceContext.getDevice().getDeviceRuntime(); - int status = spirvoclNativeCompiler.clBuildProgram(programPointer, 1, new long[] { oclDevice.getDevicePointer() }, ""); - if (status != OCLErrorCode.CL_SUCCESS) { - String log = spirvoclNativeCompiler.clGetProgramBuildInfo(programPointer, oclDevice.getDevicePointer()); - System.out.println(log); - throw new TornadoRuntimeException("[ERROR] - clBuildProgram failed"); - } + long kernelPointer = spirvoclNativeCompiler.clCreateKernel(programPointer, entryPoint, errorCode);if(errorCode[0]!=OCLErrorCode.CL_SUCCESS) + { + throw new TornadoRuntimeException("[ERROR] - clCreateKernel failed"); + } - long kernelPointer = spirvoclNativeCompiler.clCreateKernel(programPointer, entryPoint, errorCode); - if (errorCode[0] != OCLErrorCode.CL_SUCCESS) { - throw new TornadoRuntimeException("[ERROR] - clCreateKernel failed"); - } + SPIRVOCLModule module = new SPIRVOCLModule(kernelPointer, entryPoint, fileName); + final SPIRVOCLInstalledCode installedCode = new SPIRVOCLInstalledCode(entryPoint, module, deviceContext); - SPIRVOCLModule module = new SPIRVOCLModule(kernelPointer, entryPoint, pathToFile); - final SPIRVOCLInstalledCode installedCode = new SPIRVOCLInstalledCode(entryPoint, module, deviceContext); - - // Install code in the code cache - cache.put(STR."\{id}-\{entryPoint}", installedCode); - return installedCode; - } + // Install code in the code cache + cache.put(STR."\{id}-\{entryPoint}",installedCode);return installedCode; +} } diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java index 2f6c9f2cb2..268b8c8901 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java @@ -140,14 +140,22 @@ public TornadoInstalledCode installCode(SchedulableTask task) { } } - private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask task) { + private void checkSourceInFS(PrebuiltTask prebuiltTask) { + if (prebuiltTask.getKlassJar() == null) { + Path pathtoSPIRVBin = Paths.get(prebuiltTask.getFilename()); + if (!pathtoSPIRVBin.toFile().exists()) { + throw new RuntimeException("SPIRV file does not exist: " + prebuiltTask.getFilename()); + } + } + } + + private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask prebuiltTask) { final SPIRVDeviceContext deviceContext = getDeviceContext(); - if (deviceContext.isCached(task.getId(), task.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), task.getEntryPoint()); + if (deviceContext.isCached(prebuiltTask.getId(), prebuiltTask.getEntryPoint())) { + return deviceContext.getInstalledCode(prebuiltTask.getId(), prebuiltTask.getEntryPoint()); } - final Path pathToSPIRVBin = Paths.get(task.getFilename()); - TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", task.getFilename()); - return deviceContext.installBinary(task.meta(), task.getId(), task.getEntryPoint(), task.getFilename()); + checkSourceInFS(prebuiltTask); + return deviceContext.installBinary(prebuiltTask.meta(), prebuiltTask.getId(), prebuiltTask.getEntryPoint(), prebuiltTask.getFilename(), prebuiltTask.getKlassJar()); } public SPIRVBackend getBackend() { diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVTornadoCompiler.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVTornadoCompiler.java index b08e2c6abb..bb208267b0 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVTornadoCompiler.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/tests/TestSPIRVTornadoCompiler.java @@ -66,7 +66,7 @@ public static void main(String[] args) { String tornadoSDK = System.getenv("TORNADO_SDK"); String pathToSPIRVBinaryFile = tornadoSDK + "/examples/generated/add.spv"; - SPIRVInstalledCode code = codeCache.installSPIRVBinary(task, "add", "add", pathToSPIRVBinaryFile); + SPIRVInstalledCode code = codeCache.installSPIRVBinary(task, "add", "add", pathToSPIRVBinaryFile, null); String generatedCode = code.getGeneratedSourceCode(); if (scheduleMetaData.isPrintKernelEnabled()) { diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java index 355fad20fa..a84d099423 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java @@ -55,8 +55,6 @@ import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError; import uk.ac.manchester.tornado.runtime.TornadoCoreRuntime; import uk.ac.manchester.tornado.runtime.common.TornadoLogger; -import uk.ac.manchester.tornado.runtime.domain.DomainTree; -import uk.ac.manchester.tornado.runtime.domain.IntDomain; import uk.ac.manchester.tornado.runtime.tasks.CompilableTask; import uk.ac.manchester.tornado.runtime.tasks.PrebuiltTask; import uk.ac.manchester.tornado.runtime.tasks.meta.ScheduleMetaData; @@ -302,15 +300,6 @@ private static Object[] extractCapturedVariables(Object code) { return cvs; } - private static DomainTree buildDomainTree(int[] dims) { - final DomainTree domain = new DomainTree(dims.length); - for (int i = 0; i < dims.length; i++) { - domain.set(i, new IntDomain(0, 1, dims[i])); - } - return domain; - - } - public static PrebuiltTask createTask(ScheduleMetaData meta, PrebuiltTaskPackage taskPackage) { PrebuiltTask prebuiltTask = new PrebuiltTask(meta, // taskPackage.getId(), // @@ -318,9 +307,16 @@ public static PrebuiltTask createTask(ScheduleMetaData meta, PrebuiltTaskPackage taskPackage.getFilename(), // taskPackage.getArgs(), // taskPackage.getAccesses()); + + // Attach atomics if (taskPackage.getAtomics() != null) { prebuiltTask.withAtomics(taskPackage.getAtomics()); } + + // Attach Class if JAR file configuration is present + if (taskPackage.getClassJar() != null) { + prebuiltTask.withKlassJAR(taskPackage.getClassJar()); + } return prebuiltTask; } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java index 350f238ef0..94f9a5bedd 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/common/TornadoOptions.java @@ -198,7 +198,7 @@ public class TornadoOptions { /** * Use Level Zero or OpenCL as the SPIR-V Code Dispatcher and Runtime. Allowed values: "opencl", "levelzero". The default option is "opencl". */ - public static final String SPIRV_DISPATCHER = getProperty("tornado.spirv.dispatcher", "opencl"); + public static final String SPIRV_DISPATCHER = getProperty("tornado.spirv.dispatcher", "levelzero"); /** * Check I/O parameters for every task within a task-graph. */ diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java index dd4a8de90c..374658457d 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java @@ -49,6 +49,7 @@ public class PrebuiltTask implements SchedulableTask { private TornadoProfiler profiler; private boolean forceCompiler; private int[] atomics; + private Class klass; public PrebuiltTask(ScheduleMetaData scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access, TornadoDevice device, DomainTree domain) { this.entryPoint = entryPoint; @@ -246,4 +247,11 @@ public int[] getAtomics() { return atomics; } + public Class getKlassJar() { + return klass; + } + + public void withKlassJAR(Class classJar) { + this.klass = classJar; + } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/atomics/TestAtomics.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/atomics/TestAtomics.java index a6d3b9ee3e..229df26a46 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/atomics/TestAtomics.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/atomics/TestAtomics.java @@ -285,12 +285,13 @@ public void testAtomic05_precompiled() throws TornadoExecutionPlanException { accessorParameters.set(0, a, Access.WRITE_ONLY); accessorParameters.set(1, b, Access.WRITE_ONLY); + int[] atomics = new int[] { 155 }; TaskGraph taskGraph = new TaskGraph("s0") // .prebuiltTask("t0", // "add", // tornadoSDK + "/examples/generated/atomics.cl", // accessorParameters, // - new int[] { 155 } // Array for AtomicsInteger - Initial int value + atomics // Array for AtomicsInteger - Initial int value ).transferToHost(DataTransferMode.EVERY_EXECUTION, a); WorkerGrid workerGrid = new WorkerGrid1D(32); diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java index d6713ccbb1..c970035efb 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java @@ -276,6 +276,64 @@ public void testPrebuilt04SPIRVThroughOpenCLRuntime() throws TornadoExecutionPla } assertEquals(512, finalSum, 0.0f); + } + + /** + * Test: + * + * tornado-test -V uk.ac.manchester.tornado.unittests.prebuilt.PrebuiltTest#testPrebuiltWithJar + * + * + * @throws TornadoExecutionPlanException + */ + @Test + public void testPrebuiltWithJar() throws TornadoExecutionPlanException { + final int numElements = 8; + IntArray a = new IntArray(numElements); + IntArray b = new IntArray(numElements); + IntArray c = new IntArray(numElements); + + a.init(1); + b.init(2); + + String resource = switch (backendType) { + case PTX -> "add.ptx"; + case OPENCL -> "add.cl"; + case SPIRV -> "add.spv"; + default -> throw new TornadoRuntimeException("Backend not supported"); + }; + + // Define accessors for each parameter + AccessorParameters accessorParameters = new AccessorParameters(3); + accessorParameters.set(0, a, Access.READ_ONLY); + accessorParameters.set(1, b, Access.READ_ONLY); + accessorParameters.set(2, c, Access.WRITE_ONLY); + + // Define the Task-Graph + TaskGraph taskGraph = new TaskGraph("s0") // + .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) // + .prebuiltTask("t0", // task-name (up to the developer to set it up) + "add", // name of the low-level kernel to invoke (as it appears in the kernel file) + PrebuiltTest.class, // use a .class that it is contained in the JAR file of the + resource, // resource file + accessorParameters) // accessors + .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + + // When using the prebuilt API, we need to define the WorkerGrid, otherwise it will launch 1 thread + // on the target device + WorkerGrid workerGrid = new WorkerGrid1D(numElements); + GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid); + + // Launch the application on the target device + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(taskGraph.snapshot())) { + executionPlan.withGridScheduler(gridScheduler) // + .withDevice(defaultDevice) // + .execute(); + } + + for (int i = 0; i < c.getSize(); i++) { + assertEquals(a.get(i) + b.get(i), c.get(i)); + } } } diff --git a/tornado-unittests/src/main/resources/add.cl b/tornado-unittests/src/main/resources/add.cl new file mode 100644 index 0000000000..8bf25c1548 --- /dev/null +++ b/tornado-unittests/src/main/resources/add.cl @@ -0,0 +1,33 @@ +#pragma OPENCL EXTENSION cl_khr_fp64 : enable +__kernel void add(__global long *_kernel_context, __constant uchar *_constant_region, __local uchar *_local_region, __global int *_atomics, __global uchar *a, __global uchar *b, __global uchar *c) +{ + ulong ul_8, ul_10, ul_1, ul_0, ul_2, ul_12; + long l_5, l_7, l_6; + int i_11, i_9, i_15, i_14, i_13, i_3, i_4; + + // BLOCK 0 + ul_0 = a; + ul_1 = b; + ul_2 = c; + i_3 = get_global_id(0); + // BLOCK 1 MERGES [0 2 ] + i_4 = i_3; + for(;i_4 < 8;) { + // BLOCK 2 + l_5 = (long) i_4; + l_6 = l_5 << 2; + l_7 = l_6 + 24L; + ul_8 = ul_0 + l_7; + i_9 = *((__global int *) ul_8); + ul_10 = ul_1 + l_7; + i_11 = *((__global int *) ul_10); + ul_12 = ul_2 + l_7; + i_13 = i_9 + i_11; + *((__global int *) ul_12) = i_13; + i_14 = get_global_size(0); + i_15 = i_14 + i_4; + i_4 = i_15; + } + // BLOCK 3 + return; +} diff --git a/tornado-unittests/src/main/resources/add.ptx b/tornado-unittests/src/main/resources/add.ptx new file mode 100644 index 0000000000..7aa881971a --- /dev/null +++ b/tornado-unittests/src/main/resources/add.ptx @@ -0,0 +1,44 @@ + .visible .entry s0_t0_add_arrays_intarray_arrays_intarray_arrays_intarray(.param .u64 .ptr .global .align 8 kernel_context, .param .u64 .ptr .global .align 8 a, .param .u64 .ptr .global .align 8 b, .param .u64 .ptr .global .align 8 c) { + .reg .s64 rsd<3>; + .reg .pred rpb<2>; + .reg .u32 rui<5>; + .reg .s32 rsi<9>; + .reg .u64 rud<9>; + + BLOCK_0: + ld.param.u64 rud0, [kernel_context]; + ld.param.u64 rud1, [a]; + ld.param.u64 rud2, [b]; + ld.param.u64 rud3, [c]; + mov.u32 rui0, %nctaid.x; + mov.u32 rui1, %ntid.x; + mul.wide.u32 rud4, rui0, rui1; + cvt.s32.u64 rsi0, rud4; + mov.u32 rui2, %tid.x; + mov.u32 rui3, %ctaid.x; + mad.lo.s32 rsi1, rui3, rui1, rui2; + + BLOCK_1: + mov.s32 rsi2, rsi1; + LOOP_COND_1: + setp.lt.s32 rpb0, rsi2, 8; + @!rpb0 bra BLOCK_3; + + BLOCK_2: + add.s32 rsi3, rsi2, 6; + cvt.s64.s32 rsd0, rsi3; + shl.b64 rsd1, rsd0, 2; + add.u64 rud5, rud1, rsd1; + ld.global.s32 rsi4, [rud5]; + add.u64 rud6, rud2, rsd1; + ld.global.s32 rsi5, [rud6]; + add.u64 rud7, rud3, rsd1; + add.s32 rsi6, rsi4, rsi5; + st.global.s32 [rud7], rsi6; + add.s32 rsi7, rsi0, rsi2; + mov.s32 rsi2, rsi7; + bra.uni LOOP_COND_1; + + BLOCK_3: + ret; + } diff --git a/tornado-unittests/src/main/resources/add.spv b/tornado-unittests/src/main/resources/add.spv new file mode 100644 index 0000000000..3775789999 Binary files /dev/null and b/tornado-unittests/src/main/resources/add.spv differ