From f39445c0be5906b79bab9d9aa2e1ba956ee13c51 Mon Sep 17 00:00:00 2001 From: Chang-Hong Hsu Date: Tue, 6 Oct 2020 17:05:40 -0700 Subject: [PATCH] Adding necessary fields in training job message for distributed training (#84) * `distributed_protocol` added as a field in training job resource config --- .../plugins/sagemaker/training_job.pb.cc | 353 ++++++++- .../plugins/sagemaker/training_job.pb.h | 194 ++++- .../plugins/sagemaker/training_job.pb.go | 169 ++-- .../sagemaker/training_job.pb.validate.go | 69 ++ .../sagemaker/TrainingJobOuterClass.java | 732 +++++++++++++++++- .../plugins/sagemaker/training_job.proto.rst | 56 +- .../plugins/sagemaker/training_job_pb2.py | 74 +- package.json | 2 +- .../plugins/sagemaker/training_job.proto | 19 + setup.py | 2 +- 10 files changed, 1578 insertions(+), 92 deletions(-) diff --git a/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.cc b/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.cc index 7f846625f..449f9dac0 100644 --- a/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.cc +++ b/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.cc @@ -42,6 +42,10 @@ class AlgorithmSpecificationDefaultTypeInternal { public: ::google::protobuf::internal::ExplicitlyConstructed _instance; } _AlgorithmSpecification_default_instance_; +class DistributedProtocolDefaultTypeInternal { + public: + ::google::protobuf::internal::ExplicitlyConstructed _instance; +} _DistributedProtocol_default_instance_; class TrainingJobResourceConfigDefaultTypeInternal { public: ::google::protobuf::internal::ExplicitlyConstructed _instance; @@ -124,6 +128,20 @@ ::google::protobuf::internal::SCCInfo<1> scc_info_AlgorithmSpecification_flyteid {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 1, InitDefaultsAlgorithmSpecification_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto}, { &scc_info_MetricDefinition_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base,}}; +static void InitDefaultsDistributedProtocol_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto() { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + { + void* ptr = &::flyteidl::plugins::sagemaker::_DistributedProtocol_default_instance_; + new (ptr) ::flyteidl::plugins::sagemaker::DistributedProtocol(); + ::google::protobuf::internal::OnShutdownDestroyMessage(ptr); + } + ::flyteidl::plugins::sagemaker::DistributedProtocol::InitAsDefaultInstance(); +} + +::google::protobuf::internal::SCCInfo<0> scc_info_DistributedProtocol_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto = + {{ATOMIC_VAR_INIT(::google::protobuf::internal::SCCInfoBase::kUninitialized), 0, InitDefaultsDistributedProtocol_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto}, {}}; + static void InitDefaultsTrainingJobResourceConfig_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto() { GOOGLE_PROTOBUF_VERIFY_VERSION; @@ -160,12 +178,13 @@ void InitDefaults_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto() { ::google::protobuf::internal::InitSCC(&scc_info_InputContentType_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_MetricDefinition_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_AlgorithmSpecification_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); + ::google::protobuf::internal::InitSCC(&scc_info_DistributedProtocol_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_TrainingJobResourceConfig_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); ::google::protobuf::internal::InitSCC(&scc_info_TrainingJob_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); } -::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[7]; -const ::google::protobuf::EnumDescriptor* file_level_enum_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[3]; +::google::protobuf::Metadata file_level_metadata_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[8]; +const ::google::protobuf::EnumDescriptor* file_level_enum_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[4]; constexpr ::google::protobuf::ServiceDescriptor const** file_level_service_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto = nullptr; const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto::offsets[] PROTOBUF_SECTION_VARIABLE(protodesc_cold) = { @@ -202,6 +221,11 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fsagemaker_2ftr PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::AlgorithmSpecification, metric_definitions_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::AlgorithmSpecification, input_content_type_), ~0u, // no _has_bits_ + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::DistributedProtocol, _internal_metadata_), + ~0u, // no _extensions_ + ~0u, // no _oneof_case_ + ~0u, // no _weak_field_map_ + ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig, _internal_metadata_), ~0u, // no _extensions_ ~0u, // no _oneof_case_ @@ -209,6 +233,7 @@ const ::google::protobuf::uint32 TableStruct_flyteidl_2fplugins_2fsagemaker_2ftr PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig, instance_count_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig, instance_type_), PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig, volume_size_in_gb_), + PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig, distributed_protocol_), ~0u, // no _has_bits_ PROTOBUF_FIELD_OFFSET(::flyteidl::plugins::sagemaker::TrainingJob, _internal_metadata_), ~0u, // no _extensions_ @@ -223,8 +248,9 @@ static const ::google::protobuf::internal::MigrationSchema schemas[] PROTOBUF_SE { 10, -1, sizeof(::flyteidl::plugins::sagemaker::InputContentType)}, { 15, -1, sizeof(::flyteidl::plugins::sagemaker::MetricDefinition)}, { 22, -1, sizeof(::flyteidl::plugins::sagemaker::AlgorithmSpecification)}, - { 32, -1, sizeof(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig)}, - { 40, -1, sizeof(::flyteidl::plugins::sagemaker::TrainingJob)}, + { 32, -1, sizeof(::flyteidl::plugins::sagemaker::DistributedProtocol)}, + { 37, -1, sizeof(::flyteidl::plugins::sagemaker::TrainingJobResourceConfig)}, + { 46, -1, sizeof(::flyteidl::plugins::sagemaker::TrainingJob)}, }; static ::google::protobuf::Message const * const file_default_instances[] = { @@ -233,6 +259,7 @@ static ::google::protobuf::Message const * const file_default_instances[] = { reinterpret_cast(&::flyteidl::plugins::sagemaker::_InputContentType_default_instance_), reinterpret_cast(&::flyteidl::plugins::sagemaker::_MetricDefinition_default_instance_), reinterpret_cast(&::flyteidl::plugins::sagemaker::_AlgorithmSpecification_default_instance_), + reinterpret_cast(&::flyteidl::plugins::sagemaker::_DistributedProtocol_default_instance_), reinterpret_cast(&::flyteidl::plugins::sagemaker::_TrainingJobResourceConfig_default_instance_), reinterpret_cast(&::flyteidl::plugins::sagemaker::_TrainingJob_default_instance_), }; @@ -240,7 +267,7 @@ static ::google::protobuf::Message const * const file_default_instances[] = { ::google::protobuf::internal::AssignDescriptorsTable assign_descriptors_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto = { {}, AddDescriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, "flyteidl/plugins/sagemaker/training_job.proto", schemas, file_default_instances, TableStruct_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto::offsets, - file_level_metadata_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, 7, file_level_enum_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, + file_level_metadata_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, 8, file_level_enum_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, file_level_service_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, }; const char descriptor_table_protodef_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[] = @@ -259,21 +286,25 @@ const char descriptor_table_protodef_flyteidl_2fplugins_2fsagemaker_2ftraining_5 " \001(\t\022H\n\022metric_definitions\030\004 \003(\0132,.flyte" "idl.plugins.sagemaker.MetricDefinition\022N" "\n\022input_content_type\030\005 \001(\01622.flyteidl.pl" - "ugins.sagemaker.InputContentType.Value\"e" - "\n\031TrainingJobResourceConfig\022\026\n\016instance_" - "count\030\001 \001(\003\022\025\n\rinstance_type\030\002 \001(\t\022\031\n\021vo" - "lume_size_in_gb\030\003 \001(\003\"\277\001\n\013TrainingJob\022S\n" - "\027algorithm_specification\030\001 \001(\01322.flyteid" - "l.plugins.sagemaker.AlgorithmSpecificati" - "on\022[\n\034training_job_resource_config\030\002 \001(\013" - "25.flyteidl.plugins.sagemaker.TrainingJo" - "bResourceConfigB5Z3github.com/lyft/flyte" - "idl/gen/pb-go/flyteidl/pluginsb\006proto3" + "ugins.sagemaker.InputContentType.Value\"8" + "\n\023DistributedProtocol\"!\n\005Value\022\017\n\013UNSPEC" + "IFIED\020\000\022\007\n\003MPI\020\001\"\272\001\n\031TrainingJobResource" + "Config\022\026\n\016instance_count\030\001 \001(\003\022\025\n\rinstan" + "ce_type\030\002 \001(\t\022\031\n\021volume_size_in_gb\030\003 \001(\003" + "\022S\n\024distributed_protocol\030\004 \001(\01625.flyteid" + "l.plugins.sagemaker.DistributedProtocol." + "Value\"\277\001\n\013TrainingJob\022S\n\027algorithm_speci" + "fication\030\001 \001(\01322.flyteidl.plugins.sagema" + "ker.AlgorithmSpecification\022[\n\034training_j" + "ob_resource_config\030\002 \001(\01325.flyteidl.plug" + "ins.sagemaker.TrainingJobResourceConfigB" + "5Z3github.com/lyft/flyteidl/gen/pb-go/fl" + "yteidl/pluginsb\006proto3" ; ::google::protobuf::internal::DescriptorTable descriptor_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto = { false, InitDefaults_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, descriptor_table_protodef_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, - "flyteidl/plugins/sagemaker/training_job.proto", &assign_descriptors_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, 998, + "flyteidl/plugins/sagemaker/training_job.proto", &assign_descriptors_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto, 1142, }; void AddDescriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto() { @@ -350,6 +381,27 @@ const InputContentType_Value InputContentType::Value_MIN; const InputContentType_Value InputContentType::Value_MAX; const int InputContentType::Value_ARRAYSIZE; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 +const ::google::protobuf::EnumDescriptor* DistributedProtocol_Value_descriptor() { + ::google::protobuf::internal::AssignDescriptors(&assign_descriptors_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto); + return file_level_enum_descriptors_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[3]; +} +bool DistributedProtocol_Value_IsValid(int value) { + switch (value) { + case 0: + case 1: + return true; + default: + return false; + } +} + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const DistributedProtocol_Value DistributedProtocol::UNSPECIFIED; +const DistributedProtocol_Value DistributedProtocol::MPI; +const DistributedProtocol_Value DistributedProtocol::Value_MIN; +const DistributedProtocol_Value DistributedProtocol::Value_MAX; +const int DistributedProtocol::Value_ARRAYSIZE; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 // =================================================================== @@ -1849,6 +1901,215 @@ ::google::protobuf::Metadata AlgorithmSpecification::GetMetadata() const { } +// =================================================================== + +void DistributedProtocol::InitAsDefaultInstance() { +} +class DistributedProtocol::HasBitSetters { + public: +}; + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +DistributedProtocol::DistributedProtocol() + : ::google::protobuf::Message(), _internal_metadata_(nullptr) { + SharedCtor(); + // @@protoc_insertion_point(constructor:flyteidl.plugins.sagemaker.DistributedProtocol) +} +DistributedProtocol::DistributedProtocol(const DistributedProtocol& from) + : ::google::protobuf::Message(), + _internal_metadata_(nullptr) { + _internal_metadata_.MergeFrom(from._internal_metadata_); + // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.sagemaker.DistributedProtocol) +} + +void DistributedProtocol::SharedCtor() { +} + +DistributedProtocol::~DistributedProtocol() { + // @@protoc_insertion_point(destructor:flyteidl.plugins.sagemaker.DistributedProtocol) + SharedDtor(); +} + +void DistributedProtocol::SharedDtor() { +} + +void DistributedProtocol::SetCachedSize(int size) const { + _cached_size_.Set(size); +} +const DistributedProtocol& DistributedProtocol::default_instance() { + ::google::protobuf::internal::InitSCC(&::scc_info_DistributedProtocol_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); + return *internal_default_instance(); +} + + +void DistributedProtocol::Clear() { +// @@protoc_insertion_point(message_clear_start:flyteidl.plugins.sagemaker.DistributedProtocol) + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + _internal_metadata_.Clear(); +} + +#if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +const char* DistributedProtocol::_InternalParse(const char* begin, const char* end, void* object, + ::google::protobuf::internal::ParseContext* ctx) { + auto msg = static_cast(object); + ::google::protobuf::int32 size; (void)size; + int depth; (void)depth; + ::google::protobuf::uint32 tag; + ::google::protobuf::internal::ParseFunc parser_till_end; (void)parser_till_end; + auto ptr = begin; + while (ptr < end) { + ptr = ::google::protobuf::io::Parse32(ptr, &tag); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + switch (tag >> 3) { + default: { + if ((tag & 7) == 4 || tag == 0) { + ctx->EndGroup(tag); + return ptr; + } + auto res = UnknownFieldParse(tag, {_InternalParse, msg}, + ptr, end, msg->_internal_metadata_.mutable_unknown_fields(), ctx); + ptr = res.first; + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr); + if (res.second) return ptr; + } + } // switch + } // while + return ptr; +} +#else // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER +bool DistributedProtocol::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!PROTOBUF_PREDICT_TRUE(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:flyteidl.plugins.sagemaker.DistributedProtocol) + for (;;) { + ::std::pair<::google::protobuf::uint32, bool> p = input->ReadTagWithCutoffNoLastTag(127u); + tag = p.first; + if (!p.second) goto handle_unusual; + handle_unusual: + if (tag == 0) { + goto success; + } + DO_(::google::protobuf::internal::WireFormat::SkipField( + input, tag, _internal_metadata_.mutable_unknown_fields())); + } +success: + // @@protoc_insertion_point(parse_success:flyteidl.plugins.sagemaker.DistributedProtocol) + return true; +failure: + // @@protoc_insertion_point(parse_failure:flyteidl.plugins.sagemaker.DistributedProtocol) + return false; +#undef DO_ +} +#endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + +void DistributedProtocol::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:flyteidl.plugins.sagemaker.DistributedProtocol) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (_internal_metadata_.have_unknown_fields()) { + ::google::protobuf::internal::WireFormat::SerializeUnknownFields( + _internal_metadata_.unknown_fields(), output); + } + // @@protoc_insertion_point(serialize_end:flyteidl.plugins.sagemaker.DistributedProtocol) +} + +::google::protobuf::uint8* DistributedProtocol::InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:flyteidl.plugins.sagemaker.DistributedProtocol) + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + + if (_internal_metadata_.have_unknown_fields()) { + target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( + _internal_metadata_.unknown_fields(), target); + } + // @@protoc_insertion_point(serialize_to_array_end:flyteidl.plugins.sagemaker.DistributedProtocol) + return target; +} + +size_t DistributedProtocol::ByteSizeLong() const { +// @@protoc_insertion_point(message_byte_size_start:flyteidl.plugins.sagemaker.DistributedProtocol) + size_t total_size = 0; + + if (_internal_metadata_.have_unknown_fields()) { + total_size += + ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( + _internal_metadata_.unknown_fields()); + } + ::google::protobuf::uint32 cached_has_bits = 0; + // Prevent compiler warnings about cached_has_bits being unused + (void) cached_has_bits; + + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); + SetCachedSize(cached_size); + return total_size; +} + +void DistributedProtocol::MergeFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_merge_from_start:flyteidl.plugins.sagemaker.DistributedProtocol) + GOOGLE_DCHECK_NE(&from, this); + const DistributedProtocol* source = + ::google::protobuf::DynamicCastToGenerated( + &from); + if (source == nullptr) { + // @@protoc_insertion_point(generalized_merge_from_cast_fail:flyteidl.plugins.sagemaker.DistributedProtocol) + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + // @@protoc_insertion_point(generalized_merge_from_cast_success:flyteidl.plugins.sagemaker.DistributedProtocol) + MergeFrom(*source); + } +} + +void DistributedProtocol::MergeFrom(const DistributedProtocol& from) { +// @@protoc_insertion_point(class_specific_merge_from_start:flyteidl.plugins.sagemaker.DistributedProtocol) + GOOGLE_DCHECK_NE(&from, this); + _internal_metadata_.MergeFrom(from._internal_metadata_); + ::google::protobuf::uint32 cached_has_bits = 0; + (void) cached_has_bits; + +} + +void DistributedProtocol::CopyFrom(const ::google::protobuf::Message& from) { +// @@protoc_insertion_point(generalized_copy_from_start:flyteidl.plugins.sagemaker.DistributedProtocol) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void DistributedProtocol::CopyFrom(const DistributedProtocol& from) { +// @@protoc_insertion_point(class_specific_copy_from_start:flyteidl.plugins.sagemaker.DistributedProtocol) + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool DistributedProtocol::IsInitialized() const { + return true; +} + +void DistributedProtocol::Swap(DistributedProtocol* other) { + if (other == this) return; + InternalSwap(other); +} +void DistributedProtocol::InternalSwap(DistributedProtocol* other) { + using std::swap; + _internal_metadata_.Swap(&other->_internal_metadata_); +} + +::google::protobuf::Metadata DistributedProtocol::GetMetadata() const { + ::google::protobuf::internal::AssignDescriptors(&::assign_descriptors_table_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto); + return ::file_level_metadata_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto[kIndexInFileMessages]; +} + + // =================================================================== void TrainingJobResourceConfig::InitAsDefaultInstance() { @@ -1861,6 +2122,7 @@ class TrainingJobResourceConfig::HasBitSetters { const int TrainingJobResourceConfig::kInstanceCountFieldNumber; const int TrainingJobResourceConfig::kInstanceTypeFieldNumber; const int TrainingJobResourceConfig::kVolumeSizeInGbFieldNumber; +const int TrainingJobResourceConfig::kDistributedProtocolFieldNumber; #endif // !defined(_MSC_VER) || _MSC_VER >= 1900 TrainingJobResourceConfig::TrainingJobResourceConfig() @@ -1877,8 +2139,8 @@ TrainingJobResourceConfig::TrainingJobResourceConfig(const TrainingJobResourceCo instance_type_.AssignWithDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited(), from.instance_type_); } ::memcpy(&instance_count_, &from.instance_count_, - static_cast(reinterpret_cast(&volume_size_in_gb_) - - reinterpret_cast(&instance_count_)) + sizeof(volume_size_in_gb_)); + static_cast(reinterpret_cast(&distributed_protocol_) - + reinterpret_cast(&instance_count_)) + sizeof(distributed_protocol_)); // @@protoc_insertion_point(copy_constructor:flyteidl.plugins.sagemaker.TrainingJobResourceConfig) } @@ -1887,8 +2149,8 @@ void TrainingJobResourceConfig::SharedCtor() { &scc_info_TrainingJobResourceConfig_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto.base); instance_type_.UnsafeSetDefault(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); ::memset(&instance_count_, 0, static_cast( - reinterpret_cast(&volume_size_in_gb_) - - reinterpret_cast(&instance_count_)) + sizeof(volume_size_in_gb_)); + reinterpret_cast(&distributed_protocol_) - + reinterpret_cast(&instance_count_)) + sizeof(distributed_protocol_)); } TrainingJobResourceConfig::~TrainingJobResourceConfig() { @@ -1917,8 +2179,8 @@ void TrainingJobResourceConfig::Clear() { instance_type_.ClearToEmptyNoArena(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); ::memset(&instance_count_, 0, static_cast( - reinterpret_cast(&volume_size_in_gb_) - - reinterpret_cast(&instance_count_)) + sizeof(volume_size_in_gb_)); + reinterpret_cast(&distributed_protocol_) - + reinterpret_cast(&instance_count_)) + sizeof(distributed_protocol_)); _internal_metadata_.Clear(); } @@ -1965,6 +2227,14 @@ const char* TrainingJobResourceConfig::_InternalParse(const char* begin, const c GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); break; } + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + case 4: { + if (static_cast<::google::protobuf::uint8>(tag) != 32) goto handle_unusual; + ::google::protobuf::uint64 val = ::google::protobuf::internal::ReadVarint(&ptr); + msg->set_distributed_protocol(static_cast<::flyteidl::plugins::sagemaker::DistributedProtocol_Value>(val)); + GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); + break; + } default: { handle_unusual: if ((tag & 7) == 4 || tag == 0) { @@ -2040,6 +2310,20 @@ bool TrainingJobResourceConfig::MergePartialFromCodedStream( break; } + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + case 4: { + if (static_cast< ::google::protobuf::uint8>(tag) == (32 & 0xFF)) { + int value = 0; + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>( + input, &value))); + set_distributed_protocol(static_cast< ::flyteidl::plugins::sagemaker::DistributedProtocol_Value >(value)); + } else { + goto handle_unusual; + } + break; + } + default: { handle_unusual: if (tag == 0) { @@ -2087,6 +2371,12 @@ void TrainingJobResourceConfig::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteInt64(3, this->volume_size_in_gb(), output); } + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + if (this->distributed_protocol() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteEnum( + 4, this->distributed_protocol(), output); + } + if (_internal_metadata_.have_unknown_fields()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( _internal_metadata_.unknown_fields(), output); @@ -2121,6 +2411,12 @@ ::google::protobuf::uint8* TrainingJobResourceConfig::InternalSerializeWithCache target = ::google::protobuf::internal::WireFormatLite::WriteInt64ToArray(3, this->volume_size_in_gb(), target); } + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + if (this->distributed_protocol() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteEnumToArray( + 4, this->distributed_protocol(), target); + } + if (_internal_metadata_.have_unknown_fields()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( _internal_metadata_.unknown_fields(), target); @@ -2163,6 +2459,12 @@ size_t TrainingJobResourceConfig::ByteSizeLong() const { this->volume_size_in_gb()); } + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + if (this->distributed_protocol() != 0) { + total_size += 1 + + ::google::protobuf::internal::WireFormatLite::EnumSize(this->distributed_protocol()); + } + int cached_size = ::google::protobuf::internal::ToCachedSize(total_size); SetCachedSize(cached_size); return total_size; @@ -2200,6 +2502,9 @@ void TrainingJobResourceConfig::MergeFrom(const TrainingJobResourceConfig& from) if (from.volume_size_in_gb() != 0) { set_volume_size_in_gb(from.volume_size_in_gb()); } + if (from.distributed_protocol() != 0) { + set_distributed_protocol(from.distributed_protocol()); + } } void TrainingJobResourceConfig::CopyFrom(const ::google::protobuf::Message& from) { @@ -2231,6 +2536,7 @@ void TrainingJobResourceConfig::InternalSwap(TrainingJobResourceConfig* other) { GetArenaNoVirtual()); swap(instance_count_, other->instance_count_); swap(volume_size_in_gb_, other->volume_size_in_gb_); + swap(distributed_protocol_, other->distributed_protocol_); } ::google::protobuf::Metadata TrainingJobResourceConfig::GetMetadata() const { @@ -2615,6 +2921,9 @@ template<> PROTOBUF_NOINLINE ::flyteidl::plugins::sagemaker::MetricDefinition* A template<> PROTOBUF_NOINLINE ::flyteidl::plugins::sagemaker::AlgorithmSpecification* Arena::CreateMaybeMessage< ::flyteidl::plugins::sagemaker::AlgorithmSpecification >(Arena* arena) { return Arena::CreateInternal< ::flyteidl::plugins::sagemaker::AlgorithmSpecification >(arena); } +template<> PROTOBUF_NOINLINE ::flyteidl::plugins::sagemaker::DistributedProtocol* Arena::CreateMaybeMessage< ::flyteidl::plugins::sagemaker::DistributedProtocol >(Arena* arena) { + return Arena::CreateInternal< ::flyteidl::plugins::sagemaker::DistributedProtocol >(arena); +} template<> PROTOBUF_NOINLINE ::flyteidl::plugins::sagemaker::TrainingJobResourceConfig* Arena::CreateMaybeMessage< ::flyteidl::plugins::sagemaker::TrainingJobResourceConfig >(Arena* arena) { return Arena::CreateInternal< ::flyteidl::plugins::sagemaker::TrainingJobResourceConfig >(arena); } diff --git a/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.h b/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.h index ec52b16c2..8fc367ba7 100644 --- a/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.h +++ b/gen/pb-cpp/flyteidl/plugins/sagemaker/training_job.pb.h @@ -43,7 +43,7 @@ struct TableStruct_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto { PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::AuxillaryParseTableField aux[] PROTOBUF_SECTION_VARIABLE(protodesc_cold); - static const ::google::protobuf::internal::ParseTable schema[7] + static const ::google::protobuf::internal::ParseTable schema[8] PROTOBUF_SECTION_VARIABLE(protodesc_cold); static const ::google::protobuf::internal::FieldMetadata field_metadata[]; static const ::google::protobuf::internal::SerializationTable serialization_table[]; @@ -59,6 +59,9 @@ extern AlgorithmNameDefaultTypeInternal _AlgorithmName_default_instance_; class AlgorithmSpecification; class AlgorithmSpecificationDefaultTypeInternal; extern AlgorithmSpecificationDefaultTypeInternal _AlgorithmSpecification_default_instance_; +class DistributedProtocol; +class DistributedProtocolDefaultTypeInternal; +extern DistributedProtocolDefaultTypeInternal _DistributedProtocol_default_instance_; class InputContentType; class InputContentTypeDefaultTypeInternal; extern InputContentTypeDefaultTypeInternal _InputContentType_default_instance_; @@ -81,6 +84,7 @@ namespace google { namespace protobuf { template<> ::flyteidl::plugins::sagemaker::AlgorithmName* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::AlgorithmName>(Arena*); template<> ::flyteidl::plugins::sagemaker::AlgorithmSpecification* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::AlgorithmSpecification>(Arena*); +template<> ::flyteidl::plugins::sagemaker::DistributedProtocol* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::DistributedProtocol>(Arena*); template<> ::flyteidl::plugins::sagemaker::InputContentType* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::InputContentType>(Arena*); template<> ::flyteidl::plugins::sagemaker::InputMode* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::InputMode>(Arena*); template<> ::flyteidl::plugins::sagemaker::MetricDefinition* Arena::CreateMaybeMessage<::flyteidl::plugins::sagemaker::MetricDefinition>(Arena*); @@ -154,6 +158,27 @@ inline bool InputContentType_Value_Parse( return ::google::protobuf::internal::ParseNamedEnum( InputContentType_Value_descriptor(), name, value); } +enum DistributedProtocol_Value { + DistributedProtocol_Value_UNSPECIFIED = 0, + DistributedProtocol_Value_MPI = 1, + DistributedProtocol_Value_DistributedProtocol_Value_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::google::protobuf::int32>::min(), + DistributedProtocol_Value_DistributedProtocol_Value_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::google::protobuf::int32>::max() +}; +bool DistributedProtocol_Value_IsValid(int value); +const DistributedProtocol_Value DistributedProtocol_Value_Value_MIN = DistributedProtocol_Value_UNSPECIFIED; +const DistributedProtocol_Value DistributedProtocol_Value_Value_MAX = DistributedProtocol_Value_MPI; +const int DistributedProtocol_Value_Value_ARRAYSIZE = DistributedProtocol_Value_Value_MAX + 1; + +const ::google::protobuf::EnumDescriptor* DistributedProtocol_Value_descriptor(); +inline const ::std::string& DistributedProtocol_Value_Name(DistributedProtocol_Value value) { + return ::google::protobuf::internal::NameOfEnum( + DistributedProtocol_Value_descriptor(), value); +} +inline bool DistributedProtocol_Value_Parse( + const ::std::string& name, DistributedProtocol_Value* value) { + return ::google::protobuf::internal::ParseNamedEnum( + DistributedProtocol_Value_descriptor(), name, value); +} // =================================================================== class InputMode final : @@ -836,6 +861,137 @@ class AlgorithmSpecification final : }; // ------------------------------------------------------------------- +class DistributedProtocol final : + public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.sagemaker.DistributedProtocol) */ { + public: + DistributedProtocol(); + virtual ~DistributedProtocol(); + + DistributedProtocol(const DistributedProtocol& from); + + inline DistributedProtocol& operator=(const DistributedProtocol& from) { + CopyFrom(from); + return *this; + } + #if LANG_CXX11 + DistributedProtocol(DistributedProtocol&& from) noexcept + : DistributedProtocol() { + *this = ::std::move(from); + } + + inline DistributedProtocol& operator=(DistributedProtocol&& from) noexcept { + if (GetArenaNoVirtual() == from.GetArenaNoVirtual()) { + if (this != &from) InternalSwap(&from); + } else { + CopyFrom(from); + } + return *this; + } + #endif + static const ::google::protobuf::Descriptor* descriptor() { + return default_instance().GetDescriptor(); + } + static const DistributedProtocol& default_instance(); + + static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY + static inline const DistributedProtocol* internal_default_instance() { + return reinterpret_cast( + &_DistributedProtocol_default_instance_); + } + static constexpr int kIndexInFileMessages = + 5; + + void Swap(DistributedProtocol* other); + friend void swap(DistributedProtocol& a, DistributedProtocol& b) { + a.Swap(&b); + } + + // implements Message ---------------------------------------------- + + inline DistributedProtocol* New() const final { + return CreateMaybeMessage(nullptr); + } + + DistributedProtocol* New(::google::protobuf::Arena* arena) const final { + return CreateMaybeMessage(arena); + } + void CopyFrom(const ::google::protobuf::Message& from) final; + void MergeFrom(const ::google::protobuf::Message& from) final; + void CopyFrom(const DistributedProtocol& from); + void MergeFrom(const DistributedProtocol& from); + PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final; + bool IsInitialized() const final; + + size_t ByteSizeLong() const final; + #if GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + static const char* _InternalParse(const char* begin, const char* end, void* object, ::google::protobuf::internal::ParseContext* ctx); + ::google::protobuf::internal::ParseFunc _ParseFunc() const final { return _InternalParse; } + #else + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) final; + #endif // GOOGLE_PROTOBUF_ENABLE_EXPERIMENTAL_PARSER + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const final; + ::google::protobuf::uint8* InternalSerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const final; + int GetCachedSize() const final { return _cached_size_.Get(); } + + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const final; + void InternalSwap(DistributedProtocol* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return nullptr; + } + inline void* MaybeArenaPtr() const { + return nullptr; + } + public: + + ::google::protobuf::Metadata GetMetadata() const final; + + // nested types ---------------------------------------------------- + + typedef DistributedProtocol_Value Value; + static const Value UNSPECIFIED = + DistributedProtocol_Value_UNSPECIFIED; + static const Value MPI = + DistributedProtocol_Value_MPI; + static inline bool Value_IsValid(int value) { + return DistributedProtocol_Value_IsValid(value); + } + static const Value Value_MIN = + DistributedProtocol_Value_Value_MIN; + static const Value Value_MAX = + DistributedProtocol_Value_Value_MAX; + static const int Value_ARRAYSIZE = + DistributedProtocol_Value_Value_ARRAYSIZE; + static inline const ::google::protobuf::EnumDescriptor* + Value_descriptor() { + return DistributedProtocol_Value_descriptor(); + } + static inline const ::std::string& Value_Name(Value value) { + return DistributedProtocol_Value_Name(value); + } + static inline bool Value_Parse(const ::std::string& name, + Value* value) { + return DistributedProtocol_Value_Parse(name, value); + } + + // accessors ------------------------------------------------------- + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.sagemaker.DistributedProtocol) + private: + class HasBitSetters; + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + mutable ::google::protobuf::internal::CachedSize _cached_size_; + friend struct ::TableStruct_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto; +}; +// ------------------------------------------------------------------- + class TrainingJobResourceConfig final : public ::google::protobuf::Message /* @@protoc_insertion_point(class_definition:flyteidl.plugins.sagemaker.TrainingJobResourceConfig) */ { public: @@ -874,7 +1030,7 @@ class TrainingJobResourceConfig final : &_TrainingJobResourceConfig_default_instance_); } static constexpr int kIndexInFileMessages = - 5; + 6; void Swap(TrainingJobResourceConfig* other); friend void swap(TrainingJobResourceConfig& a, TrainingJobResourceConfig& b) { @@ -957,6 +1113,12 @@ class TrainingJobResourceConfig final : ::google::protobuf::int64 volume_size_in_gb() const; void set_volume_size_in_gb(::google::protobuf::int64 value); + // .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + void clear_distributed_protocol(); + static const int kDistributedProtocolFieldNumber = 4; + ::flyteidl::plugins::sagemaker::DistributedProtocol_Value distributed_protocol() const; + void set_distributed_protocol(::flyteidl::plugins::sagemaker::DistributedProtocol_Value value); + // @@protoc_insertion_point(class_scope:flyteidl.plugins.sagemaker.TrainingJobResourceConfig) private: class HasBitSetters; @@ -965,6 +1127,7 @@ class TrainingJobResourceConfig final : ::google::protobuf::internal::ArenaStringPtr instance_type_; ::google::protobuf::int64 instance_count_; ::google::protobuf::int64 volume_size_in_gb_; + int distributed_protocol_; mutable ::google::protobuf::internal::CachedSize _cached_size_; friend struct ::TableStruct_flyteidl_2fplugins_2fsagemaker_2ftraining_5fjob_2eproto; }; @@ -1008,7 +1171,7 @@ class TrainingJob final : &_TrainingJob_default_instance_); } static constexpr int kIndexInFileMessages = - 6; + 7; void Swap(TrainingJob* other); friend void swap(TrainingJob& a, TrainingJob& b) { @@ -1353,6 +1516,10 @@ inline void AlgorithmSpecification::set_input_content_type(::flyteidl::plugins:: // ------------------------------------------------------------------- +// DistributedProtocol + +// ------------------------------------------------------------------- + // TrainingJobResourceConfig // int64 instance_count = 1; @@ -1436,6 +1603,20 @@ inline void TrainingJobResourceConfig::set_volume_size_in_gb(::google::protobuf: // @@protoc_insertion_point(field_set:flyteidl.plugins.sagemaker.TrainingJobResourceConfig.volume_size_in_gb) } +// .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; +inline void TrainingJobResourceConfig::clear_distributed_protocol() { + distributed_protocol_ = 0; +} +inline ::flyteidl::plugins::sagemaker::DistributedProtocol_Value TrainingJobResourceConfig::distributed_protocol() const { + // @@protoc_insertion_point(field_get:flyteidl.plugins.sagemaker.TrainingJobResourceConfig.distributed_protocol) + return static_cast< ::flyteidl::plugins::sagemaker::DistributedProtocol_Value >(distributed_protocol_); +} +inline void TrainingJobResourceConfig::set_distributed_protocol(::flyteidl::plugins::sagemaker::DistributedProtocol_Value value) { + + distributed_protocol_ = value; + // @@protoc_insertion_point(field_set:flyteidl.plugins.sagemaker.TrainingJobResourceConfig.distributed_protocol) +} + // ------------------------------------------------------------------- // TrainingJob @@ -1557,6 +1738,8 @@ inline void TrainingJob::set_allocated_training_job_resource_config(::flyteidl:: // ------------------------------------------------------------------- +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) @@ -1582,6 +1765,11 @@ template <> inline const EnumDescriptor* GetEnumDescriptor< ::flyteidl::plugins::sagemaker::InputContentType_Value>() { return ::flyteidl::plugins::sagemaker::InputContentType_Value_descriptor(); } +template <> struct is_proto_enum< ::flyteidl::plugins::sagemaker::DistributedProtocol_Value> : ::std::true_type {}; +template <> +inline const EnumDescriptor* GetEnumDescriptor< ::flyteidl::plugins::sagemaker::DistributedProtocol_Value>() { + return ::flyteidl::plugins::sagemaker::DistributedProtocol_Value_descriptor(); +} } // namespace protobuf } // namespace google diff --git a/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.go b/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.go index 865ea9a67..3643de78d 100644 --- a/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.go +++ b/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.go @@ -93,6 +93,37 @@ func (InputContentType_Value) EnumDescriptor() ([]byte, []int) { return fileDescriptor_6a68f64d8fd9fe30, []int{2, 0} } +type DistributedProtocol_Value int32 + +const ( + // Use this value if the user wishes to use framework-native distributed training interfaces. + // If this value is used, Flyte won't configure SageMaker to initialize unnecessary components such as + // OpenMPI or Parameter Server. + DistributedProtocol_UNSPECIFIED DistributedProtocol_Value = 0 + // Use this value if the user wishes to use MPI as the underlying protocol for her distributed training job + // MPI is a framework-agnostic distributed protocol. It has multiple implementations. Currently, we have only + // tested the OpenMPI implementation, which is the recommended implementation for Horovod. + DistributedProtocol_MPI DistributedProtocol_Value = 1 +) + +var DistributedProtocol_Value_name = map[int32]string{ + 0: "UNSPECIFIED", + 1: "MPI", +} + +var DistributedProtocol_Value_value = map[string]int32{ + "UNSPECIFIED": 0, + "MPI": 1, +} + +func (x DistributedProtocol_Value) String() string { + return proto.EnumName(DistributedProtocol_Value_name, int32(x)) +} + +func (DistributedProtocol_Value) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_6a68f64d8fd9fe30, []int{5, 0} +} + // The input mode that the algorithm supports. When using the File input mode, SageMaker downloads // the training data from S3 to the provisioned ML storage Volume, and mounts the directory to docker // volume for training container. When using Pipe input mode, Amazon SageMaker streams data directly @@ -340,6 +371,39 @@ func (m *AlgorithmSpecification) GetInputContentType() InputContentType_Value { return InputContentType_TEXT_CSV } +// When enabling distributed training on a training job, the user should use this message to tell Flyte and SageMaker +// what kind of distributed protocol he/she wants to use to distribute the work. +type DistributedProtocol struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *DistributedProtocol) Reset() { *m = DistributedProtocol{} } +func (m *DistributedProtocol) String() string { return proto.CompactTextString(m) } +func (*DistributedProtocol) ProtoMessage() {} +func (*DistributedProtocol) Descriptor() ([]byte, []int) { + return fileDescriptor_6a68f64d8fd9fe30, []int{5} +} + +func (m *DistributedProtocol) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_DistributedProtocol.Unmarshal(m, b) +} +func (m *DistributedProtocol) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_DistributedProtocol.Marshal(b, m, deterministic) +} +func (m *DistributedProtocol) XXX_Merge(src proto.Message) { + xxx_messageInfo_DistributedProtocol.Merge(m, src) +} +func (m *DistributedProtocol) XXX_Size() int { + return xxx_messageInfo_DistributedProtocol.Size(m) +} +func (m *DistributedProtocol) XXX_DiscardUnknown() { + xxx_messageInfo_DistributedProtocol.DiscardUnknown(m) +} + +var xxx_messageInfo_DistributedProtocol proto.InternalMessageInfo + // TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training job, the // number of instances to launch, and the size of the ML storage volume the user wants to provision // Refer to SageMaker official doc for more details: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html @@ -349,17 +413,21 @@ type TrainingJobResourceConfig struct { // The ML compute instance type InstanceType string `protobuf:"bytes,2,opt,name=instance_type,json=instanceType,proto3" json:"instance_type,omitempty"` // The size of the ML storage volume that you want to provision. - VolumeSizeInGb int64 `protobuf:"varint,3,opt,name=volume_size_in_gb,json=volumeSizeInGb,proto3" json:"volume_size_in_gb,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + VolumeSizeInGb int64 `protobuf:"varint,3,opt,name=volume_size_in_gb,json=volumeSizeInGb,proto3" json:"volume_size_in_gb,omitempty"` + // When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training. + // If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this + // field should be set to the corresponding enum value + DistributedProtocol DistributedProtocol_Value `protobuf:"varint,4,opt,name=distributed_protocol,json=distributedProtocol,proto3,enum=flyteidl.plugins.sagemaker.DistributedProtocol_Value" json:"distributed_protocol,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *TrainingJobResourceConfig) Reset() { *m = TrainingJobResourceConfig{} } func (m *TrainingJobResourceConfig) String() string { return proto.CompactTextString(m) } func (*TrainingJobResourceConfig) ProtoMessage() {} func (*TrainingJobResourceConfig) Descriptor() ([]byte, []int) { - return fileDescriptor_6a68f64d8fd9fe30, []int{5} + return fileDescriptor_6a68f64d8fd9fe30, []int{6} } func (m *TrainingJobResourceConfig) XXX_Unmarshal(b []byte) error { @@ -401,6 +469,13 @@ func (m *TrainingJobResourceConfig) GetVolumeSizeInGb() int64 { return 0 } +func (m *TrainingJobResourceConfig) GetDistributedProtocol() DistributedProtocol_Value { + if m != nil { + return m.DistributedProtocol + } + return DistributedProtocol_UNSPECIFIED +} + // The spec of a training job. This is mostly a pass-through object // https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html type TrainingJob struct { @@ -415,7 +490,7 @@ func (m *TrainingJob) Reset() { *m = TrainingJob{} } func (m *TrainingJob) String() string { return proto.CompactTextString(m) } func (*TrainingJob) ProtoMessage() {} func (*TrainingJob) Descriptor() ([]byte, []int) { - return fileDescriptor_6a68f64d8fd9fe30, []int{6} + return fileDescriptor_6a68f64d8fd9fe30, []int{7} } func (m *TrainingJob) XXX_Unmarshal(b []byte) error { @@ -454,11 +529,13 @@ func init() { proto.RegisterEnum("flyteidl.plugins.sagemaker.InputMode_Value", InputMode_Value_name, InputMode_Value_value) proto.RegisterEnum("flyteidl.plugins.sagemaker.AlgorithmName_Value", AlgorithmName_Value_name, AlgorithmName_Value_value) proto.RegisterEnum("flyteidl.plugins.sagemaker.InputContentType_Value", InputContentType_Value_name, InputContentType_Value_value) + proto.RegisterEnum("flyteidl.plugins.sagemaker.DistributedProtocol_Value", DistributedProtocol_Value_name, DistributedProtocol_Value_value) proto.RegisterType((*InputMode)(nil), "flyteidl.plugins.sagemaker.InputMode") proto.RegisterType((*AlgorithmName)(nil), "flyteidl.plugins.sagemaker.AlgorithmName") proto.RegisterType((*InputContentType)(nil), "flyteidl.plugins.sagemaker.InputContentType") proto.RegisterType((*MetricDefinition)(nil), "flyteidl.plugins.sagemaker.MetricDefinition") proto.RegisterType((*AlgorithmSpecification)(nil), "flyteidl.plugins.sagemaker.AlgorithmSpecification") + proto.RegisterType((*DistributedProtocol)(nil), "flyteidl.plugins.sagemaker.DistributedProtocol") proto.RegisterType((*TrainingJobResourceConfig)(nil), "flyteidl.plugins.sagemaker.TrainingJobResourceConfig") proto.RegisterType((*TrainingJob)(nil), "flyteidl.plugins.sagemaker.TrainingJob") } @@ -468,42 +545,46 @@ func init() { } var fileDescriptor_6a68f64d8fd9fe30 = []byte{ - // 592 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x94, 0xd1, 0x4e, 0xdb, 0x3c, - 0x14, 0xc7, 0x09, 0x05, 0x3e, 0x38, 0x85, 0x2a, 0x58, 0xdf, 0x58, 0x61, 0xd3, 0x84, 0x32, 0x4d, - 0x02, 0x31, 0x12, 0xad, 0x88, 0xbb, 0xdd, 0x8c, 0x8e, 0xa1, 0xa2, 0x31, 0x50, 0x9a, 0x55, 0x68, - 0xbb, 0xc8, 0x92, 0xf4, 0xc4, 0x78, 0x24, 0x76, 0x94, 0x38, 0x68, 0xe5, 0x31, 0xf6, 0x14, 0x7b, - 0xbe, 0x3d, 0xc1, 0x14, 0xa7, 0x09, 0x5d, 0x45, 0xbb, 0x3b, 0xe7, 0xef, 0xe3, 0xe3, 0xff, 0xf9, - 0x9d, 0x13, 0xc3, 0x61, 0x18, 0x8d, 0x24, 0xb2, 0x61, 0x64, 0x25, 0x51, 0x4e, 0x19, 0xcf, 0xac, - 0xcc, 0xa3, 0x18, 0x7b, 0xb7, 0x98, 0x5a, 0x32, 0xf5, 0x18, 0x67, 0x9c, 0xba, 0xdf, 0x85, 0x6f, - 0x26, 0xa9, 0x90, 0x82, 0xec, 0x54, 0xe1, 0xe6, 0x38, 0xdc, 0xac, 0xc3, 0x77, 0x5e, 0x50, 0x21, - 0x68, 0x84, 0x96, 0x8a, 0xf4, 0xf3, 0xd0, 0x1a, 0xe6, 0xa9, 0x27, 0x99, 0xe0, 0xe5, 0x59, 0x63, - 0x0f, 0xd6, 0x7a, 0x3c, 0xc9, 0xe5, 0x85, 0x18, 0xa2, 0xf1, 0x0c, 0x96, 0x07, 0x5e, 0x94, 0x23, - 0x59, 0x85, 0xa5, 0x0f, 0xbd, 0x8f, 0xa7, 0xfa, 0x42, 0xb1, 0xba, 0xea, 0x5d, 0x9d, 0xea, 0x9a, - 0xf1, 0x06, 0x36, 0xde, 0x45, 0x54, 0xa4, 0x4c, 0xde, 0xc4, 0x9f, 0xbc, 0x18, 0x8d, 0xdd, 0x2a, - 0x1a, 0x60, 0xa5, 0xfb, 0xb9, 0xef, 0x5c, 0x5e, 0xe8, 0x0b, 0xa4, 0x09, 0xff, 0x5d, 0x9f, 0x9d, - 0x5c, 0x5e, 0xf6, 0x1d, 0x5d, 0x33, 0xf6, 0x41, 0x57, 0xc9, 0xbb, 0x82, 0x4b, 0xe4, 0xd2, 0x19, - 0x25, 0x68, 0x3c, 0xa9, 0x4e, 0xad, 0xc3, 0xaa, 0x73, 0x7a, 0xed, 0xb8, 0xdd, 0xfe, 0x40, 0x5f, - 0x30, 0xde, 0x82, 0x7e, 0x81, 0x32, 0x65, 0xc1, 0x7b, 0x0c, 0x19, 0x67, 0x85, 0x43, 0x42, 0x60, - 0x89, 0x7b, 0x31, 0xb6, 0xb5, 0x5d, 0x6d, 0x6f, 0xcd, 0x56, 0x6b, 0xf2, 0x3f, 0x2c, 0xa7, 0x48, - 0xf1, 0x47, 0x7b, 0x51, 0x89, 0xe5, 0x87, 0xf1, 0xab, 0x01, 0x5b, 0xb5, 0xb9, 0x7e, 0x82, 0x01, - 0x0b, 0x59, 0xa0, 0xca, 0x24, 0xe7, 0x00, 0xac, 0xf0, 0xe0, 0xc6, 0x62, 0x58, 0xa6, 0x6a, 0x75, - 0x0e, 0xcc, 0xd9, 0xc4, 0xcc, 0x1a, 0x87, 0xa9, 0x7c, 0xda, 0x6b, 0xac, 0x12, 0xc8, 0x00, 0x5a, - 0x5e, 0x75, 0x8b, 0xab, 0xac, 0x2d, 0xaa, 0x7c, 0xd6, 0xbc, 0x7c, 0x7f, 0x41, 0x1b, 0xe7, 0xdc, - 0xf0, 0x26, 0x45, 0x72, 0x00, 0x9b, 0x0f, 0x79, 0xef, 0x30, 0xcd, 0x98, 0xe0, 0xed, 0x86, 0x2a, - 0x50, 0xaf, 0x37, 0x06, 0xa5, 0x4e, 0xbe, 0x02, 0x89, 0x15, 0x29, 0x77, 0x58, 0xa3, 0xca, 0xda, - 0x4b, 0xbb, 0x8d, 0xbd, 0x66, 0xe7, 0xf5, 0x3c, 0x23, 0xd3, 0x7c, 0xed, 0xcd, 0x78, 0x4a, 0xc9, - 0xc8, 0x37, 0x20, 0x25, 0xad, 0xa0, 0x6c, 0x99, 0x2b, 0x47, 0x09, 0xb6, 0x97, 0x55, 0x95, 0x9d, - 0x7f, 0x52, 0x9b, 0xe8, 0xf3, 0xb8, 0x50, 0x9d, 0x4d, 0xf7, 0xff, 0xa7, 0x06, 0xdb, 0xce, 0x78, - 0x86, 0xcf, 0x85, 0x6f, 0x63, 0x26, 0xf2, 0x34, 0xc0, 0xae, 0xe0, 0x21, 0xa3, 0xe4, 0x15, 0xb4, - 0x18, 0xcf, 0xa4, 0xc7, 0x03, 0x74, 0x03, 0x91, 0x73, 0xa9, 0x3a, 0xd6, 0xb0, 0x37, 0x2a, 0xb5, - 0x5b, 0x88, 0xe4, 0x25, 0xd4, 0x42, 0xe9, 0xb0, 0x9c, 0x86, 0xf5, 0x4a, 0x2c, 0x6e, 0x22, 0xfb, - 0xb0, 0x79, 0x27, 0xa2, 0x3c, 0x46, 0x37, 0x63, 0xf7, 0xe8, 0x32, 0xee, 0x52, 0x5f, 0x51, 0x6d, - 0xd8, 0xad, 0x72, 0xa3, 0xcf, 0xee, 0xb1, 0xc7, 0xcf, 0x7c, 0xe3, 0xb7, 0x06, 0xcd, 0x09, 0x53, - 0xe4, 0x16, 0x9e, 0x3e, 0x34, 0x24, 0x9b, 0x9c, 0x27, 0xe5, 0xa7, 0x39, 0x9f, 0xc5, 0xe3, 0x93, - 0x68, 0x6f, 0x79, 0x8f, 0x4f, 0xe8, 0x1d, 0x3c, 0x9f, 0xfc, 0xa9, 0xdd, 0x74, 0x8c, 0xa4, 0xe8, - 0x41, 0xc8, 0xa8, 0xaa, 0xad, 0xd9, 0x39, 0x9e, 0x77, 0xe3, 0x4c, 0xa0, 0xf6, 0xb6, 0x9c, 0xb5, - 0x75, 0x72, 0xfc, 0xe5, 0x88, 0x32, 0x79, 0x93, 0xfb, 0x66, 0x20, 0x62, 0x2b, 0x1a, 0x85, 0xd2, - 0xaa, 0xdf, 0x1d, 0x8a, 0xdc, 0x4a, 0xfc, 0x43, 0x2a, 0xac, 0xe9, 0xa7, 0xc8, 0x5f, 0x51, 0x0f, - 0xc7, 0xd1, 0x9f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xc4, 0x72, 0xc9, 0x58, 0xa5, 0x04, 0x00, 0x00, + // 656 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x94, 0x4b, 0x4f, 0xdb, 0x40, + 0x14, 0x85, 0x13, 0x12, 0x5e, 0x37, 0x90, 0x9a, 0x81, 0xd2, 0x40, 0xab, 0x8a, 0xba, 0xaa, 0x04, + 0xa2, 0xd8, 0x2a, 0x08, 0xa9, 0x8b, 0x6e, 0x4a, 0x08, 0xc8, 0xa8, 0x81, 0xc8, 0x09, 0x11, 0x6a, + 0x17, 0xae, 0x1f, 0x63, 0x33, 0xc5, 0x9e, 0xb1, 0xec, 0x31, 0x6a, 0xf8, 0x45, 0xfd, 0x7d, 0x5d, + 0x77, 0x51, 0x79, 0xfc, 0x20, 0x8d, 0x42, 0xba, 0x73, 0xce, 0x5c, 0xdf, 0xdc, 0xf3, 0x9d, 0xeb, + 0x81, 0x03, 0xd7, 0x1f, 0x71, 0x4c, 0x1c, 0x5f, 0x0d, 0xfd, 0xc4, 0x23, 0x34, 0x56, 0x63, 0xd3, + 0xc3, 0x81, 0x79, 0x87, 0x23, 0x95, 0x47, 0x26, 0xa1, 0x84, 0x7a, 0xc6, 0x0f, 0x66, 0x29, 0x61, + 0xc4, 0x38, 0x43, 0xdb, 0x45, 0xb9, 0x92, 0x97, 0x2b, 0x65, 0xf9, 0xf6, 0x6b, 0x8f, 0x31, 0xcf, + 0xc7, 0xaa, 0xa8, 0xb4, 0x12, 0x57, 0x75, 0x92, 0xc8, 0xe4, 0x84, 0xd1, 0xec, 0x5d, 0x79, 0x17, + 0x96, 0x35, 0x1a, 0x26, 0xbc, 0xcb, 0x1c, 0x2c, 0xbf, 0x84, 0xf9, 0xa1, 0xe9, 0x27, 0x18, 0x2d, + 0x41, 0xfd, 0x4c, 0xfb, 0xd2, 0x91, 0x2a, 0xe9, 0x53, 0x4f, 0xeb, 0x75, 0xa4, 0xaa, 0xfc, 0x01, + 0x56, 0x3f, 0xfb, 0x1e, 0x8b, 0x08, 0xbf, 0x0d, 0x2e, 0xcd, 0x00, 0xcb, 0x3b, 0x45, 0x35, 0xc0, + 0x42, 0xfb, 0xba, 0x3f, 0xb8, 0xea, 0x4a, 0x15, 0xd4, 0x80, 0xc5, 0x9b, 0xf3, 0x93, 0xab, 0xab, + 0xfe, 0x40, 0xaa, 0xca, 0x7b, 0x20, 0x89, 0xe6, 0x6d, 0x46, 0x39, 0xa6, 0x7c, 0x30, 0x0a, 0xb1, + 0xfc, 0xbc, 0x78, 0x6b, 0x05, 0x96, 0x06, 0x9d, 0x9b, 0x81, 0xd1, 0xee, 0x0f, 0xa5, 0x8a, 0xfc, + 0x09, 0xa4, 0x2e, 0xe6, 0x11, 0xb1, 0x4f, 0xb1, 0x4b, 0x28, 0x49, 0x27, 0x44, 0x08, 0xea, 0xd4, + 0x0c, 0x70, 0xab, 0xba, 0x53, 0xdd, 0x5d, 0xd6, 0xc5, 0x33, 0xda, 0x80, 0xf9, 0x08, 0x7b, 0xf8, + 0x67, 0x6b, 0x4e, 0x88, 0xd9, 0x0f, 0xf9, 0x57, 0x0d, 0x36, 0xcb, 0xe1, 0xfa, 0x21, 0xb6, 0x89, + 0x4b, 0x6c, 0x61, 0x13, 0x5d, 0x00, 0x90, 0x74, 0x06, 0x23, 0x60, 0x4e, 0xd6, 0xaa, 0x79, 0xb8, + 0xaf, 0x3c, 0x4d, 0x4c, 0x29, 0x71, 0x28, 0x62, 0x4e, 0x7d, 0x99, 0x14, 0x02, 0x1a, 0x42, 0xd3, + 0x2c, 0xfe, 0xc5, 0x10, 0xa3, 0xcd, 0x89, 0x7e, 0xea, 0xac, 0x7e, 0xff, 0x40, 0xcb, 0x7b, 0xae, + 0x9a, 0xe3, 0x22, 0xda, 0x87, 0xb5, 0xc7, 0xbe, 0xf7, 0x38, 0x8a, 0x09, 0xa3, 0xad, 0x9a, 0x30, + 0x28, 0x95, 0x07, 0xc3, 0x4c, 0x47, 0xdf, 0x00, 0x05, 0x82, 0x94, 0xe1, 0x94, 0xa8, 0xe2, 0x56, + 0x7d, 0xa7, 0xb6, 0xdb, 0x38, 0x7c, 0x3f, 0x6b, 0x90, 0x49, 0xbe, 0xfa, 0x5a, 0x30, 0xa1, 0xc4, + 0xe8, 0x3b, 0xa0, 0x8c, 0x96, 0x9d, 0x45, 0x66, 0xf0, 0x51, 0x88, 0x5b, 0xf3, 0xc2, 0xe5, 0xe1, + 0x7f, 0xa9, 0x8d, 0xe5, 0x9c, 0x1b, 0x95, 0xc8, 0x64, 0xfe, 0x1f, 0x61, 0xfd, 0x94, 0xc4, 0x3c, + 0x22, 0x56, 0xc2, 0xb1, 0xd3, 0x4b, 0x97, 0xd0, 0x66, 0xbe, 0xfc, 0xa6, 0x58, 0x8b, 0x67, 0xd0, + 0xb8, 0xbe, 0xec, 0xf7, 0x3a, 0x6d, 0xed, 0x4c, 0xeb, 0x9c, 0x4a, 0x15, 0xb4, 0x08, 0xb5, 0x6e, + 0x4f, 0x93, 0xaa, 0xf2, 0x9f, 0x2a, 0x6c, 0x0d, 0xf2, 0xed, 0xbf, 0x60, 0x96, 0x8e, 0x63, 0x96, + 0x44, 0x36, 0x6e, 0x33, 0xea, 0x12, 0x0f, 0xbd, 0x83, 0x26, 0xa1, 0x31, 0x37, 0xa9, 0x8d, 0x0d, + 0x9b, 0x25, 0x94, 0x8b, 0xac, 0x6b, 0xfa, 0x6a, 0xa1, 0xb6, 0x53, 0x11, 0xbd, 0x85, 0x52, 0xc8, + 0xbc, 0x65, 0x7b, 0xb4, 0x52, 0x88, 0xe9, 0x8c, 0x68, 0x0f, 0xd6, 0xee, 0x99, 0x9f, 0x04, 0xd8, + 0x88, 0xc9, 0x03, 0x36, 0x08, 0x35, 0x3c, 0x4b, 0xe4, 0x51, 0xd3, 0x9b, 0xd9, 0x41, 0x9f, 0x3c, + 0x60, 0x8d, 0x9e, 0x5b, 0xe8, 0x16, 0x36, 0x9c, 0x47, 0x3b, 0x46, 0x98, 0xfb, 0x69, 0xd5, 0x05, + 0xb2, 0xe3, 0x59, 0xc8, 0xa6, 0x60, 0xc8, 0xa9, 0xad, 0x3b, 0x53, 0x08, 0xfd, 0xae, 0x42, 0x63, + 0xcc, 0x3e, 0xba, 0x83, 0x17, 0x8f, 0x4b, 0x13, 0x8f, 0xef, 0xbc, 0x70, 0xde, 0x98, 0x9d, 0xd7, + 0xf4, 0xaf, 0x45, 0xdf, 0x34, 0xa7, 0x7f, 0x45, 0xf7, 0xf0, 0x6a, 0xfc, 0xe2, 0x31, 0xa2, 0x1c, + 0x7e, 0xba, 0x27, 0x2e, 0xf1, 0x04, 0xc5, 0xc6, 0x6c, 0xbb, 0x4f, 0x46, 0xa7, 0x6f, 0xf1, 0xa7, + 0x8e, 0x4e, 0x8e, 0xbf, 0x1e, 0x79, 0x84, 0xdf, 0x26, 0x96, 0x62, 0xb3, 0x40, 0xf5, 0x47, 0x2e, + 0x57, 0xcb, 0xbb, 0xd1, 0xc3, 0x54, 0x0d, 0xad, 0x03, 0x8f, 0xa9, 0x93, 0xd7, 0xa5, 0xb5, 0x20, + 0x72, 0x38, 0xfa, 0x1b, 0x00, 0x00, 0xff, 0xff, 0xec, 0xce, 0x1b, 0x58, 0x49, 0x05, 0x00, 0x00, } diff --git a/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.validate.go b/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.validate.go index a1ceaa438..19d1eb9a2 100644 --- a/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.validate.go +++ b/gen/pb-go/flyteidl/plugins/sagemaker/training_job.pb.validate.go @@ -389,6 +389,73 @@ var _ interface { ErrorName() string } = AlgorithmSpecificationValidationError{} +// Validate checks the field values on DistributedProtocol with the rules +// defined in the proto definition for this message. If any rules are +// violated, an error is returned. +func (m *DistributedProtocol) Validate() error { + if m == nil { + return nil + } + + return nil +} + +// DistributedProtocolValidationError is the validation error returned by +// DistributedProtocol.Validate if the designated constraints aren't met. +type DistributedProtocolValidationError struct { + field string + reason string + cause error + key bool +} + +// Field function returns field value. +func (e DistributedProtocolValidationError) Field() string { return e.field } + +// Reason function returns reason value. +func (e DistributedProtocolValidationError) Reason() string { return e.reason } + +// Cause function returns cause value. +func (e DistributedProtocolValidationError) Cause() error { return e.cause } + +// Key function returns key value. +func (e DistributedProtocolValidationError) Key() bool { return e.key } + +// ErrorName returns error name. +func (e DistributedProtocolValidationError) ErrorName() string { + return "DistributedProtocolValidationError" +} + +// Error satisfies the builtin error interface +func (e DistributedProtocolValidationError) Error() string { + cause := "" + if e.cause != nil { + cause = fmt.Sprintf(" | caused by: %v", e.cause) + } + + key := "" + if e.key { + key = "key for " + } + + return fmt.Sprintf( + "invalid %sDistributedProtocol.%s: %s%s", + key, + e.field, + e.reason, + cause) +} + +var _ error = DistributedProtocolValidationError{} + +var _ interface { + Field() string + Reason() string + Key() bool + Cause() error + ErrorName() string +} = DistributedProtocolValidationError{} + // Validate checks the field values on TrainingJobResourceConfig with the rules // defined in the proto definition for this message. If any rules are // violated, an error is returned. @@ -403,6 +470,8 @@ func (m *TrainingJobResourceConfig) Validate() error { // no validation rules for VolumeSizeInGb + // no validation rules for DistributedProtocol + return nil } diff --git a/gen/pb-java/flyteidl/plugins/sagemaker/TrainingJobOuterClass.java b/gen/pb-java/flyteidl/plugins/sagemaker/TrainingJobOuterClass.java index 145062e99..ce4b0a8a4 100644 --- a/gen/pb-java/flyteidl/plugins/sagemaker/TrainingJobOuterClass.java +++ b/gen/pb-java/flyteidl/plugins/sagemaker/TrainingJobOuterClass.java @@ -3910,6 +3910,549 @@ public flyteidl.plugins.sagemaker.TrainingJobOuterClass.AlgorithmSpecification g } + public interface DistributedProtocolOrBuilder extends + // @@protoc_insertion_point(interface_extends:flyteidl.plugins.sagemaker.DistributedProtocol) + com.google.protobuf.MessageOrBuilder { + } + /** + *
+   * When enabling distributed training on a training job, the user should use this message to tell Flyte and SageMaker
+   * what kind of distributed protocol he/she wants to use to distribute the work.
+   * 
+ * + * Protobuf type {@code flyteidl.plugins.sagemaker.DistributedProtocol} + */ + public static final class DistributedProtocol extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:flyteidl.plugins.sagemaker.DistributedProtocol) + DistributedProtocolOrBuilder { + private static final long serialVersionUID = 0L; + // Use DistributedProtocol.newBuilder() to construct. + private DistributedProtocol(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private DistributedProtocol() { + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private DistributedProtocol( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!parseUnknownField( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.class, flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Builder.class); + } + + /** + * Protobuf enum {@code flyteidl.plugins.sagemaker.DistributedProtocol.Value} + */ + public enum Value + implements com.google.protobuf.ProtocolMessageEnum { + /** + *
+       * Use this value if the user wishes to use framework-native distributed training interfaces.
+       * If this value is used, Flyte won't configure SageMaker to initialize unnecessary components such as
+       * OpenMPI or Parameter Server.
+       * 
+ * + * UNSPECIFIED = 0; + */ + UNSPECIFIED(0), + /** + *
+       * Use this value if the user wishes to use MPI as the underlying protocol for her distributed training job
+       * MPI is a framework-agnostic distributed protocol. It has multiple implementations. Currently, we have only
+       * tested the OpenMPI implementation, which is the recommended implementation for Horovod.
+       * 
+ * + * MPI = 1; + */ + MPI(1), + UNRECOGNIZED(-1), + ; + + /** + *
+       * Use this value if the user wishes to use framework-native distributed training interfaces.
+       * If this value is used, Flyte won't configure SageMaker to initialize unnecessary components such as
+       * OpenMPI or Parameter Server.
+       * 
+ * + * UNSPECIFIED = 0; + */ + public static final int UNSPECIFIED_VALUE = 0; + /** + *
+       * Use this value if the user wishes to use MPI as the underlying protocol for her distributed training job
+       * MPI is a framework-agnostic distributed protocol. It has multiple implementations. Currently, we have only
+       * tested the OpenMPI implementation, which is the recommended implementation for Horovod.
+       * 
+ * + * MPI = 1; + */ + public static final int MPI_VALUE = 1; + + + public final int getNumber() { + if (this == UNRECOGNIZED) { + throw new java.lang.IllegalArgumentException( + "Can't get the number of an unknown enum value."); + } + return value; + } + + /** + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static Value valueOf(int value) { + return forNumber(value); + } + + public static Value forNumber(int value) { + switch (value) { + case 0: return UNSPECIFIED; + case 1: return MPI; + default: return null; + } + } + + public static com.google.protobuf.Internal.EnumLiteMap + internalGetValueMap() { + return internalValueMap; + } + private static final com.google.protobuf.Internal.EnumLiteMap< + Value> internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public Value findValueByNumber(int number) { + return Value.forNumber(number); + } + }; + + public final com.google.protobuf.Descriptors.EnumValueDescriptor + getValueDescriptor() { + return getDescriptor().getValues().get(ordinal()); + } + public final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptorForType() { + return getDescriptor(); + } + public static final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptor() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.getDescriptor().getEnumTypes().get(0); + } + + private static final Value[] VALUES = values(); + + public static Value valueOf( + com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException( + "EnumValueDescriptor is not for this type."); + } + if (desc.getIndex() == -1) { + return UNRECOGNIZED; + } + return VALUES[desc.getIndex()]; + } + + private final int value; + + private Value(int value) { + this.value = value; + } + + // @@protoc_insertion_point(enum_scope:flyteidl.plugins.sagemaker.DistributedProtocol.Value) + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + unknownFields.writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol)) { + return super.equals(obj); + } + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol other = (flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol) obj; + + if (!unknownFields.equals(other.unknownFields)) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+     * When enabling distributed training on a training job, the user should use this message to tell Flyte and SageMaker
+     * what kind of distributed protocol he/she wants to use to distribute the work.
+     * 
+ * + * Protobuf type {@code flyteidl.plugins.sagemaker.DistributedProtocol} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:flyteidl.plugins.sagemaker.DistributedProtocol) + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocolOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_fieldAccessorTable + .ensureFieldAccessorsInitialized( + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.class, flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Builder.class); + } + + // Construct using flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor; + } + + @java.lang.Override + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol getDefaultInstanceForType() { + return flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.getDefaultInstance(); + } + + @java.lang.Override + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol build() { + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol buildPartial() { + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol result = new flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol) { + return mergeFrom((flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol other) { + if (other == flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.getDefaultInstance()) return this; + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:flyteidl.plugins.sagemaker.DistributedProtocol) + } + + // @@protoc_insertion_point(class_scope:flyteidl.plugins.sagemaker.DistributedProtocol) + private static final flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol(); + } + + public static flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public DistributedProtocol parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new DistributedProtocol(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + public interface TrainingJobResourceConfigOrBuilder extends // @@protoc_insertion_point(interface_extends:flyteidl.plugins.sagemaker.TrainingJobResourceConfig) com.google.protobuf.MessageOrBuilder { @@ -3949,6 +4492,27 @@ public interface TrainingJobResourceConfigOrBuilder extends * int64 volume_size_in_gb = 3; */ long getVolumeSizeInGb(); + + /** + *
+     * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+     * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+     * field should be set to the corresponding enum value
+     * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + int getDistributedProtocolValue(); + /** + *
+     * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+     * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+     * field should be set to the corresponding enum value
+     * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value getDistributedProtocol(); } /** *
@@ -3970,6 +4534,7 @@ private TrainingJobResourceConfig(com.google.protobuf.GeneratedMessageV3.Builder
     }
     private TrainingJobResourceConfig() {
       instanceType_ = "";
+      distributedProtocol_ = 0;
     }
 
     @java.lang.Override
@@ -4012,6 +4577,12 @@ private TrainingJobResourceConfig(
               volumeSizeInGb_ = input.readInt64();
               break;
             }
+            case 32: {
+              int rawValue = input.readEnum();
+
+              distributedProtocol_ = rawValue;
+              break;
+            }
             default: {
               if (!parseUnknownField(
                   input, unknownFields, extensionRegistry, tag)) {
@@ -4112,6 +4683,35 @@ public long getVolumeSizeInGb() {
       return volumeSizeInGb_;
     }
 
+    public static final int DISTRIBUTED_PROTOCOL_FIELD_NUMBER = 4;
+    private int distributedProtocol_;
+    /**
+     * 
+     * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+     * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+     * field should be set to the corresponding enum value
+     * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public int getDistributedProtocolValue() { + return distributedProtocol_; + } + /** + *
+     * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+     * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+     * field should be set to the corresponding enum value
+     * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value getDistributedProtocol() { + @SuppressWarnings("deprecation") + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value result = flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.valueOf(distributedProtocol_); + return result == null ? flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.UNRECOGNIZED : result; + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -4135,6 +4735,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (volumeSizeInGb_ != 0L) { output.writeInt64(3, volumeSizeInGb_); } + if (distributedProtocol_ != flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.UNSPECIFIED.getNumber()) { + output.writeEnum(4, distributedProtocol_); + } unknownFields.writeTo(output); } @@ -4155,6 +4758,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeInt64Size(3, volumeSizeInGb_); } + if (distributedProtocol_ != flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.UNSPECIFIED.getNumber()) { + size += com.google.protobuf.CodedOutputStream + .computeEnumSize(4, distributedProtocol_); + } size += unknownFields.getSerializedSize(); memoizedSize = size; return size; @@ -4176,6 +4783,7 @@ public boolean equals(final java.lang.Object obj) { .equals(other.getInstanceType())) return false; if (getVolumeSizeInGb() != other.getVolumeSizeInGb()) return false; + if (distributedProtocol_ != other.distributedProtocol_) return false; if (!unknownFields.equals(other.unknownFields)) return false; return true; } @@ -4195,6 +4803,8 @@ public int hashCode() { hash = (37 * hash) + VOLUME_SIZE_IN_GB_FIELD_NUMBER; hash = (53 * hash) + com.google.protobuf.Internal.hashLong( getVolumeSizeInGb()); + hash = (37 * hash) + DISTRIBUTED_PROTOCOL_FIELD_NUMBER; + hash = (53 * hash) + distributedProtocol_; hash = (29 * hash) + unknownFields.hashCode(); memoizedHashCode = hash; return hash; @@ -4340,6 +4950,8 @@ public Builder clear() { volumeSizeInGb_ = 0L; + distributedProtocol_ = 0; + return this; } @@ -4369,6 +4981,7 @@ public flyteidl.plugins.sagemaker.TrainingJobOuterClass.TrainingJobResourceConfi result.instanceCount_ = instanceCount_; result.instanceType_ = instanceType_; result.volumeSizeInGb_ = volumeSizeInGb_; + result.distributedProtocol_ = distributedProtocol_; onBuilt(); return result; } @@ -4427,6 +5040,9 @@ public Builder mergeFrom(flyteidl.plugins.sagemaker.TrainingJobOuterClass.Traini if (other.getVolumeSizeInGb() != 0L) { setVolumeSizeInGb(other.getVolumeSizeInGb()); } + if (other.distributedProtocol_ != 0) { + setDistributedProtocolValue(other.getDistributedProtocolValue()); + } this.mergeUnknownFields(other.unknownFields); onChanged(); return this; @@ -4620,6 +5236,81 @@ public Builder clearVolumeSizeInGb() { onChanged(); return this; } + + private int distributedProtocol_ = 0; + /** + *
+       * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+       * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+       * field should be set to the corresponding enum value
+       * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public int getDistributedProtocolValue() { + return distributedProtocol_; + } + /** + *
+       * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+       * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+       * field should be set to the corresponding enum value
+       * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public Builder setDistributedProtocolValue(int value) { + distributedProtocol_ = value; + onChanged(); + return this; + } + /** + *
+       * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+       * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+       * field should be set to the corresponding enum value
+       * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value getDistributedProtocol() { + @SuppressWarnings("deprecation") + flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value result = flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.valueOf(distributedProtocol_); + return result == null ? flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value.UNRECOGNIZED : result; + } + /** + *
+       * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+       * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+       * field should be set to the corresponding enum value
+       * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public Builder setDistributedProtocol(flyteidl.plugins.sagemaker.TrainingJobOuterClass.DistributedProtocol.Value value) { + if (value == null) { + throw new NullPointerException(); + } + + distributedProtocol_ = value.getNumber(); + onChanged(); + return this; + } + /** + *
+       * When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training.
+       * If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this
+       * field should be set to the corresponding enum value
+       * 
+ * + * .flyteidl.plugins.sagemaker.DistributedProtocol.Value distributed_protocol = 4; + */ + public Builder clearDistributedProtocol() { + + distributedProtocol_ = 0; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -5508,6 +6199,11 @@ public flyteidl.plugins.sagemaker.TrainingJobOuterClass.TrainingJob getDefaultIn private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_flyteidl_plugins_sagemaker_AlgorithmSpecification_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_flyteidl_plugins_sagemaker_TrainingJobResourceConfig_descriptor; private static final @@ -5542,16 +6238,20 @@ public flyteidl.plugins.sagemaker.TrainingJobOuterClass.TrainingJob getDefaultIn " \001(\t\022H\n\022metric_definitions\030\004 \003(\0132,.flyte" + "idl.plugins.sagemaker.MetricDefinition\022N" + "\n\022input_content_type\030\005 \001(\01622.flyteidl.pl" + - "ugins.sagemaker.InputContentType.Value\"e" + - "\n\031TrainingJobResourceConfig\022\026\n\016instance_" + - "count\030\001 \001(\003\022\025\n\rinstance_type\030\002 \001(\t\022\031\n\021vo" + - "lume_size_in_gb\030\003 \001(\003\"\277\001\n\013TrainingJob\022S\n" + - "\027algorithm_specification\030\001 \001(\01322.flyteid" + - "l.plugins.sagemaker.AlgorithmSpecificati" + - "on\022[\n\034training_job_resource_config\030\002 \001(\013" + - "25.flyteidl.plugins.sagemaker.TrainingJo" + - "bResourceConfigB5Z3github.com/lyft/flyte" + - "idl/gen/pb-go/flyteidl/pluginsb\006proto3" + "ugins.sagemaker.InputContentType.Value\"8" + + "\n\023DistributedProtocol\"!\n\005Value\022\017\n\013UNSPEC" + + "IFIED\020\000\022\007\n\003MPI\020\001\"\272\001\n\031TrainingJobResource" + + "Config\022\026\n\016instance_count\030\001 \001(\003\022\025\n\rinstan" + + "ce_type\030\002 \001(\t\022\031\n\021volume_size_in_gb\030\003 \001(\003" + + "\022S\n\024distributed_protocol\030\004 \001(\01625.flyteid" + + "l.plugins.sagemaker.DistributedProtocol." + + "Value\"\277\001\n\013TrainingJob\022S\n\027algorithm_speci" + + "fication\030\001 \001(\01322.flyteidl.plugins.sagema" + + "ker.AlgorithmSpecification\022[\n\034training_j" + + "ob_resource_config\030\002 \001(\01325.flyteidl.plug" + + "ins.sagemaker.TrainingJobResourceConfigB" + + "5Z3github.com/lyft/flyteidl/gen/pb-go/fl" + + "yteidl/pluginsb\006proto3" }; com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { @@ -5596,14 +6296,20 @@ public com.google.protobuf.ExtensionRegistry assignDescriptors( com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_sagemaker_AlgorithmSpecification_descriptor, new java.lang.String[] { "InputMode", "AlgorithmName", "AlgorithmVersion", "MetricDefinitions", "InputContentType", }); - internal_static_flyteidl_plugins_sagemaker_TrainingJobResourceConfig_descriptor = + internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor = getDescriptor().getMessageTypes().get(5); + internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_flyteidl_plugins_sagemaker_DistributedProtocol_descriptor, + new java.lang.String[] { }); + internal_static_flyteidl_plugins_sagemaker_TrainingJobResourceConfig_descriptor = + getDescriptor().getMessageTypes().get(6); internal_static_flyteidl_plugins_sagemaker_TrainingJobResourceConfig_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_sagemaker_TrainingJobResourceConfig_descriptor, - new java.lang.String[] { "InstanceCount", "InstanceType", "VolumeSizeInGb", }); + new java.lang.String[] { "InstanceCount", "InstanceType", "VolumeSizeInGb", "DistributedProtocol", }); internal_static_flyteidl_plugins_sagemaker_TrainingJob_descriptor = - getDescriptor().getMessageTypes().get(6); + getDescriptor().getMessageTypes().get(7); internal_static_flyteidl_plugins_sagemaker_TrainingJob_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_flyteidl_plugins_sagemaker_TrainingJob_descriptor, diff --git a/gen/pb-protodoc/flyteidl/plugins/sagemaker/training_job.proto.rst b/gen/pb-protodoc/flyteidl/plugins/sagemaker/training_job.proto.rst index dda63465b..625c9d511 100644 --- a/gen/pb-protodoc/flyteidl/plugins/sagemaker/training_job.proto.rst +++ b/gen/pb-protodoc/flyteidl/plugins/sagemaker/training_job.proto.rst @@ -205,12 +205,53 @@ input_content_type +.. _api_msg_flyteidl.plugins.sagemaker.DistributedProtocol: + +flyteidl.plugins.sagemaker.DistributedProtocol +---------------------------------------------- + +`[flyteidl.plugins.sagemaker.DistributedProtocol proto] `_ + +When enabling distributed training on a training job, the user should use this message to tell Flyte and SageMaker +what kind of distributed protocol he/she wants to use to distribute the work. + +.. code-block:: json + + {} + + + +.. _api_enum_flyteidl.plugins.sagemaker.DistributedProtocol.Value: + +Enum flyteidl.plugins.sagemaker.DistributedProtocol.Value +--------------------------------------------------------- + +`[flyteidl.plugins.sagemaker.DistributedProtocol.Value proto] `_ + + +.. _api_enum_value_flyteidl.plugins.sagemaker.DistributedProtocol.Value.UNSPECIFIED: + +UNSPECIFIED + *(DEFAULT)* ⁣Use this value if the user wishes to use framework-native distributed training interfaces. + If this value is used, Flyte won't configure SageMaker to initialize unnecessary components such as + OpenMPI or Parameter Server. + + +.. _api_enum_value_flyteidl.plugins.sagemaker.DistributedProtocol.Value.MPI: + +MPI + ⁣Use this value if the user wishes to use MPI as the underlying protocol for her distributed training job + MPI is a framework-agnostic distributed protocol. It has multiple implementations. Currently, we have only + tested the OpenMPI implementation, which is the recommended implementation for Horovod. + + + .. _api_msg_flyteidl.plugins.sagemaker.TrainingJobResourceConfig: flyteidl.plugins.sagemaker.TrainingJobResourceConfig ---------------------------------------------------- -`[flyteidl.plugins.sagemaker.TrainingJobResourceConfig proto] `_ +`[flyteidl.plugins.sagemaker.TrainingJobResourceConfig proto] `_ TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training job, the number of instances to launch, and the size of the ML storage volume the user wants to provision @@ -221,7 +262,8 @@ Refer to SageMaker official doc for more details: https://docs.aws.amazon.com/sa { "instance_count": "...", "instance_type": "...", - "volume_size_in_gb": "..." + "volume_size_in_gb": "...", + "distributed_protocol": "..." } .. _api_field_flyteidl.plugins.sagemaker.TrainingJobResourceConfig.instance_count: @@ -242,6 +284,14 @@ volume_size_in_gb (`int64 `_) The size of the ML storage volume that you want to provision. +.. _api_field_flyteidl.plugins.sagemaker.TrainingJobResourceConfig.distributed_protocol: + +distributed_protocol + (:ref:`flyteidl.plugins.sagemaker.DistributedProtocol.Value `) When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training. + If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this + field should be set to the corresponding enum value + + .. _api_msg_flyteidl.plugins.sagemaker.TrainingJob: @@ -249,7 +299,7 @@ volume_size_in_gb flyteidl.plugins.sagemaker.TrainingJob -------------------------------------- -`[flyteidl.plugins.sagemaker.TrainingJob proto] `_ +`[flyteidl.plugins.sagemaker.TrainingJob proto] `_ The spec of a training job. This is mostly a pass-through object https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html diff --git a/gen/pb_python/flyteidl/plugins/sagemaker/training_job_pb2.py b/gen/pb_python/flyteidl/plugins/sagemaker/training_job_pb2.py index 776e925b9..0934d438d 100644 --- a/gen/pb_python/flyteidl/plugins/sagemaker/training_job_pb2.py +++ b/gen/pb_python/flyteidl/plugins/sagemaker/training_job_pb2.py @@ -21,7 +21,7 @@ package='flyteidl.plugins.sagemaker', syntax='proto3', serialized_options=_b('Z3github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins'), - serialized_pb=_b('\n-flyteidl/plugins/sagemaker/training_job.proto\x12\x1a\x66lyteidl.plugins.sagemaker\x1a\x1egoogle/protobuf/duration.proto\"(\n\tInputMode\"\x1b\n\x05Value\x12\x08\n\x04\x46ILE\x10\x00\x12\x08\n\x04PIPE\x10\x01\"1\n\rAlgorithmName\" \n\x05Value\x12\n\n\x06\x43USTOM\x10\x00\x12\x0b\n\x07XGBOOST\x10\x01\")\n\x10InputContentType\"\x15\n\x05Value\x12\x0c\n\x08TEXT_CSV\x10\x00\"/\n\x10MetricDefinition\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05regex\x18\x02 \x01(\t\"\xd7\x02\n\x16\x41lgorithmSpecification\x12?\n\ninput_mode\x18\x01 \x01(\x0e\x32+.flyteidl.plugins.sagemaker.InputMode.Value\x12G\n\x0e\x61lgorithm_name\x18\x02 \x01(\x0e\x32/.flyteidl.plugins.sagemaker.AlgorithmName.Value\x12\x19\n\x11\x61lgorithm_version\x18\x03 \x01(\t\x12H\n\x12metric_definitions\x18\x04 \x03(\x0b\x32,.flyteidl.plugins.sagemaker.MetricDefinition\x12N\n\x12input_content_type\x18\x05 \x01(\x0e\x32\x32.flyteidl.plugins.sagemaker.InputContentType.Value\"e\n\x19TrainingJobResourceConfig\x12\x16\n\x0einstance_count\x18\x01 \x01(\x03\x12\x15\n\rinstance_type\x18\x02 \x01(\t\x12\x19\n\x11volume_size_in_gb\x18\x03 \x01(\x03\"\xbf\x01\n\x0bTrainingJob\x12S\n\x17\x61lgorithm_specification\x18\x01 \x01(\x0b\x32\x32.flyteidl.plugins.sagemaker.AlgorithmSpecification\x12[\n\x1ctraining_job_resource_config\x18\x02 \x01(\x0b\x32\x35.flyteidl.plugins.sagemaker.TrainingJobResourceConfigB5Z3github.com/lyft/flyteidl/gen/pb-go/flyteidl/pluginsb\x06proto3') + serialized_pb=_b('\n-flyteidl/plugins/sagemaker/training_job.proto\x12\x1a\x66lyteidl.plugins.sagemaker\x1a\x1egoogle/protobuf/duration.proto\"(\n\tInputMode\"\x1b\n\x05Value\x12\x08\n\x04\x46ILE\x10\x00\x12\x08\n\x04PIPE\x10\x01\"1\n\rAlgorithmName\" \n\x05Value\x12\n\n\x06\x43USTOM\x10\x00\x12\x0b\n\x07XGBOOST\x10\x01\")\n\x10InputContentType\"\x15\n\x05Value\x12\x0c\n\x08TEXT_CSV\x10\x00\"/\n\x10MetricDefinition\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05regex\x18\x02 \x01(\t\"\xd7\x02\n\x16\x41lgorithmSpecification\x12?\n\ninput_mode\x18\x01 \x01(\x0e\x32+.flyteidl.plugins.sagemaker.InputMode.Value\x12G\n\x0e\x61lgorithm_name\x18\x02 \x01(\x0e\x32/.flyteidl.plugins.sagemaker.AlgorithmName.Value\x12\x19\n\x11\x61lgorithm_version\x18\x03 \x01(\t\x12H\n\x12metric_definitions\x18\x04 \x03(\x0b\x32,.flyteidl.plugins.sagemaker.MetricDefinition\x12N\n\x12input_content_type\x18\x05 \x01(\x0e\x32\x32.flyteidl.plugins.sagemaker.InputContentType.Value\"8\n\x13\x44istributedProtocol\"!\n\x05Value\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03MPI\x10\x01\"\xba\x01\n\x19TrainingJobResourceConfig\x12\x16\n\x0einstance_count\x18\x01 \x01(\x03\x12\x15\n\rinstance_type\x18\x02 \x01(\t\x12\x19\n\x11volume_size_in_gb\x18\x03 \x01(\x03\x12S\n\x14\x64istributed_protocol\x18\x04 \x01(\x0e\x32\x35.flyteidl.plugins.sagemaker.DistributedProtocol.Value\"\xbf\x01\n\x0bTrainingJob\x12S\n\x17\x61lgorithm_specification\x18\x01 \x01(\x0b\x32\x32.flyteidl.plugins.sagemaker.AlgorithmSpecification\x12[\n\x1ctraining_job_resource_config\x18\x02 \x01(\x0b\x32\x35.flyteidl.plugins.sagemaker.TrainingJobResourceConfigB5Z3github.com/lyft/flyteidl/gen/pb-go/flyteidl/pluginsb\x06proto3') , dependencies=[google_dot_protobuf_dot_duration__pb2.DESCRIPTOR,]) @@ -89,6 +89,28 @@ ) _sym_db.RegisterEnumDescriptor(_INPUTCONTENTTYPE_VALUE) +_DISTRIBUTEDPROTOCOL_VALUE = _descriptor.EnumDescriptor( + name='Value', + full_name='flyteidl.plugins.sagemaker.DistributedProtocol.Value', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNSPECIFIED', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='MPI', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=663, + serialized_end=696, +) +_sym_db.RegisterEnumDescriptor(_DISTRIBUTEDPROTOCOL_VALUE) + _INPUTMODE = _descriptor.Descriptor( name='InputMode', @@ -262,6 +284,31 @@ ) +_DISTRIBUTEDPROTOCOL = _descriptor.Descriptor( + name='DistributedProtocol', + full_name='flyteidl.plugins.sagemaker.DistributedProtocol', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _DISTRIBUTEDPROTOCOL_VALUE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=640, + serialized_end=696, +) + + _TRAININGJOBRESOURCECONFIG = _descriptor.Descriptor( name='TrainingJobResourceConfig', full_name='flyteidl.plugins.sagemaker.TrainingJobResourceConfig', @@ -290,6 +337,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='distributed_protocol', full_name='flyteidl.plugins.sagemaker.TrainingJobResourceConfig.distributed_protocol', index=3, + number=4, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -302,8 +356,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=640, - serialized_end=741, + serialized_start=699, + serialized_end=885, ) @@ -340,8 +394,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=744, - serialized_end=935, + serialized_start=888, + serialized_end=1079, ) _INPUTMODE_VALUE.containing_type = _INPUTMODE @@ -351,6 +405,8 @@ _ALGORITHMSPECIFICATION.fields_by_name['algorithm_name'].enum_type = _ALGORITHMNAME_VALUE _ALGORITHMSPECIFICATION.fields_by_name['metric_definitions'].message_type = _METRICDEFINITION _ALGORITHMSPECIFICATION.fields_by_name['input_content_type'].enum_type = _INPUTCONTENTTYPE_VALUE +_DISTRIBUTEDPROTOCOL_VALUE.containing_type = _DISTRIBUTEDPROTOCOL +_TRAININGJOBRESOURCECONFIG.fields_by_name['distributed_protocol'].enum_type = _DISTRIBUTEDPROTOCOL_VALUE _TRAININGJOB.fields_by_name['algorithm_specification'].message_type = _ALGORITHMSPECIFICATION _TRAININGJOB.fields_by_name['training_job_resource_config'].message_type = _TRAININGJOBRESOURCECONFIG DESCRIPTOR.message_types_by_name['InputMode'] = _INPUTMODE @@ -358,6 +414,7 @@ DESCRIPTOR.message_types_by_name['InputContentType'] = _INPUTCONTENTTYPE DESCRIPTOR.message_types_by_name['MetricDefinition'] = _METRICDEFINITION DESCRIPTOR.message_types_by_name['AlgorithmSpecification'] = _ALGORITHMSPECIFICATION +DESCRIPTOR.message_types_by_name['DistributedProtocol'] = _DISTRIBUTEDPROTOCOL DESCRIPTOR.message_types_by_name['TrainingJobResourceConfig'] = _TRAININGJOBRESOURCECONFIG DESCRIPTOR.message_types_by_name['TrainingJob'] = _TRAININGJOB _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -397,6 +454,13 @@ )) _sym_db.RegisterMessage(AlgorithmSpecification) +DistributedProtocol = _reflection.GeneratedProtocolMessageType('DistributedProtocol', (_message.Message,), dict( + DESCRIPTOR = _DISTRIBUTEDPROTOCOL, + __module__ = 'flyteidl.plugins.sagemaker.training_job_pb2' + # @@protoc_insertion_point(class_scope:flyteidl.plugins.sagemaker.DistributedProtocol) + )) +_sym_db.RegisterMessage(DistributedProtocol) + TrainingJobResourceConfig = _reflection.GeneratedProtocolMessageType('TrainingJobResourceConfig', (_message.Message,), dict( DESCRIPTOR = _TRAININGJOBRESOURCECONFIG, __module__ = 'flyteidl.plugins.sagemaker.training_job_pb2' diff --git a/package.json b/package.json index 9a2e80a5b..02f27b639 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@lyft/flyteidl", - "version": "0.18.8", + "version": "0.18.9", "description": "Compiled protocol buffers and gRPC service clients/servers for Flyte IDLs", "repository": { "type": "git", diff --git a/protos/flyteidl/plugins/sagemaker/training_job.proto b/protos/flyteidl/plugins/sagemaker/training_job.proto index 1d8dc820d..eb14da592 100644 --- a/protos/flyteidl/plugins/sagemaker/training_job.proto +++ b/protos/flyteidl/plugins/sagemaker/training_job.proto @@ -80,6 +80,21 @@ message AlgorithmSpecification { InputContentType.Value input_content_type = 5; } +// When enabling distributed training on a training job, the user should use this message to tell Flyte and SageMaker +// what kind of distributed protocol he/she wants to use to distribute the work. +message DistributedProtocol { + enum Value { + // Use this value if the user wishes to use framework-native distributed training interfaces. + // If this value is used, Flyte won't configure SageMaker to initialize unnecessary components such as + // OpenMPI or Parameter Server. + UNSPECIFIED = 0; + // Use this value if the user wishes to use MPI as the underlying protocol for her distributed training job + // MPI is a framework-agnostic distributed protocol. It has multiple implementations. Currently, we have only + // tested the OpenMPI implementation, which is the recommended implementation for Horovod. + MPI = 1; + } +} + // TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training job, the // number of instances to launch, and the size of the ML storage volume the user wants to provision // Refer to SageMaker official doc for more details: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html @@ -90,6 +105,10 @@ message TrainingJobResourceConfig { string instance_type = 2; // The size of the ML storage volume that you want to provision. int64 volume_size_in_gb = 3; + // When users specify an instance_count > 1, Flyte will try to configure SageMaker to enable distributed training. + // If the users wish to use framework-agnostic distributed protocol such as MPI or Parameter Server, this + // field should be set to the corresponding enum value + DistributedProtocol.Value distributed_protocol = 4; } // The spec of a training job. This is mostly a pass-through object diff --git a/setup.py b/setup.py index 6193d912a..c412511c1 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -__version__ = '0.18.8' +__version__ = '0.18.9' setup( name='flyteidl',