fix: streamlit component (#637)
* fix: origin data is not used in the streamlit environment

* fix: get_dataset_hash func
longxiaofei authored Oct 2, 2024
1 parent dfb36e3 commit 2e10082
Showing 3 changed files with 57 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pygwalker/__init__.py
@@ -10,7 +10,7 @@
 from pygwalker.services.global_var import GlobalVarManager
 from pygwalker.services.kaggle import show_tips_user_kaggle as __show_tips_user_kaggle
 
-__version__ = "0.4.9.9"
+__version__ = "0.4.9.10"
 __hash__ = __rand_str()
 
 from pygwalker.api.jupyter import walk, render, table
2 changes: 1 addition & 1 deletion pygwalker/api/streamlit.py
@@ -268,7 +268,7 @@ def _component(
         vis_spec: Optional[List[Dict[str, Any]]] = None,
         **kwargs: Dict[str, Any]
     ):
-        props = self.walker._get_props("streamlit")
+        props = self.walker._get_props("streamlit", [])
        props["gwMode"] = mode
        props["communicationUrl"] = BASE_URL_PATH
        if vis_spec is not None:
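The new second positional argument is the origin-data payload: in the Streamlit environment the frontend pulls rows through the communication endpoint (`props["communicationUrl"]`) rather than reading records embedded in the component props, so an empty list is passed. A minimal usage sketch based on pygwalker's documented Streamlit API; the DataFrame is demo data:

```python
import pandas as pd
import streamlit as st
from pygwalker.api.streamlit import StreamlitRenderer

# Demo data -- any DataFrame works; rows are served to the frontend via
# pygwalker's communication endpoint, not embedded in component props.
df = pd.DataFrame({"city": ["Beijing", "Shanghai"], "sales": [120, 340]})

@st.cache_resource
def get_renderer() -> StreamlitRenderer:
    # This path ends up calling _get_props("streamlit", []) -- the empty
    # origin-data list that this commit introduces.
    return StreamlitRenderer(df)

get_renderer().explorer()
```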
75 changes: 55 additions & 20 deletions pygwalker/services/data_parsers.py
@@ -84,31 +84,66 @@ def get_parser(
     return parser
 
 
+def _get_pl_dataset_hash(dataset: DataFrame) -> str:
+    """Get polars dataset hash value."""
+    import polars as pl
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_polars"
+    if row_count > 4000:
+        dataset = pl.concat([dataset[:2000], dataset[-2000:]])
+    hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_pd_dataset_hash(dataset: DataFrame) -> str:
+    """Get pandas dataset hash value."""
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_pandas"
+    if row_count > 4000:
+        dataset = pd.concat([dataset[:2000], dataset[-2000:]])
+    hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_modin_dataset_hash(dataset: DataFrame) -> str:
+    """Get modin dataset hash value."""
+    import modin.pandas as mpd
+    row_count = dataset.shape[0]
+    other_info = str(dataset.shape) + "_modin"
+    if row_count > 4000:
+        dataset = mpd.concat([dataset[:2000], dataset[-2000:]])
+    dataset = dataset._to_pandas()
+    hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
+def _get_spark_dataset_hash(dataset: DataFrame) -> str:
+    """Get pyspark dataset hash value."""
+    shape = ((dataset.count(), len(dataset.columns)))
+    row_count = shape[0]
+    other_info = str(shape) + "_pyspark"
+    if row_count > 4000:
+        dataset = dataset.limit(4000)
+    dataset_pd = dataset.toPandas()
+    hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
+    return hashlib.md5(hash_bytes).hexdigest()
+
+
 def get_dataset_hash(dataset: Union[DataFrame, Connector, str]) -> str:
     """Just a less accurate way to get different dataset hash values."""
     _, dataset_type = _get_data_parser(dataset)
-    if dataset_type in ["pandas", "modin", "polars"]:
-        row_count = dataset.shape[0]
-        other_info = str(dataset.shape) + "_" + dataset_type
-        if row_count > 4000:
-            dataset = dataset[:2000] + dataset[-2000:]
-        if dataset_type == "modin":
-            dataset = dataset._to_pandas()
-        if dataset_type in ["pandas", "modin"]:
-            hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
-        else:
-            hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
-        return hashlib.md5(hash_bytes).hexdigest()
 
+    if dataset_type == "polars":
+        return _get_pl_dataset_hash(dataset)
+
+    if dataset_type == "pandas":
+        return _get_pd_dataset_hash(dataset)
+
+    if dataset_type == "modin":
+        return _get_modin_dataset_hash(dataset)
+
     if dataset_type == "pyspark":
-        shape = ((dataset.count(), len(dataset.columns)))
-        row_count = shape[0]
-        other_info = str(shape) + "_" + dataset_type
-        if row_count > 4000:
-            dataset = dataset.limit(4000)
-        dataset_pd = dataset.toPandas()
-        hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
-        return hashlib.md5(hash_bytes).hexdigest()
+        return _get_spark_dataset_hash(dataset)
 
     if dataset_type == "connector":
         return hashlib.md5("_".join([dataset.url, dataset.view_sql, dataset_type]).encode()).hexdigest()
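Worth noting why the `get_dataset_hash` change is a fix and not just a refactor: the removed sampling line `dataset = dataset[:2000] + dataset[-2000:]` never concatenated in pandas, because `+` on two DataFrames performs index-aligned element-wise addition, and the disjoint head/tail indexes of a large frame therefore produce all-NaN rows. A short demonstration with toy data:

```python
import pandas as pd

df = pd.DataFrame({"x": range(10)})

# Old behavior: `+` aligns on the index and adds element-wise, so the
# disjoint head/tail labels all become NaN -- not a 4-row sample.
broken = df[:2] + df[-2:]
print(broken["x"].isna().all())  # True

# New behavior: pd.concat keeps the head and tail rows intact.
fixed = pd.concat([df[:2], df[-2:]])
print(list(fixed.index))  # [0, 1, 8, 9]
```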
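All four helpers share one scheme: fingerprint a head/tail sample plus the original shape, keeping hashing cheap on large frames at the cost of missing edits confined to middle rows. A self-contained sketch of that scheme for the pandas case (thresholds mirror the code above; the DataFrame is demo data):

```python
import hashlib

import pandas as pd

df = pd.DataFrame({"x": range(10_000), "y": range(10_000)})

# The shape is recorded before sampling, so frames that differ only in
# row count still hash differently even when their samples match.
other_info = str(df.shape) + "_pandas"
sample = pd.concat([df[:2000], df[-2000:]]) if df.shape[0] > 4000 else df
hash_bytes = pd.util.hash_pandas_object(sample).values.tobytes() + other_info.encode()
print(hashlib.md5(hash_bytes).hexdigest())
```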
