Update the FSDP settings in the Accelerate config files.

* Rename the key `fsdp_backward_prefetch_policy` to `fsdp_backward_prefetch`
  (value stays `BACKWARD_PRE`) in both accelerate_config.yaml and
  test_configs/fsdp_config.yaml.
* In accelerate_config.yaml only, change `fsdp_transformer_layer_cls_to_wrap`
  from `CustomOPTForCausalLM` to `OPTForCausalLM`.

NOTE(review): `fsdp_transformer_layer_cls_to_wrap` presumably expects the name
of a repeated transformer layer class rather than the top-level model class;
confirm that `OPTForCausalLM` is the intended wrap unit.

diff --git a/src/config/accelerate_config.yaml b/src/config/accelerate_config.yaml
index 8831e95..9f14171 100644
--- a/src/config/accelerate_config.yaml
+++ b/src/config/accelerate_config.yaml
@@ -3,11 +3,11 @@ distributed_type: FSDP
 downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_offload_params: false
   fsdp_sharding_strategy: 1
   fsdp_state_dict_type: FULL_STATE_DICT
-  fsdp_transformer_layer_cls_to_wrap: CustomOPTForCausalLM
+  fsdp_transformer_layer_cls_to_wrap: OPTForCausalLM
 machine_rank: 0
 main_training_function: main
 mixed_precision: bf16
diff --git a/src/config/test_configs/fsdp_config.yaml b/src/config/test_configs/fsdp_config.yaml
index 531ad87..8665e42 100644
--- a/src/config/test_configs/fsdp_config.yaml
+++ b/src/config/test_configs/fsdp_config.yaml
@@ -3,7 +3,7 @@ distributed_type: FSDP
 downcast_bf16: 'no'
 fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_backward_prefetch: BACKWARD_PRE
   fsdp_offload_params: false
   fsdp_sharding_strategy: 1
   fsdp_state_dict_type: FULL_STATE_DICT