diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py index 93cfb1c..b5fa901 100644 --- a/src/mistral_inference/cache.py +++ b/src/mistral_inference/cache.py @@ -177,19 +177,24 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata: subsequent_prefill = any(seqlen > 1 for seqlen in seqlens) if first_prefill: assert all([pos == 0 for pos in seqpos]), seqpos - mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len) + mask = BlockDiagonalCausalMask.from_seqlens( + seqlens, + device=self.device, + ).make_local_attention(self.max_seq_len) elif subsequent_prefill: mask = BlockDiagonalMask.from_seqlens( q_seqlen=seqlens, kv_seqlen=[ s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens) ], + device=self.device, ).make_local_attention_from_bottomright(self.max_seq_len) else: mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( q_seqlen=seqlens, kv_padding=self.max_seq_len, kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(), + device=self.device, ) return CacheInputMetadata(