diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index e4856524dd2..59047e688c8 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -146,7 +146,10 @@ public NDArray toDevice(Device device, boolean copy) { } return this; } - throw new UnsupportedOperationException(UNSUPPORTED_MSG); + NDArray array = getManager().create(getShape(), getDataType(), device); + array.setName(getName()); + copyTo(array); + return array; } /** {@inheritDoc} */ @@ -160,7 +163,9 @@ public NDArray toType(DataType dataType, boolean copy) { } Number[] numbers = toArray(); ByteBuffer bb = toTypeInternal(numbers, dataType); - return manager.create(bb, getShape(), dataType); + NDArray array = manager.create(bb, getShape(), dataType); + array.setName(getName()); + return array; } private ByteBuffer toTypeInternal(Number[] numbers, DataType dataType) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index e1f75ff4b33..9e36ec35884 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -162,7 +162,9 @@ public PtNDArray toDevice(Device device, boolean copy) { if (device.equals(getDevice()) && !copy) { return this; } - return JniUtils.to(this, getDataType(), device); + PtNDArray array = JniUtils.to(this, getDataType(), device); + array.setName(getName()); + return array; } /** {@inheritDoc} */ @@ -171,7 +173,9 @@ public PtNDArray toType(DataType dataType, boolean copy) { if (dataType.equals(getDataType()) && !copy) { return this; } - return JniUtils.to(this, dataType, getDevice()); + PtNDArray array = JniUtils.to(this, dataType, getDevice()); + array.setName(array.getName()); + return array; } /** {@inheritDoc} */ @@ -366,7 +370,9 @@ public void detach() { /** {@inheritDoc} */ @Override public NDArray duplicate() { - return JniUtils.clone(this); + NDArray array = JniUtils.clone(this); + array.setName(getName()); + return array; } /** {@inheritDoc} */