diff --git a/src/include/optimizer/logical_operator_visitor.h b/src/include/optimizer/logical_operator_visitor.h index e41d291c93..2ace82a8d9 100644 --- a/src/include/optimizer/logical_operator_visitor.h +++ b/src/include/optimizer/logical_operator_visitor.h @@ -75,6 +75,12 @@ class LogicalOperatorVisitor { return op; } + virtual void visitNodeLabelFilter(planner::LogicalOperator* /*op*/) {} + virtual std::shared_ptr visitNodeLabelFilterReplace( + std::shared_ptr op) { + return op; + } + virtual void visitFlatten(planner::LogicalOperator* /*op*/) {} virtual std::shared_ptr visitFlattenReplace( std::shared_ptr op) { diff --git a/src/include/optimizer/projection_push_down_optimizer.h b/src/include/optimizer/projection_push_down_optimizer.h index 420342f136..346e19653c 100644 --- a/src/include/optimizer/projection_push_down_optimizer.h +++ b/src/include/optimizer/projection_push_down_optimizer.h @@ -34,6 +34,7 @@ class ProjectionPushDownOptimizer : public LogicalOperatorVisitor { void visitExtend(planner::LogicalOperator* op) override; void visitAccumulate(planner::LogicalOperator* op) override; void visitFilter(planner::LogicalOperator* op) override; + void visitNodeLabelFilter(planner::LogicalOperator* op) override; void visitHashJoin(planner::LogicalOperator* op) override; void visitIntersect(planner::LogicalOperator* op) override; void visitProjection(planner::LogicalOperator* op) override; diff --git a/src/optimizer/logical_operator_visitor.cpp b/src/optimizer/logical_operator_visitor.cpp index 1e996b092a..92f7b39321 100644 --- a/src/optimizer/logical_operator_visitor.cpp +++ b/src/optimizer/logical_operator_visitor.cpp @@ -37,6 +37,9 @@ void LogicalOperatorVisitor::visitOperatorSwitch(LogicalOperator* op) { case LogicalOperatorType::FILTER: { visitFilter(op); } break; + case LogicalOperatorType::NODE_LABEL_FILTER: { + visitNodeLabelFilter(op); + } break; case LogicalOperatorType::FLATTEN: { visitFlatten(op); } break; @@ -123,6 +126,9 @@ std::shared_ptr LogicalOperatorVisitor::visitOperatorReplaceSwi case LogicalOperatorType::FILTER: { return visitFilterReplace(op); } + case LogicalOperatorType::NODE_LABEL_FILTER: { + return visitNodeLabelFilterReplace(op); + } case LogicalOperatorType::FLATTEN: { return visitFlattenReplace(op); } diff --git a/src/optimizer/projection_push_down_optimizer.cpp b/src/optimizer/projection_push_down_optimizer.cpp index 755370509f..4633793bef 100644 --- a/src/optimizer/projection_push_down_optimizer.cpp +++ b/src/optimizer/projection_push_down_optimizer.cpp @@ -7,6 +7,7 @@ #include "planner/operator/logical_filter.h" #include "planner/operator/logical_hash_join.h" #include "planner/operator/logical_intersect.h" +#include "planner/operator/logical_node_label_filter.h" #include "planner/operator/logical_order_by.h" #include "planner/operator/logical_projection.h" #include "planner/operator/logical_table_function_call.h" @@ -84,6 +85,11 @@ void ProjectionPushDownOptimizer::visitFilter(LogicalOperator* op) { collectExpressionsInUse(filter.getPredicate()); } +void ProjectionPushDownOptimizer::visitNodeLabelFilter(LogicalOperator* op) { + auto& filter = op->constCast(); + collectExpressionsInUse(filter.getNodeID()); +} + void ProjectionPushDownOptimizer::visitHashJoin(LogicalOperator* op) { auto& hashJoin = op->constCast(); for (auto& [probeJoinKey, buildJoinKey] : hashJoin.getJoinConditions()) {