-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Project import generated by Copybara. (#102)
GitOrigin-RevId: 4224a34cc1f2c4a947efd6c5fcc6cea040c37bc6 Co-authored-by: Snowflake Authors <[email protected]>
- Loading branch information
1 parent
c530f5c
commit 2932445
Showing
90 changed files
with
4,477 additions
and
1,158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,20 @@ | ||
load("//bazel:py_rules.bzl", "py_library") | ||
load("//bazel:py_rules.bzl", "py_library", "py_test") | ||
|
||
package(default_visibility = ["//visibility:public"]) | ||
|
||
py_library( | ||
name = "dataset_dataframe", | ||
name = "lineage_utils", | ||
srcs = [ | ||
"data_source.py", | ||
"dataset_dataframe.py", | ||
"lineage_utils.py", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "lineage_utils_test", | ||
srcs = ["lineage_utils_test.py"], | ||
deps = [ | ||
":lineage_utils", | ||
"//snowflake/ml/utils:connection_params", | ||
], | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import copy | ||
import functools | ||
from typing import Any, Callable, List | ||
|
||
from snowflake import snowpark | ||
from snowflake.ml._internal.lineage import data_source | ||
|
||
DATA_SOURCES_ATTR = "_data_sources" | ||
|
||
|
||
def _get_datasources(*args: Any) -> List[data_source.DataSource]:
    """Collect the data sources attached to the objects in an argument list.

    Each argument is probed for the ``DATA_SOURCES_ATTR`` attribute. The value is used only when it is a list
    made up entirely of ``DataSource`` instances; a missing attribute, a non-list value, or a list containing
    anything else contributes nothing.

    Args:
        *args: Arbitrary objects to inspect.

    Returns:
        A new list holding every qualifying data source, in argument order.
    """
    collected: List[data_source.DataSource] = []
    for candidate in args:
        attached = getattr(candidate, DATA_SOURCES_ATTR, None)
        if not isinstance(attached, list):
            continue
        if not all(isinstance(item, data_source.DataSource) for item in attached):
            continue
        collected.extend(attached)
    return collected
|
||
|
||
def _wrap_func(
    fn: Callable[..., snowpark.DataFrame], data_sources: List[data_source.DataSource]
) -> Callable[..., snowpark.DataFrame]:
    """Build a wrapper around ``fn`` that tags every DataFrame it returns with ``data_sources``.

    Args:
        fn: DataFrame transform function to wrap.
        data_sources: Data sources to attach to each DataFrame produced by ``fn``.

    Returns:
        A callable with the same metadata as ``fn`` (via ``functools.wraps``) whose result is patched in place.
    """

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
        # With inplace=True, patch_dataframe hands back the very object it was given,
        # so returning its result is the same as returning the derived DataFrame.
        return patch_dataframe(fn(*args, **kwargs), data_sources=data_sources, inplace=True)

    return wrapped
|
||
|
||
def patch_dataframe(
    df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
) -> snowpark.DataFrame:
    """
    Monkey patch a DataFrame so that it carries the provided data_sources as an attribute.

    The DataFrame's transformation functions are wrapped as well, so that DataFrames derived from it
    receive the same data sources attribute.

    Args:
        df: DataFrame to be patched
        data_sources: List of data sources for the DataFrame
        inplace: If True, patches the DataFrame in-place. If False, patches a shallow copy of the DataFrame
            (made with ``copy.copy``) and leaves the passed-in object untouched.

    Returns:
        Patched DataFrame
    """
    target = df if inplace else copy.copy(df)
    setattr(target, DATA_SOURCES_ATTR, data_sources)

    # Instance-level monkey-patches
    propagating_methods = [
        "_with_plan",
        "_lateral",
        "group_by",
        "group_by_grouping_sets",
        "cube",
        "pivot",
        "rollup",
        "cache_result",
        "_to_df",  # RelationalGroupedDataFrame
    ]
    for method_name in propagating_methods:
        original = getattr(target, method_name, None)
        if original is None:
            # Not every patched object exposes every method; skip the ones it lacks.
            continue
        setattr(target, method_name, _wrap_func(original, data_sources=data_sources))
    return target
|
||
|
||
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
    """Wrap a class-level function so its result inherits data sources from its arguments.

    The data sources found on the positional arguments come first, followed by those found on the keyword
    argument values. When none are found the result is returned unpatched.

    Args:
        fn: Function to wrap.

    Returns:
        A callable with the same metadata as ``fn`` (via ``functools.wraps``).
    """

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
        result = fn(*args, **kwargs)
        inherited = _get_datasources(*args, *kwargs.values())
        if not inherited:
            return result
        patch_dataframe(result, inherited, inplace=True)
        return result

    return wrapped
|
||
|
||
# Class-level monkey-patches: runs at import time and replaces each listed method on the class
# itself with a wrapper that propagates data sources from the call's arguments to its result.
for klass, func_list in (
    (snowpark.DataFrame, ["__copy__"]),
    (snowpark.RelationalGroupedDataFrame, []),  # nothing patched on this class
):
    assert isinstance(func_list, list)  # mypy
    for func in func_list:
        fn = getattr(klass, func)
        setattr(klass, func, _wrap_class_func(fn))
Oops, something went wrong.