-
Notifications
You must be signed in to change notification settings - Fork 676
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7f0315f
commit 54194b4
Showing
9 changed files
with
373 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
|
||
from io import BytesIO | ||
from typing import List, Any, Optional | ||
import re | ||
import docx2txt | ||
from langchain.docstore.document import Document | ||
import fitz | ||
from hashlib import md5 | ||
from abc import abstractmethod, ABC | ||
from copy import deepcopy | ||
import json | ||
from bs4 import BeautifulSoup | ||
|
||
class File(ABC): | ||
"""Represents an uploaded file comprised of Documents""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
id: str, | ||
metadata: Optional[dict[str, Any]] = None, | ||
docs: Optional[List[Document]] = None, | ||
): | ||
self.name = name | ||
self.id = id | ||
self.metadata = metadata or {} | ||
self.docs = docs or [] | ||
|
||
@classmethod | ||
@abstractmethod | ||
def from_bytes(cls, file: BytesIO) -> "File": | ||
"""Creates a File from a BytesIO object""" | ||
|
||
def __repr__(self) -> str: | ||
return ( | ||
f"File(name={self.name}, id={self.id}, " | ||
f"metadata={self.metadata}, docs={self.docs})" | ||
) | ||
|
||
def __str__(self) -> str: | ||
return f"File(name={self.name}, id={self.id}, metadata={self.metadata})" | ||
|
||
def copy(self) -> "File": | ||
"""Create a deep copy of this File""" | ||
return self.__class__( | ||
name=self.name, | ||
id=self.id, | ||
metadata=deepcopy(self.metadata), | ||
docs=deepcopy(self.docs), | ||
) | ||
|
||
def strip_consecutive_newlines(text: str) -> str: | ||
"""Strips consecutive newlines from a string, possibly with whitespace in between""" | ||
return re.sub(r"\s*\n\s*", "\n", text) | ||
|
||
class DocxFile(File): | ||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "DocxFile": | ||
# Use docx2txt to extract text from docx files | ||
text = docx2txt.process(file) | ||
text = strip_consecutive_newlines(text) | ||
# Create a Document object from the extracted text | ||
doc = Document(page_content=text.strip()) | ||
# Calculate a unique identifier for the file | ||
file_id = md5(file.read()).hexdigest() | ||
# Reset the file pointer to the beginning | ||
file.seek(0) | ||
return cls(name=file.name, id=file_id, docs=[doc]) | ||
|
||
class PdfFile(File): | ||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "PdfFile": | ||
# Use fitz to extract text from pdf files | ||
pdf = fitz.open(stream=file.read(), filetype="pdf") | ||
docs = [] | ||
for i, page in enumerate(pdf): | ||
text = page.get_text(sort=True) | ||
text = strip_consecutive_newlines(text) | ||
# Create a Document object from the extracted text | ||
doc = Document(page_content=text.strip()) | ||
doc.metadata["page"] = i + 1 | ||
docs.append(doc) | ||
# Calculate a unique identifier for the file | ||
file_id = md5(file.read()).hexdigest() | ||
# Reset the file pointer to the beginning | ||
file.seek(0) | ||
return cls(name=file.name, id=file_id, docs=docs) | ||
|
||
class TxtFile(File): | ||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "TxtFile": | ||
# Read the text from the file | ||
text = file.read().decode("utf-8") | ||
text = strip_consecutive_newlines(text) | ||
# Create a Document object from the extracted text | ||
doc = Document(page_content=text.strip()) | ||
# Calculate a unique identifier for the file | ||
file_id = md5(file.read()).hexdigest() | ||
# Reset the file pointer to the beginning | ||
file.seek(0) | ||
return cls(name=file.name, id=file_id, docs=[doc]) | ||
|
||
class JsonFile(File): | ||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "JsonFile": | ||
# Parse the JSON data from the file | ||
#data = json.loads(file.read()) | ||
data = json.dumps(json.load(file)) | ||
# Create a Document object from the parsed data | ||
doc = Document(page_content=data) | ||
# Calculate a unique identifier for the file | ||
file_id = md5(file.read()).hexdigest() | ||
# Reset the file pointer to the beginning | ||
file.seek(0) | ||
return cls(name=file.name, id=file_id, docs=[doc]) | ||
|
||
class HtmlFile(File): | ||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "HtmlFile": | ||
# Parse the HTML data from the file | ||
soup = BeautifulSoup(file, "html.parser") | ||
text = soup.get_text() | ||
text = strip_consecutive_newlines(text) | ||
# Create a Document object from the parsed data | ||
doc = Document(page_content=text.strip()) | ||
# Calculate a unique identifier for the file | ||
file_id = md5(file.read()).hexdigest() | ||
# Reset the file pointer to the beginning | ||
file.seek(0) | ||
return cls(name=file.name, id=file_id, docs=[doc]) | ||
|
||
def read_file(file: BytesIO) -> File: | ||
"""Reads an uploaded file and returns a File object""" | ||
# Determine the file type based on the file extension | ||
if file.name.lower().endswith(".docx"): | ||
return DocxFile.from_bytes(file) | ||
elif file.name.lower().endswith(".pdf"): | ||
return PdfFile.from_bytes(file) | ||
elif file.name.lower().endswith(".txt"): | ||
return TxtFile.from_bytes(file) | ||
elif file.name.lower().endswith(".json"): | ||
return JsonFile.from_bytes(file) | ||
elif file.name.lower().endswith(".html"): | ||
return HtmlFile.from_bytes(file) | ||
else: | ||
raise NotImplementedError( | ||
f"File type {file.name.split('.')[-1]} not supported" | ||
) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Hello World |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"message": "Hello World" | ||
} |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Hello World |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
import pytest | ||
from io import BytesIO | ||
from pathlib import Path | ||
from camel.functions.data_io_functions import ( | ||
DocxFile, | ||
PdfFile, | ||
TxtFile, | ||
JsonFile, | ||
HtmlFile, | ||
read_file, | ||
strip_consecutive_newlines, | ||
) | ||
from langchain.docstore.document import Document | ||
|
||
from camel.functions.data_io_functions import (File) | ||
import fitz | ||
|
||
|
||
# Define a FakeFile class for testing purposes | ||
class FakeFile(File): | ||
"""A fake file for testing purposes""" | ||
|
||
@classmethod | ||
def from_bytes(cls, file: BytesIO) -> "FakeFile": | ||
return NotImplemented | ||
|
||
# Define paths to test resources | ||
UNIT_TESTS_ROOT = Path(__file__).parent.resolve() | ||
TESTS_ROOT = UNIT_TESTS_ROOT.parent.resolve() | ||
PROJECT_ROOT = TESTS_ROOT.parent.resolve() | ||
RESOURCE_ROOT = PROJECT_ROOT / "test" | ||
SAMPLE_ROOT = RESOURCE_ROOT / "data_samples" | ||
|
||
|
||
# Test functions for each file type | ||
def test_docx_file(): | ||
with open(SAMPLE_ROOT / "test_hello.docx", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test.docx" | ||
docx_file = DocxFile.from_bytes(file) | ||
|
||
assert docx_file.name == "test.docx" | ||
assert len(docx_file.docs) == 1 | ||
assert docx_file.docs[0].page_content == "Hello World" | ||
|
||
|
||
def test_docx_file_with_multiple_pages(): | ||
with open(SAMPLE_ROOT / "test_hello_multi.docx", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test.docx" | ||
docx_file = DocxFile.from_bytes(file) | ||
|
||
assert docx_file.name == "test.docx" | ||
assert len(docx_file.docs) == 1 | ||
assert ( | ||
docx_file.docs[0].page_content | ||
== "Hello World 1\nHello World 2\nHello World 3" | ||
) | ||
|
||
|
||
def test_pdf_file_with_single_page(): | ||
with open(SAMPLE_ROOT / "test_hello.pdf", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test_hello.pdf" | ||
pdf_file = PdfFile.from_bytes(file) | ||
|
||
assert pdf_file.name == "test_hello.pdf" | ||
assert len(pdf_file.docs) == 1 | ||
assert pdf_file.docs[0].page_content == "Hello World" | ||
|
||
|
||
def test_pdf_file_with_multiple_pages(): | ||
with open(SAMPLE_ROOT / "test_hello_multi.pdf", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test_hello_multiple.pdf" | ||
pdf_file = PdfFile.from_bytes(file) | ||
|
||
assert pdf_file.name == "test_hello_multiple.pdf" | ||
assert len(pdf_file.docs) == 3 | ||
assert pdf_file.docs[0].page_content == "Hello World 1" | ||
assert pdf_file.docs[1].page_content == "Hello World 2" | ||
assert pdf_file.docs[2].page_content == "Hello World 3" | ||
assert pdf_file.docs[0].metadata["page"] == 1 | ||
assert pdf_file.docs[1].metadata["page"] == 2 | ||
assert pdf_file.docs[2].metadata["page"] == 3 | ||
|
||
|
||
def test_txt_file(): | ||
with open(SAMPLE_ROOT / "test_hello.txt", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test.txt" | ||
txt_file = TxtFile.from_bytes(file) | ||
|
||
assert txt_file.name == "test.txt" | ||
assert len(txt_file.docs) == 1 | ||
assert txt_file.docs[0].page_content == "Hello World" | ||
|
||
|
||
def test_json_file(): | ||
with open(SAMPLE_ROOT / "test_hello.json", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test.json" | ||
json_file = JsonFile.from_bytes(file) | ||
|
||
assert json_file.name == "test.json" | ||
assert len(json_file.docs) == 1 | ||
assert json_file.docs[0].page_content == '{"message": "Hello World"}' | ||
|
||
|
||
|
||
def test_html_file(): | ||
with open(SAMPLE_ROOT / "test_hello.html", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = "test.html" | ||
html_file = HtmlFile.from_bytes(file) | ||
|
||
assert html_file.name == "test.html" | ||
assert len(html_file.docs) == 1 | ||
assert html_file.docs[0].page_content == "Hello World" | ||
|
||
|
||
# Test the `read_file` function with each file type | ||
def test_read_file(): | ||
for ext, FileClass in [ | ||
(".docx", DocxFile), | ||
(".pdf", PdfFile), | ||
(".txt", TxtFile), | ||
(".json", JsonFile), | ||
(".html", HtmlFile), | ||
]: | ||
with open(SAMPLE_ROOT / f"test_hello{ext}", "rb") as f: | ||
file = BytesIO(f.read()) | ||
file.name = f"test_hello{ext}" | ||
file_obj = read_file(file) | ||
|
||
assert isinstance(file_obj, FileClass) | ||
assert file_obj.name == f"test_hello{ext}" | ||
assert len(file_obj.docs) == 1 | ||
assert file_obj.docs[0].page_content == "Hello World" or '{"message": "Hello World"}' | ||
|
||
|
||
# Test that read_file raises a NotImplementedError for unsupported file types | ||
def test_read_file_not_implemented(): | ||
file = BytesIO(b"Hello World") | ||
file.name = "test.unknown" | ||
with pytest.raises(NotImplementedError): | ||
read_file(file) | ||
|
||
|
||
# Test the File.copy() method | ||
def test_file_copy(): | ||
# Create a Document and FakeFile instance | ||
document = Document(page_content="test content", metadata={"page": "1"}) | ||
file = FakeFile("test_file", "1234", {"author": "test"}, [document]) | ||
|
||
# Create a copy of the file | ||
file_copy = file.copy() | ||
|
||
# Check that the original and copy are distinct objects | ||
assert file is not file_copy | ||
|
||
# Check that the copy has the same attributes as the original | ||
assert file.name == file_copy.name | ||
assert file.id == file_copy.id | ||
|
||
# Check that the mutable attributes were deeply copied | ||
assert file.metadata == file_copy.metadata | ||
assert file.metadata is not file_copy.metadata | ||
|
||
# Check that the documents were deeply copied | ||
assert file.docs == file_copy.docs | ||
assert file.docs is not file_copy.docs | ||
|
||
# Check that individual documents are not the same objects | ||
assert file.docs[0] is not file_copy.docs[0] | ||
|
||
# Check that the documents have the same attributes | ||
assert file.docs[0].page_content == file_copy.docs[0].page_content | ||
assert file.docs[0].metadata == file_copy.docs[0].metadata | ||
|
||
|
||
# Test the strip_consecutive_newlines function | ||
def test_strip_consecutive_newlines(): | ||
# Test with multiple consecutive newlines | ||
text = "\n\n\n" | ||
expected = "\n" | ||
assert strip_consecutive_newlines(text) == expected | ||
|
||
# Test with newlines and spaces | ||
text = "\n \n \n" | ||
expected = "\n" | ||
assert strip_consecutive_newlines(text) == expected | ||
|
||
# Test with newlines and tabs | ||
text = "\n\t\n\t\n" | ||
expected = "\n" | ||
assert strip_consecutive_newlines(text) == expected | ||
|
||
# Test with mixed whitespace characters | ||
text = "\n \t\n \t \n" | ||
expected = "\n" | ||
assert strip_consecutive_newlines(text) == expected | ||
|
||
# Test with no consecutive newlines | ||
text = "\nHello\nWorld\n" | ||
expected = "\nHello\nWorld\n" | ||
assert strip_consecutive_newlines(text) == expected |