Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API] Extension of the prebuilt-API to load kernels from JAR files #505

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,15 @@ public TaskGraph prebuiltTask(String id, String entryPoint, String filename, Acc
return this;
}

@Override
public TaskGraph prebuiltTask(String id, String entryPoint, Class<?> klass, String resource, AccessorParameters accessorParameters) {
checkTaskName(id);
PrebuiltTaskPackage prebuiltTask = TaskPackage.createPrebuiltTask(id, entryPoint, resource, accessorParameters);
prebuiltTask.withClass(klass);
taskGraphImpl.addPrebuiltTask(prebuiltTask);
return this;
}

/**
* Obtains the task-schedule name that was assigned.
*
Expand Down Expand Up @@ -849,11 +858,11 @@ long getTotalBytesCopyOut() {
return taskGraphImpl.getTotalBytesCopyOut();
}

protected String getProfileLog() {
String getProfileLog() {
return taskGraphImpl.getProfileLog();
}

public Collection<?> getOutputs() {
Collection<?> getOutputs() {
return taskGraphImpl.getOutputs();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,31 @@ <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> TaskGraphInte
* Kernel's name of the entry point
* @param filename
* Input OpenCL C Kernel
* @param accessorParameters
* {@link AccessorParameters} that contains the accessor for each input and output parameter to the kernel.
jjfumero marked this conversation as resolved.
Show resolved Hide resolved
* @param atomics
* Atomics region.
* @return {@link TaskGraphInterface}
*
*/
TaskGraphInterface prebuiltTask(String id, String entryPoint, String filename, AccessorParameters accessorParameters, int[] atomics);

/**
*
* @param id
* Task-id
* @param entryPoint
* Kernel's name of the entry point
* @param klass
* Class that can access the resource within the JAR file.
* @param resource
* Input file that represents the kernel source. It could be either SPIR-V, OpenCL C, or PTX code.
* @param accessorParameters
* {@link AccessorParameters} that contains the accessor for each input and output parameter to the kernel.
* @return {@link TaskGraphInterface}
*/
TaskGraphInterface prebuiltTask(String id, String entryPoint, Class<?> klass, String resource, AccessorParameters accessorParameters);

/**
* Obtains the task-schedule name that was assigned.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class PrebuiltTaskPackage extends TaskPackage {
private final Object[] args;
private final Access[] accesses;
private int[] atomics;
private Class<?> klass;

PrebuiltTaskPackage(String id, String entryPoint, String fileName, AccessorParameters accessorParameters) {
super(id, null);
Expand Down Expand Up @@ -70,4 +71,12 @@ public boolean isPrebuiltTask() {
public int[] getAtomics() {
return atomics;
}

public void withClass(Class<?> klass) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have been withKlass, right?

this.klass = klass;
}

public Class<?> getClassJar() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this getKlass?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Java, there is a method called getClass, so to avoid miss-understanding about the functionality of each, I named it ClassJar

return klass;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* </p>
*
* <p>
* The constant {@link ARRAY_HEADER} represents the size of the header in bytes.
* The constant {@link TornadoNativeArray#ARRAY_HEADER} represents the size of the header in bytes.
* </p>
*/
public abstract sealed class TornadoNativeArray //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package uk.ac.manchester.tornado.drivers.opencl.runtime;

import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -53,14 +54,8 @@
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.api.profiler.ProfilerType;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.CharArray;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.tensors.Tensor;
import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider;
import uk.ac.manchester.tornado.drivers.opencl.OCLBackendImpl;
import uk.ac.manchester.tornado.drivers.opencl.OCLCodeCache;
Expand Down Expand Up @@ -309,31 +304,39 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
}
}

private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final PrebuiltTask executable = (PrebuiltTask) task;
if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint());
private byte[] getSource(PrebuiltTask prebuiltTask) {
byte[] source;
Class<?> klass = prebuiltTask.getClassJAR();
if (klass != null) {
try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) {
TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename());
source = inputStream.readAllBytes();
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e);
}
} else {
final Path path = Paths.get(prebuiltTask.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename());
try {
source = Files.readAllBytes(path);
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e);
}
}
return source;
}

final Path path = Paths.get(executable.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename());
try {
final byte[] source = Files.readAllBytes(path);
private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final PrebuiltTask prebuiltTask = (PrebuiltTask) task;
if (deviceContext.isCached(task.getId(), prebuiltTask.getEntryPoint()))
return deviceContext.getInstalledCode(task.getId(), prebuiltTask.getEntryPoint());

OCLInstalledCode installedCode;
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) {
// A) for FPGA
installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
} else {
// B) for CPU multi-core or GPU
installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source);
}
return installedCode;
} catch (IOException e) {
e.printStackTrace();
}
return null;
byte[] source = getSource(prebuiltTask);
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext))
return deviceContext.installCode(task.getId(), prebuiltTask.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
else
return deviceContext.installCode(prebuiltTask.meta(), task.getId(), prebuiltTask.getEntryPoint(), source);
}

