Skip to content

Commit

Permalink
ogr2ogr: optim: call GetArrowStream() only once on source layer when …
Browse files Browse the repository at this point in the history
…using Arrow interface

as this might be an expensive operation on some drivers.

Covered by existing tests.
  • Loading branch information
rouault committed Oct 13, 2024
1 parent c98ef3a commit 0f2eefb
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 50 deletions.
110 changes: 60 additions & 50 deletions apps/ogr2ogr_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ struct TargetLayerInfo
const char *m_pszGeomField = nullptr;
std::vector<int> m_anDateTimeFieldIdx{};
bool m_bSupportCurves = false;
OGRArrowArrayStream m_sArrowArrayStream{};
};

struct AssociatedLayers
Expand All @@ -507,7 +508,8 @@ class SetupTargetLayer
bool CanUseWriteArrowBatch(OGRLayer *poSrcLayer, OGRLayer *poDstLayer,
bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions,
bool &bError);
bool bPreserveFID, bool &bError,
OGRArrowArrayStream &streamSrc);

public:
GDALDataset *m_poSrcDS = nullptr;
Expand Down Expand Up @@ -3990,13 +3992,46 @@ static int GetArrowGeomFieldIndex(const struct ArrowSchema *psLayerSchema,
return -1;
}

/************************************************************************/
/* BuildGetArrowStreamOptions() */
/************************************************************************/

/** Build the option list passed to OGRLayer::GetArrowStream().
 *
 * The returned list always contains SILENCE_GET_SCHEMA_ERROR=YES and
 * GEOMETRY_ENCODING=WKB. INCLUDE_FID=NO is added when the FID is not
 * preserved. MAX_FEATURES_IN_BATCH is added when a feature limit and/or a
 * transaction group size is set in psOptions.
 *
 * @param psOptions    Translation options; nLimit and nGroupTransactions
 *                     are read.
 * @param bPreserveFID Whether the source FID is to be preserved.
 * @return the option list.
 */
static CPLStringList
BuildGetArrowStreamOptions(const GDALVectorTranslateOptions *psOptions,
                           bool bPreserveFID)
{
    const bool bHasGroupSize = psOptions->nGroupTransactions > 0;

    CPLStringList aosStreamOptions;
    aosStreamOptions.SetNameValue("SILENCE_GET_SCHEMA_ERROR", "YES");
    aosStreamOptions.SetNameValue("GEOMETRY_ENCODING", "WKB");
    if (!bPreserveFID)
    {
        aosStreamOptions.SetNameValue("INCLUDE_FID", "NO");
    }

    if (psOptions->nLimit < 0)
    {
        // No feature limit: the batch size is only driven by the
        // transaction group size, when one is set.
        if (bHasGroupSize)
        {
            aosStreamOptions.SetNameValue(
                "MAX_FEATURES_IN_BATCH",
                CPLSPrintf("%d", psOptions->nGroupTransactions));
        }
    }
    else
    {
        // Feature limit set: cap the batch size to the limit, using the
        // transaction group size (or 65536 when there is none) as the
        // upper bound.
        const GIntBig nUpperBound =
            bHasGroupSize ? psOptions->nGroupTransactions : 65536;
        const GIntBig nBatchSize =
            std::min<GIntBig>(psOptions->nLimit, nUpperBound);
        aosStreamOptions.SetNameValue("MAX_FEATURES_IN_BATCH",
                                      CPLSPrintf(CPL_FRMT_GIB, nBatchSize));
    }

    return aosStreamOptions;
}

/************************************************************************/
/* SetupTargetLayer::CanUseWriteArrowBatch() */
/************************************************************************/

bool SetupTargetLayer::CanUseWriteArrowBatch(
OGRLayer *poSrcLayer, OGRLayer *poDstLayer, bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions, bool &bError)
const GDALVectorTranslateOptions *psOptions, bool bPreserveFID,
bool &bError, OGRArrowArrayStream &streamSrc)
{
bError = false;

Expand Down Expand Up @@ -4050,20 +4085,20 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
}

