djl-convert does not produce working model from Huggingface #3518

Open
AlEscher opened this issue Nov 14, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@AlEscher

Description

I am trying to convert a Hugging Face model to make it compatible with DJL.
My goal is to use djl-convert to convert the model and then load it locally.
After that, I want to generate code embeddings for Java code, using e.g. CodeBERT.
I ran djl-convert -m microsoft/codebert-base -o models/codebert and then used this code to load the model:

Criteria<String, float[]> criteria = Criteria.builder().setTypes(String.class, float[].class)
  .optApplication(Application.NLP.TEXT_EMBEDDING).optModelPath(Paths.get("models/codebert"))
  .optModelName("codebert-base.pt").optTranslator(translator).optProgress(new ProgressBar()).build();
ZooModel<String, float[]> model = ModelZoo.loadModel(criteria);
Predictor<String, float[]> predictor = model.newPredictor();
float[] embeddings = predictor.predict(input);

The translator is implemented like this:

protected ModelTranslator(String tokenizerPath, boolean useTokenTypes) throws IOException {
  this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath, "tokenizer.json"));
  this.useTokenTypes = useTokenTypes;
}

@Override
public NDList processInput(TranslatorContext ctx, String input) {
  return tokenizer.encode(input).toNDList(ctx.getNDManager(), useTokenTypes);
}

@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
  // Retrieve the embeddings from the output
  NDArray embeddings = list.singletonOrThrow();
  return embeddings.toFloatArray();
}

When generating the embeddings, the model fails with:

ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 3 argument(s) for operator 'forward', but received 4 argument(s). Declaration: forward(__torch__.transformers.models.roberta.modeling_roberta.RobertaModel self, Tensor input_ids, Tensor attention_mask) -> Dict(str, Tensor)

What am I doing wrong? Is there a better approach to loading a model from Hugging Face? codebert-base does not seem to be available in the Model Zoo.

Expected Behavior

The convert tool produces a model that can be loaded locally and has a working forward method.

Error Message

ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 3 argument(s) for operator 'forward', but received 4 argument(s). Declaration: forward(__torch__.transformers.models.roberta.modeling_roberta.RobertaModel self, Tensor input_ids, Tensor attention_mask) -> Dict(str, Tensor)

How to Reproduce?

See the code provided above.

Steps to reproduce


  1. Run the djl-convert tool as described above
  2. Attempt to generate embeddings as described above

What have you tried to solve it?

I tried many different ways of getting a Hugging Face model to work locally. This approach seems to be the intended one according to https://djl.ai/extensions/tokenizers/#convert-huggingface-model-to-torchscript

Environment Info

Output of ./gradlew debugEnv, run from the root directory of DJL:

Found C:\Dev\Research\djl\\gradle\wrapper\gradle-wrapper.jar
Starting a Gradle Daemon (subsequent builds will be faster)

> Task :engines:ml:xgboost:processResources
Downloading https://publish.djl.ai/xgboost/2.0.3/jnilib/linux/aarch64/libxgboost4j.so

> Task :engines:pytorch:pytorch-jni:processResources
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cpu/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cpu-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-aarch64/cpu-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/osx-aarch64/cpu/libdjl_torch.dylib
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/win-x86_64/cpu/djl_torch.dll
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cu124/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cu124-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/win-x86_64/cu124/djl_torch.dll

> Task :integration:debugEnv
----------- System Properties -----------
java.specification.version: 21
sun.cpu.isalist: amd64
sun.jnu.encoding: Cp1252
java.class.path: C:\Dev\Research\djl\integration\build\classes\java\main;C:\Dev\Research\djl\integration\build\resources\main;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-cli\commons-cli\1.9.0\e1cdfa8bf40ccbb7440b2d1232f9f45bb20a1844\commons-cli-1.9.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.logging.log4j\log4j-slf4j2-impl\2.24.0\3d550671b19e83591d5e66cc8c77272e7aaac34c\log4j-slf4j2-impl-2.24.0.jar;C:\Dev\Research\djl\basicdataset\build\libs\basicdataset-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\model-zoo\build\libs\model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\testing\build\libs\testing-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\mxnet\mxnet-model-zoo\build\libs\mxnet-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\pytorch\pytorch-model-zoo\build\libs\pytorch-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\pytorch\pytorch-jni\build\libs\pytorch-jni-2.4.0-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-model-zoo\build\libs\tensorflow-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\ml\xgboost\build\libs\xgboost-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\ml\lightgbm\build\libs\lightgbm-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\onnxruntime\onnxruntime-engine\build\libs\onnxruntime-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\extensions\tokenizers\build\libs\tokenizers-0.31.0-SNAPSHOT.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.logging.log4j\log4j-core\2.24.0\537543d3b84d78b4d7ad055c98f8af13e5e7f3a8\log4j-core-2.24.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.logging.log4j\log4j-api\2.24.0\c6d9bd0c95c9bb6c530f4800da9507b98f018654\log4j-api-2.24.0.jar;C:\Dev\Research\djl\engines\mxnet\mxnet-engine\build\libs\mxnet-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\pytorch\pytorch-engine\build\libs\pytorch-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-engine\build\libs\tensorflow-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\api\build\libs\api-0.31.0-SNAPSHOT.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.testng\testng\7.10.2\30742acada21960d4333a4204039fbdc6a92083a\testng-7.10.2.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.slf4j\slf4j-api\2.0.16\172931663a09a1fa515567af5fbef00897d3c04\slf4j-api-2.0.16.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.commons\commons-csv\1.11.0\8f2dc805097da534612128b7cdf491a5a76752bf\commons-csv-1.11.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\ml.dmlc\xgboost4j_2.12\2.0.3\db511d04d1ca1364cde79a6c8238a2694e31c592\xgboost4j_2.12-2.0.3.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-logging\commons-logging\1.3.4\b9fc14968d63a8b8a8a2c1885fe3e90564239708\commons-logging-1.3.4.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.microsoft.ml.lightgbm\lightgbmlib\3.2.110\f6c85e5d7cc44d49c4544240ea5c96004680007b\lightgbmlib-3.2.110.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.microsoft.onnxruntime\onnxruntime\1.19.0\52985f239457f0b1f635b9a0e9e5b0b03c76b22b\onnxruntime-1.19.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.code.gson\gson\2.11.0\527175ca6d81050b53bdd4c457a6d6e017626b0e\gson-2.11.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\net.java.dev.jna\jna\5.14.0\67bf3eaea4f0718cb376a181a629e5f88fa1c9dd\jna-5.14.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.commons\commons-compress\1.27.1\a19151084758e2fbb6b41eddaa88e7b8ff4e6599\commons-compress-1.27.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-io\commons-io\2.16.1\377d592e740dc77124e0901291dbfaa6810a200e\commons-io-2.16.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-codec\commons-codec\1.17.1\973638b7149d333563584137ebf13a691bb60579\commons-codec-1.17.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.beust\jcommander\1.82\a7c5fef184d238065de38f81bbc6ee50cca2e21\jcommander-1.82.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.webjars\jquery\3.7.1\42088e652462c40a369b64d87e18e825644acfab\jquery-3.7.1.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-api\build\libs\tensorflow-api-0.31.0-SNAPSHOT.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.scala-lang.modules\scala-collection-compat_2.12\2.10.0\bf81785e892f4185f470bddd205b011237aab553\scala-collection-compat_2.12-2.10.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.errorprone\error_prone_annotations\2.27.0\91b2c29d8a6148b5e2e4930f070d4840e2e48e34\error_prone_annotations-2.27.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tensorflow\tensorflow-core-api\1.0.0-rc.1\ea1878fb8e289742237e5a0ba6f15398f3e9b7ef\tensorflow-core-api-1.0.0-rc.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tensorflow\tensorflow-core-native\1.0.0-rc.1\62b5fa3283865cc696dfbebf073ca2116b18f327\tensorflow-core-native-1.0.0-rc.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.bytedeco\javacpp\1.5.10\afb6ae145e7563c66b677cb4896dd0197d49fce6\javacpp-1.5.10.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.protobuf\protobuf-java\3.25.5\5ae5c9ec39930ae9b5a61b32b93288818ec05ec1\protobuf-java-3.25.5.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tensorflow\ndarray\1.0.0-rc.1\4a96a398ad87bec32be9177b1441b9880c04d822\ndarray-1.0.0-rc.1.jar
java.vm.vendor: Oracle Corporation
sun.arch.data.model: 64
user.variant:
java.vendor.url: https://java.oracle.com/
user.timezone: Europe/Berlin
java.vm.specification.version: 21
os.name: Windows 11
user.country: GB
sun.java.launcher: SUN_STANDARD
sun.boot.library.path: C:\Program Files\Java\jdk-21\bin
sun.java.command: ai.djl.integration.util.DebugEnvironment
jdk.debug: release
sun.cpu.endian: little
user.home: C:\Users\Alessandro
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2024-10-15
java.home: C:\Program Files\Java\jdk-21
file.separator: \
java.vm.compressedOopsMode: Zero based
line.separator:

java.vm.specification.vendor: Oracle Corporation
java.specification.name: Java Platform API Specification
user.script:
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.runtime.version: 21.0.5+9-LTS-239
user.name: Alessandro
stdout.encoding: Cp1252
path.separator: ;
os.version: 10.0
java.runtime.name: Java(TM) SE Runtime Environment
file.encoding: UTF-8
java.vm.name: Java HotSpot(TM) 64-Bit Server VM
java.vendor.url.bug: https://bugreport.java.com/bugreport/
java.io.tmpdir: C:\Users\ALESSA~1\AppData\Local\Temp\
java.version: 21.0.5
user.dir: C:\Dev\Research\djl\integration
os.arch: amd64
java.vm.specification.name: Java Virtual Machine Specification
sun.os.patch.level:
native.encoding: Cp1252
java.library.path: C:\Program Files\Java\jdk-21\bin;C:\WINDOWS\Sun\Java\bin;C:\WINDOWS\system32;C:\WINDOWS;C:\Program Files\BullseyeCoverage\bin;C:\Testwell\CTC;C:\Program Files\Common Files\Oracle\Java\javapath;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1.0\;C:\WINDOWS\System32\OpenSSH\;C:\ProgramData\chocolatey\bin;C:\Program Files\osquery;C:\LLVM\bin;C:\cygwin64\bin;C:\Strawberry\c\bin;C:\Strawberry\perl\site\bin;C:\Strawberry\perl\bin;C:\Program Files\microsoft.codecoverage.17.1.0\build\netstandard1.0\CodeCoverage;C:\Program Files\Git LFS;C:\Program Files\teamscale-upload-windows;C:\Program Files\dotnet\;C:\Program Files\OpenCppCoverage;C:\Program Files\BullseyeCoverage\lib;C:\Program Files\Go\bin;C:\Program Files\apache-maven-3.9.0\bin;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\bin;C:\Program Files (x86)\GtkSharp\2.12\bin;C:\Program Files\nodejs\;C:\Program Files\Docker\Docker\resources\bin;C:\Users\Alessandro\AppData\Local\Android\Sdk\platform-tools;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\cmake\win\x64\bin;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\mingw\bin;C:\Users\Alessandro\AppData\Roaming\Python\Python312\Scripts;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Scripts\;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\;C:\Users\Alessandro\AppData\Local\Programs\Python\Launcher\;C:\Users\Alessandro\AppData\Local\pnpm;C:\Users\Alessandro\.poetry\bin;C:\Users\Alessandro\AppData\Local\Microsoft\WindowsApps;C:\Users\Alessandro\AppData\Local\Programs\Microsoft VS Code\bin;C:\Users\Alessandro\go\bin;C:\Users\Alessandro\AppData\Roaming\npm;C:\Users\Alessandro\AppData\Local\JetBrains\Toolbox\scripts;C:\Users\Alessandro\.dotnet\tools;C:\Users\Alessandro\AppData\Local\Programs\Git\cmd;.
java.vm.info: mixed mode, sharing
stderr.encoding: Cp1252
java.vendor: Oracle Corporation
java.vm.version: 21.0.5+9-LTS-239
sun.io.unicode.encoding: UnicodeLittle
java.class.version: 65.0

--------- Environment Variables ---------
USERDOMAIN_ROAMINGPROFILE: DESKTOP-022SVF2
PROCESSOR_LEVEL: 6
LCOV_HOME: C:\ProgramData\chocolatey\lib\lcov\tools
SESSIONNAME: Console
ALLUSERSPROFILE: C:\ProgramData
COVFILE: C:\Dev\bullseye-testwise-coverage\bullseye.cov
PROCESSOR_ARCHITECTURE: AMD64
PSModulePath: C:\Users\Alessandro\OneDrive\Documents\WindowsPowerShell\Modules;C:\Program Files\WindowsPowerShell\Modules;C:\WINDOWS\system32\WindowsPowerShell\v1.0\Modules;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\platform\PowerShell
SystemDrive: C:
PNPM_HOME: C:\Users\Alessandro\AppData\Local\pnpm
DIRNAME: C:\Dev\Research\djl\
USERNAME: Alessandro
CMD_LINE_ARGS: debugEnv
ProgramFiles(x86): C:\Program Files (x86)
APP_HOME: C:\Dev\Research\djl\
PATHEXT: .COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW;.CPL
DriverData: C:\Windows\System32\Drivers\DriverData
OneDriveConsumer: C:\Users\Alessandro\OneDrive
GOPATH: C:\Users\Alessandro\go
ProgramData: C:\ProgramData
GIT_LFS_PATH: C:\Program Files\Git LFS
ProgramW6432: C:\Program Files
HOMEPATH: \Users\Alessandro
PROCESSOR_IDENTIFIER: Intel64 Family 6 Model 158 Stepping 10, GenuineIntel
ProgramFiles: C:\Program Files
PUBLIC: C:\Users\Public
windir: C:\WINDOWS
CTCHOME: C:\Testwell\CTC
=::: ::\
ZES_ENABLE_SYSMAN: 1
_SKIP: 2
LOCALAPPDATA: C:\Users\Alessandro\AppData\Local
USERDOMAIN: DESKTOP-022SVF2
LOGONSERVER: \\DESKTOP-022SVF2
JAVA_HOME: C:\Program Files\Java\jdk-21
PROMPT: $P$G
JETBRAINS_INTELLIJ_COMMAND_END_MARKER: SUyi67dqQ9BKmVqVo3br2NnDywq1xvC4ulCIYXe9Obl4Owe0u0wC2bPj9Yi6YYBr
EFC_10204: 1
OneDrive: C:\Users\Alessandro\OneDrive
=C:: C:\Dev\Research\djl
APPDATA: C:\Users\Alessandro\AppData\Roaming
DOWNLOAD_URL: "https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar"
GTK_BASEPATH: C:\Program Files (x86)\GtkSharp\2.12\
JAVA_EXE: C:\Program Files\Java\jdk-21/bin/java.exe
ChocolateyInstall: C:\ProgramData\chocolatey
CommonProgramFiles: C:\Program Files\Common Files
Path: C:\Program Files\BullseyeCoverage\bin;C:\Testwell\CTC;C:\Program Files\Common Files\Oracle\Java\javapath;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1.0\;C:\WINDOWS\System32\OpenSSH\;C:\ProgramData\chocolatey\bin;C:\Program Files\osquery;C:\LLVM\bin;C:\cygwin64\bin;C:\Strawberry\c\bin;C:\Strawberry\perl\site\bin;C:\Strawberry\perl\bin;C:\Program Files\microsoft.codecoverage.17.1.0\build\netstandard1.0\CodeCoverage;C:\Program Files\Git LFS;C:\Program Files\teamscale-upload-windows;C:\Program Files\dotnet\;C:\Program Files\OpenCppCoverage;C:\Program Files\BullseyeCoverage\lib;C:\Program Files\Go\bin;C:\Program Files\apache-maven-3.9.0\bin;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\bin;C:\Program Files (x86)\GtkSharp\2.12\bin;C:\Program Files\nodejs\;C:\Program Files\Docker\Docker\resources\bin;C:\Users\Alessandro\AppData\Local\Android\Sdk\platform-tools;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\cmake\win\x64\bin;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\mingw\bin;C:\Users\Alessandro\AppData\Roaming\Python\Python312\Scripts;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Scripts\;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\;C:\Users\Alessandro\AppData\Local\Programs\Python\Launcher\;C:\Users\Alessandro\AppData\Local\pnpm;C:\Users\Alessandro\.poetry\bin;C:\Users\Alessandro\AppData\Local\Microsoft\WindowsApps;C:\Users\Alessandro\AppData\Local\Programs\Microsoft VS Code\bin;C:\Users\Alessandro\go\bin;C:\Users\Alessandro\AppData\Roaming\npm;C:\Users\Alessandro\AppData\Local\JetBrains\Toolbox\scripts;C:\Users\Alessandro\.dotnet\tools;C:\Users\Alessandro\AppData\Local\Programs\Git\cmd
OS: Windows_NT
COMPUTERNAME: DESKTOP-022SVF2
PROCESSOR_REVISION: 9e0a
CLASSPATH: C:\Dev\Research\djl\\gradle\wrapper\gradle-wrapper.jar
CommonProgramW6432: C:\Program Files\Common Files
ComSpec: C:\WINDOWS\system32\cmd.exe
APP_BASE_NAME: gradlew
TERMINAL_EMULATOR: JetBrains-JediTerm
PSExecutionPolicyPreference: Bypass
SystemRoot: C:\WINDOWS
TEMP: C:\Users\ALESSA~1\AppData\Local\Temp
HOMEDRIVE: C:
USERPROFILE: C:\Users\Alessandro
TMP: C:\Users\ALESSA~1\AppData\Local\Temp
CommonProgramFiles(x86): C:\Program Files (x86)\Common Files
NUMBER_OF_PROCESSORS: 12

-------------- Directories --------------
temp directory: C:\Users\ALESSA~1\AppData\Local\Temp
DJL cache directory: C:\Users\Alessandro\.djl.ai
Engine cache directory: C:\Users\Alessandro\.djl.ai

------------------ CUDA -----------------
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.31.0-SNAPSHOT
[INFO ] - Downloading libgcc_s_seh-1.dll ...
[INFO ] - Downloading libgfortran-3.dll ...
[INFO ] - Downloading libopenblas.dll ...
[INFO ] - Downloading libquadmath-0.dll ...
[INFO ] - Downloading mxnet.dll ...
Default Engine: MXNet:1.9.0, capabilities: [
        SIGNAL_HANDLER,
        LAPACK,
        BLAS_OPEN,
        OPENMP,
        OPENCV,
        MKLDNN,
]
MXNet Library: C:\Users\Alessandro\.djl.ai\mxnet\1.9.1-mkl-win-x86_64\mxnet.dll
Default Device: cpu()
Rust: 4
PyTorch: 2
MXNet: 0
XGBoost: 10
LightGBM: 10
OnnxRuntime: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 12
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 513832784
Maximum memory (bytes): 8527020032
Total memory available to JVM (bytes): 536870912
Heap committed: 536870912
Heap nonCommitted: 29818880

BUILD SUCCESSFUL in 49s
64 actionable tasks: 15 executed, 49 up-to-date

@AlEscher AlEscher added the bug Something isn't working label Nov 14, 2024
@frankfliu
Contributor

@AlEscher
Can you use our built-in TextEmbeddingTranslator?

The following code works for me for this model:

        Criteria<String, float[]> criteria =
                Criteria.builder()
                        .setTypes(String.class, float[].class)
                        .optModelPath(path)
                        .optEngine("PyTorch")
                        .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                        .optProgress(new ProgressBar())
                        .build();

@frankfliu
Contributor

@AlEscher

Your error may be caused by setting useTokenTypes = true.
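To see why the argument count in the original error is off by one, here is a plain-Java sketch (hypothetical class and method names, no DJL dependency) of how the list of model inputs grows when token type IDs are included. It assumes, as the error message suggests, that the encoding is passed as input_ids, attention_mask and, optionally, token_type_ids, while the traced RobertaModel.forward declares only self, input_ids and attention_mask.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration: mimics turning a tokenizer encoding into model
// inputs with or without token type IDs. Not DJL code; names are made up.
final class EncodingInputs {

    /** Returns the arrays that would be passed to the model's forward(), in order. */
    static List<long[]> toInputs(long[] inputIds, long[] attentionMask,
                                 long[] tokenTypeIds, boolean useTokenTypes) {
        List<long[]> inputs = new ArrayList<>();
        inputs.add(inputIds);
        inputs.add(attentionMask);
        if (useTokenTypes) {
            // The traced RoBERTa forward() has no parameter for this tensor.
            inputs.add(tokenTypeIds);
        }
        return inputs;
    }

    public static void main(String[] args) {
        long[] ids = {0, 100, 2};   // placeholder token IDs
        long[] mask = {1, 1, 1};
        long[] types = {0, 0, 0};
        // forward(self, input_ids, attention_mask): two tensor parameters.
        int declaredTensorParams = 2;
        for (boolean useTokenTypes : new boolean[] {true, false}) {
            int passed = toInputs(ids, mask, types, useTokenTypes).size();
            // TorchScript counts `self` as an argument too, hence the +1 on both sides.
            System.out.printf("useTokenTypes=%b -> %d argument(s), forward accepts at most %d%n",
                    useTokenTypes, passed + 1, declaredTensorParams + 1);
        }
    }
}
```

Running main prints 4 arguments against a maximum of 3 for useTokenTypes=true, matching the reported EngineException, and 3 against 3 for useTokenTypes=false.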

@AlEscher
Author

@frankfliu Thank you for looking into this!
It seems to have been an issue with the translator.
Setting useTokenTypes to false throws a new error:

Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/transformers/models/roberta/modeling_roberta.py", line 19, in forward
    batch_size = ops.prim.NumToTensor(torch.size(input_ids, 0))
    _0 = int(batch_size)
    seq_length = ops.prim.NumToTensor(torch.size(input_ids, 1))
                                      ~~~~~~~~~~ <--- HERE
    _1 = int(seq_length)
    _2 = int(seq_length)

Traceback of TorchScript, original code (most recent call last):
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\models\roberta\modeling_roberta.py(892): forward
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py(1726): _slow_forward
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py(1747): _call_impl
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py(1736): _wrapped_call_impl
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\jit\_trace.py(1278): trace_module
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\jit\_trace.py(698): _trace_impl
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\jit\_trace.py(1002): trace
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\djl_converter\huggingface_converter.py(286): jit_trace_model
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\djl_converter\huggingface_converter.py(231): save_pytorch_model
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\djl_converter\huggingface_converter.py(64): save_model
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\djl_converter\model_converter.py(74): main
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Scripts\djl-convert.exe\__main__.py(7): <module>
<frozen runpy>(88): _run_code
<frozen runpy>(198): _run_module_as_main
RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

However just using the built-in translator like in the code snippet you provided seems to work.
I guess there is no reason to use a custom translator for my use case?
Thanks!

@frankfliu
Contributor

@AlEscher

For the text embedding case, I strongly suggest using our built-in TextEmbeddingTranslatorFactory. It reads the serving.properties file (exported by djl-convert) and handles pre/post-processing properly and efficiently.

You can always create your own translator if you need further customization. If it's a generic use case, please raise an issue or create a PR to improve our built-in translator.
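One piece of post-processing a custom processOutput has to handle is reducing the per-token hidden states to a single vector per input. The sketch below shows masked mean pooling, with optional L2 normalization, in plain Java. It illustrates the general technique only; it is not DJL's implementation, and which pooling the built-in translator applies to a given model depends on its configuration, which this sketch does not model. The class and method names are hypothetical.

```java
import java.util.Arrays;

// Sketch of masked mean pooling: averages the hidden states of real tokens
// (attention mask == 1) and ignores padding. Not DJL's actual implementation.
final class MeanPooling {

    /** tokenEmbeddings is [seqLen][hidden]; attentionMask is [seqLen]. */
    static float[] meanPool(float[][] tokenEmbeddings, long[] attentionMask) {
        if (tokenEmbeddings.length == 0 || tokenEmbeddings.length != attentionMask.length) {
            throw new IllegalArgumentException("need one mask entry per token and at least one token");
        }
        int hidden = tokenEmbeddings[0].length;
        float[] sum = new float[hidden];
        int count = 0;
        for (int t = 0; t < tokenEmbeddings.length; t++) {
            if (attentionMask[t] == 0) {
                continue; // skip padding tokens
            }
            for (int h = 0; h < hidden; h++) {
                sum[h] += tokenEmbeddings[t][h];
            }
            count++;
        }
        if (count == 0) {
            return sum; // every token was padding: return a zero vector
        }
        for (int h = 0; h < hidden; h++) {
            sum[h] /= count;
        }
        return sum;
    }

    /** Scales v to unit length so cosine similarity reduces to a dot product. */
    static float[] l2Normalize(float[] v) {
        double norm = 0;
        for (float x : v) {
            norm += (double) x * x;
        }
        norm = Math.sqrt(norm);
        if (norm == 0) {
            return v.clone();
        }
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = (float) (v[i] / norm);
        }
        return out;
    }

    public static void main(String[] args) {
        // [seqLen = 3][hidden = 2]; the last token is padding and is ignored.
        float[][] hiddenStates = {{1f, 2f}, {3f, 4f}, {100f, 100f}};
        long[] mask = {1, 1, 0};
        System.out.println(Arrays.toString(meanPool(hiddenStates, mask))); // [2.0, 3.0]
    }
}
```

Masking matters because padded positions would otherwise pull the average toward arbitrary values, as the third row in the example would.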

I think your error is caused by batch handling. You need to return the stack batchifier in your Translator:

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
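A plain-Java sketch of what stacking does conceptually: it adds a leading batch dimension, so a single input_ids array of shape [seq] becomes [1, seq]. Judging from the traceback, the traced model received a 1-D input_ids, for which torch.size(input_ids, 1) is out of range (the only valid dimensions are -1 and 0). The class and method names are hypothetical, and DJL's Batchifier.STACK operates on NDList/NDArray values rather than Java arrays.

```java
import java.util.Arrays;
import java.util.List;

// Conceptual stand-in for a "stack" batchifier: N inputs of shape [seq]
// become one batch of shape [N, seq]. Illustration only, not DJL code.
final class StackSketch {

    static long[][] stack(List<long[]> items) {
        if (items.isEmpty()) {
            throw new IllegalArgumentException("nothing to stack");
        }
        int seqLen = items.get(0).length;
        long[][] batch = new long[items.size()][];
        for (int i = 0; i < items.size(); i++) {
            if (items.get(i).length != seqLen) {
                // Stacking needs identical shapes; ragged inputs must be padded first.
                throw new IllegalArgumentException("all items must have the same length");
            }
            batch[i] = items.get(i).clone();
        }
        return batch;
    }

    /** Returns {size of dim 0, size of dim 1}, like torch.size(x, 0) and torch.size(x, 1). */
    static int[] shape(long[][] batch) {
        return new int[] {batch.length, batch[0].length};
    }

    public static void main(String[] args) {
        long[] inputIds = {0, 100, 200, 2}; // placeholder token IDs, shape [4]
        long[][] batched = stack(List.of(inputIds));
        System.out.println(Arrays.toString(shape(batched))); // [1, 4]
    }
}
```

With the leading batch dimension in place, dimension 1 exists and holds the sequence length, which is what the traced forward reads as seq_length.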

@AlEscher
Author

@frankfliu I see, thanks!
I think the default translator should be enough for me for now. :)
It seems to be working now for CodeBERT. However, for some Hugging Face models, djl-convert is not able to perform the conversion.
Two examples are microsoft/graphcodebert-base and Salesforce/codet5p-110m-embedding.
GraphCodeBERT:

djl-convert -m microsoft/graphcodebert-base -o models/graphcodebert-base
converting HuggingFace hub model: microsoft/graphcodebert-base
Loading model: microsoft/graphcodebert-base ...
Tracing model: microsoft/graphcodebert-base include_token_types=False ...
C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\modeling_utils.py:5006: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
  warnings.warn(
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Saving torchscript model: graphcodebert-base.pt ...
Verifying torchscript model(include_token_types=False): models/graphcodebert-base\graphcodebert-base.pt ...
Unexpected inference result: world
microsoft/graphcodebert-base: Unexpected inference result

CodeT5:

djl-convert -m Salesforce/codet5p-110m-embedding -o models/codet5
converting HuggingFace hub model: Salesforce/codet5p-110m-embedding
Unsupported model architecture: CodeT5p_Embedding for Salesforce/codet5p-110m-embedding.

The CodeT5 problem seems to be caused by the model using a construct that torch.jit.script does not support, NotSupportedError(r, "Comprehension ifs are not supported yet") (I assume torch.jit.script is used in the conversion). I encountered this as well when trying to convert the model on my own.
However, I am not sure what causes the issue for GraphCodeBERT.
Do you know whether this is something that is fixable in djl-convert, or whether the issue lies in e.g. the torch.jit library, as with CodeT5?
Thank you for your time so far!

@frankfliu
Contributor

  1. microsoft/graphcodebert-base should work. It seems there is a bug in djl-convert: after converting the model, we validate its output, and there is a whitespace difference between the results. I will take a look and fix it. In the meantime, you can convert your model to OnnxRuntime: djl-convert -m microsoft/graphcodebert-base -f OnnxRuntime
  2. Salesforce/codet5p-110m-embedding: we cannot support it for now. The default jit trace doesn't work for this model, and the ONNX conversion also failed. Candle doesn't support this model either. It seems this is the only model that uses the CodeT5p_Embedding architecture.
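The whitespace mismatch in item 1 can be illustrated with a small, hypothetical check. The expected and actual strings below are assumptions (for instance, a decoded token carrying a leading space, which byte-level BPE tokenizers such as RoBERTa's can produce); this is not djl-convert's validation code, which lives in the Python djl_converter package.

```java
// Hypothetical sketch: a strict string comparison rejects a result that differs
// only by surrounding whitespace, while a whitespace-tolerant one accepts it.
final class ResultCheck {

    static boolean strictMatch(String expected, String actual) {
        return expected.equals(actual);
    }

    static boolean whitespaceTolerantMatch(String expected, String actual) {
        // String.strip() (Java 11+) removes leading and trailing Unicode whitespace.
        return expected.strip().equals(actual.strip());
    }

    public static void main(String[] args) {
        String expected = "world";  // assumed expected token
        String actual = " world";   // assumed decoded output with a leading space
        System.out.println(strictMatch(expected, actual));             // false
        System.out.println(whitespaceTolerantMatch(expected, actual)); // true
    }
}
```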
