Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] support in predict for complex types (backport #26333) #26978

Merged
merged 1 commit into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions be/src/column/map_column.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "column/map_column.h"

#include <cstdint>
#include <set>

#include "column/column_helper.h"
#include "column/fixed_length_column.h"
Expand Down Expand Up @@ -422,34 +423,75 @@ int MapColumn::equals(size_t left, const Column& rhs, size_t right, bool safe_eq
return false;
}

bool has_null = false;
// Process the null key (if one exists) last, so that every non-null key can first be matched exactly.
// If any non-null key has no match, return false.
// Otherwise (all non-null keys matched with a true or null result), check the null key.
// If the null key has no match, return false; otherwise, if any key matched only with a null result,
// return null; else return true.

bool has_null_eq = false;
uint32_t null_id = 0;
std::vector<uint32_t> index;
for (uint32_t i = lhs_offset; i < lhs_end; ++i) {
bool found = false;
for (uint32_t j = rhs_offset; j < rhs_end; ++j) {
int res = _keys->equals(i, *(rhs_map._keys.get()), j, safe_eq);
if (res == EQUALS_FALSE) {
if (_keys->is_null(i)) {
null_id = i;
continue;
}
index.push_back(i);
}
if (index.size() < (lhs_end - lhs_offset)) {
index.push_back(null_id);
}
std::set<uint32_t> right_index;
for (uint32_t j = rhs_offset; j < rhs_end; ++j) {
right_index.insert(j);
}

for (auto i : index) {
bool real_eq = false;
bool null_eq = false;
uint32_t eq_id = 0;
for (unsigned int j : right_index) {
int key_res = _keys->equals(i, *(rhs_map._keys.get()), j, safe_eq);
if (key_res == EQUALS_FALSE) {
continue;
}

has_null |= (res == EQUALS_NULL);
// So two keys is the same
res = _values->equals(i, *(rhs_map._values.get()), j, safe_eq);
if (res == EQUALS_FALSE) {
return EQUALS_FALSE;
// So two keys are the same or right key is null
int val_res = _values->equals(i, *(rhs_map._values.get()), j, safe_eq);

// case 1: key_res == EQUALS_TRUE
if (key_res == EQUALS_TRUE) {
if (val_res == EQUALS_FALSE) {
return EQUALS_FALSE;
} else if (val_res == EQUALS_NULL) {
null_eq = true;
} else if (val_res == EQUALS_TRUE) {
null_eq = false;
real_eq = true;
}
eq_id = j;
break;
}
// case 2: key_res == EQUALS_NULL, continue
if (val_res != EQUALS_FALSE) {
eq_id = j;
null_eq = true;
}
has_null |= (res == EQUALS_NULL);
found = true;
break;
}
if (!found) {
if (null_eq || real_eq) {
right_index.erase(eq_id);
has_null_eq |= (!real_eq && null_eq);
} else {
return EQUALS_FALSE;
}
}

// unsafe eq && has null, should return NULL
// unsafe eq && none null, should return TRUE
DCHECK(right_index.empty()); // all matched return null or true

// unsafe eq && has null eq, should return NULL
// unsafe eq && none null eq, should return TRUE
// safe eq, should return TRUE
return !safe_eq && has_null ? EQUALS_NULL : EQUALS_TRUE;
return !safe_eq && has_null_eq ? EQUALS_NULL : EQUALS_TRUE;
}

void MapColumn::fnv_hash_at(uint32_t* hash, uint32_t idx) const {
Expand Down
15 changes: 14 additions & 1 deletion be/src/column/struct_column.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ class StructColumn final : public ColumnFactory<Column, StructColumn> {
using Container = Buffer<std::string>;

// Used to construct an unnamed struct
StructColumn(Columns fields) : _fields(std::move(fields)) {}
StructColumn(Columns fields) : _fields(std::move(fields)) {
    // A struct column must have at least one field.
    DCHECK(_fields.size() > 0);
    // Validate the stored member, not the `fields` parameter: the parameter was moved from in
    // the initializer list above, so iterating it would visit nothing and silently skip every
    // check below (Columns is presumably a std::vector, which is empty after being moved from).
    for (auto& f : _fields) {
        // Every field is expected to be nullable and to hold as many rows as the struct itself.
        DCHECK(f->is_nullable());
        DCHECK_EQ(f->size(), size());
        f->check_or_die();
    }
}

StructColumn(Columns fields, std::vector<std::string> field_names)
: _fields(std::move(fields)), _field_names(std::move(field_names)) {
Expand All @@ -37,6 +44,12 @@ class StructColumn final : public ColumnFactory<Column, StructColumn> {

// fields and field_names must have the same size.
DCHECK(_fields.size() == _field_names.size());

// Validate the stored member, not the `fields` parameter: the parameter was moved from in the
// constructor's initializer list, so iterating it would visit nothing and skip every check.
for (auto& f : _fields) {
    // Every field is expected to be nullable and to hold as many rows as the struct itself.
    DCHECK(f->is_nullable());
    DCHECK_EQ(f->size(), size());
    f->check_or_die();
}
}

StructColumn(const StructColumn& rhs) {
Expand Down
22 changes: 13 additions & 9 deletions be/src/exprs/binary_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ArrayPredicate final : public Predicate {
EvalCmpZero _comparator;
};

template <bool is_equal>
class CommonEqualsPredicate final : public Predicate {
public:
explicit CommonEqualsPredicate(const TExprNode& node) : Predicate(node) {}
Expand All @@ -168,14 +169,15 @@ class CommonEqualsPredicate final : public Predicate {
if (l->only_null() || r->only_null()) {
return ColumnHelper::create_const_null_column(l->size());
}
auto& const1 = FunctionHelper::get_data_column_of_nullable(l);
auto& const2 = FunctionHelper::get_data_column_of_nullable(r);
// a nullable column must not contain const columns
size_t lstep = l->is_constant() ? 0 : 1;
size_t rstep = r->is_constant() ? 0 : 1;

size_t lstep = const1->is_constant() ? 0 : 1;
size_t rstep = const2->is_constant() ? 0 : 1;
auto& const1 = FunctionHelper::get_data_column_of_const(l);
auto& const2 = FunctionHelper::get_data_column_of_const(r);

auto& data1 = FunctionHelper::get_data_column_of_const(const1);
auto& data2 = FunctionHelper::get_data_column_of_const(const2);
auto& data1 = FunctionHelper::get_data_column_of_nullable(const1);
auto& data2 = FunctionHelper::get_data_column_of_nullable(const2);

size_t size = l->size();
ColumnBuilder<TYPE_BOOLEAN> builder(size);
Expand All @@ -187,7 +189,7 @@ class CommonEqualsPredicate final : public Predicate {
if (res == -1) {
builder.append_null();
} else {
builder.append(res);
builder.append(!(res ^ is_equal));
}
}

Expand Down Expand Up @@ -342,17 +344,19 @@ Expr* VectorizedBinaryPredicateFactory::from_thrift(const TExprNode& node) {

if (type == TYPE_ARRAY) {
if (node.opcode == TExprOpcode::EQ) {
return new CommonEqualsPredicate(node);
return new CommonEqualsPredicate<true>(node);
} else if (node.opcode == TExprOpcode::EQ_FOR_NULL) {
return new CommonNullSafeEqualsPredicate(node);
} else {
return new ArrayPredicate(node);
}
} else if (type == TYPE_MAP || type == TYPE_STRUCT) {
if (node.opcode == TExprOpcode::EQ) {
return new CommonEqualsPredicate(node);
return new CommonEqualsPredicate<true>(node);
} else if (node.opcode == TExprOpcode::EQ_FOR_NULL) {
return new CommonNullSafeEqualsPredicate(node);
} else if (node.opcode == TExprOpcode::NE) {
return new CommonEqualsPredicate<false>(node);
} else {
return nullptr;
}
Expand Down
4 changes: 3 additions & 1 deletion be/src/exprs/cast_nested.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ StatusOr<ColumnPtr> CastStructExpr::evaluate_checked(ExprContext* context, Chunk
Chunk field_chunk;
field_chunk.append_column(struct_column->fields()[i], 0);
ASSIGN_OR_RETURN(auto casted_field, _field_casts[i]->evaluate_checked(context, &field_chunk));
casted_field = NullableColumn::wrap_if_necessary(casted_field);
casted_fields.emplace_back(std::move(casted_field));
} else {
casted_fields.emplace_back(struct_column->fields()[i]->clone_shared());
casted_fields.emplace_back(NullableColumn::wrap_if_necessary(struct_column->fields()[i]->clone_shared()));
}
DCHECK(casted_fields[i]->is_nullable());
}

auto casted_struct = StructColumn::create(std::move(casted_fields), _type.field_names);
Expand Down
117 changes: 117 additions & 0 deletions be/src/exprs/in_const_predicate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@

#pragma once

#include "column/chunk.h"
#include "column/column_builder.h"
#include "column/column_helper.h"
#include "column/column_viewer.h"
#include "column/hash_set.h"
#include "common/object_pool.h"
#include "exprs/function_helper.h"
#include "exprs/literal.h"
#include "exprs/predicate.h"
#include "gutil/strings/substitute.h"
#include "simd/simd.h"

namespace starrocks {

Expand Down Expand Up @@ -396,6 +400,119 @@ class VectorizedInConstPredicate final : public Predicate {
std::vector<ColumnPtr> _string_values;
};

// IN / NOT IN predicate that compares child 0 (the probed value) against children 1..n
// row by row through Column::equals(), instead of going through a typed hash set.
//
// Result per row:
//   - NULL          if the probed value is NULL;
//   - "matched"     (true for IN, false for NOT IN) as soon as one candidate equals the value;
//   - NULL          if nothing matched but a candidate was NULL or a comparison returned NULL;
//   - "not matched" (false for IN, true for NOT IN) otherwise.
class VectorizedInConstPredicateGeneric final : public Predicate {
public:
    explicit VectorizedInConstPredicateGeneric(const TExprNode& node)
            : Predicate(node), _is_not_in(node.in_predicate.is_not_in) {}

    // _const_input is not copied here; it is (re)populated by open().
    VectorizedInConstPredicateGeneric(const VectorizedInConstPredicateGeneric& other)
            : Predicate(other), _is_not_in(other._is_not_in) {}

    ~VectorizedInConstPredicateGeneric() override = default;

    Expr* clone(ObjectPool* pool) const override { return pool->add(new VectorizedInConstPredicateGeneric(*this)); }

    // Evaluates every constant child once and caches the result in _const_input;
    // slots of non-constant children are left as nullptr and evaluated per chunk.
    Status open(RuntimeState* state, ExprContext* context, FunctionContext::FunctionStateScope scope) override {
        RETURN_IF_ERROR(Expr::open(state, context, scope));
        _const_input.resize(_children.size());
        for (size_t i = 0; i < _children.size(); ++i) {
            if (_children[i]->is_constant()) {
                // The cached column may or may not be a ConstColumn.
                ASSIGN_OR_RETURN(_const_input[i], _children[i]->evaluate_checked(context, nullptr));
            } else {
                _const_input[i] = nullptr;
            }
        }
        return Status::OK();
    }

    // Returns a boolean column (nullable when at least one row is NULL). When every child is a
    // constant expression the result is a ConstColumn, or a const-null column if the single
    // computed row is NULL.
    StatusOr<ColumnPtr> evaluate_checked(ExprContext* context, Chunk* ptr) override {
        const size_t child_size = _children.size();
        Columns input_data(child_size);
        std::vector<NullColumnPtr> input_null(child_size);
        // is_const[i] == true means child i must always be read at row 0.
        std::vector<bool> is_const(child_size, true);
        // Keeps the evaluated columns alive while their inner data/null columns are in use.
        Columns columns_ref(child_size);
        ColumnPtr value;
        // True only when every child is a constant expression cached by open().
        bool all_const = true;
        for (size_t i = 0; i < child_size; ++i) {
            value = _const_input[i];
            if (value == nullptr) {
                ASSIGN_OR_RETURN(value, _children[i]->evaluate_checked(context, ptr));
                is_const[i] = false;
                all_const = false;
            }
            if (i == 0) {
                // A probed value that is NULL for every row yields an all-NULL result.
                RETURN_IF_COLUMNS_ONLY_NULL({value});
            }
            columns_ref[i] = value;
            if (value->is_constant()) {
                // A non-constant expression may still evaluate to a ConstColumn (for example an
                // only-null column). Its unwrapped data column is addressed at row 0, so it must
                // not be indexed with the current row number.
                is_const[i] = true;
                value = down_cast<ConstColumn*>(value.get())->data_column();
            }
            if (value->is_nullable()) {
                auto nullable = down_cast<const NullableColumn*>(value.get());
                input_null[i] = nullable->null_column();
                input_data[i] = nullable->data_column();
            } else {
                input_null[i] = nullptr;
                input_data[i] = value;
            }
        }
        // NOTE(review): assumes child 0's column carries the chunk's row count; this is only
        // enforced by the DCHECK below in debug builds -- confirm for a constant child 0.
        auto size = columns_ref[0]->size();
        DCHECK(ptr == nullptr || ptr->num_rows() == size); // ptr is null in tests.
        auto dest_size = size;
        if (all_const) {
            // All inputs are constant: compute a single row and wrap it as a constant below.
            dest_size = 1;
        }
        // Every row starts as "not matched" (_is_not_in) and flagged with DATUM_NULL.
        BooleanColumn::Ptr res = BooleanColumn::create(dest_size, _is_not_in);
        NullColumnPtr res_null = NullColumn::create(dest_size, DATUM_NULL);
        auto& res_data = res->get_data();
        auto& res_null_data = res_null->get_data();
        for (size_t i = 0; i < dest_size; ++i) {
            const size_t id_0 = is_const[0] ? 0 : i;
            // A NULL probed value keeps the row NULL; only non-NULL values are compared.
            if (input_null[0] == nullptr || !input_null[0]->get_data()[id_0]) {
                bool has_null = false;
                for (size_t j = 1; j < child_size; ++j) {
                    const size_t id = is_const[j] ? 0 : i;
                    // input[j] is null
                    if (input_null[j] != nullptr && input_null[j]->get_data()[id]) {
                        has_null = true;
                        continue;
                    }
                    // input[j] is not null
                    auto is_equal = input_data[0]->equals(id_0, *input_data[j], id, false);
                    if (is_equal == 1) {
                        // Definite match: the row is non-NULL and carries the "matched" answer.
                        res_null_data[i] = false;
                        res_data[i] = !_is_not_in;
                        break;
                    } else if (is_equal == -1) {
                        // The comparison itself produced NULL.
                        has_null = true;
                    }
                }
                // No match found: the row is NULL only if some candidate or comparison was NULL.
                if (_is_not_in == res_data[i]) {
                    res_null_data[i] = has_null;
                }
            }
        }
        if (all_const) {
            if (res_null_data[0]) { // return only_null column
                return ColumnHelper::create_const_null_column(size);
            } else {
                return ConstColumn::create(res, size);
            }
        } else {
            // Only wrap in a NullableColumn when at least one row is actually NULL.
            if (SIMD::count_nonzero(res_null_data) > 0) {
                return NullableColumn::create(std::move(res), std::move(res_null));
            } else {
                return res;
            }
        }
    }

private:
    // True for NOT IN, false for IN.
    const bool _is_not_in{false};
    // Per-child cache of constant inputs filled by open(); nullptr for non-constant children.
    Columns _const_input;
};

class VectorizedInConstPredicateBuilder {
public:
VectorizedInConstPredicateBuilder(RuntimeState* state, ObjectPool* pool, Expr* expr)
Expand Down
13 changes: 11 additions & 2 deletions be/src/exprs/in_predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@ namespace starrocks {
struct InConstPredicateBuilder {
template <LogicalType ltype>
Expr* operator()(const TExprNode& node) {
return new VectorizedInConstPredicate<ltype>(node);
if constexpr (lt_is_collection<ltype>) {
return new VectorizedInConstPredicateGeneric(node);
} else {
return new VectorizedInConstPredicate<ltype>(node);
}
}
};

Expr* VectorizedInPredicateFactory::from_thrift(const TExprNode& node) {
// children type
LogicalType child_type = thrift_to_type(node.child_type);
if (node.__isset.child_type_desc) {
child_type = TypeDescriptor::from_thrift(node.child_type_desc).type;
} else {
child_type = thrift_to_type(node.child_type);
}

if (child_type == TYPE_CHAR) {
child_type = TYPE_VARCHAR;
Expand All @@ -39,7 +48,7 @@ Expr* VectorizedInPredicateFactory::from_thrift(const TExprNode& node) {
switch (node.opcode) {
case TExprOpcode::FILTER_IN:
case TExprOpcode::FILTER_NOT_IN:
return type_dispatch_basic(child_type, InConstPredicateBuilder(), node);
return type_dispatch_basic_and_complex_types(child_type, InConstPredicateBuilder(), node);
case TExprOpcode::FILTER_NEW_IN:
case TExprOpcode::FILTER_NEW_NOT_IN:
// NOTE: These two opcode are deprecated
Expand Down
13 changes: 13 additions & 0 deletions be/src/types/logical_type_infra.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,19 @@ auto type_dispatch_basic(LogicalType ltype, Functor fun, Args... args) {
}
}

// Dispatches on the runtime `ltype` over the types listed by APPLY_FOR_ALL_SCALAR_TYPE_WITH_NULL
// plus the complex types TYPE_ARRAY, TYPE_MAP and TYPE_STRUCT.
// NOTE(review): _TYPE_DISPATCH_CASE is defined outside this view; it presumably expands to a
// `case` label that invokes `fun` with the matching type as template argument and forwards
// `args` -- confirm against the macro definition before renaming these parameters.
// Any type not covered by the cases aborts through CHECK(false).
template <class Functor, class... Args>
auto type_dispatch_basic_and_complex_types(LogicalType ltype, Functor fun, Args... args) {
switch (ltype) {
APPLY_FOR_ALL_SCALAR_TYPE_WITH_NULL(_TYPE_DISPATCH_CASE)
// Complex types handled in addition to the scalar list above.
_TYPE_DISPATCH_CASE(TYPE_ARRAY)
_TYPE_DISPATCH_CASE(TYPE_MAP)
_TYPE_DISPATCH_CASE(TYPE_STRUCT)
default:
// Unsupported type: fail with a message that includes the offending type.
CHECK(false) << "Unknown type: " << ltype;
// Marks the end of the default branch as unreachable once CHECK(false) has fired.
__builtin_unreachable();
}
}

template <class Functor, class... Args>
auto type_dispatch_all(LogicalType ltype, Functor fun, Args... args) {
switch (ltype) {
Expand Down
Loading
Loading