From 2fc80589af9902f0abc943e84b4ba4d6934cfff4 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 1 Dec 2023 09:42:41 -0800 Subject: [PATCH] [tokenizer] Uses fp32 for TextembeddingTranslator clip() (#2881) --- .../ai/djl/huggingface/translator/TextEmbeddingTranslator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java index 6dc1a4ed454..326a641dee0 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextEmbeddingTranslator.java @@ -113,7 +113,7 @@ private static NDArray meanPool(NDArray embeddings, NDArray attentionMask, boole long[] shape = embeddings.getShape().getShape(); attentionMask = attentionMask.expandDims(-1).broadcast(shape); NDArray inputAttentionMaskSum = attentionMask.sum(AXIS); - NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12); + NDArray clamp = inputAttentionMaskSum.clip(1e-9f, 1e12f); NDArray prod = embeddings.mul(attentionMask); NDArray sum = prod.sum(AXIS); if (sqrt) {