Skip to content

Commit

Permalink
Implement lowering to LLVM for new FromArrowArrayStreamOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingomueller-net committed Mar 15, 2023
1 parent 72201a3 commit c3dd884
Show file tree
Hide file tree
Showing 7 changed files with 904 additions and 0 deletions.
116 changes: 116 additions & 0 deletions experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===-- ArrowUtils.cpp - Utils for converting Arrow to LLVM -----*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ArrowUtils.h"

#include "iterators/Dialect/Iterators/IR/ArrowUtils.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;
using namespace mlir::iterators;
using namespace mlir::LLVM;

namespace mlir {
namespace iterators {

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayGetSize` in the given module. The function is
/// declared as taking a pointer to an ArrowArray struct and returning an i64.
LLVMFuncOp lookupOrInsertArrowArrayGetSize(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowArray *) -> i64.
  Type arrayPtrTy = LLVMPointerType::get(getArrowArrayType(ctx));
  Type resultTy = IntegerType::get(ctx, 64);

  return lookupOrCreateFn(module, "mlirIteratorsArrowArrayGetSize",
                          {arrayPtrTy}, resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayGet<Kind><Width>Column` that matches `elementType`,
/// where <Kind> is `Int` (signed or signless integers), `UInt` (unsigned
/// integers), or `Float`, and <Width> is the bit width of `elementType`. The
/// function is declared as taking a pointer to an ArrowArray struct, a pointer
/// to an ArrowSchema struct, and an i64, and as returning a pointer to
/// `elementType`. Only int or float element types are supported currently.
LLVMFuncOp lookupOrInsertArrowArrayGetColumn(ModuleOp module,
                                             Type elementType) {
  assert(elementType.isIntOrFloat() &&
         "only int or float types supported currently");

  // Select the kind component of the function name. The three integer
  // signedness predicates are mutually exclusive, so the order of the checks
  // does not matter.
  const char *kindName = nullptr;
  if (elementType.isUnsignedInteger()) {
    kindName = "UInt";
  } else if (elementType.isSignedInteger() ||
             elementType.isSignlessInteger()) {
    kindName = "Int";
  } else {
    assert(elementType.isF16() || elementType.isF32() || elementType.isF64());
    kindName = "Float";
  }

  // Assemble the function name piece by piece.
  std::string funcName = "mlirIteratorsArrowArrayGet";
  funcName += kindName;
  funcName += std::to_string(elementType.getIntOrFloatBitWidth());
  funcName += "Column";

  // Signature: (ArrowArray *, ArrowSchema *, i64) -> elementType *.
  MLIRContext *ctx = module.getContext();
  Type arrayPtrTy = LLVMPointerType::get(getArrowArrayType(ctx));
  Type schemaPtrTy = LLVMPointerType::get(getArrowSchemaType(ctx));
  Type i64Ty = IntegerType::get(ctx, 64);
  Type resultTy = LLVMPointerType::get(elementType);

  return lookupOrCreateFn(module, funcName, {arrayPtrTy, schemaPtrTy, i64Ty},
                          resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayRelease` in the given module. The function is
/// declared as taking a pointer to an ArrowArray struct and returning void.
LLVMFuncOp lookupOrInsertArrowArrayRelease(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowArray *) -> void.
  Type arrayPtrTy = LLVMPointerType::get(getArrowArrayType(ctx));
  Type resultTy = LLVMVoidType::get(ctx);

  return lookupOrCreateFn(module, "mlirIteratorsArrowArrayRelease",
                          {arrayPtrTy}, resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowSchemaRelease` in the given module. The function is
/// declared as taking a pointer to an ArrowSchema struct and returning void.
LLVMFuncOp lookupOrInsertArrowSchemaRelease(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowSchema *) -> void.
  Type schemaPtrTy = LLVMPointerType::get(getArrowSchemaType(ctx));
  Type resultTy = LLVMVoidType::get(ctx);

  return lookupOrCreateFn(module, "mlirIteratorsArrowSchemaRelease",
                          {schemaPtrTy}, resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayStreamGetSchema` in the given module. The function
/// is declared as taking a pointer to an ArrowArrayStream struct and a pointer
/// to an ArrowSchema struct, and as returning void.
LLVMFuncOp lookupOrInsertArrowArrayStreamGetSchema(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowArrayStream *, ArrowSchema *) -> void.
  Type streamPtrTy = LLVMPointerType::get(getArrowArrayStreamType(ctx));
  Type schemaPtrTy = LLVMPointerType::get(getArrowSchemaType(ctx));
  Type resultTy = LLVMVoidType::get(ctx);

  return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetSchema",
                          {streamPtrTy, schemaPtrTy}, resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayStreamGetNext` in the given module. The function is
/// declared as taking a pointer to an ArrowArrayStream struct and a pointer to
/// an ArrowArray struct, and as returning an i1.
LLVMFuncOp lookupOrInsertArrowArrayStreamGetNext(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowArrayStream *, ArrowArray *) -> i1.
  Type streamPtrTy = LLVMPointerType::get(getArrowArrayStreamType(ctx));
  Type arrayPtrTy = LLVMPointerType::get(getArrowArrayType(ctx));
  Type resultTy = IntegerType::get(ctx, 1);

  return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetNext",
                          {streamPtrTy, arrayPtrTy}, resultTy);
}

/// Looks up or inserts the declaration of the runtime function
/// `mlirIteratorsArrowArrayStreamRelease` in the given module. The function is
/// declared as taking a pointer to an ArrowArrayStream struct and returning
/// void.
LLVMFuncOp lookupOrInsertArrowArrayStreamRelease(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Signature: (ArrowArrayStream *) -> void.
  Type streamPtrTy = LLVMPointerType::get(getArrowArrayStreamType(ctx));
  Type resultTy = LLVMVoidType::get(ctx);

  return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamRelease",
                          {streamPtrTy}, resultTy);
}

} // namespace iterators
} // namespace mlir
60 changes: 60 additions & 0 deletions experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===-- ArrowUtils.h - Utils for converting Arrow to LLVM -------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H
#define LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H

