diff --git a/CHANGELOG.md b/CHANGELOG.md index d7fbc8c..8fd7b56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,6 @@ # CHANGELOG +## 0.3.35 +- [x] add: validations for the input file for conversation tagging ## 0.3.34 - [x] PL-61: Add retry mechanism for uploading data to Label studio diff --git a/pyproject.toml b/pyproject.toml index db35829..810a948 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "skit-labels" -version = "0.3.34" +version = "0.3.35" description = "Command line tool for interacting with labelled datasets at skit.ai." authors = [] license = "MIT" diff --git a/skit_labels/cli.py b/skit_labels/cli.py index 9dba2df..9d805fe 100644 --- a/skit_labels/cli.py +++ b/skit_labels/cli.py @@ -227,6 +227,13 @@ def upload_dataset_to_labelstudio_command( required=True, help="The data label implying the source of data", ) + + parser.add_argument( + "--tagging-type", + type=str, + help="The tagging type for the calls being uploaded", + ) + return parser @@ -319,12 +326,17 @@ def build_cli(): return parser -def upload_dataset(input_file, url, token, job_id, data_source, data_label = None): +def upload_dataset(input_file, url, token, job_id, data_source, data_label = None, tagging_type=None): input_file = utils.add_data_label(input_file, data_label) if data_source == const.SOURCE__DB: fn = commands.upload_dataset_to_db elif data_source == const.SOURCE__LABELSTUDIO: - fn = commands.upload_dataset_to_labelstudio + if tagging_type: + is_valid, error = utils.validate_input_data(tagging_type, input_file) + if not is_valid: + return error, None + + fn = commands.upload_dataset_to_labelstudio errors, df_size = asyncio.run( fn( input_file, @@ -386,7 +398,7 @@ def cmd_to_str(args: argparse.Namespace) -> str: arg_id = args.job_id _ = is_valid_data_label(args.data_label) - errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label) + errors, df_size = upload_dataset(args.input, 
def validate_headers(input_file, tagging_type):
    """Check that the CSV's headers match the expected set for a tagging type.

    Extra columns are dropped from the file in place (the CSV is rewritten);
    missing columns are returned so the caller can abort the upload.

    :param input_file: Path to the CSV file to validate (may be rewritten).
    :param tagging_type: Key into const.EXPECTED_COLUMNS_MAPPING.
    :return: Set of missing headers; empty list when the file is usable.
    """
    expected_columns_mapping = const.EXPECTED_COLUMNS_MAPPING
    expected_headers = expected_columns_mapping.get(tagging_type)

    df = pd.read_csv(input_file)

    # Normalize the dataframe's own columns to lowercase so the comparison,
    # the drop, and the rewrite below all operate on the same names.
    # (Previously only a separate lowercased list was compared, so dropping
    # "additional" headers raised KeyError for any mixed-case CSV column.)
    df.columns = [header.lower() for header in df.columns]

    column_headers = sorted(df.columns.to_list())
    expected_headers = sorted(expected_headers)

    logger.info(f"column_headers: {column_headers}")
    logger.info(f"expected_headers: {expected_headers}")

    is_match = column_headers == expected_headers
    logger.info(f"Is match: {is_match}")

    if not is_match:
        missing_headers = set(expected_headers).difference(column_headers)
        additional_headers = set(column_headers).difference(expected_headers)
        if missing_headers:
            # Fatal: the caller turns this into a validation error.
            return missing_headers
        elif additional_headers:
            # Surplus columns are harmless: drop them and rewrite the file.
            df.drop(list(additional_headers), axis=1, inplace=True)
            df.to_csv(input_file, index=False)
            logger.info(f"Following additional headers have been removed from the csv: {additional_headers}")
    return []


def validate_input_data(tagging_type, input_file):
    """Validate an input file before upload, per tagging type.

    Only const.CONVERSATION_TAGGING currently has validation rules; any other
    tagging type is accepted as-is.

    :param tagging_type: The tagging type for the calls being uploaded.
    :param input_file: Path to the CSV file to validate.
    :return: Tuple (is_valid, error); ``error`` is '' when the file is valid.
    """
    is_valid = True
    error = ''
    if tagging_type == const.CONVERSATION_TAGGING:
        missing_headers = validate_headers(input_file, tagging_type)
        if missing_headers:
            error = f'Headers in the input file does not match the expected fields. Missing fields = {missing_headers}'
            is_valid = False

    return is_valid, error