- `base_model`
  - Description: Path to the base model or `caikit.Resource` object of the base model to be used for tuning. A model-name string may also be provided; in this case, the Transformers API will automatically load it from the Hugging Face model cache if the model is available locally. If it is not available, the model may be downloaded by setting the `ALLOW_DOWNLOADS` environment variable to `true`.
  - Accepted values:
    - The model needs to be of type causal-lm or seq2seq, and thus loadable via the HuggingFace `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` loading methods (see the sketch below).
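As a quick sanity check that a candidate base model satisfies this constraint, it should load with one of the two auto-classes. A minimal sketch, where the model names are only illustrative examples:

```python
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Causal-lm base model (illustrative model name)
causal_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

# Seq2seq base model (illustrative model name)
seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
```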
- `tuning_type`: (str or `caikit_nlp.modules.text_generation.TuningType`)
  - Type of PEFT tuning config to be built.
  - Accepted values: `PROMPT_TUNING` and `MULTITASK_PROMPT_TUNING`
  - Default: `PROMPT_TUNING`
- `num_epochs`: (int)
  - The number of complete passes through the training dataset.
  - Tuning quality depends heavily on the number of epochs.
  - Expose to end user recommendation: True
  - Accepted values: any positive int
  - Default: 20
- `device`: (str)
  - Training device to be used; could be `cpu`, `cuda`, or a CUDA-specific device name.
  - Expose to end user recommendation: False
- `lr` / `learning_rate`: (float)
  - Learning rate to be used for training. (The parameter name is soon to be changed to make it more intuitive.)
  - Expose to end user recommendation: True
- `accumulate_steps`:
  - Number of steps to be used for gradient accumulation. Gradient accumulation collects gradients for a configured number of steps instead of updating the model variables at every step, then applies the accumulated update to the model variables. This can be used as a tool to overcome small batch size limitations, and is often discussed in conjunction with the "effective batch size" (see the sketch below).
  - Expose to end user recommendation: True
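To make the mechanics concrete, here is a minimal self-contained PyTorch sketch of gradient accumulation with a toy model (all names and values are illustrative, not caikit_nlp internals); the effective batch size is `batch_size * accumulate_steps`:

```python
import torch

model = torch.nn.Linear(16, 1)              # toy stand-in for the tuned model
optimizer = torch.optim.SGD(model.parameters(), lr=0.3)
loss_fn = torch.nn.MSELoss()
batch_size, accumulate_steps = 8, 4         # effective batch size = 32

optimizer.zero_grad()
for step in range(16):                      # 16 micro-batches of toy data
    inputs, targets = torch.randn(batch_size, 16), torch.randn(batch_size, 1)
    loss = loss_fn(model(inputs), targets)
    (loss / accumulate_steps).backward()    # scale so summed gradients average out
    if (step + 1) % accumulate_steps == 0:
        optimizer.step()                    # apply the accumulated update
        optimizer.zero_grad()               # start a new accumulation window
```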
- `verbalizer`
  - Verbalizer template to be used for formatting data at train and inference time. This template may use brackets to indicate where fields from the data model `TrainGenerationRecord` must be rendered (see the sketch below).
  - Default: `"{{input}}"`, i.e., the raw text.
  - Expose to end user recommendation: True
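As an illustration of how such a template behaves, a minimal sketch; the `render` helper here is hypothetical, not the library's implementation:

```python
def render(verbalizer: str, input_text: str) -> str:
    """Hypothetical helper: substitute the {{input}} field into the template."""
    return verbalizer.replace("{{input}}", input_text)

# The default template passes the raw text through unchanged
assert render("{{input}}", "great movie!") == "great movie!"

# A task-specific template wraps the input with instructions
print(render("Classify the sentiment: {{input}}", "great movie!"))
# -> Classify the sentiment: great movie!
```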
- `batch_size`:
  - The number of samples processed before the model is updated.
  - Default: 8
  - Expose to end user recommendation: True
- `max_source_length`:
  - Max length of input sequences being considered.
  - Default: 256
  - Expose to end user recommendation: True
- `max_target_length`:
  - Max length of target sequences being predicted.
  - Default: 128
  - Expose to end user recommendation: True
- `tuning_config.num_virtual_tokens`:
  - Number of virtual tokens to be used for training. In prompt tuning we are essentially learning the embedded representations for soft prompts, known as virtual tokens, via backpropagation for a specific task (or tasks) while keeping the rest of the model fixed. `num_virtual_tokens` is the number of these virtual tokens (see the sketch below).
  - This should also correspond to the available source prompt, if one exists; i.e., a user needs to select the number of virtual tokens to match the available source prompt if they want to use MPT source prompts.
  - Expose to end user recommendation: True (default to be set by the application layer)
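To illustrate what is actually being learned, a minimal sketch of the trainable soft-prompt tensor; the hidden size of 768 is just an example of a base model's embedding dimension:

```python
import torch

num_virtual_tokens = 20   # this parameter
hidden_size = 768         # example embedding dimension of the base model

# The only trainable weights in prompt tuning: one learned embedding per
# virtual token. The base model's own weights stay frozen.
soft_prompt = torch.nn.Parameter(torch.randn(num_virtual_tokens, hidden_size))
print(soft_prompt.shape)  # torch.Size([20, 768])
```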
- `tuning_config.prompt_tuning_init_method`:
  - Could be: `RANDOM`, `TEXT`, `ONLY_SOURCE_SHARED`, and `AVERAGE_SOURCE`
    - `TEXT` requires `tuning_config.prompt_tuning_init_text` to be set (see the sketch below).
    - `ONLY_SOURCE_SHARED` and `AVERAGE_SOURCE` require `tuning_config.prompt_tuning_init_source_model` to be set and a source prompt model to be available for the given `base_model`.
  - Default: `RANDOM`
  - Expose to end user recommendation: True
    - Only `RANDOM`, `TEXT`, and `AVERAGE_SOURCE` are to be exposed, where `AVERAGE_SOURCE` is only applicable when the tuning type is `MULTITASK_PROMPT_TUNING`.
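For reference, the underlying PEFT library draws the same `RANDOM`/`TEXT` distinction; a minimal sketch of a `TEXT`-initialized config using `peft` directly (the model name and init text are illustrative, and caikit_nlp's own tuning_config object may differ in shape):

```python
from peft import PromptTuningConfig, PromptTuningInit

# TEXT init: the soft prompt starts from the embeddings of the init text
# rather than from random values (illustrative model name and text).
peft_config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify the sentiment of the review:",
    num_virtual_tokens=20,
    tokenizer_name_or_path="bigscience/bloom-560m",
)
```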
- `tuning_config.prompt_tuning_init_text`:
  - Initialization text to be used if `tuning_config.prompt_tuning_init_method` is set to `TEXT`; otherwise this will be ignored.
  - Default: no default.
  - Expose to end user recommendation: True (if the `TEXT` init method is exposed to customers)
- `tuning_config.prompt_tuning_init_source_model`:
  - Path pointing to the source prompt model. This path is relative to `config.source_prompt_base` (or the `SOURCE_PROMPT_BASE` env variable).
  - The source model selection needs to correspond to the `base_model`.
  - There may be cases where multiple source prompts are available for a given model, in which case their selection criteria need to be determined.
  - Default: depends on the `base_model`. If `MULTITASK_PROMPT_TUNING` is not selected as the tuning type, this will be ignored.
- `tuning_config.output_model_types`: List(str)
  - A list containing the strings `ENCODER` and/or `DECODER`.
  - Acceptable values for types of models:
    - CausalLM: `["DECODER"]`
    - Seq2Seq: `["ENCODER"]`, `["DECODER"]`, `["ENCODER", "DECODER"]`
  - Default:
    - CausalLM: `["DECODER"]`
    - Seq2Seq: `["ENCODER"]`
  - Expose to end user recommendation: False
- `torch_dtype`: (str)
  - Datatype to use for training of the underlying text generation model. If no value is provided, we pull from `torch_dtype` in config. If an in-memory resource is provided which does not match the specified data type, the model underpinning the resource will be converted in place to the correct torch dtype (see the sketch below).
  - Expose to end user recommendation: False
  - Recommended to be configured at the environment or server configuration level.
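As a small illustration of the in-place conversion described above (a hedged sketch, not caikit_nlp's actual code path; the model name is an example):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # example

torch_dtype = "bfloat16"             # dtype string as it might appear in config
dtype = getattr(torch, torch_dtype)  # resolve the string to torch.bfloat16
model.to(dtype)                      # convert the model's weights in place
```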
- `silence_progress_bars`: (bool)
  - Toggle to control progress bars during training. This is relevant only to the "Python user experience" and doesn't apply to training via caikit runtime.
  - Expose to end user recommendation: False