diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp new file mode 100644 index 000000000000..e0402abadd7b --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp @@ -0,0 +1,116 @@ +//===-- ArrowUtils.cpp - Utils for converting Arrow to LLVM -----*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ArrowUtils.h" + +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" + +using namespace mlir; +using namespace mlir::iterators; +using namespace mlir::LLVM; + +namespace mlir { +namespace iterators { + +LLVMFuncOp lookupOrInsertArrowArrayGetSize(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayGetSize", {arrayPtr}, + i64); +} + +LLVMFuncOp lookupOrInsertArrowArrayGetColumn(ModuleOp module, + Type elementType) { + assert(elementType.isIntOrFloat() && + "only int or float types supported currently"); + MLIRContext *context = module.getContext(); + + // Assemble types for signature. + Type elementPtr = LLVMPointerType::get(elementType); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + + // Assemble function name. + StringRef typeNameBase; + if (elementType.isSignedInteger() || elementType.isSignlessInteger()) + typeNameBase = "Int"; + else if (elementType.isUnsignedInteger()) + typeNameBase = "UInt"; + else { + assert(elementType.isF16() || elementType.isF32() || elementType.isF64()); + typeNameBase = "Float"; + } + std::string typeWidth = std::to_string(elementType.getIntOrFloatBitWidth()); + std::string funcName = + ("mlirIteratorsArrowArrayGet" + typeNameBase + typeWidth + "Column") + .str(); + + // Lookup or insert function. + return lookupOrCreateFn(module, funcName, {arrayPtr, schemaPtr, i64}, + elementPtr); +} + +LLVMFuncOp lookupOrInsertArrowArrayRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayRelease", {arrayPtr}, + voidType); +} + +LLVMFuncOp lookupOrInsertArrowSchemaRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowSchemaRelease", + {schemaPtr}, voidType); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetSchema(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetSchema", + {arrayStreamPtr, schemaPtr}, voidType); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetNext(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i1 = IntegerType::get(context, 1); + Type stream = getArrowArrayStreamType(context); + Type streamPtr = LLVMPointerType::get(stream); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetNext", + {streamPtr, arrayPtr}, i1); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamRelease", + {arrayStreamPtr}, voidType); +} + +} // namespace iterators +} // namespace mlir diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h new file mode 100644 index 000000000000..e1b8a7e5722d --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h @@ -0,0 +1,60 @@ +//===-- ArrowUtils.h - Utils for converting Arrow to LLVM -------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H +#define LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H + +namespace mlir { +class ModuleOp; +class Type; +namespace LLVM { +class LLVMFuncOp; +} // namespace LLVM +} // namespace mlir + +namespace mlir { +namespace iterators { + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGetSize` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayGetSize(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGet*Column` +/// corresponding to the given type is present in the current module and returns +/// the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayGetColumn(mlir::ModuleOp module, + mlir::Type elementType); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowSchemaRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowSchemaRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetSchema` +/// is present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetSchema(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetNext` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetNext(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamRelease(mlir::ModuleOp module); + +} // namespace iterators +} // namespace mlir + +#endif // LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt index a8e7f1d61dc4..3db21401aa5e 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(MLIRIteratorsToLLVM + ArrowUtils.cpp IteratorsToLLVM.cpp IteratorAnalysis.cpp diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp index 426c03dd5b82..96cbb58f1461 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp @@ -1,13 +1,16 @@ #include "IteratorAnalysis.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Utils/NameAssigner.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::iterators; +using namespace mlir::LLVM; using SymbolTriple = std::tuple; @@ -77,6 +80,26 @@ StateTypeComputer::operator()(FilterOp op, return StateType::get(context, {upstreamStateTypes[0]}); } +/// The state of FromArrowArrayStreamOp consists of the pointers to the +/// ArrowArrayStream struct it reads, to an ArrowSchema struct describing the +/// stream, and to an ArrowArray struct that owns the memory of the last element +/// the iterator has returned. Pseudocode: +/// +/// struct { struct ArrowArrayStream *stream; struct ArrowSchema *schema; }; +template <> +StateType StateTypeComputer::operator()( + FromArrowArrayStreamOp op, + llvm::SmallVector /*upstreamStateTypes*/) { + MLIRContext *context = op->getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return StateType::get(context, {arrayStreamPtr, schemaPtr, arrayPtr}); +} + /// The state of MapOp only consists of the state of its upstream iterator, /// i.e., the state of the iterator that produces its input stream. template <> @@ -170,6 +193,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis( // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 5ce14745a2cb..09a999ac8556 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -9,10 +9,13 @@ #include "iterators/Conversion/IteratorsToLLVM/IteratorsToLLVM.h" #include "../PassDetail.h" +#include "ArrowUtils.h" #include "IteratorAnalysis.h" #include "iterators/Conversion/TabularToLLVM/TabularToLLVM.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Dialect/Tabular/IR/Tabular.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -628,6 +631,301 @@ static Value buildStateCreation(FilterOp op, FilterOp::Adaptor adaptor, return b.create(stateType, upstreamState); } +//===----------------------------------------------------------------------===// +// FromArrowArrayStreamOp. +//===----------------------------------------------------------------------===// + +/// Builds IR that retrieves the schema from the input input stream in order to +/// allow cached access during the next calls. Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// llvm.call @mlirIteratorsArrowArrayStreamGetSchema(%0, %1) : +/// (!llvm.ptr, !llvm.ptr) -> () +static Value buildOpenBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. + Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + + // Call runtime function to load schema. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp getSchemaFunc = lookupOrInsertArrowArrayStreamGetSchema(module); + b.create(getSchemaFunc, ValueRange{streamPtr, schemaPtr}); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that calls the get_next function of the Arrow array stream and +/// returns the obtained record batch wrapped in a tabular view. Pseudo-code +/// +/// if (array = arrow_stream->get_next(arrow_stream)): +/// return convert_to_tabular_view(array) +/// return {} +/// +/// Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// %2 = iterators.extractvalue %arg0[2] : !state_type +/// llvm.call @mlirIteratorsArrowArrayRelease(%2) : +/// (!llvm.ptr) -> () +/// %3 = llvm.call @mlirIteratorsArrowArrayStreamGetNext(%0, %2) : +/// (!llvm.ptr, !llvm.ptr) -> i1 +/// %c0_i64 = arith.constant 0 : i64 +/// %4 = scf.if %3 -> (i64) { +/// %6 = llvm.call @mlirIteratorsArrowArrayGetSize(%2) : +/// (!llvm.ptr) -> i64 +/// scf.yield %6 : i64 +/// } else { +/// scf.yield %c0_i64 : i64 +/// } +/// %5:2 = scf.if %3 -> (!llvm.ptr, i64) { +/// %c2_i64 = arith.constant 2 : i64 +/// %6 = llvm.call @mlirIteratorsArrowArrayGetInt32Column(%2, %1, %c2_i64) : +/// (!llvm.ptr, !llvm.ptr, i64) -> +/// !llvm.ptr +/// scf.yield %6, %4 : !llvm.ptr, i64 +/// } else { +/// %6 = llvm.mlir.null : !llvm.ptr +/// scf.yield %6, %c0_i64 : !llvm.ptr, i64 +/// } +/// %6 = llvm.mlir.undef : !memref_descr_type +/// %7 = llvm.insertvalue %5#0, %6[0] : !memref_descr_type +/// %8 = llvm.insertvalue %5#0, %7[1] : !memref_descr_type +/// %9 = llvm.insertvalue %c0_i64, %8[2] : !memref_descr_type +/// %10 = llvm.insertvalue %5#1, %9[3, 0] : !memref_descr_type +/// %11 = llvm.insertvalue %5#1, %10[4, 0] : !memref_descr_type +/// %12 = builtin.unrealized_conversion_cast %11 : +/// !memref_descr_type to memref +/// %tabularview = tabular.view_as_tabular %12 : (memref) -> +/// !tabular.tabular_view +static llvm::SmallVector +buildNextBody(FromArrowArrayStreamOp op, OpBuilder &builder, Value initialState, + ArrayRef upstreamInfos, Type elementType) { + MLIRContext *context = op->getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream, schema, and block pointers from state. + Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + Type arrowArray = getArrowArrayType(context); + Type arrowArrayPtr = LLVMPointerType::get(arrowArray); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + Value arrayPtr = b.create( + arrowArrayPtr, initialState, b.getIndexAttr(2)); + + // Get type-unspecific LLVM functions. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp releaseArrayFunc = lookupOrInsertArrowArrayRelease(module); + LLVMFuncOp getNextFunc = lookupOrInsertArrowArrayStreamGetNext(module); + LLVMFuncOp getArraySizeFunc = lookupOrInsertArrowArrayGetSize(module); + + // Release Arrow array from previous call to next. + b.create(releaseArrayFunc, arrayPtr); + + // Call getNext on Arrow stream. + auto getNextResult = + b.create(getNextFunc, ValueRange{streamPtr, arrayPtr}); + Value hasNextElement = getNextResult.getResult(); + + // Call getSize on current array if we got one; use 0 otherwise. + Value zero = b.create(/*value=*/0, /*width=*/64); + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + /*elseBuilder*/ + auto callOp = b.create(getArraySizeFunc, arrayPtr); + Value arraySize = callOp.getResult(); + b.create(arraySize); + }, + [&](OpBuilder &builder, Location loc) { + // Apply map function. + ImplicitLocOpBuilder b(loc, builder); + b.create(zero); + }); + Value arraySize = ifOp->getResult(0); + + // Extract column pointers from Arrow array. + auto tabularViewType = elementType.cast(); + SmallVector memrefs; + LLVMTypeConverter typeConverter(context); + for (auto [idx, t] : llvm::enumerate(tabularViewType.getColumnTypes())) { + auto memrefType = MemRefType::get({ShapedType::kDynamic}, t); + Type columnPtrType = LLVMPointerType::get(t); + + // Get column pointer from the array if we got one; nullptr otherwise. + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&, idx = idx, t = t](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Call type-specific getColumn on current array. + auto idxValue = + b.create(/*value=*/idx, /*width=*/64); + LLVMFuncOp getColumnFunc = + lookupOrInsertArrowArrayGetColumn(module, t); + auto callOp = b.create( + getColumnFunc, ValueRange{arrayPtr, schemaPtr, idxValue}); + Value columnPtr = callOp->getResult(0); + + b.create(ValueRange{columnPtr, arraySize}); + }, + /*elseBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Use nullptr instead. + Value columnPtr = b.create(columnPtrType); + b.create(ValueRange{columnPtr, zero}); + }); + + Value columnPtr = ifOp.getResult(0); + Value size = ifOp->getResult(1); + + // Assemble a memref descriptor and cast it to memref. + auto memrefValues = {/*allocated pointer=*/columnPtr, + /*aligned pointer=*/columnPtr, + /*offset=*/zero, /*sizes=*/size, + /*shapes=*/size}; + auto memrefDescriptor = + MemRefDescriptor::pack(b, loc, typeConverter, memrefType, memrefValues); + auto castOp = + b.create(memrefType, memrefDescriptor); + + memrefs.push_back(castOp.getResult(0)); + } + + // Create a tabular view from the memrefs. + Value tab = b.create(elementType, memrefs); + + return {initialState, hasNextElement, tab}; +} + +/// Builds IR that frees up all resources, namely, release the stream, the +/// schema, and the current array. Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// %2 = iterators.extractvalue %arg0[2] : !state_type +/// llvm.call @mlirIteratorsArrowArrayStreamRelease(%0) : +/// (!llvm.ptr) -> () +/// llvm.call @mlirIteratorsArrowSchemaRelease(%1) : +/// (!llvm.ptr) -> () +/// llvm.call @mlirIteratorsArrowArrayRelease(%2) : +/// (!llvm.ptr) -> () +static Value buildCloseBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. + Type arrayStreamType = getArrowArrayStreamType(context); + Type arrayStreamPtrType = LLVMPointerType::get(arrayStreamType); + Type schemaType = getArrowSchemaType(context); + Type schemaPtrType = LLVMPointerType::get(schemaType); + Type arrayType = getArrowArrayType(context); + Type arrayPtrType = LLVMPointerType::get(arrayType); + + Value streamPtr = b.create( + arrayStreamPtrType, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + schemaPtrType, initialState, b.getIndexAttr(1)); + Value arrayPtr = b.create( + arrayPtrType, initialState, b.getIndexAttr(2)); + + // Call runtime functions to release structs. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp releaseStreamFunc = lookupOrInsertArrowArrayStreamRelease(module); + LLVMFuncOp releaseSchemaFunc = lookupOrInsertArrowSchemaRelease(module); + LLVMFuncOp releaseArrayFunc = lookupOrInsertArrowArrayRelease(module); + b.create(releaseStreamFunc, streamPtr); + b.create(releaseSchemaFunc, schemaPtr); + b.create(releaseArrayFunc, arrayPtr); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that allocates data for the schema and the current array on the +/// stack and stores pointers to them in the state. Possible output: +/// +/// %c1_i64 = arith.constant 1 : i64 +/// %0 = llvm.alloca %c1_i64 x !llvm.!array_type : (i64) -> +/// !llvm.ptr +/// %1 = llvm.alloca %c1_i64 x !llvm.!schema_type : (i64) -> +/// !llvm.ptr +/// %c0_i8 = arith.constant 0 : i8 +/// %false = arith.constant false +/// %c80_i64 = arith.constant 80 : i64 +/// %c72_i64 = arith.constant 72 : i64 +/// "llvm.intr.memset"(%0, %c0_i8, %c80_i64, %false) : +/// (!llvm.ptr, i8, i64, i1) -> () +/// "llvm.intr.memset"(%1, %c0_i8, %c72_i64, %false) : +/// (!llvm.ptr, i8, i64, i1) -> () +/// %state = iterators.createstate(%arg0, %1, %0) : !state_type +static Value buildStateCreation(FromArrowArrayStreamOp op, + FromArrowArrayStreamOp::Adaptor adaptor, + OpBuilder &builder, StateType stateType) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Allocate memory for schema and array on the stack. + Value one = b.create(/*value=*/1, /*width=*/64); + LLVMStructType arrayType = getArrowArrayType(context); + LLVMStructType schemaType = getArrowSchemaType(context); + Type arrayPtrType = LLVMPointerType::get(arrayType); + Type schemaPtrType = LLVMPointerType::get(schemaType); + Value arrayPtr = b.create(arrayPtrType, one); + Value schemaPtr = b.create(schemaPtrType, one); + + // Initialize it with zeros. + Value zero = b.create(/*value=*/0, /*width=*/8); + Value constFalse = b.create(/*value=*/0, /*width=*/1); + uint32_t arrayTypeSize = mlir::DataLayout::closest(op).getTypeSize(arrayType); + uint32_t schemaTypeSize = + mlir::DataLayout::closest(op).getTypeSize(schemaType); + Value arrayTypeSizeVal = + b.create(/*value=*/arrayTypeSize, + /*width=*/64); + Value schemaTypeSizeVal = + b.create(/*value=*/schemaTypeSize, + /*width=*/64); + b.create(arrayPtr, zero, arrayTypeSizeVal, + /*isVolatile=*/constFalse); + b.create(schemaPtr, zero, schemaTypeSizeVal, + /*isVolatile=*/constFalse); + + // Create the state. + Value streamPtr = adaptor.getArrowStream(); + return b.create(stateType, + ValueRange{streamPtr, schemaPtr, arrayPtr}); +} + //===----------------------------------------------------------------------===// // MapOp. //===----------------------------------------------------------------------===// @@ -1289,6 +1587,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1308,6 +1607,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1328,6 +1628,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1346,6 +1647,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir b/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir new file mode 100644 index 000000000000..5c9fb621a10d --- /dev/null +++ b/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir @@ -0,0 +1,105 @@ +// RUN: iterators-opt %s -convert-iterators-to-llvm \ +// RUN: | FileCheck --enable-var-scope %s +!arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.close.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayStreamRelease(%[[V0]]) : (!llvm.ptr<[[STREAMTYPE]]>) -> () +// CHECK-NEXT: llvm.call @mlirIteratorsArrowSchemaRelease(%[[V1]]) : (!llvm.ptr<[[SCHEMATYPE]]>) -> () +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayRelease(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> () +// CHECK-NEXT: return %[[ARG0]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.next.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: (!iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>>, i1, !llvm.struct<(i64, ptr)>) { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayRelease(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> () +// CHECK-NEXT: %[[V3:.*]] = llvm.call @mlirIteratorsArrowArrayStreamGetNext(%[[V0]], %[[V2]]) : (!llvm.ptr<[[STREAMTYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>) -> i1 +// CHECK-NEXT: %[[V4:.*]] = arith.constant 0 : i64 +// CHECK-NEXT: %[[V5:.*]] = scf.if %[[V3]] -> (i64) { +// CHECK-NEXT: %[[V6:.*]] = llvm.call @mlirIteratorsArrowArrayGetSize(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> i64 +// CHECK-NEXT: scf.yield %[[V6]] : i64 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[V4]] : i64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V7:.*]]:2 = scf.if %[[V3]] -> (!llvm.ptr, i64) { +// CHECK-NEXT: %[[V8:.*]] = arith.constant 0 : i64 +// CHECK-NEXT: %[[V9:.*]] = llvm.call @mlirIteratorsArrowArrayGetInt32Column(%[[V2]], %[[V1]], %[[V8]]) : (!llvm.ptr<[[ARRAYTYPE]]>, !llvm.ptr<[[SCHEMATYPE]]>, i64) -> !llvm.ptr +// CHECK-NEXT: scf.yield %[[V9]], %[[V5]] : !llvm.ptr, i64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[Va:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: scf.yield %[[Va]], %[[V4]] : !llvm.ptr, i64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[Vb:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vc:.*]] = llvm.insertvalue %[[V7]]#0, %[[Vb]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vd:.*]] = llvm.insertvalue %[[V7]]#0, %[[Vc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Ve:.*]] = llvm.insertvalue %[[V4]], %[[Vd]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vf:.*]] = llvm.insertvalue %[[V7]]#1, %[[Ve]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vg:.*]] = llvm.insertvalue %[[V7]]#1, %[[Vf]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vh:.*]] = builtin.unrealized_conversion_cast %[[Vg]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref +// CHECK-NEXT: %[[Vi:.*]] = tabular.view_as_tabular %[[Vh]] : (memref) -> !tabular.tabular_view +// CHECK-NEXT: %[[Vj:.*]] = builtin.unrealized_conversion_cast %[[Vi]] : !tabular.tabular_view to !llvm.struct<(i64, ptr)> +// CHECK-NEXT: return %[[ARG0]], %[[V3]], %[[Vj]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>>, i1, !llvm.struct<(i64, ptr)> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.open.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayStreamGetSchema(%[[V0]], %[[V1]]) : (!llvm.ptr<[[STREAMTYPE]]>, !llvm.ptr<[[SCHEMATYPE]]>) -> () +// CHECK-NEXT: return %[[ARG0]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<[[STREAMTYPE:.*]]>) +// CHECK-NEXT: %[[V0:.*]] = arith.constant 1 : i64 +// CHECK-NEXT: %[[V1:.*]] = llvm.alloca %[[V0]] x !llvm.[[ARRAYTYPE:.*]] : (i64) -> +// CHECK-SAME: !llvm.ptr<[[ARRAYTYPE]]> +// CHECK-NEXT: %[[V2:.*]] = llvm.alloca %[[V0]] x !llvm.[[SCHEMATYPE:.*]] : (i64) -> +// CHECK-SAME: !llvm.ptr<[[SCHEMATYPE]]> +// CHECK-NEXT: %[[V3:.*]] = arith.constant 0 : i8 +// CHECK-NEXT: %[[V4:.*]] = arith.constant false +// CHECK-NEXT: %[[V5:.*]] = arith.constant 80 : i64 +// CHECK-NEXT: %[[V6:.*]] = arith.constant 72 : i64 +// CHECK-NEXT: "llvm.intr.memset"(%[[V1]], %[[V3]], %[[V5]], %[[V4]]) : (!llvm.ptr<[[ARRAYTYPE]]>, i8, i64, i1) -> () +// CHECK-NEXT: "llvm.intr.memset"(%[[V2]], %[[V3]], %[[V6]], %[[V4]]) : (!llvm.ptr<[[SCHEMATYPE]]>, i8, i64, i1) -> () +// CHECK-NEXT: %[[V7:.*]] = iterators.createstate(%[[ARG0]], %[[V2]], %[[V1]]) : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: return +func.func @main(%arrow_stream: !llvm.ptr) { + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream to !iterators.stream> + return +} diff --git a/experimental/iterators/test/python/dialects/iterators/arrow.py b/experimental/iterators/test/python/dialects/iterators/arrow.py new file mode 100644 index 000000000000..2acb4e444b79 --- /dev/null +++ b/experimental/iterators/test/python/dialects/iterators/arrow.py @@ -0,0 +1,296 @@ +# RUN: %PYTHON %s | FileCheck %s + +import ctypes +import os +import sys +from tempfile import NamedTemporaryFile +import time + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.cffi +import pyarrow.csv + +from mlir_iterators.dialects import iterators as it +from mlir_iterators.dialects import tabular as tab +from mlir_iterators.passmanager import PassManager +from mlir_iterators.execution_engine import ExecutionEngine +from mlir_iterators.ir import Context, Module + + +def run(f): + print("\nTEST:", f.__name__) + with Context(): + it.register_dialect() + tab.register_dialect() + f() + return f + + +# MLIR definitions of the C structs of the Arrow ABI. +ARROW_STRUCT_DEFINITIONS_MLIR = ''' + !arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> + !arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> + !arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + ''' + +# Arrow data types that are currently supported. +ARROW_SUPPORTED_TYPES = [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.float16(), + pa.float32(), + pa.float64() +] + + +# Converts the given Arrow type to the name of the corresponding MLIR type. +def to_mlir_type(t: pa.DataType) -> str: + if pa.types.is_signed_integer(t): + return 'i' + str(t.bit_width) + if pa.types.is_floating(t): + return 'f' + str(t.bit_width) + raise NotImplementedError("Only floats and signed integers supported") + + +# Compiles the given code and wraps it into an execution engine. +def build_and_create_engine(code: str) -> ExecutionEngine: + mod = Module.parse(ARROW_STRUCT_DEFINITIONS_MLIR + code) + pm = PassManager.parse('builtin.module(' + 'convert-iterators-to-llvm,' + 'convert-tabular-to-llvm,' + 'convert-states-to-llvm,' + 'one-shot-bufferize,' + 'canonicalize,cse,' + 'expand-strided-metadata,' + 'finalize-memref-to-llvm,' + 'canonicalize,cse,' + 'convert-func-to-llvm,' + 'reconcile-unrealized-casts,' + 'convert-scf-to-cf,' + 'convert-cf-to-llvm)') + pm.run(mod) + runtime_lib = os.environ['ITERATORS_RUNTIME_LIBRARY_PATH'] + engine = ExecutionEngine(mod, shared_libs=[runtime_lib]) + return engine + + +# Generate MLIR that reads the arrays of an Arrow array stream and produces (and +# prints) the element-wise sum of each array. +def generate_sum_batches_elementwise_code(schema: pa.Schema) -> str: + mlir_types = [to_mlir_type(t) for t in schema.types] + + # Generate code that, for each type, extracts rhs and lhs struct values, adds + # them, and then inserts the result into a result struct. + elementwise_sum = '%struct0 = llvm.mlir.undef : !struct_type' + for i, t in enumerate(mlir_types): + elementwise_sum += f''' + %lhs{i} = llvm.extractvalue %lhs[{i}] : !struct_type + %rhs{i} = llvm.extractvalue %rhs[{i}] : !struct_type + %sum{i} = arith.add{t[0]} %lhs{i}, %rhs{i} : {t} + %struct{i+1} = llvm.insertvalue %sum{i}, %struct{i}[{i}] : !struct_type + ''' + + # Adapt main program to types of the given schema. + code = f''' + !struct_type = !llvm.struct<({', '.join(mlir_types)})> + !tabular_view_type = !tabular.tabular_view<{', '.join(mlir_types)}> + + // Add numbers of two structs element-wise. + func.func private @sum_struct(%lhs : !struct_type, %rhs : !struct_type) -> !struct_type {{ + {elementwise_sum} + return %struct{len(mlir_types)} : !struct_type + }} + + // Consume the given tabular view and produce one element-wise sum from the elements. + func.func @sum_tabular_view(%tabular_view: !tabular_view_type) -> !struct_type {{ + %tabular_stream = iterators.tabular_view_to_stream %tabular_view + to !iterators.stream + %reduced = "iterators.reduce"(%tabular_stream) {{reduceFuncRef = @sum_struct}} + : (!iterators.stream) -> (!iterators.stream) + %result:2 = iterators.stream_to_value %reduced : !iterators.stream + return %result#0 : !struct_type + }} + + // For each Arrow array in the input stream, produce an element-wise sum. + func.func @main(%arrow_stream: !llvm.ptr) + attributes {{ llvm.emit_c_interface }} {{ + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream + to !iterators.stream + %sums = "iterators.map"(%tabular_view_stream) {{mapFuncRef = @sum_tabular_view}} + : (!iterators.stream) -> (!iterators.stream) + "iterators.sink"(%sums) : (!iterators.stream) -> () + return + }} + ''' + + return code + + +# Feeds the given Arrow array stream/record batch reader into an MLIR kernel +# that reads the arrays the stream and produces (and prints) the element-wise +# sum of each array/record batch. +def sum_batches_elementwise_with_iterators( + record_batch_reader: pa.RecordBatchReader) -> None: + + code = generate_sum_batches_elementwise_code(record_batch_reader.schema) + engine = build_and_create_engine(code) + + # Create C struct describing the record batch reader. + ffi = pa.cffi.ffi + cffi_stream = ffi.new('struct ArrowArrayStream *') + cffi_stream_ptr = int(ffi.cast("intptr_t", cffi_stream)) + record_batch_reader._export_to_c(cffi_stream_ptr) + + # Wrap argument and invoke compiled function. + arg = ctypes.pointer(ctypes.cast(cffi_stream_ptr, ctypes.c_void_p)) + engine.invoke('main', arg) + + +# Create a sample Arrow table with one column per supported type. +def create_test_input() -> pa.Table: + # Use pyarrow to create an Arrow table in memory. + fields = [pa.field(str(t), t, False) for t in ARROW_SUPPORTED_TYPES] + schema = pa.schema(fields) + arrays = [ + pa.array(np.array(np.arange(10) + 100 * i, field.type.to_pandas_dtype())) + for i, field in enumerate(fields) + ] + table = pa.table(arrays, schema) + return table + + +# Test case: Read from a sequence of Arrow arrays/record batches (produced by a +# Python generator). + + +# CHECK-LABEL: TEST: testArrowStreamInput +@run +def testArrowStreamInput(): + # Use pyarrow to create an Arrow table in memory. + table = create_test_input() + + # Make physically separate batches from the table. (This ensures offset=0). + batches = (b for batch in table.to_batches(max_chunksize=5) + for b in pa.Table.from_pandas(batch.to_pandas()).to_batches()) + + # Create a RecordBatchReader and export it as a C struct. + reader = pa.RecordBatchReader.from_batches(table.schema, batches) + + # Hand the reader as an Arrow array stream to the Iterators test program. + # CHECK-NEXT: (10, 510, 1010, 1510, 2010, 2510, 3010) + # CHECK-NEXT: (35, 535, 1035, 1535, 2035, 2535, 3035) + sum_batches_elementwise_with_iterators(reader) + + +# Test case: Read data from a CSV file (through pyarrow's C++-implemented CSV +# reader). + + +# CHECK-LABEL: TEST: testArrowCSVInput +@run +def testArrowCSVInput(): + table = create_test_input() + # Remove f16 column, which the CSV reader doesn't support yet. + table = table.drop(['halffloat']) + + # Create a temporary CSV file with test data. + with NamedTemporaryFile(mode='wb') as temp_file, \ + pa.PythonFile(temp_file) as csv_file: + # Export test data as CSV. + pa.csv.write_csv(table, csv_file) + + # Flush and rewind to the beginning of the file. + csv_file.flush() + temp_file.flush() + temp_file.seek(0) + + # Re-open as read-only. + csv_file = pa.PythonFile(open(temp_file.name, 'rb')) + + # Create a RecordBatchReader (which reads CSV as a stream). The block_size + # is the manually determined number of bytes of the header plus the first 5 + # rows (such that we get two record batches of 5 rows each). + convert_options = pa.csv.ConvertOptions(column_types=table.schema) + read_options = pa.csv.ReadOptions(block_size=158) + reader = pa.csv.open_csv(csv_file, + read_options=read_options, + convert_options=convert_options) + + # Hand the reader as an Arrow array stream to the Iterators test program. + # CHECK-NEXT: (10, 510, 1010, 1510, 2510, 3010) + # CHECK-NEXT: (35, 535, 1035, 1535, 2535, 3035) + sum_batches_elementwise_with_iterators(reader) + + +# Test case: Read from a sequence of Arrow arrays/record batches (produced by a +# Python generator). + + +# Create a generator that produces single-row record batches with increasing +# numbers with an artificial delay of one second after each of them. Since each +# generated record batch immediately produces output, this visually demonstrate +# that the consumption by the MLIR-based iterators interleaves with the +# Python-based production of the record batches in the stream. +def generate_batches_with_delay(schema: pa.Schema) -> None: + for i in range(5): + arrays = [ + pa.array(np.array([i], field.type.to_pandas_dtype())) + for field in schema + ] + batch = pa.RecordBatch.from_arrays(arrays, schema=schema) + yield batch + # Sleep only when a TTY is attached (in order not to delay unit tests). + if sys.stdout.isatty(): + time.sleep(1) + + +# CHECK-LABEL: TEST: testGeneratorInput +@run +def testGeneratorInput(): + # Use pyarrow to create an Arrow table in memory. + table = create_test_input() + + # Make physically separate batches from the table. (This ensures offset=0). + generator = generate_batches_with_delay(table.schema) + + # Create a RecordBatchReader and export it as a C struct. + reader = pa.RecordBatchReader.from_batches(table.schema, generator) + + # Hand the reader as an Arrow array stream to the Iterators test program. + # CHECK-NEXT: (0, 0, 0, 0, 0, 0, 0) + # CHECK-NEXT: (1, 1, 1, 1, 1, 1, 1) + # CHECK-NEXT: (2, 2, 2, 2, 2, 2, 2) + # CHECK-NEXT: (3, 3, 3, 3, 3, 3, 3) + # CHECK-NEXT: (4, 4, 4, 4, 4, 4, 4) + sum_batches_elementwise_with_iterators(reader)