diff --git a/skit_labels/utils.py b/skit_labels/utils.py index 2982e56..b35acb9 100644 --- a/skit_labels/utils.py +++ b/skit_labels/utils.py @@ -117,30 +117,38 @@ def validate_headers(input_file, tagging_type): expected_headers = expected_columns_mapping.get(tagging_type) df = pd.read_csv(input_file) + column_headers = df.columns.to_list() column_headers = [header.lower() for header in column_headers] column_headers = sorted(column_headers) expected_headers = sorted(expected_headers) + logger.info(f"column_headers: {column_headers}") logger.info(f"expected_headers: {expected_headers}") is_match = column_headers == expected_headers - mismatch_headers = [] logger.info(f"Is match: {is_match}") if not is_match: - mismatch_headers_set =set(column_headers).symmetric_difference(set(expected_headers)) - mismatch_headers = list(mismatch_headers_set) - return is_match, mismatch_headers + missing_headers = set(expected_headers).difference(set(column_headers)) + additional_headers = set(column_headers).difference(set(expected_headers)) + if missing_headers: + return missing_headers + elif additional_headers: + df.drop(additional_headers, axis=1, inplace=True) + df.to_csv(input_file, index=False) + is_match = True + logger.info(f"Following additional headers have been removed from the csv: {additional_headers}") + return [] def validate_input_data(tagging_type, input_file): is_valid = True error = '' if tagging_type == const.CONVERSATION_TAGGING: - is_match, mismatch_headers = validate_headers(input_file, tagging_type) - if not is_match: - error = f'Headers in the input file does not match the expected fields. Mismatched fields = {mismatch_headers}' + missing_headers = validate_headers(input_file, tagging_type) + if missing_headers: + error = f'Headers in the input file does not match the expected fields. Missing fields = {missing_headers}' is_valid = False return is_valid, error \ No newline at end of file