Commit 8debb21

Finalize the API changes for 2.0 (#374)
* `qkv_split` arguments for attention heads are now mandatory

* Rename `FromHFHub` mixins to `FromHF`

* Remove `FromHF.convert_hf_state_dict`
danieldk authored Apr 16, 2024
1 parent c96e565 commit 8debb21
Showing 37 changed files with 124 additions and 123 deletions.
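The rename and the now-mandatory `qkv_split` argument are the breaking pieces of the 2.0 API surface. A minimal migration sketch for downstream code (the head count is illustrative, not taken from this commit):

```python
# Hypothetical migration of downstream code to the 2.0 API.
from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvSplitGroupedByKVHeads,
)

# 1.x: from curated_transformers.models import FromHFHub
from curated_transformers.models import FromHF

# 1.x: AttentionHeads.uniform(12) -- qkv_split defaulted to
# QkvSplitGroupedByKVHeads(); the argument is now mandatory.
heads = AttentionHeads.uniform(12, QkvSplitGroupedByKVHeads())
```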
4 changes: 2 additions & 2 deletions curated_transformers/generation/__init__.py
@@ -5,7 +5,7 @@
from .falcon import FalconGenerator
from .generator import Generator
from .generator_wrapper import GeneratorWrapper
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF
from .llama import LlamaGenerator
from .logits import (
CompoundLogitsTransform,
@@ -32,7 +32,7 @@
"DollyV2Generator",
"EndOfSequenceCondition",
"FalconGenerator",
-"FromHFHub",
+"FromHF",
"Generator",
"GeneratorConfig",
"GeneratorWrapper",
6 changes: 3 additions & 3 deletions curated_transformers/generation/auto_generator.py
@@ -7,13 +7,13 @@
from .dolly_v2 import DollyV2Generator
from .falcon import FalconGenerator
from .generator_wrapper import GeneratorWrapper
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF
from .llama import LlamaGenerator
from .mpt import MPTGenerator

# For the time being, we enable support for a generator on a case-by-case basis.
# In the future we might defer all unknown generators to DefaultGenerator.
-GENERATOR_MAP: Dict[str, Type[FromHFHub]] = {
+GENERATOR_MAP: Dict[str, Type[FromHF]] = {
"dolly-v2": DollyV2Generator,
"falcon": FalconGenerator,
"llama": LlamaGenerator,
@@ -70,7 +70,7 @@ def from_hf_hub(
return generator


-def _resolve_generator_class(name: str) -> Type[FromHFHub]:
+def _resolve_generator_class(name: str) -> Type[FromHF]:
for substring, generator_cls in GENERATOR_MAP.items():
if substring in name.lower():
return generator_cls
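`GENERATOR_MAP` keys on substrings of Hugging Face repository names, so resolution works the same after the rename; only the type annotations now point at `FromHF`. A minimal usage sketch (the repository name and device are illustrative, and it assumes the public `AutoGenerator` helper built on this module):

```python
import torch
from curated_transformers.generation import AutoGenerator

# "falcon" in the repository name matches the GENERATOR_MAP entry above, so a
# FalconGenerator (a FromHF subclass) is instantiated behind the scenes.
generator = AutoGenerator.from_hf_hub(
    name="tiiuae/falcon-7b-instruct",
    device=torch.device("cpu"),
)
```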
4 changes: 2 additions & 2 deletions curated_transformers/generation/default_generator.py
@@ -13,14 +13,14 @@
from .config import GeneratorConfig, SampleGeneratorConfig
from .generator import Generator
from .generator_wrapper import GeneratorWrapper
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF
from .string_generator import StringGenerator

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="DefaultGenerator")


-class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHFHub):
+class DefaultGenerator(Generic[CacheT], GeneratorWrapper, FromHF):
"""
Generator wrapper for models that do not need specific prompting.
"""
4 changes: 2 additions & 2 deletions curated_transformers/generation/falcon.py
@@ -4,13 +4,13 @@
from ..tokenizers.chunks import InputChunks, TextChunk
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="FalconGenerator")


-class FalconGenerator(DefaultGenerator, FromHFHub):
+class FalconGenerator(DefaultGenerator, FromHF):
"""
Generator for Falcon model variants.
"""
4 changes: 2 additions & 2 deletions curated_transformers/generation/hf_hub.py
@@ -6,10 +6,10 @@
from ..quantization.bnb.config import BitsAndBytesConfig

# Only provided as typing.Self in Python 3.11+.
-Self = TypeVar("Self", bound="FromHFHub")
+Self = TypeVar("Self", bound="FromHF")


-class FromHFHub(ABC):
+class FromHF(ABC):
"""
Mixin class for downloading generators from Hugging Face Hub.
4 changes: 2 additions & 2 deletions curated_transformers/generation/llama.py
@@ -3,13 +3,13 @@
from ..models.llama import LlamaCausalLM
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF

# Only provided as typing.Self in Python 3.11+.
Self = TypeVar("Self", bound="LlamaGenerator")


-class LlamaGenerator(DefaultGenerator, FromHFHub):
+class LlamaGenerator(DefaultGenerator, FromHF):
"""
Generator for Llama and Llama 2 model variants.
"""
4 changes: 2 additions & 2 deletions curated_transformers/generation/mpt.py
@@ -3,10 +3,10 @@
from ..models.mpt import MPTCausalLM
from ..tokenizers.tokenizer import Tokenizer
from .default_generator import DefaultGenerator
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF


-class MPTGenerator(DefaultGenerator, FromHFHub):
+class MPTGenerator(DefaultGenerator, FromHF):
"""
Generator for MPT model variants.
"""
12 changes: 4 additions & 8 deletions curated_transformers/layers/attention.py
@@ -11,7 +11,6 @@
from torch import Tensor
from torch.nn import Dropout, Linear, Module

-from ..semver import Default, FutureMandatory
from .cache import KeyValueCache
from .embeddings import QueryKeyRotaryEmbeddings

@@ -346,7 +345,7 @@ def __init__(
*,
n_query_heads: int,
n_key_value_heads: int,
-qkv_split: FutureMandatory[QkvSplit] = Default,
+qkv_split: QkvSplit,
):
"""
Construct an attention head configuration. This constructor must
@@ -366,16 +365,13 @@
"""
self._n_query_heads = n_query_heads
self._n_key_value_heads = n_key_value_heads
-
-qkv_split = QkvSplitGroupedByKVHeads() if qkv_split is Default else qkv_split
-assert isinstance(qkv_split, QkvSplit)
self._qkv_split = qkv_split

@classmethod
def uniform(
cls,
n_attention_heads: int,
-qkv_split: FutureMandatory[QkvSplit] = Default,
+qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a head configuration where query, key, and value have the
@@ -398,7 +394,7 @@ def uniform(
def multi_query(
cls,
n_query_heads: int,
-qkv_split: FutureMandatory[QkvSplit] = Default,
+qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a multi-query attention configuration: key has one head,
@@ -425,7 +421,7 @@ def key_value_broadcast(
*,
n_query_heads: int,
n_key_value_heads: int,
-qkv_split: FutureMandatory[QkvSplit] = Default,
+qkv_split: QkvSplit,
) -> "AttentionHeads":
"""
Construct a head configuration where query has a larger number
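With the `FutureMandatory` machinery gone, every `AttentionHeads` constructor needs an explicit `QkvSplit`; `QkvSplitGroupedByKVHeads()` is what the removed default used to supply. A short sketch of the three constructors after this change (head counts are illustrative):

```python
from curated_transformers.layers.attention import (
    AttentionHeads,
    QkvSplitGroupedByKVHeads,
)

split = QkvSplitGroupedByKVHeads()

# Multi-head attention: query, key, and value all use 12 heads.
mha = AttentionHeads.uniform(12, split)

# Multi-query attention: key and value use a single head.
mqa = AttentionHeads.multi_query(12, split)

# Grouped-query attention: 32 query heads share 8 key/value heads.
gqa = AttentionHeads.key_value_broadcast(
    n_query_heads=32,
    n_key_value_heads=8,
    qkv_split=split,
)
```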
4 changes: 2 additions & 2 deletions curated_transformers/models/__init__.py
@@ -12,7 +12,7 @@
)
from .falcon import FalconCausalLM, FalconConfig, FalconDecoder
from .gpt_neox import GPTNeoXCausalLM, GPTNeoXConfig, GPTNeoXDecoder
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF
from .llama import LlamaCausalLM, LlamaConfig, LlamaDecoder
from .module import CausalLMModule, DecoderModule, EncoderModule
from .mpt import MPTCausalLM, MPTConfig, MPTDecoder
@@ -37,7 +37,7 @@
"FalconCausalLM",
"FalconConfig",
"FalconDecoder",
-"FromHFHub",
+"FromHF",
"GPTNeoXCausalLM",
"GPTNeoXConfig",
"GPTNeoXDecoder",
4 changes: 2 additions & 2 deletions curated_transformers/models/albert/encoder.py
@@ -10,7 +10,7 @@
EmbeddingLayerNorms,
TransformerEmbeddings,
)
-from ..hf_hub import FromHFHub
+from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..module import EncoderModule
from ..output import ModelOutput
@@ -22,7 +22,7 @@
Self = TypeVar("Self", bound="ALBERTEncoder")


-class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHFHub[ALBERTConfig]):
+class ALBERTEncoder(EncoderModule[ALBERTConfig], FromHF[ALBERTConfig]):
"""
ALBERT (`Lan et al., 2022`_) encoder.
4 changes: 3 additions & 1 deletion curated_transformers/models/albert/layer_group.py
@@ -9,6 +9,7 @@
AttentionHeads,
AttentionMask,
QkvMode,
+QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -45,7 +46,8 @@ def __init__(
EncoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(
-attention_config.n_query_heads
+attention_config.n_query_heads,
+QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=attention_config.dropout_prob,
10 changes: 5 additions & 5 deletions curated_transformers/models/auto_model.py
@@ -14,7 +14,7 @@
from ..repository.hf_hub import HfHubRepository
from ..repository.repository import ModelRepository, Repository
from .config import TransformerConfig
-from .hf_hub import FromHFHub
+from .hf_hub import FromHF
from .module import CausalLMModule, DecoderModule, EncoderModule, TransformerModule

ModelT = TypeVar("ModelT")
@@ -33,14 +33,14 @@ class AutoModel(ABC, Generic[ModelT]):
def _resolve_model_cls(
cls,
repo: ModelRepository,
-) -> Type[FromHFHub]:
+) -> Type[FromHF]:
config = repo.model_config()

for entrypoint, module_cls in cls._registry.get_entry_points().items():
-if not issubclass(module_cls, FromHFHub):
+if not issubclass(module_cls, FromHF):
warnings.warn(
f"Entry point `{entrypoint}` cannot load from Hugging Face Hub "
-"since the FromHFHub mixin is not implemented"
+"since the FromHF mixin is not implemented"
)
continue

@@ -70,7 +70,7 @@ def _instantiate_model(
repo: Repository,
device: Optional[torch.device],
quantization_config: Optional[BitsAndBytesConfig],
-) -> FromHFHub:
+) -> FromHF:
module_cls = cls._resolve_model_cls(ModelRepository(repo))
module = module_cls.from_repo(
repo=repo,
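`_resolve_model_cls` only accepts entry-point classes that implement the renamed `FromHF` mixin, and `_instantiate_model` then constructs the resolved class through `from_repo`. A minimal loading sketch (the repository name and device are illustrative, assuming the public `AutoDecoder` auto class that builds on this machinery):

```python
import torch
from curated_transformers.models import AutoDecoder

# The concrete FromHF subclass is resolved from the repository's model config,
# then instantiated on the requested device.
decoder = AutoDecoder.from_hf_hub(
    name="tiiuae/falcon-7b",
    device=torch.device("cpu"),
)
```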
8 changes: 5 additions & 3 deletions curated_transformers/models/bert/encoder.py
@@ -8,6 +8,7 @@
from ...layers.attention import (
AttentionHeads,
QkvMode,
+QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -20,7 +21,7 @@
TransformerEmbeddings,
TransformerLayerNorms,
)
-from ..hf_hub import FromHFHub
+from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerEncoder
from ._hf import HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -30,7 +31,7 @@
Self = TypeVar("Self", bound="BERTEncoder")


-class BERTEncoder(TransformerEncoder[BERTConfig], FromHFHub[BERTConfig]):
+class BERTEncoder(TransformerEncoder[BERTConfig], FromHF[BERTConfig]):
"""
BERT (`Devlin et al., 2018`_) encoder.
@@ -85,7 +86,8 @@ def __init__(
EncoderLayer(
attention_layer=SelfAttention(
attention_heads=AttentionHeads.uniform(
-config.layer.attention.n_query_heads
+config.layer.attention.n_query_heads,
+QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
4 changes: 2 additions & 2 deletions curated_transformers/models/falcon/causal_lm.py
@@ -5,7 +5,7 @@
from torch.nn import Linear

from ...quantization.quantizable import Quantizable
-from ..hf_hub import FromHFHub
+from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerCausalLM
from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -17,7 +17,7 @@


class FalconCausalLM(
-TransformerCausalLM[FalconConfig], FromHFHub[FalconConfig], Quantizable
+TransformerCausalLM[FalconConfig], FromHF[FalconConfig], Quantizable
):
"""
Falcon (`Penedo et al., 2019`_) causal language model.
6 changes: 4 additions & 2 deletions curated_transformers/models/falcon/decoder.py
@@ -9,6 +9,7 @@
AttentionHeads,
AttentionLinearBiases,
QkvMode,
+QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -22,7 +23,7 @@
TransformerEmbeddings,
TransformerLayerNorms,
)
-from ..hf_hub import FromHFHub
+from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerDecoder
from ._hf import DECODER_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -33,7 +34,7 @@
Self = TypeVar("Self", bound="FalconDecoder")


-class FalconDecoder(TransformerDecoder[FalconConfig], FromHFHub[FalconConfig]):
+class FalconDecoder(TransformerDecoder[FalconConfig], FromHF[FalconConfig]):
"""
Falcon (`Penedo et al., 2019`_) decoder.
@@ -166,6 +167,7 @@ def _create_new_decoder_architecture_layer(
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=n_attention_heads,
n_key_value_heads=config.layer.attention.n_key_value_heads,
+qkv_split=QkvSplitGroupedByKVHeads(),
),
attention_scorer=ScaledDotProductAttention(
dropout_prob=config.layer.attention.dropout_prob,
2 changes: 2 additions & 0 deletions curated_transformers/models/falcon/layer.py
@@ -10,6 +10,7 @@
AttentionMask,
KeyValueCache,
QkvMode,
+QkvSplitGroupedByKVHeads,
ScaledDotProductAttention,
SelfAttention,
)
@@ -71,6 +72,7 @@ def __init__(
attention_heads=AttentionHeads.key_value_broadcast(
n_query_heads=attention_config.n_query_heads,
n_key_value_heads=attention_config.n_key_value_heads,
+qkv_split=QkvSplitGroupedByKVHeads(),
),
rotary_embeds=rotary_embeds,
qkv_mode=(
4 changes: 2 additions & 2 deletions curated_transformers/models/gpt_neox/causal_lm.py
@@ -5,7 +5,7 @@
from torch.nn import Linear

from ...quantization import Quantizable
-from ..hf_hub import FromHFHub
+from ..hf_hub import FromHF
from ..hf_hub.conversion import state_dict_from_hf, state_dict_to_hf
from ..transformer import TransformerCausalLM
from ._hf import CAUSAL_LM_HF_PARAM_KEY_TRANSFORMS, _config_from_hf, _config_to_hf
@@ -17,7 +17,7 @@


class GPTNeoXCausalLM(
-TransformerCausalLM[GPTNeoXConfig], FromHFHub[GPTNeoXConfig], Quantizable
+TransformerCausalLM[GPTNeoXConfig], FromHF[GPTNeoXConfig], Quantizable
):
"""
GPT-NeoX (`Black et al., 2022`_) causal language model.
(Diffs for the remaining changed files are not shown here.)

0 comments on commit 8debb21