Skip to content

Commit

Permalink
Better PythonFunction error message (#883)
Browse files Browse the repository at this point in the history
* Better PythonFunction error message

Signed-off-by: Albert Wolant <[email protected]>
  • Loading branch information
awolant authored and JanuszL committed May 16, 2019
1 parent 783a3e9 commit 527e2fd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
8 changes: 6 additions & 2 deletions dali/pipeline/operators/python_function/python_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ void PythonFunctionImpl<CPUBackend>::RunImpl(SampleWorkspace *ws, const int idx)
const auto &input = ws->Input<CPUBackend>(idx);
auto &output = ws->Output<CPUBackend>(idx);
py::gil_scoped_acquire guard{};
py::array output_array = python_function(TensorToNumpyArray(input));
CopyNumpyArrayToTensor(output, output_array);
try {
py::array output_array = python_function(TensorToNumpyArray(input));
CopyNumpyArrayToTensor(output, output_array);
} catch(const py::error_already_set & e) {
throw std::runtime_error(to_string("PythonFunction error: ") + to_string(e.what()));
}
}

DALI_REGISTER_OPERATOR(PythonFunctionImpl, PythonFunctionImpl<CPUBackend>, CPU);
Expand Down
11 changes: 11 additions & 0 deletions dali/test/python/test_python_function_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,14 @@ def test_python_operator_flip():
dali_output, = dali_flip.run()
for i in range(len(numpy_output)):
assert numpy.array_equal(numpy_output.at(i), dali_output.at(i))

def invalid_function(image):
return img

def test_python_operator_invalid_function():
invalid_pipe = PythonOperatorPipeline(BATCH_SIZE, NUM_WORKERS, DEVICE_ID, SEED, images_dir, invalid_function)
invalid_pipe.build()
try:
invalid_pipe.run()
except Exception as e:
print(e)

0 comments on commit 527e2fd

Please sign in to comment.