Skip to content

Commit

Permalink
Fix reading grib with extensions that gdal ignores
Browse files Browse the repository at this point in the history
  • Loading branch information
amarandon committed Nov 8, 2024
1 parent 46ef711 commit 08003ad
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
8 changes: 7 additions & 1 deletion eodag_cube/api/product/drivers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
logger = logging.getLogger("eodag-cube.driver.generic")


# File extensions to accept on top of those known to rasterio/GDAL
EXTRA_ALLOWED_FILE_EXTENSIONS = [".grib", ".grib2"]


class GenericDriver(DatasetDriver):
"""Generic Driver for products that need to be downloaded"""

Expand All @@ -56,7 +60,9 @@ def get_data_address(self, eo_product: EOProduct, band: str) -> str:
matching_files = []
for f_path in Path(uri_to_path(eo_product.location)).glob("**/*"):
f_str = str(f_path.resolve())
if p.search(f_str):
if f_path.suffix in EXTRA_ALLOWED_FILE_EXTENSIONS:
matching_files.append(f_str)
elif p.search(f_str):
try:
# files readable by rasterio
rasterio.drivers.driver_from_extension(f_path)
Expand Down
Binary file not shown.
Binary file not shown.
28 changes: 28 additions & 0 deletions tests/units/test_eoproduct_driver_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
UnsupportedDatasetAddressScheme,
)

TEST_GRIB_PRODUCT = (
"CAMS_EAC4_20210101_20210102_4d792734017419d1719b53f4d5b5d4d6888641de"
)
TEST_GRIB_FILENAME = f"{TEST_GRIB_PRODUCT}.grib"
TEST_GRIB_PRODUCT_PATH = os.path.join(
TEST_RESOURCES_PATH,
"products",
TEST_GRIB_PRODUCT,
)


class TestEOProductDriverGeneric(EODagTestCase):
def setUp(self):
Expand Down Expand Up @@ -58,6 +68,15 @@ def test_driver_get_local_dataset_address_ok(self):
address = self.product.driver.get_data_address(product, band)
self.assertEqual(address, self.local_band_file)

def test_driver_get_local_grib_dataset_address_ok(self):
"""Driver returns a good address for a grib file"""
with self._grib_product() as product:

address = self.product.driver.get_data_address(product, TEST_GRIB_FILENAME)

grib_path = os.path.join(TEST_GRIB_PRODUCT_PATH, TEST_GRIB_FILENAME)
self.assertEqual(address, grib_path)

def test_driver_get_http_remote_dataset_address_fail(self):
"""Driver must raise UnsupportedDatasetAddressScheme if location scheme is http or https"""
# Default value of self.product.location is 'https://...'
Expand All @@ -79,3 +98,12 @@ def _filesystem_product(self):
yield self.product
finally:
self.product.location = original

@contextmanager
def _grib_product(self):
original = self.product.location
try:
self.product.location = f"file://{TEST_GRIB_PRODUCT_PATH}"
yield self.product
finally:
self.product.location = original

0 comments on commit 08003ad

Please sign in to comment.