Skip to content

Commit

Permalink
YQL-17542 split FillIoMaps (#1537)
Browse files Browse the repository at this point in the history
  • Loading branch information
zverevgeny authored Feb 2, 2024
1 parent 205480d commit 70035f4
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ class TDqAsyncComputeActor : public TDqComputeActorBase<TDqAsyncComputeActor, TC
Stat->AddCounters2(ev->Get()->Sensors);
}
TypeEnv = const_cast<NKikimr::NMiniKQL::TTypeEnvironment*>(&typeEnv);
FillIoMaps(holderFactory, typeEnv, secureParams, taskParams, readRanges);
FillIoMaps(holderFactory, typeEnv, secureParams, taskParams, readRanges, nullptr);

{
// say "Hello" to executer
Expand Down
74 changes: 28 additions & 46 deletions ydb/library/yql/dq/actors/compute/dq_compute_actor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1441,19 +1441,12 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
const NKikimr::NMiniKQL::TTypeEnvironment& typeEnv,
const THashMap<TString, TString>& secureParams,
const THashMap<TString, TString>& taskParams,
const TVector<TString>& readRanges)
const TVector<TString>& readRanges,
IRandomProvider* randomProvider
)
{
if (TaskRunner) {
for (auto& [channelId, channel] : InputChannelsMap) {
channel.Channel = TaskRunner->GetInputChannel(channelId);
}
}
auto collectStatsLevel = StatsModeToCollectStatsLevel(RuntimeSettings.StatsMode);
for (auto& [inputIndex, source] : SourcesMap) {
if constexpr (!TDerived::HasAsyncTaskRunner) {
source.Buffer = TaskRunner->GetSource(inputIndex);
Y_ABORT_UNLESS(source.Buffer);
}
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& inputDesc = Task.GetInputs(inputIndex);
Y_ABORT_UNLESS(inputDesc.HasSource());
Expand Down Expand Up @@ -1487,9 +1480,8 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
this->RegisterWithSameMailbox(source.Actor);
}
for (auto& [inputIndex, transform] : InputTransformsMap) {
if constexpr (!TDerived::HasAsyncTaskRunner) {
transform.ProgramBuilder.ConstructInPlace(TaskRunner->GetTypeEnv(), *FunctionRegistry);
std::tie(transform.InputBuffer, transform.Buffer) = TaskRunner->GetInputTransform(inputIndex);
Y_ABORT_UNLESS(TaskRunner);
transform.ProgramBuilder.ConstructInPlace(typeEnv, *FunctionRegistry);
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& inputDesc = Task.GetInputs(inputIndex);
CA_LOG_D("Create transform for input " << inputIndex << " " << inputDesc.ShortDebugString());
Expand All @@ -1515,43 +1507,33 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
throw yexception() << "Failed to create input transform " << inputDesc.GetTransform().GetType() << ": " << ex.what();
}
this->RegisterWithSameMailbox(transform.Actor);
}
}
if (TaskRunner) {
for (auto& [channelId, channel] : OutputChannelsMap) {
channel.Channel = TaskRunner->GetOutputChannel(channelId);
}
}
for (auto& [outputIndex, transform] : OutputTransformsMap) {
if (TaskRunner) {
transform.ProgramBuilder.ConstructInPlace(TaskRunner->GetTypeEnv(), *FunctionRegistry);
std::tie(transform.Buffer, transform.OutputBuffer) = TaskRunner->GetOutputTransform(outputIndex);
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& outputDesc = Task.GetOutputs(outputIndex);
CA_LOG_D("Create transform for output " << outputIndex << " " << outputDesc.ShortDebugString());
try {
std::tie(transform.AsyncOutput, transform.Actor) = AsyncIoFactory->CreateDqOutputTransform(
IDqAsyncIoFactory::TOutputTransformArguments {
.OutputDesc = outputDesc,
.OutputIndex = outputIndex,
.StatsLevel = collectStatsLevel,
.TxId = TxId,
.TransformOutput = transform.OutputBuffer,
.Callback = static_cast<TOutputTransformCallbacks*>(this),
.SecureParams = secureParams,
.TaskParams = taskParams,
.TypeEnv = typeEnv,
.HolderFactory = holderFactory,
.ProgramBuilder = *transform.ProgramBuilder
});
} catch (const std::exception& ex) {
throw yexception() << "Failed to create output transform " << outputDesc.GetTransform().GetType() << ": " << ex.what();
}
this->RegisterWithSameMailbox(transform.Actor);
transform.ProgramBuilder.ConstructInPlace(typeEnv, *FunctionRegistry);
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& outputDesc = Task.GetOutputs(outputIndex);
CA_LOG_D("Create transform for output " << outputIndex << " " << outputDesc.ShortDebugString());
try {
std::tie(transform.AsyncOutput, transform.Actor) = AsyncIoFactory->CreateDqOutputTransform(
IDqAsyncIoFactory::TOutputTransformArguments {
.OutputDesc = outputDesc,
.OutputIndex = outputIndex,
.StatsLevel = collectStatsLevel,
.TxId = TxId,
.TransformOutput = transform.OutputBuffer,
.Callback = static_cast<TOutputTransformCallbacks*>(this),
.SecureParams = secureParams,
.TaskParams = taskParams,
.TypeEnv = typeEnv,
.HolderFactory = holderFactory,
.ProgramBuilder = *transform.ProgramBuilder
});
} catch (const std::exception& ex) {
throw yexception() << "Failed to create output transform " << outputDesc.GetTransform().GetType() << ": " << ex.what();
}
this->RegisterWithSameMailbox(transform.Actor);
}
for (auto& [outputIndex, sink] : SinksMap) {
if (TaskRunner) { sink.Buffer = TaskRunner->GetSink(outputIndex); }
Y_ABORT_UNLESS(AsyncIoFactory);
const auto& outputDesc = Task.GetOutputs(outputIndex);
Y_ABORT_UNLESS(outputDesc.HasSink());
Expand All @@ -1569,7 +1551,7 @@ class TDqComputeActorBase : public NActors::TActorBootstrapped<TDerived>
.TaskParams = taskParams,
.TypeEnv = typeEnv,
.HolderFactory = holderFactory,
.RandomProvider = TaskRunner ? TaskRunner->GetRandomProvider() : nullptr
.RandomProvider = randomProvider
});
} catch (const std::exception& ex) {
throw yexception() << "Failed to create sink " << outputDesc.GetSink().GetType() << ": " << ex.what();
Expand Down
37 changes: 32 additions & 5 deletions ydb/library/yql/dq/actors/compute/dq_sync_compute_actor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,39 @@ class TDqSyncComputeActorBase: public TDqComputeActorBase<TDerived, TComputeActo

this->TaskRunner->Prepare(this->Task, limits, execCtx);

for (auto& [channelId, channel] : this->InputChannelsMap) {
channel.Channel = this->TaskRunner->GetInputChannel(channelId);
}

for (auto& [inputIndex, source] : this->SourcesMap) {
source.Buffer = this->TaskRunner->GetSource(inputIndex);
Y_ABORT_UNLESS(source.Buffer);
}

for (auto& [inputIndex, transform] : this->InputTransformsMap) {
std::tie(transform.InputBuffer, transform.Buffer) = this->TaskRunner->GetInputTransform(inputIndex);
}

for (auto& [channelId, channel] : this->OutputChannelsMap) {
channel.Channel = this->TaskRunner->GetOutputChannel(channelId);
}

for (auto& [outputIndex, transform] : this->OutputTransformsMap) {
std::tie(transform.Buffer, transform.OutputBuffer) = this->TaskRunner->GetOutputTransform(outputIndex);
}

for (auto& [outputIndex, sink] : this->SinksMap) {
sink.Buffer = this->TaskRunner->GetSink(outputIndex);
}

TBase::FillIoMaps(
this->TaskRunner->GetHolderFactory(),
this->TaskRunner->GetTypeEnv(),
this->TaskRunner->GetSecureParams(),
this->TaskRunner->GetTaskParams(),
this->TaskRunner->GetReadRanges());
this->TaskRunner->GetHolderFactory(),
this->TaskRunner->GetTypeEnv(),
this->TaskRunner->GetSecureParams(),
this->TaskRunner->GetTaskParams(),
this->TaskRunner->GetReadRanges(),
this->TaskRunner->GetRandomProvider()
);
}
};

Expand Down

0 comments on commit 70035f4

Please sign in to comment.