diff --git a/detoxify/detoxify.py b/detoxify/detoxify.py index 92fe0c9..5837348 100644 --- a/detoxify/detoxify.py +++ b/detoxify/detoxify.py @@ -1,6 +1,7 @@ import torch import transformers + DOWNLOAD_URL = "https://github.com/unitaryai/detoxify/releases/download/" MODEL_URLS = { "original": DOWNLOAD_URL + "v0.1-alpha/toxic_original-c1212f89.ckpt", @@ -17,10 +18,10 @@ def get_model_and_tokenizer( model_type, model_name, tokenizer_name, num_classes, state_dict, huggingface_config_path=None ): model_class = getattr(transformers, model_name) + config = model_class.config_class.from_pretrained(model_type, num_labels=num_classes) model = model_class.from_pretrained( pretrained_model_name_or_path=None, - config=huggingface_config_path or model_type, - num_labels=num_classes, + config=huggingface_config_path or config, state_dict=state_dict, local_files_only=huggingface_config_path is not None, ) diff --git a/setup.py b/setup.py index 722a8db..8dee5a1 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ author_email="laura@unitary.ai", url="https://github.com/unitaryai/detoxify", install_requires=[ - "transformers == 4.30.0", + "transformers", "torch >= 1.7.0", "sentencepiece >= 0.1.94", ],