Merge pull request #485 from SuperCowPowers/training_to_bool
Switching training views from 0/1 to False/True and misc safeguards/sanity checks
brifordwylie authored Nov 10, 2024
2 parents 1eaac2c + 38b5ba5 commit dfe6c7b
Showing 21 changed files with 218 additions and 75 deletions.
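
The change is mechanical but repeats across the docs, examples, notebooks, and training scripts below: Athena queries against the training view now filter on booleans, and Pandas splits use boolean masks instead of integer comparisons. A minimal before/after sketch of the two patterns, with names taken from the examples in this diff:

```python
# Before: the training view exposed 0/1 integers
df = fs.query(f"SELECT * FROM {athena_table} where training = 0")

# After: the view exposes proper booleans
df = fs.query(f"SELECT * FROM {athena_table} where training = FALSE")

# Pandas-side splits likewise move from integer comparisons to boolean masks
df_train = all_df[all_df["training"]].copy()   # training rows
df_val = all_df[~all_df["training"]].copy()    # hold-out rows
```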
2 changes: 1 addition & 1 deletion docs/api_classes/endpoint.md
@@ -23,7 +23,7 @@ endpoint = Endpoint("abalone-regression-end")
model = Model(endpoint.get_input())
fs = FeatureSet(model.get_input())
athena_table = fs.view("training").table
df = fs.query(f"SELECT * FROM {athena_table} where training = 0")
df = fs.query(f"SELECT * FROM {athena_table} where training = FALSE")

# Run inference/predictions on the Endpoint
results_df = endpoint.inference(df)
2 changes: 1 addition & 1 deletion docs/api_classes/overview.md
@@ -47,7 +47,7 @@ endpoint = Endpoint("abalone-regression-end")

# Get a DataFrame of data (not used to train) and run predictions
athena_table = fs.view("training").table
df = fs.query(f"SELECT * FROM {athena_table} where training = 0")
df = fs.query(f"SELECT * FROM {athena_table} where training = FALSE")
results = endpoint.predict(df)
print(results[["class_number_of_rings", "prediction"]])
```
2 changes: 1 addition & 1 deletion examples/endpoint/endpoint_inference.py
@@ -10,7 +10,7 @@
model = Model(endpoint.get_input())
fs = FeatureSet(model.get_input())
athena_table = fs.view("training").table
df = fs.query(f"SELECT * FROM {athena_table} where training = 0")
df = fs.query(f"SELECT * FROM {athena_table} where training = FALSE")

# Run inference/predictions on the Endpoint
results_df = endpoint.inference(df)
2 changes: 1 addition & 1 deletion examples/full_ml_pipeline.py
@@ -45,6 +45,6 @@

# Get a DataFrame of data (not used to train) and run predictions
athena_table = fs.view("training").table
df = fs.query(f"SELECT * FROM {athena_table} where training = 0")
df = fs.query(f"SELECT * FROM {athena_table} where training = FALSE")
results = endpoint.inference(df)
print(results[["class_number_of_rings", "prediction"]])
2 changes: 1 addition & 1 deletion examples/storage/endpoint_inference.py
@@ -22,7 +22,7 @@ def run_inference(endpoint_name):
feature_set = ModelCore(model).get_input()
features = FeatureSetCore(feature_set)
table = features.view("training").table
-test_df = features.query(f'SELECT * FROM "{table}" where training = 0')
+test_df = features.query(f'SELECT * FROM "{table}" where training = FALSE')

# Drop some columns
test_df.drop(["write_time", "api_invocation_time", "is_deleted"], axis=1, inplace=True)
2 changes: 1 addition & 1 deletion notebooks/ML_Pipeline_with_SageWorks.ipynb
@@ -303,7 +303,7 @@
"source": [
"# Get a DataFrame of data (not used to train) and run predictions\n",
"table = feature_set.view(\"training\").table\n",
"test_df = feature_set.query(f\"SELECT * FROM {table} where training = 0\")\n",
"test_df = feature_set.query(f\"SELECT * FROM {table} where training = FALSE\")\n",
"test_df.head()"
]
},
4 changes: 2 additions & 2 deletions notebooks/Regression_Confidence_Experiments.ipynb
@@ -98,8 +98,8 @@
"source": [
"# Grab the Training View\n",
"table = fs.view(\"training\").table\n",
"train_df = fs.query(f\"SELECT * FROM {table} where training = 1\")\n",
"hold_out_df = fs.query(f\"SELECT * FROM {table} where training = 0\")"
"train_df = fs.query(f\"SELECT * FROM {table} where training = TRUE\")\n",
"hold_out_df = fs.query(f\"SELECT * FROM {table} where training = FALSE\")"
]
},
{
2 changes: 1 addition & 1 deletion notebooks/Residual_Analysis.ipynb
@@ -302,7 +302,7 @@
"source": [
"# Get a DataFrame of data (not used to train) and run predictions\n",
"table = feature_set.view(\"training\").table\n",
"test_df = feature_set.query(f\"SELECT * FROM {table} where training = 0\")\n",
"test_df = feature_set.query(f\"SELECT * FROM {table} where training = FALSE\")\n",
"test_df.head()"
],
"outputs": []
101 changes: 101 additions & 0 deletions scripts/convert_training_views.py
@@ -0,0 +1,101 @@
"""Convert training views from 0/1 to FALSE/TRUE in the specified AWS Glue database.
Note: This script is a 'schema' change for the training views, and is a 'one-time'
operation. The code quality here is not as important as the correctness of the
operation and since this will only be run once for existing clients and never
again, we don't want to sweat the details.
"""

import json
import re
import base64
import logging
import awswrangler as wr
from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp

log = logging.getLogger("sageworks")

# Initialize your AWS session and Glue client
aws_account_clamp = AWSAccountClamp()
session = aws_account_clamp.boto3_session
glue_client = session.client("glue")


def _decode_view_sql(encoded_sql: str) -> str:
"""Decode the base64-encoded SQL query from the view.
Args:
encoded_sql (str): The encoded SQL query in the ViewOriginalText.
Returns:
str: The decoded SQL query.
"""
# Extract the base64-encoded content from the comment
match = re.search(r"Presto View: ([\w=+/]+)", encoded_sql)
if match:
base64_sql = match.group(1)
decoded_bytes = base64.b64decode(base64_sql)
decoded_str = decoded_bytes.decode("utf-8")

# Parse the decoded string as JSON to extract the SQL
try:
view_json = json.loads(decoded_str)
return view_json.get("originalSql", "")
except json.JSONDecodeError:
log.error("Failed to parse the decoded view SQL as JSON.")
return ""
return ""


def convert_training_views(database):
"""Convert training views from 0/1 to FALSE/TRUE in the specified AWS Glue database"""

# Use the Glue client to get the list of tables (views) from the database
paginator = glue_client.get_paginator("get_tables")

for page in paginator.paginate(DatabaseName=database):
for table in page["TableList"]:
# Check if the table name ends with "_training" and is a view
if table["Name"].endswith("_training") and table.get("TableType") == "VIRTUAL_VIEW":
print(f"Checking view: {table['Name']}...")

# Decode the 'ViewOriginalText' for the view (default to "" so a missing value can't reach re.search)
view_original_text = _decode_view_sql(table.get("ViewOriginalText", ""))
if view_original_text and (" THEN 0" in view_original_text or " THEN 1" in view_original_text):
print(f"\tConverting view: {table['Name']}...")

# Update the THEN and ELSE view definitions by replacing 0/1 with FALSE/TRUE
updated_query = view_original_text.replace(" THEN 0", " THEN FALSE").replace(
" THEN 1", " THEN TRUE"
)
updated_query = updated_query.replace(" ELSE 0", " ELSE FALSE").replace(" ELSE 1", " ELSE TRUE")

# Construct the full CREATE OR REPLACE VIEW query
query = f"""
CREATE OR REPLACE VIEW {table['Name']} AS
{updated_query}
"""

try:
# Execute the query using awswrangler
query_execution_id = wr.athena.start_query_execution(
sql=query,
database=database,
boto3_session=session,
)
print(f"\tQueryExecutionId: {query_execution_id}")

# Wait for the query to complete
wr.athena.wait_query(query_execution_id=query_execution_id, boto3_session=session)
print(f"\tSuccessfully converted view: {table['Name']}")
except Exception as e:
print(f"\tError updating view {table['Name']}: {e}")
else:
print(f"\tNo conversion needed for view: {table['Name']}")


if __name__ == "__main__":
# Specify your database scope
database_scope = ["sagemaker_featurestore"]
for db in database_scope:
convert_training_views(db)
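
For context on what `_decode_view_sql` is undoing: Athena (Presto/Trino) stores a view's definition as base64-encoded JSON inside a `/* Presto View: ... */` comment in the Glue table's `ViewOriginalText`. A self-contained round-trip sketch, using a made-up payload rather than a real view:

```python
import base64
import json
import re

# Made-up minimal Presto view payload; real ones carry catalog/schema/column fields too
payload = {"originalSql": "SELECT *, CASE WHEN id < 10 THEN TRUE ELSE FALSE END AS training FROM my_table"}
encoded = base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8")
view_original_text = f"/* Presto View: {encoded} */"

# The same extraction the script performs above
match = re.search(r"Presto View: ([\w=+/]+)", view_original_text)
decoded = json.loads(base64.b64decode(match.group(1)).decode("utf-8"))
print(decoded["originalSql"])  # the SQL the script rewrites with THEN TRUE / ELSE FALSE
```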
2 changes: 1 addition & 1 deletion src/sageworks/api/endpoint.py
@@ -81,7 +81,7 @@ def fast_inference(self, eval_df: pd.DataFrame) -> pd.DataFrame:
model = Model(my_endpoint.get_input())
my_features = FeatureSet(model.get_input())
table = my_features.view("training").table
-df = my_features.query(f'SELECT * FROM "{table}" where training = 0')
+df = my_features.query(f'SELECT * FROM "{table}" where training = FALSE')
results = my_endpoint.inference(df)
target = model.target()
pprint(results[[target, "prediction"]])
7 changes: 6 additions & 1 deletion src/sageworks/core/artifacts/artifact.py
@@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from datetime import datetime
import logging
+from typing import Union

# SageWorks Imports
from sageworks.core.cloud_platform.aws.aws_account_clamp import AWSAccountClamp
@@ -128,8 +129,12 @@ def exists(self) -> bool:
"""Does the Artifact exist? Can we connect to it?"""
pass

-def sageworks_meta(self) -> dict:
+def sageworks_meta(self) -> Union[dict, None]:
"""Get the SageWorks specific metadata for this Artifact
+Returns:
+Union[dict, None]: Dictionary of SageWorks metadata for this Artifact
+Note: This functionality will work for FeatureSets, Models, and Endpoints
+but not for DataSources and Graphs, those classes need to override this method.
"""
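
Since `sageworks_meta()` can now return `None`, call sites need a guard; a hedged sketch of the pattern, given some Artifact instance `artifact` and mirroring the sanity checks added in `model_core.py` below:

```python
meta = artifact.sageworks_meta()
if meta is None:
    # Per the docstring: DataSources and Graphs override this method; for other
    # artifact types a None here means the artifact needs onboard() or deletion
    raise ValueError(f"{artifact.uuid} has no sageworks_meta()")
metrics = meta.get("sageworks_training_metrics")
```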
25 changes: 18 additions & 7 deletions src/sageworks/core/artifacts/endpoint_core.py
@@ -336,14 +336,21 @@ def auto_inference(self, capture: bool = False) -> pd.DataFrame:
capture (bool, optional): Capture the inference results and metrics (default=False)
"""

-# Backtrack to the FeatureSet
-model_name = self.get_input()
-fs_name = ModelCore(model_name).get_input()
-fs = FeatureSetCore(fs_name)
+# Sanity Check that we have a model
+model = ModelCore(self.get_input())
+if not model.exists():
+self.log.error("No model found for this endpoint. Returning empty DataFrame.")
+return pd.DataFrame()
+
+# Now get the FeatureSet and make sure it exists
+fs = FeatureSetCore(model.get_input())
+if not fs.exists():
+self.log.error("No FeatureSet found for this endpoint. Returning empty DataFrame.")
+return pd.DataFrame()

# Grab the evaluation data from the FeatureSet
table = fs.view("training").table
-eval_df = fs.query(f'SELECT * FROM "{table}" where training = 0')
+eval_df = fs.query(f'SELECT * FROM "{table}" where training = FALSE')
capture_uuid = "auto_inference" if capture else None
return self.inference(eval_df, capture_uuid, id_column=fs.id_column)

@@ -647,8 +654,7 @@ def _capture_inference_results(
self.log.important(f"Recomputing Details for {self.uuid} to show latest Inference Results...")
self.details(recompute=True)

-@staticmethod
-def regression_metrics(target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
+def regression_metrics(self, target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
"""Compute the performance metrics for this Endpoint
Args:
target_column (str): Name of the target column
@@ -657,6 +663,11 @@ def regression_metrics(target_column: str, prediction_df: pd.DataFrame) -> pd.DataFrame:
pd.DataFrame: DataFrame with the performance metrics
"""

+# Sanity Check the prediction DataFrame
+if prediction_df.empty:
+self.log.warning("No predictions were made. Returning empty DataFrame.")
+return pd.DataFrame()

# Compute the metrics
y_true = prediction_df[target_column]
prediction_col = "prediction" if "prediction" in prediction_df.columns else "predictions"
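
The rest of `regression_metrics` is folded out of the diff; for orientation, a sketch of what a typical continuation looks like, using sklearn (an assumption, not necessarily the repo's exact metric set):

```python
import numpy as np
import pandas as pd
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def regression_metrics_sketch(y_true: pd.Series, y_pred: pd.Series) -> pd.DataFrame:
    """Hypothetical metric computation mirroring the folded method body."""
    return pd.DataFrame([{
        "MAE": mean_absolute_error(y_true, y_pred),
        "RMSE": np.sqrt(mean_squared_error(y_true, y_pred)),
        "R2": r2_score(y_true, y_pred),
    }])
```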
19 changes: 16 additions & 3 deletions src/sageworks/core/artifacts/model_core.py
@@ -246,6 +246,12 @@ def get_inference_metrics(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:

# Grab the metrics captured during model training (could return None)
if capture_uuid == "model_training":
+# Sanity check the sageworks metadata
+if self.sageworks_meta() is None:
+error_msg = f"Model {self.model_name} has no sageworks_meta(). Either onboard() or delete this model!"
+self.log.critical(error_msg)
+raise ValueError(error_msg)

metrics = self.sageworks_meta().get("sageworks_training_metrics")
return pd.DataFrame.from_dict(metrics) if metrics else None

@@ -266,6 +272,13 @@ def confusion_matrix(self, capture_uuid: str = "latest") -> Union[pd.DataFrame, None]:
Returns:
pd.DataFrame: DataFrame of the Confusion Matrix (might be None)
"""

+# Sanity check the sageworks metadata
+if self.sageworks_meta() is None:
+error_msg = f"Model {self.model_name} has no sageworks_meta(). Either onboard() or delete this model!"
+self.log.critical(error_msg)
+raise ValueError(error_msg)

# Grab the metrics from the SageWorks Metadata (try inference first, then training)
if capture_uuid == "latest":
cm = self.sageworks_meta().get("sageworks_inference_cm")
@@ -317,17 +330,17 @@ def arn(self) -> str:

def group_arn(self) -> Union[str, None]:
"""AWS ARN (Amazon Resource Name) for the Model Package Group"""
-return self.model_meta["ModelPackageGroupArn"]
+return self.model_meta["ModelPackageGroupArn"] if self.model_meta else None

def model_package_arn(self) -> Union[str, None]:
"""AWS ARN (Amazon Resource Name) for the Latest Model Package (within the Group)"""
if self.latest_model is None:
return None
return self.latest_model["ModelPackageArn"]

-def container_info(self) -> dict:
+def container_info(self) -> Union[dict, None]:
"""Container Info for the Latest Model Package"""
-return self.latest_model["InferenceSpecification"]["Containers"][0]
+return self.latest_model["InferenceSpecification"]["Containers"][0] if self.latest_model else None

def container_image(self) -> str:
"""Container Image for the Latest Model Package"""
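
With `group_arn()` and `container_info()` now returning `None` when no model package exists, callers should guard accordingly; a small hedged usage sketch (the "Image" key follows the SageMaker InferenceSpecification container schema):

```python
info = model.container_info()
if info is None:
    print("No model package registered yet")
else:
    print(info.get("Image"))  # ECR image URI of the latest model package container
```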
2 changes: 1 addition & 1 deletion src/sageworks/core/artifacts/monitor_core.py
@@ -489,7 +489,7 @@ def monitoring_schedule_exists(self):
#

# Make predictions on the Endpoint using the FeatureSet evaluation data
-pred_df = endpoint_utils.predictions_using_fs(my_endpoint)
+pred_df = my_endpoint.auto_inference()
print(pred_df.head())

# Check that data capture is working
9 changes: 8 additions & 1 deletion src/sageworks/core/cloud_platform/aws/aws_meta.py
@@ -491,8 +491,15 @@ def s3_describe_objects(self, bucket: str) -> Union[dict, None]:
return wr.s3.describe_objects(path=bucket, boto3_session=self.boto3_session)

@aws_throttle
-def get_aws_tags(self, arn: str) -> dict:
+def get_aws_tags(self, arn: str) -> Union[dict, None]:
"""List the tags for the given AWS ARN"""

+# Sanity check the ARN
+if arn is None:
+self.log.error("ARN is None, cannot retrieve tags.")
+return None
+
+# Grab the tags from AWS
return aws_tags_to_dict(self.sm_session.list_tags(resource_arn=arn))

@aws_throttle
@@ -175,8 +175,8 @@ if __name__ == "__main__":
# Does the dataframe have a training column?
elif "training" in all_df.columns:
print("Found training column, splitting data based on training column")
-df_train = all_df[all_df["training"] == 1].copy()
-df_val = all_df[all_df["training"] == 0].copy()
+df_train = all_df[all_df["training"]].copy()
+df_val = all_df[~all_df["training"]].copy()
else:
# Just do a random training Split
print("WARNING: No training column found, splitting data with random state=42")
@@ -158,8 +158,8 @@ if __name__ == "__main__":
# Does the dataframe have a training column?
elif "training" in all_df.columns:
print("Found training column, splitting data based on training column")
-df_train = all_df[all_df["training"] == 1].copy()
-df_val = all_df[all_df["training"] == 0].copy()
+df_train = all_df[all_df["training"]].copy()
+df_val = all_df[~all_df["training"]].copy()
else:
# Just do a random training Split
print("WARNING: No training column found, splitting data with random state=42")
@@ -286,7 +286,7 @@ def prep_dataframe(self):
"""Training column detected: Since FeatureSets are read-only, SageWorks creates a training view
that can be dynamically changed. We'll use this training column to create a training view."""
)
-self.incoming_hold_out_ids = self.output_df[self.output_df["training"] == 0][self.id_column].tolist()
+self.incoming_hold_out_ids = self.output_df[~self.output_df["training"]][self.id_column].tolist()
self.output_df = self.output_df.drop(columns=["training"])

def create_feature_group(self):
@@ -423,8 +423,8 @@ def wait_for_rows(self, expected_rows: int):
data_df = ds.sample()

# Test setting a training column
data_df["training"] = 0
data_df.loc[0:10, "training"] = 1
data_df["training"] = False
data_df.loc[0:10, "training"] = True

# Create my DF to Feature Set Transform (with one-hot encoding)
df_to_features = PandasToFeatures("test_features")