Skip to content

Commit

Permalink
Refactor the column check function
Browse files (browse the repository at this point in the history)
  • Loading branch information
d-shree committed Dec 27, 2023
1 parent dd8d2ab commit 3cbc67e
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions skit_labels/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -117,30 +117,38 @@ def validate_headers(input_file, tagging_type):
expected_headers = expected_columns_mapping.get(tagging_type)

df = pd.read_csv(input_file)

column_headers = df.columns.to_list()
column_headers = [header.lower() for header in column_headers]
column_headers = sorted(column_headers)
expected_headers = sorted(expected_headers)

logger.info(f"column_headers: {column_headers}")
logger.info(f"expected_headers: {expected_headers}")

is_match = column_headers == expected_headers
mismatch_headers = []
logger.info(f"Is match: {is_match}")

if not is_match:
mismatch_headers_set =set(column_headers).symmetric_difference(set(expected_headers))
mismatch_headers = list(mismatch_headers_set)
return is_match, mismatch_headers
missing_headers = set(expected_headers).difference(set(column_headers))
additional_headers = set(column_headers).difference(set(expected_headers))
if missing_headers:
return missing_headers
elif additional_headers:
df.drop(additional_headers, axis=1, inplace=True)
df.to_csv(input_file, index=False)
is_match = True
logger.info(f"Following additional headers have been removed from the csv: {additional_headers}")
return []


def validate_input_data(tagging_type, input_file):
is_valid = True
error = ''
if tagging_type == const.CONVERSATION_TAGGING:
is_match, mismatch_headers = validate_headers(input_file, tagging_type)
if not is_match:
error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}'
missing_headers = validate_headers(input_file, tagging_type)
if missing_headers:
error = f'Headers in the input file does not match the expected fields. Missing fields = {missing_headers}'
is_valid = False

return is_valid, error

0 comments on commit 3cbc67e

Please sign in to comment.