struct ArrowArrayStream streamSrc;
const char *const apszOptions[] = {"SILENCE_GET_SCHEMA_ERROR=YES",
nullptr};
if (poSrcLayer->GetArrowStream(&streamSrc, apszOptions))
const CPLStringList aosGetArrowStreamOptions(
BuildGetArrowStreamOptions(psOptions, bPreserveFID));
if (poSrcLayer->GetArrowStream(streamSrc.get(),
aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaSrc;
if (streamSrc.get_schema(&streamSrc, &schemaSrc) == 0)
if (streamSrc.get_schema(&schemaSrc) == 0)
{
if (psOptions->bTransform &&
GetArrowGeomFieldIndex(&schemaSrc,
poSrcLayer->GetGeometryColumn()) < 0)
{
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}

Expand Down Expand Up @@ -4145,7 +4180,7 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
"Cannot create field %s",
pszFieldName);
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}
}
Expand All @@ -4157,7 +4192,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
// check that it looks to be the same as the source
// one
struct ArrowArrayStream streamDst;
if (poDstLayer->GetArrowStream(&streamDst, nullptr))
if (poDstLayer->GetArrowStream(
&streamDst, aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaDst;
if (streamDst.get_schema(&streamDst, &schemaDst) ==
Expand Down Expand Up @@ -4188,7 +4224,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
schemaSrc.release(&schemaSrc);
}
streamSrc.release(&streamSrc);
if (!bUseWriteArrowBatch)
streamSrc.clear();
}
}
return bUseWriteArrowBatch;
Expand Down Expand Up @@ -4915,8 +4952,10 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
}

bool bError = false;
const bool bUseWriteArrowBatch = CanUseWriteArrowBatch(
poSrcLayer, poDstLayer, bJustCreatedLayer, psOptions, bError);
OGRArrowArrayStream streamSrc;
const bool bUseWriteArrowBatch =
CanUseWriteArrowBatch(poSrcLayer, poDstLayer, bJustCreatedLayer,
psOptions, bPreserveFID, bError, streamSrc);
if (bError)
return nullptr;

Expand Down Expand Up @@ -5378,7 +5417,7 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
nTotalEventsDone = 0;
}

std::unique_ptr<TargetLayerInfo> psInfo(new TargetLayerInfo);
auto psInfo = std::make_unique<TargetLayerInfo>();
psInfo->m_bUseWriteArrowBatch = bUseWriteArrowBatch;
psInfo->m_nFeaturesRead = 0;
psInfo->m_bPerFeatureCT = false;
Expand Down Expand Up @@ -5475,6 +5514,8 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
psInfo->m_bSupportCurves =
CPL_TO_BOOL(poDstLayer->TestCapability(OLCCurveGeometries));

psInfo->m_sArrowArrayStream = std::move(streamSrc);

return psInfo;
}

