From 2e100826085b9b2a49d596118c73d7df6ed75035 Mon Sep 17 00:00:00 2001
From: Douding
Date: Wed, 2 Oct 2024 15:18:00 +0800
Subject: [PATCH] fix: streamlit component (#637)

* fix: origin data is not used in streamlit environment

* fix: get_dataset_hash func
---
 pygwalker/__init__.py              |  2 +-
 pygwalker/api/streamlit.py         |  2 +-
 pygwalker/services/data_parsers.py | 75 ++++++++++++++++++++++--------
 3 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/pygwalker/__init__.py b/pygwalker/__init__.py
index 4cbb445..391f155 100644
--- a/pygwalker/__init__.py
+++ b/pygwalker/__init__.py
@@ -10,7 +10,7 @@
 from pygwalker.services.global_var import GlobalVarManager
 from pygwalker.services.kaggle import show_tips_user_kaggle as __show_tips_user_kaggle
 
-__version__ = "0.4.9.9"
+__version__ = "0.4.9.10"
 __hash__ = __rand_str()
 
 from pygwalker.api.jupyter import walk, render, table
diff --git a/pygwalker/api/streamlit.py b/pygwalker/api/streamlit.py
index 967b86b..be34303 100644
--- a/pygwalker/api/streamlit.py
+++ b/pygwalker/api/streamlit.py
@@ -268,7 +268,7 @@ def _component(
         vis_spec: Optional[List[Dict[str, Any]]] = None,
         **kwargs: Dict[str, Any]
     ):
-        props = self.walker._get_props("streamlit")
+        props = self.walker._get_props("streamlit", [])
         props["gwMode"] = mode
         props["communicationUrl"] = BASE_URL_PATH
         if vis_spec is not None:
diff --git a/pygwalker/services/data_parsers.py b/pygwalker/services/data_parsers.py
index f69a0f2..2e523ae 100644
--- a/pygwalker/services/data_parsers.py
+++ b/pygwalker/services/data_parsers.py
@@ -84,31 +84,66 @@ def get_parser(
     return parser
 
 
+def _get_pl_dataset_hash(dataset: DataFrame) -> str:
+    """Get polars dataset hash value."""
+    import polars as pl
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_polars"
+    if row_count > 4000:
+        dataset = pl.concat([dataset[:2000], dataset[-2000:]])
+    hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_pd_dataset_hash(dataset: DataFrame) -> str:
+    """Get pandas dataset hash value."""
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_pandas"
+    if row_count > 4000:
+        dataset = pd.concat([dataset[:2000], dataset[-2000:]])
+    hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_modin_dataset_hash(dataset: DataFrame) -> str:
+    """Get modin dataset hash value."""
+    import modin.pandas as mpd
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_modin"
+    if row_count > 4000:
+        dataset = mpd.concat([dataset[:2000], dataset[-2000:]])
+    dataset = dataset._to_pandas()
+    hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_spark_dataset_hash(dataset: DataFrame) -> str:
+    """Get pyspark dataset hash value."""
+    shape = ((dataset.count(), len(dataset.columns)))
+    row_count = shape[0]
+    other_info = str(shape) + "_pyspark"
+    if row_count > 4000:
+        dataset = dataset.limit(4000)
+    dataset_pd = dataset.toPandas()
+    hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
 def get_dataset_hash(dataset: Union[DataFrame, Connector, str]) -> str:
     """Just a less accurate way to get different dataset hash values."""
     _, dataset_type = _get_data_parser(dataset)
-    if dataset_type in ["pandas", "modin", "polars"]:
-        row_count = dataset.shape[0]
-        other_info = str(dataset.shape) + "_" + dataset_type
-        if row_count > 4000:
-            dataset = dataset[:2000] + dataset[-2000:]
-        if dataset_type == "modin":
-            dataset = dataset._to_pandas()
-        if dataset_type in ["pandas", "modin"]:
-            hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
-        else:
-            hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
-        return hashlib.md5(hash_bytes).hexdigest()
+
+    if dataset_type == "polars":
+        return _get_pl_dataset_hash(dataset)
+
+    if dataset_type == "pandas":
+        return _get_pd_dataset_hash(dataset)
+
+    if dataset_type == "modin":
+        return _get_modin_dataset_hash(dataset)
 
     if dataset_type == "pyspark":
-        shape = ((dataset.count(), len(dataset.columns)))
-        row_count = shape[0]
-        other_info = str(shape) + "_" + dataset_type
-        if row_count > 4000:
-            dataset = dataset.limit(4000)
-        dataset_pd = dataset.toPandas()
-        hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
-        return hashlib.md5(hash_bytes).hexdigest()
+        return _get_spark_dataset_hash(dataset)
 
     if dataset_type == "connector":
         return hashlib.md5("_".join([dataset.url, dataset.view_sql, dataset_type]).encode()).hexdigest()
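Note on the get_dataset_hash change above (not part of the patch): for frames longer than 4000 rows, the old code sampled head and tail with `dataset[:2000] + dataset[-2000:]`. In pandas, `+` on two DataFrames is index-aligned element-wise addition, not row concatenation, so the two disjoint slices produced an all-NaN frame and the hash no longer reflected the data; this appears to be the bug behind the "fix: get_dataset_hash func" line in the commit message. The new per-engine helpers use the proper concatenation for each backend (`pd.concat`, `pl.concat`, `mpd.concat`). A minimal sketch of the difference, using pandas only; the frame and column name are illustrative, not from the patch:

    import pandas as pd

    # An illustrative frame longer than the 4000-row sampling threshold.
    df = pd.DataFrame({"a": range(6000)})

    # Old behavior: the head and tail slices have disjoint indexes
    # (0..1999 vs 4000..5999), so `+` aligns on the index union and
    # every cell comes out NaN -- the hash input was effectively garbage.
    added = df[:2000] + df[-2000:]
    assert added["a"].notna().sum() == 0

    # New behavior (what _get_pd_dataset_hash does): stack the two
    # slices, then hash the 4000 sampled rows.
    sampled = pd.concat([df[:2000], df[-2000:]])
    assert sampled.shape == (4000, 1)

Appending `other_info` (the shape plus an engine tag such as `"_pandas"`) to the hashed bytes also keeps equal-looking datasets of different shapes or from different engines from colliding on the same digest.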