namespace mlir {
class ModuleOp;
class Type;
namespace LLVM {
class LLVMFuncOp;
} // namespace LLVM
} // namespace mlir

namespace mlir {
namespace iterators {

/// Ensures that the runtime function `mlirIteratorsArrowArrayGetSize` is
/// present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowArray struct and
/// returning an i64.
mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayGetSize(mlir::ModuleOp module);

/// Ensures that the runtime function `mlirIteratorsArrowArrayGet*Column`
/// corresponding to the given type is present in the current module and returns
/// the corresponding LLVM func op. The `*` in the name is `Int`, `UInt`, or
/// `Float` (signless integers map to `Int`) followed by the bit width of
/// `elementType`, e.g., `mlirIteratorsArrowArrayGetInt32Column`. The function
/// is declared as taking a pointer to an ArrowArray struct, a pointer to an
/// ArrowSchema struct, and an i64, and as returning a pointer to `elementType`.
/// Only int or float element types are supported currently.
mlir::LLVM::LLVMFuncOp
lookupOrInsertArrowArrayGetColumn(mlir::ModuleOp module,
mlir::Type elementType);

/// Ensures that the runtime function `mlirIteratorsArrowArrayRelease` is
/// present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowArray struct and
/// returning void.
mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayRelease(mlir::ModuleOp module);

/// Ensures that the runtime function `mlirIteratorsArrowSchemaRelease` is
/// present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowSchema struct and
/// returning void.
mlir::LLVM::LLVMFuncOp lookupOrInsertArrowSchemaRelease(mlir::ModuleOp module);

/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetSchema`
/// is present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowArrayStream struct
/// and a pointer to an ArrowSchema struct, and as returning void.
mlir::LLVM::LLVMFuncOp
lookupOrInsertArrowArrayStreamGetSchema(mlir::ModuleOp module);

/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetNext` is
/// present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowArrayStream struct
/// and a pointer to an ArrowArray struct, and as returning an i1.
mlir::LLVM::LLVMFuncOp
lookupOrInsertArrowArrayStreamGetNext(mlir::ModuleOp module);

/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamRelease` is
/// present in the current module and returns the corresponding LLVM func op.
/// The function is declared as taking a pointer to an ArrowArrayStream struct
/// and returning void.
mlir::LLVM::LLVMFuncOp
lookupOrInsertArrowArrayStreamRelease(mlir::ModuleOp module);

} // namespace iterators
} // namespace mlir

#endif // LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRIteratorsToLLVM
ArrowUtils.cpp
IteratorsToLLVM.cpp
IteratorAnalysis.cpp

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#include "IteratorAnalysis.h"

#include "iterators/Dialect/Iterators/IR/ArrowUtils.h"
#include "iterators/Dialect/Iterators/IR/Iterators.h"
#include "iterators/Utils/NameAssigner.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::iterators;
using namespace mlir::LLVM;

using SymbolTriple = std::tuple<SymbolRefAttr, SymbolRefAttr, SymbolRefAttr>;

Expand Down Expand Up @@ -77,6 +80,26 @@ StateTypeComputer::operator()(FilterOp op,
return StateType::get(context, {upstreamStateTypes[0]});
}

/// The state of FromArrowArrayStreamOp consists of three pointers: one to the
/// ArrowArrayStream struct it reads, one to an ArrowSchema struct describing
/// the stream, and one to an ArrowArray struct that owns the memory of the last
/// element the iterator has returned. Pseudocode:
///
/// struct {
///   struct ArrowArrayStream *stream;
///   struct ArrowSchema *schema;
///   struct ArrowArray *array;
/// };
template <>
StateType StateTypeComputer::operator()(
    FromArrowArrayStreamOp op,
    llvm::SmallVector<StateType> /*upstreamStateTypes*/) {
  MLIRContext *ctx = op->getContext();

  // One pointer field per Arrow struct, in the order stream, schema, array.
  Type streamPtrTy = LLVMPointerType::get(getArrowArrayStreamType(ctx));
  Type schemaPtrTy = LLVMPointerType::get(getArrowSchemaType(ctx));
  Type arrayPtrTy = LLVMPointerType::get(getArrowArrayType(ctx));

  return StateType::get(ctx, {streamPtrTy, schemaPtrTy, arrayPtrTy});
}

/// The state of MapOp only consists of the state of its upstream iterator,
/// i.e., the state of the iterator that produces its input stream.
template <>
Expand Down Expand Up @@ -170,6 +193,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
// clang-format off
ConstantStreamOp,
FilterOp,
FromArrowArrayStreamOp,
MapOp,
ReduceOp,
TabularViewToStreamOp,
Expand Down
Loading

0 comments on commit c3dd884

Please sign in to comment.