
Commit

rework to maintain static tables as an option
rsuderman committed Oct 14, 2024
1 parent 7d39095 commit bdf643a
Showing 1 changed file with 10 additions and 0 deletions.
sharktank/sharktank/layers/rotary_embedding.py (10 additions, 0 deletions)
@@ -25,6 +25,7 @@ def __init__(
         rope_freq_base: float,
         device: Optional[torch.device] = None,
         use_hf: bool = False,
+        static_tables: bool = False,
         use_table: bool = True,
         tensor_parallelism_size: int = 1,
     ):
@@ -33,14 +34,23 @@ def __init__(
         self.rope_dimension_count = rope_dimension_count
         self.max_seqlen = max_seqlen
         self.use_hf = use_hf
+        self.static_tables = static_tables
         self.use_table = use_table

         self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
         self.tensor_parallelism_size = tensor_parallelism_size
+        if static_tables:
+            ops.module_register_buffer(
+                self, "static_rotary_embed_table", self._create_rotary_embed_table()
+            )
+        else:
+            self.static_rotary_embed_table = None

     @property
     def rotary_embed_table(self):
         if self.use_table:
+            if self.static_tables:
+                return self.static_rotary_embed_table
             return self._create_rotary_embed_table()

         if self.tensor_parallelism_size == 1:
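For context, here is a minimal usage sketch (not part of the commit) of what the new flag controls. It assumes the layer defined in rotary_embedding.py is named RotaryEmbeddingLayer, is importable from sharktank.layers, and takes the keyword arguments visible in the diff; the rope_dimension_count and max_seqlen values are illustrative only.

# Hypothetical usage sketch; the class name RotaryEmbeddingLayer, the import
# path, and the numeric values below are assumptions for illustration.
from sharktank.layers import RotaryEmbeddingLayer

# static_tables=True: the table is built once in __init__ and registered on the
# module via ops.module_register_buffer, so the rotary_embed_table property
# returns the precomputed static_rotary_embed_table.
static_layer = RotaryEmbeddingLayer(
    rope_dimension_count=128,  # assumed value
    max_seqlen=4096,           # assumed value
    rope_freq_base=10000.0,
    static_tables=True,
)

# static_tables=False (the default after this commit): no buffer is registered
# (static_rotary_embed_table is None) and the property recomputes the table on
# each access via _create_rotary_embed_table().
dynamic_layer = RotaryEmbeddingLayer(
    rope_dimension_count=128,  # assumed value
    max_seqlen=4096,           # assumed value
    rope_freq_base=10000.0,
    static_tables=False,
)

In short, static_tables=False keeps the previous on-the-fly behavior of the use_table path, while static_tables=True makes the rotary table part of the module's registered state.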
