Skip to content

Commit

Permalink
Added expression-based permutation functor
Browse files Browse the repository at this point in the history
Example:

    auto reverse = vex::eslice(N - 1 - vex::element_index());
    Y = reverse(X);
  • Loading branch information
ddemidov committed Sep 9, 2013
1 parent 6d54cab commit d78879e
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 35 deletions.
17 changes: 13 additions & 4 deletions tests/vector_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,22 @@ BOOST_AUTO_TEST_CASE(vector_permutation)
vex::vector<double> Y(queue, N);
vex::vector<size_t> I(queue, N);

I = N - 1 - vex::element_index();
{
I = N - 1 - vex::element_index();
vex::permutation reverse(I);
Y = reverse(X);

vex::permutation reverse(I);
check_sample(Y, [&](size_t idx, double v) { BOOST_CHECK_EQUAL(v, x[N - 1 - idx]); });
}

Y = 0;

Y = reverse(X);
{
auto reverse = vex::eslice(N - 1 - vex::element_index());
Y = reverse(X);

check_sample(Y, [&](size_t idx, double v) { BOOST_CHECK_EQUAL(v, x[N - 1 - idx]); });
check_sample(Y, [&](size_t idx, double v) { BOOST_CHECK_EQUAL(v, x[N - 1 - idx]); });
}
}

BOOST_AUTO_TEST_CASE(reduce_slice)
Expand Down
155 changes: 124 additions & 31 deletions vexcl/vector_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ struct vector_view : public vector_view_terminal_expression
operator cop(const Expr &expr) { \
std::vector<size_t> part(2, 0); \
part.back() = slice.size(); \
if (part.back() == 0) part.back() = base.size(); \
detail::assign_expression<op>(*this, expr, base.queue_list(), part); \
return *this; \
} \
const vector_view& operator cop(const vector_view &other) { \
std::vector<size_t> part(2, 0); \
part.back() = slice.size(); \
if (part.back() == 0) part.back() = base.size(); \
detail::assign_expression<op>(*this, other, base.queue_list(), part); \
return *this; \
}
Expand Down Expand Up @@ -124,31 +126,31 @@ struct proto_terminal_is_value< vector_view_terminal > : std::true_type {};