Expand Down Expand Up @@ -5769,49 +5810,19 @@ bool LayerTranslator::TranslateArrow(
GIntBig *pnReadFeatureCount, GDALProgressFunc pfnProgress,
void *pProgressArg, const GDALVectorTranslateOptions *psOptions)
{
struct ArrowArrayStream stream;
struct ArrowSchema schema;
CPLStringList aosOptionsGetArrowStream;
CPLStringList aosOptionsWriteArrowBatch;
aosOptionsGetArrowStream.SetNameValue("GEOMETRY_ENCODING", "WKB");
if (!psInfo->m_bPreserveFID)
aosOptionsGetArrowStream.SetNameValue("INCLUDE_FID", "NO");
else
if (psInfo->m_bPreserveFID)
{
aosOptionsWriteArrowBatch.SetNameValue(
"FID", psInfo->m_poSrcLayer->GetFIDColumn());
aosOptionsWriteArrowBatch.SetNameValue("IF_FID_NOT_PRESERVED",
"WARNING");
}
if (psOptions->nLimit >= 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf(CPL_FRMT_GIB,
std::min<GIntBig>(psOptions->nLimit,
(psOptions->nGroupTransactions > 0
? psOptions->nGroupTransactions
: 65536))));
}
else if (psOptions->nGroupTransactions > 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf("%d", psOptions->nGroupTransactions));
}
if (psInfo->m_poSrcLayer->GetArrowStream(&stream,
aosOptionsGetArrowStream.List()))
{
if (stream.get_schema(&stream, &schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
stream.release(&stream);
return false;
}
}
else

if (psInfo->m_sArrowArrayStream.get_schema(&schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "GetArrowStream() failed");
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
return false;
}

Expand Down Expand Up @@ -5865,7 +5876,7 @@ bool LayerTranslator::TranslateArrow(
{
struct ArrowArray array;
// Acquire source batch
if (stream.get_next(&stream, &array) != 0)
if (psInfo->m_sArrowArrayStream.get_next(&array) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_next() failed");
bRet = false;
Expand Down Expand Up @@ -6043,7 +6054,6 @@ bool LayerTranslator::TranslateArrow(

schema.release(&schema);

stream.release(&stream);
return bRet;
}

Expand Down
61 changes: 61 additions & 0 deletions ogr/ogrsf_frmts/generic/ogrlayerarrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <map>
#include <string>

#include "ogr_recordbatch.h"

constexpr const char *ARROW_EXTENSION_NAME_KEY = "ARROW:extension:name";
constexpr const char *ARROW_EXTENSION_METADATA_KEY = "ARROW:extension:metadata";
constexpr const char *EXTENSION_NAME_OGC_WKB = "ogc.wkb";
Expand All @@ -34,4 +36,63 @@ bool CPL_DLL OGRCloneArrowArray(const struct ArrowSchema *schema,
bool CPL_DLL OGRCloneArrowSchema(const struct ArrowSchema *schema,
struct ArrowSchema *out_schema);

/** C++ wrapper on top of ArrowArrayStream.
 *
 * The wrapper owns the embedded stream: when its release callback is set,
 * it is invoked by clear() and by the destructor. Instances can be moved
 * (the source is left empty) but not copied.
 */
class OGRArrowArrayStream
{
  public:
    /** Constructor: instantiate an empty (zeroed) ArrowArrayStream */
    inline OGRArrowArrayStream()
    {
        memset(&m_stream, 0, sizeof(m_stream));
    }

    /** Move constructor: take over the stream of other and leave other
     * empty, so that the stream is released only once.
     */
    inline OGRArrowArrayStream(OGRArrowArrayStream &&other) noexcept
    {
        memcpy(&m_stream, &(other.m_stream), sizeof(m_stream));
        memset(&(other.m_stream), 0, sizeof(m_stream));
    }

    /** Destructor: call release() on the stream if not already done */
    inline ~OGRArrowArrayStream()
    {
        clear();
    }

    /** Call release() on the stream if not already done.
     *
     * Safe to call several times: the release callback is reset to nullptr
     * after the first call.
     */
    inline void clear()
    {
        if (m_stream.release)
        {
            m_stream.release(&m_stream);
            m_stream.release = nullptr;
        }
    }

    /** Return a pointer to the wrapped ArrowArrayStream, typically to let
     * a producer populate it. The wrapper keeps ownership.
     */
    inline ArrowArrayStream *get()
    {
        return &m_stream;
    }

    /** Forward to the get_schema() callback of the wrapped stream.
     *
     * Must only be called on a stream that has been populated: on an empty
     * wrapper the callback pointer is null.
     *
     * @param schema Output schema.
     * @return the value returned by the callback.
     */
    inline int get_schema(struct ArrowSchema *schema)
    {
        return m_stream.get_schema(&m_stream, schema);
    }

    /** Forward to the get_next() callback of the wrapped stream.
     *
     * Must only be called on a stream that has been populated: on an empty
     * wrapper the callback pointer is null.
     *
     * @param array Output array.
     * @return the value returned by the callback.
     */
    inline int get_next(struct ArrowArray *array)
    {
        return m_stream.get_next(&m_stream, array);
    }

    /** Move assignment operator: release the current stream (if any), take
     * over the stream of other and leave other empty.
     */
    inline OGRArrowArrayStream &operator=(OGRArrowArrayStream &&other) noexcept
    {
        if (this != &other)
        {
            clear();
            memcpy(&m_stream, &(other.m_stream), sizeof(m_stream));
            // Zero the source so that its destructor does not release the
            // stream a second time.
            memset(&(other.m_stream), 0, sizeof(m_stream));
        }
        return *this;
    }

  private:
    struct ArrowArrayStream m_stream
    {
    };

    // Copying would lead to a double release of the stream.
    OGRArrowArrayStream(const OGRArrowArrayStream &) = delete;
    OGRArrowArrayStream &operator=(const OGRArrowArrayStream &) = delete;
};

#endif // OGRLAYERARROW_H_DEFINED

0 comments on commit 0f2eefb

Please sign in to comment.