private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) {
Expand Down Expand Up @@ -404,8 +407,7 @@ public int[] checkAtomicsForTask(SchedulableTask task) {

@Override
public int[] checkAtomicsForTask(SchedulableTask task, int[] array, int paramIndex, Object value) {
if (value instanceof AtomicInteger) {
AtomicInteger ai = (AtomicInteger) value;
if (value instanceof AtomicInteger ai) {
if (TornadoAtomicIntegerNode.globalAtomicsParameters.containsKey(task.meta().getCompiledResolvedJavaMethod())) {
HashMap<Integer, Integer> values = TornadoAtomicIntegerNode.globalAtomicsParameters.get(task.meta().getCompiledResolvedJavaMethod());
int index = values.get(paramIndex);
Expand Down Expand Up @@ -535,21 +537,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, OCLDeviceCont
result = new OCLVectorWrapper(deviceContext, object, batchSize);
} else if (object instanceof MemorySegment) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof IntArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof FloatArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof DoubleArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof LongArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof ShortArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof ByteArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof CharArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof HalfFloatArray) {
} else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) {
mairooni marked this conversation as resolved.
Show resolved Hide resolved
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else {
result = new OCLXPUBuffer(deviceContext, object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXCodeUtil.buildKernelName;

import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -49,14 +50,8 @@
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.api.profiler.ProfilerType;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.CharArray;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.tensors.Tensor;
import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider;
import uk.ac.manchester.tornado.drivers.ptx.PTX;
import uk.ac.manchester.tornado.drivers.ptx.PTXBackendImpl;
Expand Down Expand Up @@ -201,24 +196,38 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
}
}

private byte[] getSource(PrebuiltTask prebuiltTask) {
byte[] source;
Class<?> klass = prebuiltTask.getClassJAR();
if (klass != null) {
try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) {
TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename());
source = inputStream.readAllBytes();
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e);
}
} else {
final Path path = Paths.get(prebuiltTask.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename());
try {
source = Files.readAllBytes(path);
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in reallAllBytes", e);
}
}
return source;
}

private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
final PTXDeviceContext deviceContext = getDeviceContext();
final PrebuiltTask executable = (PrebuiltTask) task;
String functionName = buildKernelName(executable.getEntryPoint(), executable);
if (deviceContext.isCached(executable.getEntryPoint(), executable)) {
final PrebuiltTask prebuiltTask = (PrebuiltTask) task;
String functionName = buildKernelName(prebuiltTask.getEntryPoint(), prebuiltTask);
if (deviceContext.isCached(prebuiltTask.getEntryPoint(), prebuiltTask))
return deviceContext.getInstalledCode(functionName);
}

final Path path = Paths.get(executable.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename());
try {
byte[] source = Files.readAllBytes(path);
source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend());
return deviceContext.installCode(functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled());
} catch (IOException e) {
e.printStackTrace();
}
return null;
byte[] source = getSource(prebuiltTask);
source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend());
return deviceContext.installCode(functionName, source, prebuiltTask.getEntryPoint(), task.meta().isPrintKernelEnabled());
}

@Override
Expand All @@ -229,8 +238,7 @@ public boolean isFullJITMode(SchedulableTask task) {
@Override
public TornadoInstalledCode getCodeFromCache(SchedulableTask task) {
String methodName;
if (task instanceof PrebuiltTask) {
PrebuiltTask prebuiltTask = (PrebuiltTask) task;
if (task instanceof PrebuiltTask prebuiltTask) {
methodName = prebuiltTask.getEntryPoint();
} else {
CompilableTask compilableTask = (CompilableTask) task;
Expand Down Expand Up @@ -260,21 +268,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, long batchSiz
result = new PTXVectorWrapper(getDeviceContext(), object, batchSize);
} else if (object instanceof MemorySegment) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof IntArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof FloatArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof DoubleArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof LongArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof ShortArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof ByteArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof CharArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof HalfFloatArray) {
} else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) {
mairooni marked this conversation as resolved.
Show resolved Hide resolved
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else {
result = new PTXObjectWrapper(getDeviceContext(), object);
Expand Down Expand Up @@ -527,8 +521,7 @@ public void sync(long executionPlanId) {

@Override
public boolean equals(Object obj) {
if (obj instanceof PTXTornadoDevice) {
final PTXTornadoDevice other = (PTXTornadoDevice) obj;
if (obj instanceof PTXTornadoDevice other) {
return (other.deviceIndex == deviceIndex);
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package uk.ac.manchester.tornado.drivers.spirv.runtime;

import java.lang.foreign.MemorySegment;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
Expand Down Expand Up @@ -140,14 +141,28 @@ public TornadoInstalledCode installCode(SchedulableTask task) {
}
}

private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask task) {
private Path getSource(PrebuiltTask prebuiltTask) {
Class<?> klass = prebuiltTask.getClassJAR();
if (klass != null) {
URL url = klass.getClassLoader().getResource(prebuiltTask.getFilename());
if (url != null) {
return Paths.get(url.getPath());
} else {
throw new TornadoRuntimeException("[ERROR] Prebuilt file path not found: " + url.getPath());
}
} else {
return Paths.get(prebuiltTask.getFilename());
}
}

private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask prebuiltTask) {
final SPIRVDeviceContext deviceContext = getDeviceContext();
if (deviceContext.isCached(task.getId(), task.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), task.getEntryPoint());
if (deviceContext.isCached(prebuiltTask.getId(), prebuiltTask.getEntryPoint())) {
return deviceContext.getInstalledCode(prebuiltTask.getId(), prebuiltTask.getEntryPoint());
}
final Path pathToSPIRVBin = Paths.get(task.getFilename());
TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", task.getFilename());
return deviceContext.installBinary(task.meta(), task.getId(), task.getEntryPoint(), task.getFilename());
final Path pathToSPIRVBin = getSource(prebuiltTask);
TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", prebuiltTask.getFilename());
return deviceContext.installBinary(prebuiltTask.meta(), prebuiltTask.getId(), prebuiltTask.getEntryPoint(), prebuiltTask.getFilename());
}

public SPIRVBackend getBackend() {
Expand Down
Loading