Skip to content

Commit

Permalink
refactor: simplification and provider masking [no ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
program-- committed Jun 24, 2024
1 parent f2c7a49 commit 03cfd76
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 344 deletions.
176 changes: 111 additions & 65 deletions include/forcing/ForcingsEngineDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,44 @@ static constexpr auto forcings_engine_python_class = "NWMv3_Forcing_Engine_BMI_
static constexpr auto forcings_engine_python_classpath = "NextGen_Forcings_Engine.NWMv3_Forcing_Engine_BMI_model";
static constexpr auto default_time_format = "%Y-%m-%d %H:%M:%S";

//! Parse time string from format.
//! Utility function for ForcingsEngineLumpedDataProvider constructor.
time_t parse_time(const std::string& time, const std::string& fmt);

/**
* Check that requirements for running the forcings engine
* are available at runtime. If requirements are not available,
* then this function throws.
*/
void assert_forcings_engine_requirements();

namespace detail {

//! Storage for Forcings Engine-specific BMI instances.
struct ForcingsEngineStorage {
using key_type = std::string;
using bmi_type = models::bmi::Bmi_Py_Adapter;
using value_type = std::shared_ptr<bmi_type>;

value_type get(const key_type& key)
{
auto pos = data_.find(key);
if (pos == data_.end()) {
return nullptr;
}

return pos->second;
}

void set(const key_type& key, value_type value) { data_[key] = value; }
void set(const key_type& key, value_type value)
{
data_[key] = value;
}

void clear() { data_.clear(); }
void clear()
{
data_.clear();
}

private:
//! Instance map of underlying BMI models.
Expand All @@ -45,17 +62,6 @@ static ForcingsEngineStorage forcings_engine_instances{};

} // namespace detail

//! Parse time string from format.
//! Utility function for ForcingsEngineLumpedDataProvider constructor.
time_t parse_time(const std::string& time, const std::string& fmt);

/**
* Check that requirements for running the forcings engine
* are available at runtime. If requirements are not available,
* then this function throws.
*/
void assert_forcings_engine_requirements();

template<typename DataType, typename SelectionType>
struct ForcingsEngineDataProvider
: public DataProvider<DataType, SelectionType>
Expand All @@ -64,7 +70,7 @@ struct ForcingsEngineDataProvider
using selection_type = SelectionType;
using clock_type = std::chrono::system_clock;

~ForcingsEngineDataProvider() = default;
~ForcingsEngineDataProvider() override = default;

boost::span<const std::string> get_available_variable_names() override
{
Expand All @@ -75,7 +81,7 @@ struct ForcingsEngineDataProvider
{
return clock_type::to_time_t(time_begin_);
}

long get_data_stop_time() override
{
return clock_type::to_time_t(time_end_);
Expand Down Expand Up @@ -109,65 +115,50 @@ struct ForcingsEngineDataProvider
/* Remaining virtual member functions from DataProvider must be implemented
by derived classes. */

data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override = 0;

std::vector<data_type> get_values(const selection_type& selector, data_access::ReSampleMethod m) override = 0;
data_type get_value(
const selection_type& selector,
data_access::ReSampleMethod m
) override = 0;

std::vector<data_type> get_values(
const selection_type& selector,
data_access::ReSampleMethod m
) override = 0;

protected:
// TODO: It may make more sense to have time_begin_seconds and time_end_seconds coalesced into
// a single argument: `clock_type::duration time_duration`, since the forcings engine
// manages time via a duration rather than time points. !! Need to double check
ForcingsEngineDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
std::string init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds
)
: time_begin_(std::chrono::seconds{time_begin_seconds})
, time_end_(std::chrono::seconds{time_end_seconds})
, bmi_(detail::forcings_engine_instances.get(init))
, time_step_(std::chrono::seconds{static_cast<int64_t>(bmi_->GetTimeStep())})
, time_current_index_(std::chrono::seconds{static_cast<int64_t>(bmi_->GetCurrentTime())} / time_step_)
, var_output_names_(bmi_->GetOutputVarNames())
{}

/**
* Update the Forcings Engine instance to the next timestep.
*/
void next()
{

assert_forcings_engine_requirements();

bmi_ = detail::forcings_engine_instances.get(init);
if (bmi_ == nullptr) {
bmi_ = std::make_shared<models::bmi::Bmi_Py_Adapter>(
"ForcingsEngine",
init,
forcings_engine_python_classpath,
/*allow_exceed_end=*/true,
/*has_fixed_time_step=*/true,
utils::getStdOut()
);

detail::forcings_engine_instances.set(init, bmi_);
}

time_step_ = std::chrono::seconds{static_cast<int64_t>(bmi_->GetTimeStep())};
time_current_index_ = std::chrono::seconds{static_cast<int64_t>(bmi_->GetCurrentTime())} / time_step_;
var_output_names_ = bmi_->GetOutputVarNames();
}

template<typename Derived>
static std::unique_ptr<ForcingsEngineDataProvider> make_instance(
const std::string& init,
const std::string& time_begin,
const std::string& time_end,
const std::string& time_fmt = default_time_format
)
{
auto time_begin_epoch = parse_time(time_begin, time_fmt);
auto time_end_epoch = parse_time(time_end, time_fmt);
return std::unique_ptr<Derived>{
new Derived{init, time_begin_epoch, time_end_epoch}
};
}

void next() {
bmi_->Update();
time_current_index_++;
}

void next(double time) {
/**
* Update the Forcings Engine instance to the next timestep
* before `time`.
*
* @param time Time in seconds to update to.
* i.e. A value of 14401 will update
* the instance to the 5th time index,
* since 3600 * 4 = 14400 but 14400 < 14401.
*/
void next(double time)
{
const auto start = bmi_->GetCurrentTime();
bmi_->UpdateUntil(time);
const auto end = bmi_->GetCurrentTime();
Expand All @@ -192,4 +183,59 @@ struct ForcingsEngineDataProvider
std::size_t time_current_index_{};
};


// Forcings Engine factory function
template<
typename Tp,
typename... Args,
std::enable_if_t<
std::is_base_of<
ForcingsEngineDataProvider<
typename Tp::data_type,
typename Tp::selection_type
>,
Tp
>::value,
bool
> = true
>
std::unique_ptr<typename Tp::base_type> make_forcings_engine(
const std::string& init,
const std::string& time_begin,
const std::string& time_end,
const std::string& time_fmt = default_time_format,
Args&&... args
)
{
// Ensure python and other requirements are met to use
// the python forcings engine.
// TODO: Move this to a place where it only runs once,
// possibly in NGen.cpp?
assert_forcings_engine_requirements();

// Get or create a BMI instance based on the init config path
auto instance = detail::forcings_engine_instances.get(init);
if (instance == nullptr) {
instance = std::make_shared<models::bmi::Bmi_Py_Adapter>(
"ForcingsEngine",
init,
forcings_engine_python_classpath,
/*allow_exceed_end=*/true,
/*has_fixed_time_step=*/true,
utils::getStdOut()
);

detail::forcings_engine_instances.set(init, instance);
}

// Create the derived instance
return std::make_unique<Tp>(
init,
parse_time(time_begin, time_fmt),
parse_time(time_end, time_fmt),
std::forward<Args>(args)...
);
}


} // namespace data_access
87 changes: 65 additions & 22 deletions include/forcing/ForcingsEngineGriddedDataProvider.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,87 @@
#include "GridDataSelector.hpp"
#include "ForcingsEngineDataProvider.hpp"

#include <utilities/mdarray.hpp>

namespace data_access {

struct ForcingsEngineGriddedDataProvider
: public ForcingsEngineDataProvider<Cell, GridDataSelector>
: public ForcingsEngineDataProvider<double, GridDataSelector>
{
using data_type = data_type;
using selection_type = selection_type;
using base_type = ForcingsEngineDataProvider<data_type, selection_type>;

~ForcingsEngineGriddedDataProvider() override = default;

data_type get_value(const selection_type& selector, data_access::ReSampleMethod m) override;

std::vector<data_type> get_values(const selection_type& selector, data_access::ReSampleMethod m) override;

static std::unique_ptr<ForcingsEngineDataProvider> make_gridded_instance(
/**
* Construct a domain-wide Gridded Forcings Engine data provider
* @param init Path to instance initialization file
* @param time_begin_seconds Time in seconds for begin time. Typically 0.
* @param time_end_seconds Time in seconds for end time. Typically the lifetime of the simulation.
*/
ForcingsEngineGriddedDataProvider(
const std::string& init,
const std::string& time_start,
const std::string& time_end,
const std::string& time_fmt = default_time_format
)
{
return make_instance<ForcingsEngineGriddedDataProvider>(init, time_start, time_end, time_fmt);
}
std::size_t time_begin_seconds,
std::size_t time_end_seconds
);

private:
friend base_type;
/**
* Construct a masked Gridded Forcings Engine data provider
* @param mask Bounding box used to mask data provider, results will be returned
* within this region.
*/
ForcingsEngineGriddedDataProvider(
const std::string& init,
std::size_t time_begin_seconds,
std::size_t time_end_seconds,
const BoundingBox& mask
);

/**
* Construct a polygon masked Gridded Forcings Engine data provider
* @param boundary Polygon used to mask data provider, results will be returned
* within the **bounding box** of this region.
*/
ForcingsEngineGriddedDataProvider(
const std::string& init,
std::time_t time_begin_seconds,
std::time_t time_end_seconds
std::size_t time_begin_seconds,
std::size_t time_end_seconds,
const geojson::polygon_t& boundary
);

int var_grid_id_{};
~ForcingsEngineGriddedDataProvider() override = default;

data_type get_value(
const selection_type& selector,
data_access::ReSampleMethod m
) override;

/**
* Get the values of a gridded variable in time.
*/
std::vector<data_type> get_values(
const selection_type& selector,
data_access::ReSampleMethod m
) override;

private:
/**
* Get the underlying grid specification of this provider's instance.
*/
const GridSpecification& grid() const noexcept;

/**
* Get the mask of this provider. If the provider is domain-wide, the
* returned bounding box will equivalent to the grid's bounding box.
*/
const BoundingBox& mask() const noexcept;


//! Grid ID for underlying forcings engine instance
int var_grid_id_ = -1;

//! Total grid specification for forcings engine instance
GridSpecification var_grid_{};

//! Provider Grid Mask, the AOI this provider operates under
BoundingBox var_grid_mask_{};
};

} // namespace data_access
Loading

0 comments on commit 03cfd76

Please sign in to comment.