From 4a4e681a71fa4ddeb5ca6eeac4b48aa49702c1d6 Mon Sep 17 00:00:00 2001
From: Cornelius <39997278+cornzz@users.noreply.github.com>
Date: Wed, 28 Aug 2024 14:09:12 +0200
Subject: [PATCH 1/2] Fix device error when using cuda device other than cuda:0

Attention bias was being moved to cuda:0 regardless of the selected cuda device
---
 src/mistral_inference/cache.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index 93cfb1c..c7459c7 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -177,19 +177,24 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
         subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
         if first_prefill:
             assert all([pos == 0 for pos in seqpos]), seqpos
-            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.max_seq_len)
+            mask = BlockDiagonalCausalMask.from_seqlens(
+                seqlens,
+                device=self.device
+            ).make_local_attention(self.max_seq_len)
         elif subsequent_prefill:
             mask = BlockDiagonalMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_seqlen=[
                     s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
                 ],
+                device=self.device
             ).make_local_attention_from_bottomright(self.max_seq_len)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_padding=self.max_seq_len,
                 kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
+                device=self.device
             )
 
         return CacheInputMetadata(

From 83fb40a89617837b5e9f4f5220477ae9b5faddc7 Mon Sep 17 00:00:00 2001
From: Cornelius <39997278+cornzz@users.noreply.github.com>
Date: Thu, 29 Aug 2024 13:17:23 +0200
Subject: [PATCH 2/2] Fix formatting in BufferCache.get_input_metadata()

---
 src/mistral_inference/cache.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/mistral_inference/cache.py b/src/mistral_inference/cache.py
index c7459c7..b5fa901 100644
--- a/src/mistral_inference/cache.py
+++ b/src/mistral_inference/cache.py
@@ -179,7 +179,7 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
             assert all([pos == 0 for pos in seqpos]), seqpos
             mask = BlockDiagonalCausalMask.from_seqlens(
                 seqlens,
-                device=self.device
+                device=self.device,
             ).make_local_attention(self.max_seq_len)
         elif subsequent_prefill:
             mask = BlockDiagonalMask.from_seqlens(
@@ -187,14 +187,14 @@ def get_input_metadata(self, seqlens: List[int]) -> CacheInputMetadata:
                 q_seqlen=seqlens,
                 kv_seqlen=[
                     s + cached_s.clamp(max=self.max_seq_len).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)
                 ],
-                device=self.device
+                device=self.device,
             ).make_local_attention_from_bottomright(self.max_seq_len)
         else:
             mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                 q_seqlen=seqlens,
                 kv_padding=self.max_seq_len,
                 kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.max_seq_len).tolist(),
-                device=self.device
+                device=self.device,
             )
         return CacheInputMetadata(
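
For context, a minimal sketch of the failure mode the first patch addresses. It assumes
the xformers release pinned by mistral-inference, whose from_seqlens() accepts the
device keyword used in the diff above; the device index, sequence lengths, and window
size below are illustrative, not taken from the repository.

    import torch
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    # Model weights and KV cache live on a non-default GPU.
    device = torch.device("cuda:1")
    seqlens = [5, 3]  # first prefill: two prompts of 5 and 3 tokens

    # Before the fix, the mask was built without an explicit device, so (per the
    # commit message) the attention bias ended up on cuda:0 and attention failed
    # with a device-mismatch error once q/k/v were on cuda:1. Passing the cache's
    # device, as the patch does with device=self.device, keeps everything on cuda:1.
    mask = BlockDiagonalCausalMask.from_seqlens(
        seqlens,
        device=device,
    ).make_local_attention(4096)  # 4096 stands in for BufferCache.max_seq_len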