Skip to content

Commit

Permalink
Add convert_to_gdf function (#872)
Browse files Browse the repository at this point in the history
* Add convert_to_gdf function

* Fix gpd import error

* Use forward reference

* Use type checking
  • Loading branch information
giswqs authored Aug 21, 2024
1 parent 96ccb2d commit 18c8b7f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 32 deletions.
130 changes: 110 additions & 20 deletions leafmap/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
import whitebox
import subprocess
from pathlib import Path
from typing import Union, List, Dict, Optional, Tuple
from typing import Union, List, Dict, Optional, Tuple, TYPE_CHECKING
from .stac import *

try:
from IPython.display import display, IFrame
from IPython.display import display
except ImportError:
pass

try:
if TYPE_CHECKING:
import geopandas as gpd
except ImportError:
gpd = None


class WhiteboxTools(whitebox.WhiteboxTools):
Expand Down Expand Up @@ -104,7 +102,7 @@ def set_proxy(
try:
response = requests.get("https://google.com")
response.raise_for_status()
except requests.exceptions.RequestException as error_requests:
except requests.exceptions.RequestException as e:
print(
"Failed to connect to Google Services. "
"Please double check the port number and IP address."
Expand Down Expand Up @@ -421,7 +419,7 @@ def display_html(
else:
raise ValueError("Invalid input type. Expected a file path or an HTML string.")

display(IFrame(srcdoc=html_content, width=width, height=height))
display(IFrame(src=html_content, width=width, height=height))


def has_transparency(img) -> bool:
Expand Down Expand Up @@ -2492,19 +2490,6 @@ def to_hex_colors(colors):
return colors


def display_html(src, width=950, height=600):
"""Display an HTML file in a Jupyter Notebook.
Args
src (str): File path to HTML file.
width (int, optional): Width of the map. Defaults to 950.
height (int, optional): Height of the map. Defaults to 600.
"""
if not os.path.isfile(src):
raise ValueError(f"{src} is not a valid file path.")
display(IFrame(src=src, width=width, height=height))


def get_census_dict(reset=False):
"""Returns a dictionary of Census data.
Expand Down Expand Up @@ -14146,3 +14131,108 @@ def d2s_tile(
return f"{url}?API_KEY={api_key}"
else:
return url


def convert_to_gdf(
data: Union[pd.DataFrame, str],
geometry: Optional[str] = None,
lat: Optional[str] = None,
lon: Optional[str] = None,
crs: str = "EPSG:4326",
included: Optional[List[str]] = None,
excluded: Optional[List[str]] = None,
obj_to_str: bool = False,
open_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> "gpd.GeoDataFrame":
"""Convert data to a GeoDataFrame.
Args:
data (Union[pd.DataFrame, str]): The input data, either as a DataFrame or a file path.
geometry (Optional[str], optional): The column name containing geometry data. Defaults to None.
lat (Optional[str], optional): The column name containing latitude data. Defaults to None.
lon (Optional[str], optional): The column name containing longitude data. Defaults to None.
crs (str, optional): The coordinate reference system to use. Defaults to "EPSG:4326".
included (Optional[List[str]], optional): List of columns to include. Defaults to None.
excluded (Optional[List[str]], optional): List of columns to exclude. Defaults to None.
obj_to_str (bool, optional): Whether to convert object dtype columns to string. Defaults to False.
open_args (Optional[Dict[str, Any]], optional): Additional arguments for file opening functions. Defaults to None.
**kwargs (Any): Additional keyword arguments for GeoDataFrame creation.
Returns:
gpd.GeoDataFrame: The converted GeoDataFrame.
Raises:
ValueError: If the file format is unsupported or required columns are not provided.
"""
import geopandas as gpd
from shapely.geometry import Point, shape

if open_args is None:
open_args = {}

if not isinstance(data, pd.DataFrame):
if isinstance(data, str):
if data.endswith(".parquet"):
data = pd.read_parquet(data, **open_args)
elif data.endswith(".csv"):
data = pd.read_csv(data, **open_args)
elif data.endswith(".json"):
data = pd.read_json(data, **open_args)
elif data.endswith(".xlsx"):
data = pd.read_excel(data, **open_args)
else:
raise ValueError(
"Unsupported file format. Only Parquet, CSV, JSON, and Excel files are supported."
)

# If include_cols is specified, filter the DataFrame to include only those columns
if included:
if geometry:
included.append(geometry)
elif lat and lon:
included.append(lat)
included.append(lon)
data = data[included]

# Exclude specified columns if provided
if excluded:
data = data.drop(columns=excluded)

# Convert 'object' dtype columns to 'string' if obj_to_str is True
if obj_to_str:
data = data.astype(
{col: "string" for col in data.select_dtypes(include="object").columns}
)

# Handle the creation of geometry
if geometry:

def convert_geometry(x):
if isinstance(x, str):
try:
# Parse the string as JSON and then convert to a geometry
return shape(json.loads(x))
except (json.JSONDecodeError, TypeError) as e:
print(f"Error converting geometry: {e}")
return None
return x

data = data[data[geometry].notnull()]
data[geometry] = data[geometry].apply(convert_geometry)
elif lat and lon:
# Create a geometry column from latitude and longitude
data["geometry"] = data.apply(lambda row: Point(row[lon], row[lat]), axis=1)
geometry = "geometry"
else:
raise ValueError(
"Either geometry_col or both lat_col and lon_col must be provided."
)

# Convert the DataFrame to a GeoDataFrame
gdf = gpd.GeoDataFrame(data, geometry=geometry, **kwargs)

# Set CRS (assuming WGS84 by default, modify as needed)
gdf.set_crs(crs, inplace=True)

return gdf
24 changes: 12 additions & 12 deletions leafmap/leafmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,39 +211,39 @@ def handle_draw(target, action, geo_json):
if "catalog_source" in kwargs:
self.set_catalog_source(kwargs["catalog_source"])

def add(self, object, index=None, **kwargs) -> None:
def add(self, obj, index=None, **kwargs) -> None:
"""Adds a layer to the map.
Args:
layer (object): The layer to add to the map.
index (int, optional): The index at which to add the layer. Defaults to None.
"""
if isinstance(object, str):
if object in basemaps.keys():
object = get_basemap(object)
if isinstance(obj, str):
if obj in basemaps.keys():
obj = get_basemap(obj)
else:
if object == "nasa_earth_data":
if obj == "nasa_earth_data":
from .toolbar import nasa_data_gui

nasa_data_gui(self, **kwargs)
elif object == "inspector":
elif obj == "inspector":
from .toolbar import inspector_gui

inspector_gui(self, **kwargs)

elif object == "stac":
elif obj == "stac":
self.add_stac_gui(**kwargs)
elif object == "basemap":
elif obj == "basemap":
self.add_basemap_gui(**kwargs)
elif object == "inspector":
elif obj == "inspector":
self.add_inspector_gui(**kwargs)
elif object == "layer_manager":
elif obj == "layer_manager":
self.add_layer_manager(**kwargs)
elif object == "oam":
elif obj == "oam":
self.add_oam_gui(**kwargs)
return

super().add(object, index=index)
super().add(obj, index=index)

if hasattr(self, "layer_manager_widget"):
self.update_layer_manager()
Expand Down

0 comments on commit 18c8b7f

Please sign in to comment.