From 527e2fd037e5d49af1357037daeabe68ff3f84a1 Mon Sep 17 00:00:00 2001 From: Albert Wolant <44801854+awolant@users.noreply.github.com> Date: Wed, 15 May 2019 13:05:41 +0200 Subject: [PATCH] Better PythonFunction error message (#883) * Better PythonFunction error message Signed-off-by: Albert Wolant --- .../operators/python_function/python_function.cc | 8 ++++++-- dali/test/python/test_python_function_operator.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/dali/pipeline/operators/python_function/python_function.cc b/dali/pipeline/operators/python_function/python_function.cc index 85bd073eee..0046879ed7 100644 --- a/dali/pipeline/operators/python_function/python_function.cc +++ b/dali/pipeline/operators/python_function/python_function.cc @@ -120,8 +120,12 @@ void PythonFunctionImpl::RunImpl(SampleWorkspace *ws, const int idx) const auto &input = ws->Input(idx); auto &output = ws->Output(idx); py::gil_scoped_acquire guard{}; - py::array output_array = python_function(TensorToNumpyArray(input)); - CopyNumpyArrayToTensor(output, output_array); + try { + py::array output_array = python_function(TensorToNumpyArray(input)); + CopyNumpyArrayToTensor(output, output_array); + } catch(const py::error_already_set & e) { + throw std::runtime_error(to_string("PythonFunction error: ") + to_string(e.what())); + } } DALI_REGISTER_OPERATOR(PythonFunctionImpl, PythonFunctionImpl, CPU); diff --git a/dali/test/python/test_python_function_operator.py b/dali/test/python/test_python_function_operator.py index 1c9730f7a8..8bde65569e 100644 --- a/dali/test/python/test_python_function_operator.py +++ b/dali/test/python/test_python_function_operator.py @@ -130,3 +130,14 @@ def test_python_operator_flip(): dali_output, = dali_flip.run() for i in range(len(numpy_output)): assert numpy.array_equal(numpy_output.at(i), dali_output.at(i)) + +def invalid_function(image): + return img + +def test_python_operator_invalid_function(): + invalid_pipe = PythonOperatorPipeline(BATCH_SIZE, NUM_WORKERS, DEVICE_ID, SEED, images_dir, invalid_function) + invalid_pipe.build() + try: + invalid_pipe.run() + except Exception as e: + print(e)