template <typename T, class Slice>
struct terminal_preamble< vector_view<T, Slice> > {
static std::string get(const vector_view<T, Slice>&,
const cl::Device&, const std::string &prm_name,
static std::string get(const vector_view<T, Slice> &term,
const cl::Device &device, const std::string &prm_name,
detail::kernel_generator_state&)
{
return Slice::indexing_function(prm_name);
return term.slice.indexing_function(prm_name, device);
}
};

template <typename T, class Slice>
struct kernel_param_declaration< vector_view<T, Slice> > {
static std::string get(const vector_view<T, Slice>&,
const cl::Device&, const std::string &prm_name,
static std::string get(const vector_view<T, Slice> &term,
const cl::Device &device, const std::string &prm_name,
detail::kernel_generator_state&)
{
return Slice::template parameter_declaration<T>(prm_name);
return term.slice.template parameter_declaration<T>(prm_name, device);
}
};

template <typename T, class Slice>
struct partial_vector_expr< vector_view<T, Slice> > {
static std::string get(const vector_view<T, Slice>&,
const cl::Device&, const std::string &prm_name,
static std::string get(const vector_view<T, Slice> &term,
const cl::Device &device, const std::string &prm_name,
detail::kernel_generator_state&)
{
return Slice::partial_expression(prm_name);
return term.slice.partial_expression(prm_name, device);
}
};

Expand All @@ -160,7 +162,8 @@ struct kernel_arg_setter< vector_view<T, Slice> > {
{
assert(device == 0);

Slice::setArgs(kernel, device, index_offset, position, term);
kernel.setArg(position++, term.base(device));
term.slice.setArgs(kernel, device, index_offset, position);
}
};

Expand Down Expand Up @@ -238,7 +241,9 @@ struct gslice {
static_cast<size_t>(1), std::multiplies<size_t>());
}

static std::string indexing_function(const std::string &prm_name) {
std::string indexing_function(const std::string &prm_name,
const cl::Device&) const
{
std::ostringstream s;

s << type_name<size_t>() << " slice_" << prm_name
Expand All @@ -263,7 +268,9 @@ struct gslice {
return s.str();
}

static std::string partial_expression(const std::string &prm_name) {
std::string partial_expression(const std::string &prm_name,
const cl::Device&) const
{
std::ostringstream s;

s << prm_name << "_base[" << "slice_" << prm_name << "("
Expand All @@ -277,7 +284,9 @@ struct gslice {
}

template <typename T>
static std::string parameter_declaration(const std::string &prm_name) {
std::string parameter_declaration(const std::string &prm_name,
const cl::Device&) const
{
std::ostringstream s;

s << ",\n\tglobal " << type_name<T>() << " * " << prm_name << "_base"
Expand All @@ -290,15 +299,13 @@ struct gslice {
return s.str();
}

template <typename T>
static void setArgs(cl::Kernel &kernel, unsigned device, size_t/*index_offset*/,
unsigned &position, const vector_view<T, gslice> &term)
void setArgs(cl::Kernel &kernel, unsigned device, size_t/*index_offset*/,
unsigned &position) const
{
kernel.setArg(position++, term.base(device));
kernel.setArg(position++, term.slice.start);
kernel.setArg(position++, start);
for(size_t k = 0; k < NDIM; ++k) {
kernel.setArg(position++, term.slice.length[k]);
kernel.setArg(position++, term.slice.stride[k]);
kernel.setArg(position++, length[k]);
kernel.setArg(position++, stride[k]);
}
}

Expand Down Expand Up @@ -512,19 +519,25 @@ struct permutation {
return index.size();
}

static std::string partial_expression(const std::string &prm_name) {
std::string partial_expression(const std::string &prm_name,
const cl::Device&) const
{
std::ostringstream s;
s << prm_name << "_base[" << prm_name << "_index[idx]]";

return s.str();
}

static std::string indexing_function(const std::string &/*prm_name*/) {
std::string indexing_function(const std::string &/*prm_name*/,
const cl::Device&) const
{
return "";
}

template <typename T>
static std::string parameter_declaration(const std::string &prm_name) {
std::string parameter_declaration(const std::string &prm_name,
const cl::Device&) const
{
std::ostringstream s;

s << ",\n\tglobal " << type_name<T>() << " * " << prm_name << "_base"
Expand All @@ -533,12 +546,10 @@ struct permutation {
return s.str();
}

template <typename T>
static void setArgs(cl::Kernel &kernel, unsigned device, size_t/*index_offset*/,
unsigned &position, const vector_view<T, permutation> &term)
void setArgs(cl::Kernel &kernel, unsigned device, size_t/*index_offset*/,
unsigned &position) const
{
kernel.setArg(position++, term.base(device));
kernel.setArg(position++, term.slice.index(device));
kernel.setArg(position++, index(device));
}

template <typename T>
Expand All @@ -548,6 +559,88 @@ struct permutation {
}
};

/// Expression-based permutation operator.
template <class Expr>
struct expr_slice {
const Expr expr;

expr_slice(const Expr &expr) : expr(expr) {}

size_t size() const {
detail::get_expression_properties prop;
detail::extract_terminals()(expr, prop);
return prop.size;
}

std::string indexing_function(const std::string &prm_name,
const cl::Device &dev) const
{
std::ostringstream s;

detail::output_terminal_preamble ctx(s, dev, 1, prm_name + "_");
boost::proto::eval(boost::proto::as_child(expr), ctx);

return s.str();
}

std::string partial_expression(const std::string &prm_name,
const cl::Device &dev) const
{
// TODO: local preamble?
std::ostringstream s;
s << prm_name << "_base[";
detail::vector_expr_context ctx(s, dev, 1, prm_name + "_");
boost::proto::eval(boost::proto::as_child(expr), ctx);
s << "]";

return s.str();
}

template <typename T>
std::string parameter_declaration(const std::string &prm_name,
const cl::Device &dev) const
{
std::ostringstream s;

s << ",\n\tglobal " << type_name<T>() << " * " << prm_name << "_base";

detail::declare_expression_parameter ctx(s, dev, 1, prm_name + "_");
detail::extract_terminals()(boost::proto::as_child(expr), ctx);

return s.str();
}

void setArgs(cl::Kernel &kernel, unsigned device, size_t index_offset,
unsigned &position) const
{
detail::extract_terminals()( boost::proto::as_child(expr),
detail::set_expression_argument(kernel, device, position, index_offset));
}

template <typename T>
vector_view<T, expr_slice> operator()(const vector<T> &base) const {
assert(base.queue_list().size() == 1);
return vector_view<T, expr_slice>(base, *this);
}
};

/// Returns permutation functor which is based on an integral expression.
/**
* Example:
* \code
* auto reverse = vex::eslice(N - 1 - vex::element_index());
* Y = reverse(X);
* \endcode
*/
template <class Expr>
typename std::enable_if<
std::is_integral<typename detail::return_type<Expr>::type>::value,
expr_slice<Expr>
>::type
eslice(const Expr &expr) {
return expr_slice<Expr>(expr);
}

//---------------------------------------------------------------------------
// Slice reduction
//---------------------------------------------------------------------------
Expand Down Expand Up @@ -633,11 +726,11 @@ struct terminal_preamble< reduced_vector_view<T, NDIM, NR, RDC> > {

template <typename T, size_t NDIM, size_t NR, class RDC>
struct kernel_param_declaration< reduced_vector_view<T, NDIM, NR, RDC> > {
static std::string get(const reduced_vector_view<T, NDIM, NR, RDC>&,
const cl::Device&, const std::string &prm_name,
static std::string get(const reduced_vector_view<T, NDIM, NR, RDC> &term,
const cl::Device &device, const std::string &prm_name,
detail::kernel_generator_state&)
{
return gslice<NDIM>::template parameter_declaration<T>(prm_name);
return term.slice.template parameter_declaration<T>(prm_name, device);
}
};

Expand Down

0 comments on commit d78879e

Please sign in to comment.