From 552a711dfebb60105f043931295741c4ee3d9c53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 16 Feb 2023 14:15:00 +0000 Subject: [PATCH] xxx-llvm-lowering --- .../Conversion/IteratorsToLLVM/ArrowUtils.cpp | 116 +++++++++ .../Conversion/IteratorsToLLVM/ArrowUtils.h | 60 +++++ .../Conversion/IteratorsToLLVM/CMakeLists.txt | 1 + .../IteratorsToLLVM/IteratorAnalysis.cpp | 24 ++ .../IteratorsToLLVM/IteratorsToLLVM.cpp | 235 ++++++++++++++++++ .../test/python/dialects/iterators/dialect.py | 147 ++++++++++- 6 files changed, 582 insertions(+), 1 deletion(-) create mode 100644 experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp create mode 100644 experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp new file mode 100644 index 000000000000..5fdf735fb84e --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp @@ -0,0 +1,116 @@ +//===-- ArrowUtils.cpp - Arrow runtime function lookup ----------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ArrowUtils.h" + +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" +#include "iterators/Utils/MLIRSupport.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" + +using namespace mlir; +using namespace mlir::iterators; +using namespace mlir::LLVM; + +namespace mlir { +namespace iterators { + +LLVMFuncOp lookupOrInsertArrowArrayGetSize(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayGetSize", {arrayPtr}, + i64); +} + +LLVMFuncOp lookupOrInsertArrowArrayGetColumn(ModuleOp module, + Type elementType) { + assert(elementType.isIntOrFloat() && + "only int or float types supported currently"); + MLIRContext *context = module.getContext(); + + // Assemble types for signature. + Type elementPtr = LLVMPointerType::get(elementType); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + + // Assemble function name; signless integers are treated as signed. + StringRef typeNameBase; + if (elementType.isSignedInteger() || elementType.isSignlessInteger()) + typeNameBase = "Int"; + else if (elementType.isUnsignedInteger()) + typeNameBase = "UInt"; + else { + assert(elementType.isF16() || elementType.isF32() || elementType.isF64()); + typeNameBase = "Float"; + } + std::string typeWidth = std::to_string(elementType.getIntOrFloatBitWidth()); + std::string funcName = + ("mlirIteratorsArrowArrayGet" + typeNameBase + typeWidth + "Column") + .str(); + + // Lookup or insert function. 
+ return lookupOrCreateFn(module, funcName, {arrayPtr, schemaPtr, i64}, + elementPtr); +} + +LLVMFuncOp lookupOrInsertArrowArrayRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayRelease", {arrayPtr}, + voidType); +} + +LLVMFuncOp lookupOrInsertArrowSchemaRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowSchemaRelease", + {schemaPtr}, voidType); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetSchema(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetSchema", + {arrayStreamPtr, schemaPtr}, voidType); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetNext(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i1 = IntegerType::get(context, 1); + Type stream = getArrowArrayStreamType(context); + Type streamPtr = LLVMPointerType::get(stream); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetNext", + {streamPtr, arrayPtr}, i1); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type voidType = 
+ LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamRelease", + {arrayStreamPtr}, voidType); +} + +} // namespace iterators +} // namespace mlir diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h new file mode 100644 index 000000000000..203bcd9efa5b --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h @@ -0,0 +1,60 @@ +//===-- ArrowUtils.h - Arrow runtime function lookup ------------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H +#define LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H + +namespace mlir { +class ModuleOp; +class Type; +namespace LLVM { +class LLVMFuncOp; +} // namespace LLVM +} // namespace mlir + +namespace mlir { +namespace iterators { + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGetSize` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayGetSize(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGet*Column` +/// corresponding to the given type is present in the current module and returns +/// the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayGetColumn(mlir::ModuleOp module, + mlir::Type elementType); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayRelease` is +/// present in the current module and returns the corresponding LLVM func op. 
+mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowSchemaRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowSchemaRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetSchema` +/// is present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetSchema(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetNext` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetNext(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamRelease(mlir::ModuleOp module); + +} // namespace iterators +} // namespace mlir + +#endif // LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt index a8e7f1d61dc4..3db21401aa5e 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(MLIRIteratorsToLLVM + ArrowUtils.cpp IteratorsToLLVM.cpp IteratorAnalysis.cpp diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp index 426c03dd5b82..96cbb58f1461 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp +++ 
b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp @@ -1,13 +1,16 @@ #include "IteratorAnalysis.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Utils/NameAssigner.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::iterators; +using namespace mlir::LLVM; using SymbolTriple = std::tuple; @@ -77,6 +80,26 @@ StateTypeComputer::operator()(FilterOp op, return StateType::get(context, {upstreamStateTypes[0]}); } +/// The state of FromArrowArrayStreamOp holds pointers to the ArrowArrayStream +/// it reads, to an ArrowSchema describing the stream, and to the ArrowArray +/// that owns the memory of the last element returned. Pseudocode: +/// +/// struct { struct ArrowArrayStream *stream; struct ArrowSchema *schema; +/// struct ArrowArray *array; }; +template <> +StateType StateTypeComputer::operator()( + FromArrowArrayStreamOp op, + llvm::SmallVector /*upstreamStateTypes*/) { + MLIRContext *context = op->getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return StateType::get(context, {arrayStreamPtr, schemaPtr, arrayPtr}); +} + /// The state of MapOp only consists of the state of its upstream iterator, /// i.e., the state of the iterator that produces its input stream. 
template <> @@ -170,6 +193,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis( // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 2edcb34db5ff..6ce78eba6919 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -9,11 +9,14 @@ #include "iterators/Conversion/IteratorsToLLVM/IteratorsToLLVM.h" #include "../PassDetail.h" +#include "ArrowUtils.h" #include "IteratorAnalysis.h" #include "iterators/Conversion/TabularToLLVM/TabularToLLVM.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Dialect/Tabular/IR/Tabular.h" #include "iterators/Utils/MLIRSupport.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -610,6 +613,234 @@ static Value buildStateCreation(FilterOp op, FilterOp::Adaptor adaptor, return b.create(stateType, upstreamState); } +//===----------------------------------------------------------------------===// +// FromArrowArrayStreamOp. +//===----------------------------------------------------------------------===// + +/// Builds IR that loads the schema of the Arrow array stream into the schema +/// struct referenced by the state by calling the runtime function +/// `mlirIteratorsArrowArrayStreamGetSchema`. Returns the state unchanged. +static Value buildOpenBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. 
+ Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + + // Call runtime function to load schema. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp getSchemaFunc = lookupOrInsertArrowArrayStreamGetSchema(module); + b.create(getSchemaFunc, ValueRange{streamPtr, schemaPtr}); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that releases the ArrowArray from the previous call, fetches the +/// next one via `mlirIteratorsArrowArrayStreamGetNext`, and assembles a value +/// of the element type from memrefs over its columns (of size zero if the +/// runtime call reports that there is no next element). Returns the unchanged +/// state, the flag returned by the runtime call, and the assembled value. +/// +/// NOTE(review): the memref descriptor below passes the batch size in its last +/// field, which is presumably the stride (expected: 1) -- confirm. +static llvm::SmallVector +buildNextBody(FromArrowArrayStreamOp op, OpBuilder &builder, Value initialState, + ArrayRef upstreamInfos, Type elementType) { + MLIRContext *context = op->getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream, schema, and array pointers from state. 
+ Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + Type arrowArray = getArrowArrayType(context); + Type arrowArrayPtr = LLVMPointerType::get(arrowArray); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + Value arrayPtr = b.create( + arrowArrayPtr, initialState, b.getIndexAttr(2)); + + ModuleOp module = op->getParentOfType(); + + LLVMFuncOp releaseArrayFunc = lookupOrInsertArrowArrayRelease(module); + b.create(releaseArrayFunc, arrayPtr); + + // Call getNext on arrow stream. + LLVMFuncOp getNextFunc = lookupOrInsertArrowArrayStreamGetNext(module); + auto getNextResult = + b.create(getNextFunc, ValueRange{streamPtr, arrayPtr}); + Value hasNextElement = getNextResult.getResult(); + + // XXX: return here if there is no next record batch. + auto tabularViewType = elementType.cast(); + SmallVector memrefs; + LLVMTypeConverter typeConverter(context); + Value zero = b.create(/*value=*/0, /*width=*/64); + + // Call getBatchSize on current record batch. + LLVMFuncOp getBatchSizeFunc = lookupOrInsertArrowArrayGetSize(module); + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + // Next element exists: query the size of the new record batch. + auto getBatchSizeResult = + b.create(getBatchSizeFunc, ValueRange{arrayPtr}); + Value batchSize = getBatchSizeResult.getResult(); + b.create(batchSize); + }, + [&](OpBuilder &builder, Location loc) { + // elseBuilder: no next element, so yield a size of zero. + ImplicitLocOpBuilder b(loc, builder); + b.create(zero); + }); + Value batchSize = ifOp->getResult(0); + + // Extract column pointers from record batch. 
+ for (auto [idx, t] : llvm::enumerate(tabularViewType.getColumnTypes())) { + auto memrefType = MemRefType::get({ShapedType::kDynamic}, t); + Type columnPtrType = LLVMPointerType::get(t); + + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&, idx = idx, t = t](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Call getColumn on current record batch. + auto idxValue = + b.create(/*value=*/idx, /*width=*/64); + LLVMFuncOp getColumnFunc = + lookupOrInsertArrowArrayGetColumn(module, t); + auto getColumnResult = b.create( + getColumnFunc, ValueRange{arrayPtr, schemaPtr, idxValue}); + Value columnPtr = getColumnResult->getResult(0); + + b.create(ValueRange{columnPtr, batchSize}); + }, + /*elseBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + Value columnPtr = b.create(columnPtrType); + b.create(ValueRange{columnPtr, zero}); + }); + + Value columnPtr = ifOp.getResult(0); + Value size = ifOp->getResult(1); + + // XXX: Use makeStridedMemRefDescriptor instead + auto memrefValues = {/*allocated pointer=*/columnPtr, + /*aligned pointer=*/columnPtr, + /*offset=*/zero, /*sizes=*/size, + /*shapes=*/size}; + auto memrefDescriptor = + MemRefDescriptor::pack(b, loc, typeConverter, memrefType, memrefValues); + auto castOp = + b.create(memrefType, memrefDescriptor); + + memrefs.push_back(castOp.getResult(0)); + } + + Value tab = b.create(elementType, memrefs); + + return {initialState, hasNextElement, tab}; +} + +/// Builds IR that releases the stream and schema structs. 
Possible output: +/// +/// call @mlirIteratorsArrowArrayStreamRelease(%stream) +/// call @mlirIteratorsArrowSchemaRelease(%schema) +/// where %stream and %schema are fields 0 and 1 extracted from the state. +/// NOTE(review): the ArrowArray (field 2) is not released here -- confirm. +static Value buildCloseBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. + Type arrayStreamType = getArrowArrayStreamType(context); + Type arrayStreamPtrType = LLVMPointerType::get(arrayStreamType); + Type schemaType = getArrowSchemaType(context); + Type schemaPtrType = LLVMPointerType::get(schemaType); + + Value streamPtr = b.create( + arrayStreamPtrType, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + schemaPtrType, initialState, b.getIndexAttr(1)); + + // Call runtime functions to release structs. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp releaseStreamFunc = lookupOrInsertArrowArrayStreamRelease(module); + LLVMFuncOp releaseSchemaFunc = lookupOrInsertArrowSchemaRelease(module); + b.create(releaseStreamFunc, streamPtr); + b.create(releaseSchemaFunc, schemaPtr); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that creates the initial iterator state: allocates an ArrowSchema +/// and an ArrowArray struct on the stack, fills both with zeros, and stores +/// pointers to them in the state together with the pointer to the input +/// ArrowArrayStream. The open and next functions later pass these pointers to +/// the runtime functions that load the schema and the next array. 
+static Value buildStateCreation(FromArrowArrayStreamOp op, + FromArrowArrayStreamOp::Adaptor adaptor, + OpBuilder &builder, StateType stateType) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Allocate memory for schema and array on the stack. 
+ Value one = b.create(/*value=*/1, /*width=*/64); + LLVMStructType arrayType = getArrowArrayType(context); + LLVMStructType schemaType = getArrowSchemaType(context); + Type arrayPtrType = LLVMPointerType::get(arrayType); + Type schemaPtrType = LLVMPointerType::get(schemaType); + Value arrayPtr = b.create(arrayPtrType, one); + Value schemaPtr = b.create(schemaPtrType, one); + + // Initialize both structs with zeros. + Value zero = b.create(/*value=*/0, /*width=*/8); + Value constFalse = b.create(/*value=*/0, /*width=*/1); + uint32_t arrayTypeSize = mlir::DataLayout::closest(op).getTypeSize(arrayType); + uint32_t schemaTypeSize = + mlir::DataLayout::closest(op).getTypeSize(schemaType); + Value arrayTypeSizeVal = + b.create(/*value=*/arrayTypeSize, + /*width=*/64); + Value schemaTypeSizeVal = + b.create(/*value=*/schemaTypeSize, + /*width=*/64); + b.create(arrayPtr, zero, arrayTypeSizeVal, + /*isVolatile=*/constFalse); + b.create(schemaPtr, zero, schemaTypeSizeVal, + /*isVolatile=*/constFalse); + + Value streamPtr = adaptor.getArrowStream(); + return b.create(stateType, + ValueRange{streamPtr, schemaPtr, arrayPtr}); +} + //===----------------------------------------------------------------------===// // MapOp. 
//===----------------------------------------------------------------------===// @@ -1274,6 +1505,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1293,6 +1525,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1313,6 +1546,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1331,6 +1565,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/test/python/dialects/iterators/dialect.py b/experimental/iterators/test/python/dialects/iterators/dialect.py index 3e5ed5cb297b..039ef6852dc1 100644 --- a/experimental/iterators/test/python/dialects/iterators/dialect.py +++ b/experimental/iterators/test/python/dialects/iterators/dialect.py @@ -1,9 +1,12 @@ # RUN: %PYTHON %s | FileCheck %s import ctypes +import os -import pandas as pd import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.cffi from mlir_iterators.runtime.pandas_to_iterators import to_tabular_view_descriptor from mlir_iterators.dialects import iterators as it @@ -146,3 +149,145 @@ def testEndToEndWithInput(): # CHECK-NEXT: (2, 5) engine = ExecutionEngine(mod) engine.invoke('main', arg) + + +# CHECK-LABEL: TEST: testArrowStreamInput +@run +def testArrowStreamInput(): + mod = Module.parse(''' + !arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> 
+ !arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> + !arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + + !struct_type = !llvm.struct<(i8, i16, i32, i64, f16, f32, f64)> + !tabular_view_type = !tabular.tabular_view + + func.func private @sum_struct(%lhs : !struct_type, %rhs : !struct_type) -> !struct_type { + %lhs0 = llvm.extractvalue %lhs[0] : !struct_type + %lhs1 = llvm.extractvalue %lhs[1] : !struct_type + %lhs2 = llvm.extractvalue %lhs[2] : !struct_type + %lhs3 = llvm.extractvalue %lhs[3] : !struct_type + %lhs4 = llvm.extractvalue %lhs[4] : !struct_type + %lhs5 = llvm.extractvalue %lhs[5] : !struct_type + %lhs6 = llvm.extractvalue %lhs[6] : !struct_type + %rhs0 = llvm.extractvalue %rhs[0] : !struct_type + %rhs1 = llvm.extractvalue %rhs[1] : !struct_type + %rhs2 = llvm.extractvalue %rhs[2] : !struct_type + %rhs3 = llvm.extractvalue %rhs[3] : !struct_type + %rhs4 = llvm.extractvalue %rhs[4] : !struct_type + %rhs5 = llvm.extractvalue %rhs[5] : !struct_type + %rhs6 = llvm.extractvalue %rhs[6] : !struct_type + %sum0 = arith.addi %lhs0, %rhs0 : i8 + %sum1 = arith.addi %lhs1, %rhs1 : i16 + %sum2 = arith.addi %lhs2, %rhs2 : i32 + %sum3 = arith.addi %lhs3, %rhs3 : i64 + %sum4 = arith.addf %lhs4, %rhs4 : f16 + %sum5 = arith.addf %lhs5, %rhs5 : f32 + %sum6 = arith.addf %lhs6, %rhs6 : f64 + %undef = llvm.mlir.undef : !struct_type + %struct0 = llvm.insertvalue %sum0, %undef[0] : !struct_type + %struct1 = llvm.insertvalue %sum1, %struct0[1] : !struct_type + %struct2 = llvm.insertvalue %sum2, %struct1[2] : !struct_type + %struct3 = llvm.insertvalue %sum3, %struct2[3] : !struct_type + %struct4 = llvm.insertvalue %sum4, %struct3[4] : 
!struct_type + %struct5 = llvm.insertvalue %sum5, %struct4[5] : !struct_type + %struct6 = llvm.insertvalue %sum6, %struct5[6] : !struct_type + return %struct6 : !struct_type + } + + func.func @sum_tabular_view(%tabular_view: !tabular_view_type) -> !struct_type { + %tabular_stream = iterators.tabular_view_to_stream %tabular_view + to !iterators.stream + %reduced = "iterators.reduce"(%tabular_stream) {reduceFuncRef = @sum_struct} + : (!iterators.stream) -> (!iterators.stream) + %result:2 = iterators.stream_to_value %reduced : !iterators.stream + return %result#0 : !struct_type + } + + func.func @main(%arrow_stream: !llvm.ptr) + attributes { llvm.emit_c_interface } { + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream + to !iterators.stream + %sums = "iterators.map"(%tabular_view_stream) {mapFuncRef = @sum_tabular_view} + : (!iterators.stream) -> (!iterators.stream) + "iterators.sink"(%sums) : (!iterators.stream) -> () + return + } + ''') + pm = PassManager.parse('builtin.module(' + 'convert-iterators-to-llvm,' + 'convert-tabular-to-llvm,' + 'convert-states-to-llvm,' + 'one-shot-bufferize,' + 'canonicalize,cse,' + 'expand-strided-metadata,' + 'finalize-memref-to-llvm,' + 'canonicalize,cse,' + 'convert-func-to-llvm,' + 'reconcile-unrealized-casts,' + 'convert-scf-to-cf,' + 'convert-cf-to-llvm)') + pm.run(mod) + + # Use pyarrow to create a record batch reader from an Arrow table. 
+ types = [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.float16(), + pa.float32(), + pa.float64() + ] + fields = [pa.field(str(t), t, False) for t in types] + schema = pa.schema(fields) + arrays = [ + pa.array(np.array(np.arange(10) + 100 * i, field.type.to_pandas_dtype())) + for i, field in enumerate(fields) + ] + table = pa.table(arrays, schema) + batches = [ + b for batch in table.to_batches(max_chunksize=5) + for b in pa.Table.from_pandas(batch.to_pandas()).to_batches() + ] + reader = pa.RecordBatchReader.from_batches(schema, batches) + + # Create C struct describing the record batch reader. + ffi = pa.cffi.ffi + cffi_stream = ffi.new('struct ArrowArrayStream *') + cffi_stream_ptr = int(ffi.cast("intptr_t", cffi_stream)) + reader._export_to_c(cffi_stream_ptr) + + arg = ctypes.pointer(ctypes.cast(cffi_stream_ptr, ctypes.c_void_p)) + runtime_lib = os.environ['ITERATORS_RUNTIME_LIBRARY_PATH'] + engine = ExecutionEngine(mod, shared_libs=[runtime_lib]) + # CHECK-NEXT: (10, 510, 1010, 1510, 2010, 2510, 3010) + # CHECK-NEXT: (35, 535, 1035, 1535, 2035, 2535, 3035) + engine.invoke('main', arg)