Skip to content

Commit

Permalink
Add support for scanning edge properties to GDS edgeCompute
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Oct 17, 2024
1 parent b3bfc6e commit 2f2a454
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 52 deletions.
4 changes: 4 additions & 0 deletions src/catalog/catalog_entry/table_catalog_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ column_id_t TableCatalogEntry::getColumnID(const std::string& propertyName) cons
return propertyCollection.getColumnID(propertyName);
}

common::column_id_t TableCatalogEntry::getColumnID(common::idx_t idx) const {
return propertyCollection.getColumnID(idx);
}

void TableCatalogEntry::addProperty(const PropertyDefinition& propertyDefinition) {
propertyCollection.add(propertyDefinition);
}
Expand Down
9 changes: 6 additions & 3 deletions src/function/gds/all_shortest_paths.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/expression/expression_util.h"
#include "common/data_chunk/sel_vector.h"
#include "common/vector/value_vector.h"
#include "function/gds/bfs_graph.h"
#include "function/gds/gds_frontier.h"
#include "function/gds/gds_function_collection.h"
Expand Down Expand Up @@ -208,7 +209,7 @@ class AllSPLengthsEdgeCompute : public EdgeCompute {
: frontierPair{frontierPair}, multiplicities{multiplicities} {};

void edgeCompute(nodeID_t boundNodeID, std::span<const nodeID_t> nbrIDs,
std::span<const relID_t>, SelectionVector& mask, bool) override {
std::span<const relID_t>, SelectionVector& mask, bool, const ValueVector*) override {
size_t activeCount = 0;
mask.forEach([&](auto i) {
auto nbrVal =
Expand Down Expand Up @@ -249,7 +250,8 @@ class AllSPPathsEdgeCompute : public EdgeCompute {
}

void edgeCompute(nodeID_t boundNodeID, std::span<const nodeID_t> nbrNodeIDs,
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool fwdEdge) override {
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool fwdEdge,
const ValueVector*) override {
size_t activeCount = 0;
mask.forEach([&](auto i) {
auto nbrLen = frontiersPair->pathLengths->getMaskValueFromNextFrontierFixedMask(
Expand Down Expand Up @@ -399,7 +401,8 @@ struct VarLenJoinsEdgeCompute : public EdgeCompute {
};

void edgeCompute(nodeID_t boundNodeID, std::span<const nodeID_t> nbrNodeIDs,
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool isFwd) override {
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool isFwd,
const ValueVector*) override {
mask.forEach([&](auto i) {
// We should always update the nbrID in variable length joins
if (!parentPtrsBlock->hasSpace()) {
Expand Down
38 changes: 19 additions & 19 deletions src/function/gds/gds_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ using namespace kuzu::common;
namespace kuzu {
namespace function {

static uint64_t computeScanResult(nodeID_t sourceNodeID, std::span<const nodeID_t> nbrNodeIDs,
std::span<const relID_t> edgeIDs, SelectionVector& mask, EdgeCompute& ec,
FrontierPair& frontierPair, bool isFwd) {
KU_ASSERT(nbrNodeIDs.size() == edgeIDs.size());
ec.edgeCompute(sourceNodeID, nbrNodeIDs, edgeIDs, mask, isFwd);
frontierPair.getNextFrontierUnsafe().setActive(mask, nbrNodeIDs);
return mask.getSelSize();
static uint64_t computeScanResult(nodeID_t sourceNodeID, graph::GraphScanState::Chunk& chunk,
EdgeCompute& ec, FrontierPair& frontierPair, bool isFwd) {
KU_ASSERT(chunk.nbrNodes.size() == chunk.edges.size());
ec.edgeCompute(sourceNodeID, chunk.nbrNodes, chunk.edges, chunk.selVector, isFwd,
chunk.propertyVector);
frontierPair.getNextFrontierUnsafe().setActive(chunk.selVector, chunk.nbrNodes);
return chunk.selVector.getSelSize();
}

void FrontierTask::run() {
Expand All @@ -29,25 +29,25 @@ void FrontierTask::run() {
if (sharedState->frontierPair.curFrontier->isActive(nodeID)) {
switch (info.direction) {
case ExtendDirection::FWD: {
for (auto [nodes, edges, mask] : graph->scanFwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, nodes, edges,
mask, *localEc, sharedState->frontierPair, true);
for (auto chunk : graph->scanFwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, chunk,
*localEc, sharedState->frontierPair, true);
}
} break;
case ExtendDirection::BWD: {
for (auto [nodes, edges, mask] : graph->scanBwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, nodes, edges,
mask, *localEc, sharedState->frontierPair, false);
for (auto chunk : graph->scanBwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, chunk,
*localEc, sharedState->frontierPair, false);
}
} break;
case ExtendDirection::BOTH: {
for (auto [nodes, edges, mask] : graph->scanFwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, nodes, edges,
mask, *localEc, sharedState->frontierPair, true);
for (auto chunk : graph->scanFwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, chunk,
*localEc, sharedState->frontierPair, true);
}
for (auto [nodes, edges, mask] : graph->scanBwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, nodes, edges,
mask, *localEc, sharedState->frontierPair, false);
for (auto chunk : graph->scanBwd(nodeID, *scanState)) {
numApproxActiveNodesForNextIter += computeScanResult(nodeID, chunk,
*localEc, sharedState->frontierPair, false);
}
} break;
default:
Expand Down
6 changes: 4 additions & 2 deletions src/function/gds/single_shortest_paths.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "common/data_chunk/sel_vector.h"
#include "common/vector/value_vector.h"
#include "function/gds/bfs_graph.h"
#include "function/gds/gds_frontier.h"
#include "function/gds/gds_function_collection.h"
Expand Down Expand Up @@ -63,7 +64,7 @@ class SingleSPLengthsEdgeCompute : public EdgeCompute {
: frontierPair{frontierPair} {};

void edgeCompute(common::nodeID_t, std::span<const common::nodeID_t> nbrIDs,
std::span<const relID_t>, SelectionVector& mask, bool) override {
std::span<const relID_t>, SelectionVector& mask, bool, const ValueVector*) override {
size_t activeCount = 0;
mask.forEach([&](auto i) {
if (frontierPair->pathLengths->getMaskValueFromNextFrontierFixedMask(
Expand All @@ -90,7 +91,8 @@ class SingleSPPathsEdgeCompute : public EdgeCompute {
}

void edgeCompute(nodeID_t boundNodeID, std::span<const nodeID_t> nbrNodeIDs,
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool isFwd) override {
std::span<const relID_t> edgeIDs, SelectionVector& mask, bool isFwd,
const ValueVector*) override {
size_t activeCount = 0;
mask.forEach([&](auto i) {
auto shouldUpdate = frontierPair->pathLengths->getMaskValueFromNextFrontierFixedMask(
Expand Down
42 changes: 34 additions & 8 deletions src/graph/on_disk_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "binder/expression/property_expression.h"
#include "common/assert.h"
#include "common/data_chunk/data_chunk_state.h"
#include "common/enums/rel_direction.h"
#include "common/types/types.h"
#include "common/vector/value_vector.h"
Expand All @@ -30,11 +31,15 @@ namespace graph {
static std::unique_ptr<RelTableScanState> getRelScanState(MemoryManager& mm,
const TableCatalogEntry& relEntry, const RelTable& table, RelDataDirection direction,
ValueVector* srcVector, ValueVector* dstVector, ValueVector* relIDVector,
expression_vector properties, const Schema& schema, const ResultSet& resultSet) {
expression_vector predicateEdgeProperties, std::optional<column_id_t> edgePropertyID,
ValueVector* propertyVector, const Schema& schema, const ResultSet& resultSet) {
auto columnIDs = std::vector<column_id_t>{NBR_ID_COLUMN_ID, REL_ID_COLUMN_ID};
for (auto property : properties) {
for (auto property : predicateEdgeProperties) {
columnIDs.push_back(property->constCast<PropertyExpression>().getColumnID(relEntry));
}
if (edgePropertyID) {
columnIDs.push_back(*edgePropertyID);
}
auto columns = std::vector<Column*>{};
for (const auto columnID : columnIDs) {
columns.push_back(table.getColumn(columnID, direction));
Expand All @@ -44,17 +49,20 @@ static std::unique_ptr<RelTableScanState> getRelScanState(MemoryManager& mm,
scanState->nodeIDVector = srcVector;
scanState->outputVectors.push_back(dstVector);
scanState->outputVectors.push_back(relIDVector);
for (auto& property : properties) {
for (auto& property : predicateEdgeProperties) {
auto pos = DataPos(schema.getExpressionPos(*property));
auto vector = resultSet.getValueVector(pos).get();
scanState->outputVectors.push_back(vector);
}
if (edgePropertyID) {
scanState->outputVectors.push_back(propertyVector);
}
scanState->outState = dstVector->state.get();
return scanState;
}

OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span<RelTable*> tables,
const GraphEntry& graphEntry)
const GraphEntry& graphEntry, std::optional<idx_t> edgePropertyIndex)
: iteratorIndex{0}, direction{RelDataDirection::INVALID} {
auto schema = graphEntry.getRelPropertiesSchema();
auto descriptor = ResultSetDescriptor(&schema);
Expand All @@ -70,6 +78,21 @@ OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span<R
relIDVector =
std::make_unique<ValueVector>(LogicalType::INTERNAL_ID(), context->getMemoryManager());
relIDVector->state = state;
std::optional<column_id_t> edgePropertyID;
if (edgePropertyIndex) {
// Edge property scans are only supported for single table scans at the moment
KU_ASSERT(tables.size() == 1);
// TODO(bmwinger): If there are both a predicate and a custom edgePropertyIndex, they will
// currently be scanned twice. The propertyVector could simply be one of the vectors used
// for the predicate.
auto catalogEntry =
context->getCatalog()->getTableCatalogEntry(context->getTx(), tables[0]->getTableID());
propertyVector = std::make_unique<ValueVector>(
catalogEntry->getProperty(*edgePropertyIndex).getType().copy(),
context->getMemoryManager());
propertyVector->state = std::make_shared<DataChunkState>();
edgePropertyID = catalogEntry->getColumnID(*edgePropertyIndex);
}
if (graphEntry.hasRelPredicate()) {
auto mapper = ExpressionMapper(&schema);
relPredicateEvaluator = mapper.getEvaluator(graphEntry.getRelPredicate());
Expand All @@ -80,10 +103,12 @@ OnDiskGraphScanStates::OnDiskGraphScanStates(ClientContext* context, std::span<R
auto relEntry = graphEntry.getRelEntry(table->getTableID());
auto fwdState = getRelScanState(*context->getMemoryManager(), *relEntry, *table,
RelDataDirection::FWD, srcNodeIDVector.get(), dstNodeIDVector.get(), relIDVector.get(),
graphEntry.getRelProperties(), schema, resultSet);
graphEntry.getRelProperties(), *edgePropertyID, propertyVector.get(), schema,
resultSet);
auto bwdState = getRelScanState(*context->getMemoryManager(), *relEntry, *table,
RelDataDirection::BWD, srcNodeIDVector.get(), dstNodeIDVector.get(), relIDVector.get(),
graphEntry.getRelProperties(), schema, resultSet);
graphEntry.getRelProperties(), *edgePropertyID, propertyVector.get(), schema,
resultSet);
scanStates.emplace_back(table->getTableID(),
OnDiskGraphScanState{context, *table, std::move(fwdState), std::move(bwdState)});
}
Expand Down Expand Up @@ -165,10 +190,11 @@ std::vector<RelTableIDInfo> OnDiskGraph::getRelTableIDInfos() {
return result;
}

std::unique_ptr<GraphScanState> OnDiskGraph::prepareScan(table_id_t relTableID) {
std::unique_ptr<GraphScanState> OnDiskGraph::prepareScan(table_id_t relTableID,
std::optional<idx_t> edgePropertyIndex) {
auto relTable = context->getStorageManager()->getTable(relTableID)->ptrCast<RelTable>();
return std::unique_ptr<OnDiskGraphScanStates>(
new OnDiskGraphScanStates(context, std::span(&relTable, 1), graphEntry));
new OnDiskGraphScanStates(context, std::span(&relTable, 1), graphEntry, edgePropertyIndex));
}

std::unique_ptr<GraphScanState> OnDiskGraph::prepareMultiTableScanFwd(
Expand Down
2 changes: 2 additions & 0 deletions src/include/catalog/catalog_entry/table_catalog_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "catalog/catalog_entry/catalog_entry.h"
#include "catalog/property_definition_collection.h"
#include "common/enums/table_type.h"
#include "common/types/types.h"
#include "function/table_functions.h"

namespace kuzu {
Expand Down Expand Up @@ -58,6 +59,7 @@ class KUZU_API TableCatalogEntry : public CatalogEntry {
const binder::PropertyDefinition& getProperty(const std::string& propertyName) const;
const binder::PropertyDefinition& getProperty(common::idx_t idx) const;
virtual common::column_id_t getColumnID(const std::string& propertyName) const;
common::column_id_t getColumnID(common::idx_t idx) const;
void addProperty(const binder::PropertyDefinition& propertyDefinition);
void dropProperty(const std::string& propertyName);
void renameProperty(const std::string& propertyName, const std::string& newName);
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "common/assert.h"
#include "common/cast.h"
#include "common/copy_constructors.h"
#include "common/data_chunk/data_chunk_state.h"
#include "common/null_mask.h"
#include "common/types/ku_string.h"
Expand All @@ -30,6 +31,7 @@ class KUZU_API ValueVector {
KU_ASSERT(dataTypeID != LogicalTypeID::LIST);
}

DELETE_COPY_AND_MOVE(ValueVector);
~ValueVector() = default;

void setState(const std::shared_ptr<DataChunkState>& state_);
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/gds/gds_frontier.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "common/data_chunk/sel_vector.h"
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include "storage/buffer_manager/memory_manager.h"

namespace kuzu {
Expand All @@ -26,7 +27,7 @@ class EdgeCompute {
// **do not** call setActive. Helper functions in GDSUtils will do that work.
virtual void edgeCompute(common::nodeID_t boundNodeID,
std::span<const common::nodeID_t> nbrNodeID, std::span<const common::relID_t> edgeID,
common::SelectionVector& mask, bool fwdEdge) = 0;
common::SelectionVector& mask, bool fwdEdge, const common::ValueVector* edgeProperty) = 0;

virtual std::unique_ptr<EdgeCompute> copy() = 0;
};
Expand Down
6 changes: 5 additions & 1 deletion src/include/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>

#include "common/copy_constructors.h"
#include "common/data_chunk/sel_vector.h"
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include <span>

namespace kuzu {
Expand All @@ -26,6 +28,7 @@ class GraphScanState {
// this reference can be modified, but the underlying data will be reset the next time next
// is called
common::SelectionVector& selVector;
const common::ValueVector* propertyVector;
};
virtual ~GraphScanState() = default;
virtual Chunk getChunk() = 0;
Expand Down Expand Up @@ -115,7 +118,8 @@ class Graph {
virtual std::vector<RelTableIDInfo> getRelTableIDInfos() = 0;

// Prepares scan on the specified relationship table (works for backwards and forwards scans)
virtual std::unique_ptr<GraphScanState> prepareScan(common::table_id_t relTableID) = 0;
virtual std::unique_ptr<GraphScanState> prepareScan(common::table_id_t relTableID,
std::optional<common::column_id_t> edgePropertyID = std::nullopt) = 0;
// Prepares scan on all connected relationship tables using forward adjList.
virtual std::unique_ptr<GraphScanState> prepareMultiTableScanFwd(
std::span<common::table_id_t> nodeTableIDs) = 0;
Expand Down
10 changes: 7 additions & 3 deletions src/include/graph/on_disk_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ class OnDiskGraphScanStates : public GraphScanState {
public:
GraphScanState::Chunk getChunk() override {
auto& iter = getInnerIterator();
return Chunk{iter.getNbrNodes(), iter.getEdges(), iter.getSelVectorUnsafe()};
return Chunk{iter.getNbrNodes(), iter.getEdges(), iter.getSelVectorUnsafe(),
propertyVector.get()};
}
bool next() override;

Expand Down Expand Up @@ -110,13 +111,15 @@ class OnDiskGraphScanStates : public GraphScanState {
std::unique_ptr<common::ValueVector> srcNodeIDVector;
std::unique_ptr<common::ValueVector> dstNodeIDVector;
std::unique_ptr<common::ValueVector> relIDVector;
std::unique_ptr<common::ValueVector> propertyVector;
size_t iteratorIndex;
common::RelDataDirection direction;

std::unique_ptr<evaluator::ExpressionEvaluator> relPredicateEvaluator;

explicit OnDiskGraphScanStates(main::ClientContext* context,
std::span<storage::RelTable*> tableIDs, const GraphEntry& graphEntry);
std::span<storage::RelTable*> tableIDs, const GraphEntry& graphEntry,
std::optional<common::idx_t> edgePropertyIndex = std::nullopt);
std::vector<std::pair<common::table_id_t, OnDiskGraphScanState>> scanStates;
};

Expand All @@ -134,7 +137,8 @@ class OnDiskGraph final : public Graph {

std::vector<RelTableIDInfo> getRelTableIDInfos() override;

std::unique_ptr<GraphScanState> prepareScan(common::table_id_t relTableID) override;
std::unique_ptr<GraphScanState> prepareScan(common::table_id_t relTableID,
std::optional<common::column_id_t> propertyID = std::nullopt) override;
std::unique_ptr<GraphScanState> prepareMultiTableScanFwd(
std::span<common::table_id_t> nodeTableIDs) override;
std::unique_ptr<GraphScanState> prepareMultiTableScanBwd(
Expand Down
Loading

0 comments on commit 2f2a454

Please sign in to comment.