Skip to content

Commit

Permalink
Add validation for tagging type (#23)
Browse files Browse the repository at this point in the history
* Add validation for tagging type

* Fix import

* Update the expected fields

* Move values to constants

* Update the expected fields

* Add support for tagging type arg in cli

* Update the version

* Refactor the column check function
  • Loading branch information
d-shree authored Dec 27, 2023
1 parent 3900ac6 commit 5094419
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# CHANGELOG
## 0.3.35
- [x] add: validations for the input file for conversation tagging

## 0.3.34
- [x] PL-61: Add retry mechanism for uploading data to Label studio
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "skit-labels"
version = "0.3.34"
version = "0.3.35"
description = "Command line tool for interacting with labelled datasets at skit.ai."
authors = []
license = "MIT"
Expand Down
18 changes: 15 additions & 3 deletions skit_labels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,13 @@ def upload_dataset_to_labelstudio_command(
required=True,
help="The data label implying the source of data",
)

parser.add_argument(
"--tagging-type",
type=str,
help="The tagging type for the calls being uploaded",
)

return parser


Expand Down Expand Up @@ -319,12 +326,17 @@ def build_cli():
return parser


def upload_dataset(input_file, url, token, job_id, data_source, data_label = None):
def upload_dataset(input_file, url, token, job_id, data_source, data_label = None, tagging_type=None):
input_file = utils.add_data_label(input_file, data_label)
if data_source == const.SOURCE__DB:
fn = commands.upload_dataset_to_db
elif data_source == const.SOURCE__LABELSTUDIO:
fn = commands.upload_dataset_to_labelstudio
if tagging_type:
is_valid, error = utils.validate_input_data(tagging_type, input_file)
if not is_valid:
return error, None

fn = commands.upload_dataset_to_labelstudio
errors, df_size = asyncio.run(
fn(
input_file,
Expand Down Expand Up @@ -386,7 +398,7 @@ def cmd_to_str(args: argparse.Namespace) -> str:
arg_id = args.job_id

_ = is_valid_data_label(args.data_label)
errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label)
errors, df_size = upload_dataset(args.input, args.url, args.token, arg_id, args.data_source, args.data_label, args.tagging_type)

if errors:
return (
Expand Down
8 changes: 7 additions & 1 deletion skit_labels/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,10 @@
FROM_NAME_INTENT = "tag"
CHOICES = "choices"
TAXONOMY = "taxonomy"
VALUE = "value"
VALUE = "value"

EXPECTED_COLUMNS_MAPPING = {
"conversation_tagging": ['scenario', 'scenario_category', 'situation_str', 'call', 'data_label']
}

CONVERSATION_TAGGING = 'conversation_tagging'
44 changes: 43 additions & 1 deletion skit_labels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime
import pandas as pd
from typing import Union

from skit_labels import constants as const

LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "SUCCESS", "INFO", "DEBUG", "TRACE"]

Expand Down Expand Up @@ -110,3 +110,45 @@ def add_data_label(input_file: str, data_label: Optional[str] = None) -> str:
df = df.assign(data_label=data_label)
df.to_csv(input_file, index=False)
return input_file


def validate_headers(input_file, tagging_type):
    """Check that the CSV at ``input_file`` has the columns expected for ``tagging_type``.

    Comparison is case-insensitive. Extra columns are dropped from the file
    in place; missing columns are returned so the caller can report them.

    :param input_file: Path to the CSV file to validate (rewritten in place
        if additional columns are removed).
    :param tagging_type: Key into ``const.EXPECTED_COLUMNS_MAPPING``.
    :return: A set of missing column names (lower-cased) when required
        columns are absent, otherwise an empty list.
    """
    expected_headers = const.EXPECTED_COLUMNS_MAPPING.get(tagging_type)
    if expected_headers is None:
        # Unknown tagging type: nothing to validate against. The original
        # code would crash in sorted(None) here.
        logger.warning(f"No expected columns configured for tagging type: {tagging_type}")
        return []

    df = pd.read_csv(input_file)

    # Map lower-cased header -> original header so we can compare
    # case-insensitively but still drop columns by their real names.
    original_by_lower = {header.lower(): header for header in df.columns}
    column_headers = sorted(original_by_lower)
    expected_headers = sorted(header.lower() for header in expected_headers)

    logger.info(f"column_headers: {column_headers}")
    logger.info(f"expected_headers: {expected_headers}")

    is_match = column_headers == expected_headers
    logger.info(f"Is match: {is_match}")

    if not is_match:
        missing_headers = set(expected_headers).difference(column_headers)
        additional_headers = set(column_headers).difference(expected_headers)
        if missing_headers:
            return missing_headers
        if additional_headers:
            # Drop by the original (possibly mixed-case) column names;
            # dropping the lower-cased names would raise KeyError for any
            # column that is not already lower-case in the CSV.
            df.drop([original_by_lower[h] for h in additional_headers], axis=1, inplace=True)
            df.to_csv(input_file, index=False)
            logger.info(f"Following additional headers have been removed from the csv: {additional_headers}")
    return []


def validate_input_data(tagging_type, input_file):
    """Validate an input file against the requirements of a tagging type.

    :param tagging_type: The tagging flow the file is meant for; only
        ``const.CONVERSATION_TAGGING`` currently has column requirements.
    :param input_file: Path to the CSV file to validate.
    :return: Tuple ``(is_valid, error)`` where ``error`` is an empty string
        when the file is valid.
    """
    if tagging_type != const.CONVERSATION_TAGGING:
        # No validation rules for other tagging types yet.
        return True, ''

    missing_headers = validate_headers(input_file, tagging_type)
    if missing_headers:
        # Grammar fix: plural subject ("Headers ... do not match").
        error = f'Headers in the input file do not match the expected fields. Missing fields = {missing_headers}'
        return False, error
    return True, ''

0 comments on commit 5094419

Please sign in to comment.