Creates MultiDevice (#2819)
This creates an abstraction for combining devices into a single device. The main
use case for now is in DJL Serving TP_parallel. It will allow us to create a
WorkerGroup and a PyPredictor for a set of devices and then track the usage of
devices properly. It could also be used later for multi-GPU training or other
multi-device cases.
zachgk authored Oct 25, 2023
1 parent 0b6474f commit 185981b
Showing 3 changed files with 114 additions and 3 deletions.
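To make the new API concrete, here is a minimal usage sketch based on the constructors, the "+"-delimited name parsing, and the getDevices() accessor introduced in this commit. The class name MultiDeviceSketch is an illustrative assumption, and running it requires a DJL engine on the classpath because Device.fromName consults Engine.getInstance().

import ai.djl.Device;
import ai.djl.Device.MultiDevice;

import java.util.List;

public class MultiDeviceSketch {

    public static void main(String[] args) {
        // Combine two explicit GPUs into one logical device.
        MultiDevice pair = new MultiDevice(Device.gpu(0), Device.gpu(1));

        // Equivalent range form: gpu0 (inclusive) through gpu2 (exclusive).
        MultiDevice range = new MultiDevice("gpu", 0, 2);

        // Device names joined with '+' now parse into a MultiDevice.
        Device parsed = Device.fromName("gpu0+gpu1");

        List<Device> subDevices = pair.getDevices();
        System.out.println(subDevices.size());   // expected: 2
        System.out.println(pair.equals(range));  // expected: true
        System.out.println(pair.equals(parsed)); // expected: true
    }
}

All three values should compare equal because the list-based constructor sorts its sub-devices and derives the composite device type (here "gpu0+gpu1") from them.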
105 changes: 102 additions & 3 deletions api/src/main/java/ai/djl/Device.java
@@ -14,11 +14,16 @@

import ai.djl.engine.Engine;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
@@ -30,7 +35,7 @@
* @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter
* on GPU devices</a>
*/
-public final class Device {
+public class Device {

private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();

@@ -39,8 +44,8 @@ public final class Device {

private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

-    private String deviceType;
-    private int deviceId;
+    protected String deviceType;
+    protected int deviceId;

/**
* Creates a {@code Device} with basic information.
@@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) {
return engine.defaultDevice();
}

if (deviceName.contains("+")) {
String[] split = deviceName.split("\\+");
List<Device> subDevices =
Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList());
return new MultiDevice(subDevices);
}

Matcher matcher = DEVICE_NAME.matcher(deviceName);
if (matcher.matches()) {
String deviceType = matcher.group(1);
@@ -214,4 +226,91 @@ public interface Type {
String CPU = "cpu";
String GPU = "gpu";
}

/** A combined {@link Device} representing the composition of multiple other devices. */
public static class MultiDevice extends Device {

List<Device> devices;

/**
* Constructs a {@link MultiDevice} with a range of new devices.
*
* @param deviceType the type of the sub-devices
* @param startInclusive the start (inclusive) of the devices range
* @param endExclusive the end (exclusive) of the devices range
*/
public MultiDevice(String deviceType, int startInclusive, int endExclusive) {
this(
IntStream.range(startInclusive, endExclusive)
.mapToObj(i -> Device.of(deviceType, i))
.collect(Collectors.toList()));
}

/**
* Constructs a {@link MultiDevice} from sub devices.
*
* @param devices the sub devices
*/
public MultiDevice(Device... devices) {
this(Arrays.asList(devices));
}

/**
* Constructs a {@link MultiDevice} from sub devices.
*
* @param devices the sub devices
*/
public MultiDevice(List<Device> devices) {
super(null, -1);
devices.sort(
Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER)
.thenComparingInt(Device::getDeviceId));
this.deviceType =
String.join(
"+",
(Iterable<String>)
() ->
devices.stream()
.map(d -> d.getDeviceType() + d.getDeviceId())
.iterator());
this.devices = devices;
}

/**
* Returns the sub devices.
*
* @return the sub devices
*/
public List<Device> getDevices() {
return devices;
}

/** {@inheritDoc} */
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
MultiDevice that = (MultiDevice) o;
return Objects.equals(devices, that.devices);
}

/** {@inheritDoc} */
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), devices);
}

/** {@inheritDoc} */
@Override
public String toString() {
return deviceType + "()";
}
}
}
5 changes: 5 additions & 0 deletions api/src/main/java/ai/djl/training/ParameterStore.java
@@ -14,6 +14,7 @@
package ai.djl.training;

import ai.djl.Device;
import ai.djl.Device.MultiDevice;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
@@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices
this.parameterServer = parameterServer;
deviceMap.clear();
for (int i = 0; i < devices.length; ++i) {
if (devices[i] instanceof MultiDevice) {
throw new IllegalArgumentException(
"The parameter store does not support MultiDevices");
}
if (deviceMap.put(devices[i], i) != null) {
throw new IllegalArgumentException("Duplicated devices are not allowed.");
}
7 changes: 7 additions & 0 deletions api/src/test/java/ai/djl/DeviceTest.java
@@ -13,6 +13,7 @@

package ai.djl;

import ai.djl.Device.MultiDevice;
import ai.djl.engine.Engine;

import org.testng.Assert;
@@ -37,6 +38,8 @@ public void testDevice() {

System.setProperty("test_key", "test");
Engine.debugEnvironment();

Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size());
}

@Test
@@ -54,5 +57,9 @@ public void testDeviceName() {
Device defaultDevice = Engine.getInstance().defaultDevice();
Assert.assertEquals(Device.fromName(""), defaultDevice);
Assert.assertEquals(Device.fromName(null), defaultDevice);

Assert.assertEquals(
Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1)));
Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3));
}
}
