Skip to content

Commit

Permalink
fix: support exact integers as floats for IntCat
Browse files Browse the repository at this point in the history
  • Loading branch information
henryiii committed Feb 16, 2022
1 parent 23d886f commit eb12664
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion include/bh_python/register_axis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ auto vectorize_index(T input) {
#define BHP_NOEXCEPT_17
#endif

namespace detail {
template <class T>
decltype(auto) axis_cast(py::handle x) {
return special_cast<T>(x);
}

template <>
inline decltype(auto) axis_cast<int>(py::handle x) {
if(py::isinstance<int>(x))
return py::cast<int>(x);

auto val = py::cast<float>(x);
auto ival = static_cast<int>(val);

if(static_cast<float>(ival) == val)
return ival;

throw py::type_error(py::str("cannot cast {} to int").format(val));
}
} // namespace detail

// we overload vectorize index for category axis
template <class T, class Options>
auto vectorize_index(int (bh::axis::category<T, metadata_t, Options>::*pindex)(const T&)
Expand All @@ -56,7 +77,7 @@ auto vectorize_index(int (bh::axis::category<T, metadata_t, Options>::*pindex)(c
auto index = std::mem_fn(pindex);

if(detail::is_value<T>(arg)) {
auto index_value = index(self, detail::special_cast<T>(arg));
auto index_value = index(self, detail::axis_cast<T>(arg));
if(index_value >= self.size())
throw pybind11::key_error(py::str("{!r} not in axis").format(arg));
return py::cast(index_value);
Expand Down

0 comments on commit eb12664

Please sign in to comment.