Skip to content

Commit

Permalink
add init carveout and add validator to flow field arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
RHammond2 committed Dec 12, 2023
1 parent 675e418 commit 5af909e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 13 deletions.
27 changes: 16 additions & 11 deletions floris/simulation/flow_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
Grid,
)
from floris.type_dec import (
array_5D_field,
floris_array_converter,
NDArrayFloat,
validate_3DArray_shape,
ValidateMixin,
)

Expand All @@ -49,16 +49,18 @@ class FlowField(BaseClass, ValidateMixin):

n_wind_speeds: int = field(init=False)
n_wind_directions: int = field(init=False)

u_initial_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
v_initial_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
w_initial_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
u_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
v_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
w_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
u: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
v: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
w: NDArrayFloat = field(init=False, factory=lambda: np.array([]))
n_turbines: int = field(init=False)
grid_resolution: int = field(init=False)

u_initial_sorted: NDArrayFloat = array_5D_field
v_initial_sorted: NDArrayFloat = array_5D_field
w_initial_sorted: NDArrayFloat = array_5D_field
u_sorted: NDArrayFloat = array_5D_field
v_sorted: NDArrayFloat = array_5D_field
w_sorted: NDArrayFloat = array_5D_field
u: NDArrayFloat = array_5D_field
v: NDArrayFloat = array_5D_field
w: NDArrayFloat = array_5D_field
het_map: list = field(init=False, default=None)
dudz_initial_sorted: NDArrayFloat = field(init=False, factory=lambda: np.array([]))

Expand Down Expand Up @@ -133,6 +135,9 @@ def initialize_velocity_field(self, grid: Grid) -> None:
# determined by this line. Since the right-most dimension on grid.z is storing the values
# for height, using it here to apply the shear law makes that dimension store the vertical
# wind profile.
self.n_turbines = grid.n_turbines
self.grid_resolution = grid.grid_resolution

wind_profile_plane = (grid.z_sorted / self.reference_wind_height) ** self.wind_shear
dwind_profile_plane = (
self.wind_shear
Expand Down
19 changes: 17 additions & 2 deletions floris/type_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
import attrs
import numpy as np
import numpy.typing as npt
from attrs import Attribute, define
from attrs import (
Attribute,
define,
field,
)


### Define general data types used throughout
Expand Down Expand Up @@ -160,6 +164,10 @@ def validate_3DArray_shape(instance, attribute: Attribute, value: np.ndarray) ->
if not isinstance(value, np.ndarray):
raise TypeError(f"`{attribute.name}` is not a valid NumPy array type.")

# Don't fail on the initialized empty array
if value.size == 0:
return

shape = (instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines)
if value.shape != shape:
# The grid sorted_coord_indices are broadcast along the wind speed dimension
Expand Down Expand Up @@ -188,6 +196,10 @@ def validate_5DArray_shape(instance, attribute: Attribute, value: np.ndarray) ->
print(type(value))
raise TypeError(f"`{attribute.name}` is not a valid NumPy array type.")

# Don't fail on the initialized empty array
if value.size == 0:
return

grid = instance.grid_resolution
shape = (instance.n_wind_directions, instance.n_wind_speeds, instance.n_turbines, grid, grid)
if value.shape != shape:
Expand Down Expand Up @@ -265,7 +277,10 @@ def validate(self) -> None:
attrs.validate(self)


# Avoids constant redefinition of the same attr.ib properties for model attributes
# Avoids constant redefinition of the same field properties for model attributes

array_5D_field = field(init=False, factory=lambda: np.array([]), validator=validate_5DArray_shape)


# from functools import partial, update_wrapper

Expand Down

0 comments on commit 5af909e

Please sign in to comment.