Skip to content

Commit

Permalink
add_data_io_feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Wendong-Fan committed Sep 7, 2023
1 parent 7f0315f commit 54194b4
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 0 deletions.
161 changes: 161 additions & 0 deletions camel/functions/data_io_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from io import BytesIO
from typing import List, Any, Optional
import re
import docx2txt
from langchain.docstore.document import Document
import fitz
from hashlib import md5
from abc import abstractmethod, ABC
from copy import deepcopy
import json
from bs4 import BeautifulSoup

class File(ABC):
"""Represents an uploaded file comprised of Documents"""

def __init__(
self,
name: str,
id: str,
metadata: Optional[dict[str, Any]] = None,
docs: Optional[List[Document]] = None,
):
self.name = name
self.id = id
self.metadata = metadata or {}
self.docs = docs or []

@classmethod
@abstractmethod
def from_bytes(cls, file: BytesIO) -> "File":
"""Creates a File from a BytesIO object"""

def __repr__(self) -> str:
return (
f"File(name={self.name}, id={self.id}, "
f"metadata={self.metadata}, docs={self.docs})"
)

def __str__(self) -> str:
return f"File(name={self.name}, id={self.id}, metadata={self.metadata})"

def copy(self) -> "File":
"""Create a deep copy of this File"""
return self.__class__(
name=self.name,
id=self.id,
metadata=deepcopy(self.metadata),
docs=deepcopy(self.docs),
)

def strip_consecutive_newlines(text: str) -> str:
"""Strips consecutive newlines from a string, possibly with whitespace in between"""
return re.sub(r"\s*\n\s*", "\n", text)

class DocxFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "DocxFile":
# Use docx2txt to extract text from docx files
text = docx2txt.process(file)
text = strip_consecutive_newlines(text)
# Create a Document object from the extracted text
doc = Document(page_content=text.strip())
# Calculate a unique identifier for the file
file_id = md5(file.read()).hexdigest()
# Reset the file pointer to the beginning
file.seek(0)
return cls(name=file.name, id=file_id, docs=[doc])

class PdfFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "PdfFile":
# Use fitz to extract text from pdf files
pdf = fitz.open(stream=file.read(), filetype="pdf")
docs = []
for i, page in enumerate(pdf):
text = page.get_text(sort=True)
text = strip_consecutive_newlines(text)
# Create a Document object from the extracted text
doc = Document(page_content=text.strip())
doc.metadata["page"] = i + 1
docs.append(doc)
# Calculate a unique identifier for the file
file_id = md5(file.read()).hexdigest()
# Reset the file pointer to the beginning
file.seek(0)
return cls(name=file.name, id=file_id, docs=docs)

class TxtFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "TxtFile":
# Read the text from the file
text = file.read().decode("utf-8")
text = strip_consecutive_newlines(text)
# Create a Document object from the extracted text
doc = Document(page_content=text.strip())
# Calculate a unique identifier for the file
file_id = md5(file.read()).hexdigest()
# Reset the file pointer to the beginning
file.seek(0)
return cls(name=file.name, id=file_id, docs=[doc])

class JsonFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "JsonFile":
# Parse the JSON data from the file
#data = json.loads(file.read())
data = json.dumps(json.load(file))
# Create a Document object from the parsed data
doc = Document(page_content=data)
# Calculate a unique identifier for the file
file_id = md5(file.read()).hexdigest()
# Reset the file pointer to the beginning
file.seek(0)
return cls(name=file.name, id=file_id, docs=[doc])

class HtmlFile(File):
@classmethod
def from_bytes(cls, file: BytesIO) -> "HtmlFile":
# Parse the HTML data from the file
soup = BeautifulSoup(file, "html.parser")
text = soup.get_text()
text = strip_consecutive_newlines(text)
# Create a Document object from the parsed data
doc = Document(page_content=text.strip())
# Calculate a unique identifier for the file
file_id = md5(file.read()).hexdigest()
# Reset the file pointer to the beginning
file.seek(0)
return cls(name=file.name, id=file_id, docs=[doc])

def read_file(file: BytesIO) -> File:
"""Reads an uploaded file and returns a File object"""
# Determine the file type based on the file extension
if file.name.lower().endswith(".docx"):
return DocxFile.from_bytes(file)
elif file.name.lower().endswith(".pdf"):
return PdfFile.from_bytes(file)
elif file.name.lower().endswith(".txt"):
return TxtFile.from_bytes(file)
elif file.name.lower().endswith(".json"):
return JsonFile.from_bytes(file)
elif file.name.lower().endswith(".html"):
return HtmlFile.from_bytes(file)
else:
raise NotImplementedError(
f"File type {file.name.split('.')[-1]} not supported"
)
Binary file added test/data_samples/test_hello.docx
Binary file not shown.
1 change: 1 addition & 0 deletions test/data_samples/test_hello.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello World
3 changes: 3 additions & 0 deletions test/data_samples/test_hello.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"message": "Hello World"
}
Binary file added test/data_samples/test_hello.pdf
Binary file not shown.
1 change: 1 addition & 0 deletions test/data_samples/test_hello.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hello World
Binary file added test/data_samples/test_hello_multi.docx
Binary file not shown.
Binary file added test/data_samples/test_hello_multi.pdf
Binary file not shown.
207 changes: 207 additions & 0 deletions test/functions/test_data_io_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import pytest
from io import BytesIO
from pathlib import Path
from camel.functions.data_io_functions import (
DocxFile,
PdfFile,
TxtFile,
JsonFile,
HtmlFile,
read_file,
strip_consecutive_newlines,
)
from langchain.docstore.document import Document

from camel.functions.data_io_functions import (File)
import fitz


# Define a FakeFile class for testing purposes
class FakeFile(File):
"""A fake file for testing purposes"""

@classmethod
def from_bytes(cls, file: BytesIO) -> "FakeFile":
return NotImplemented

# Define paths to test resources
UNIT_TESTS_ROOT = Path(__file__).parent.resolve()
TESTS_ROOT = UNIT_TESTS_ROOT.parent.resolve()
PROJECT_ROOT = TESTS_ROOT.parent.resolve()
RESOURCE_ROOT = PROJECT_ROOT / "test"
SAMPLE_ROOT = RESOURCE_ROOT / "data_samples"


# Test functions for each file type
def test_docx_file():
with open(SAMPLE_ROOT / "test_hello.docx", "rb") as f:
file = BytesIO(f.read())
file.name = "test.docx"
docx_file = DocxFile.from_bytes(file)

assert docx_file.name == "test.docx"
assert len(docx_file.docs) == 1
assert docx_file.docs[0].page_content == "Hello World"


def test_docx_file_with_multiple_pages():
with open(SAMPLE_ROOT / "test_hello_multi.docx", "rb") as f:
file = BytesIO(f.read())
file.name = "test.docx"
docx_file = DocxFile.from_bytes(file)

assert docx_file.name == "test.docx"
assert len(docx_file.docs) == 1
assert (
docx_file.docs[0].page_content
== "Hello World 1\nHello World 2\nHello World 3"
)


def test_pdf_file_with_single_page():
with open(SAMPLE_ROOT / "test_hello.pdf", "rb") as f:
file = BytesIO(f.read())
file.name = "test_hello.pdf"
pdf_file = PdfFile.from_bytes(file)

assert pdf_file.name == "test_hello.pdf"
assert len(pdf_file.docs) == 1
assert pdf_file.docs[0].page_content == "Hello World"


def test_pdf_file_with_multiple_pages():
with open(SAMPLE_ROOT / "test_hello_multi.pdf", "rb") as f:
file = BytesIO(f.read())
file.name = "test_hello_multiple.pdf"
pdf_file = PdfFile.from_bytes(file)

assert pdf_file.name == "test_hello_multiple.pdf"
assert len(pdf_file.docs) == 3
assert pdf_file.docs[0].page_content == "Hello World 1"
assert pdf_file.docs[1].page_content == "Hello World 2"
assert pdf_file.docs[2].page_content == "Hello World 3"
assert pdf_file.docs[0].metadata["page"] == 1
assert pdf_file.docs[1].metadata["page"] == 2
assert pdf_file.docs[2].metadata["page"] == 3


def test_txt_file():
with open(SAMPLE_ROOT / "test_hello.txt", "rb") as f:
file = BytesIO(f.read())
file.name = "test.txt"
txt_file = TxtFile.from_bytes(file)

assert txt_file.name == "test.txt"
assert len(txt_file.docs) == 1
assert txt_file.docs[0].page_content == "Hello World"


def test_json_file():
with open(SAMPLE_ROOT / "test_hello.json", "rb") as f:
file = BytesIO(f.read())
file.name = "test.json"
json_file = JsonFile.from_bytes(file)

assert json_file.name == "test.json"
assert len(json_file.docs) == 1
assert json_file.docs[0].page_content == '{"message": "Hello World"}'



def test_html_file():
with open(SAMPLE_ROOT / "test_hello.html", "rb") as f:
file = BytesIO(f.read())
file.name = "test.html"
html_file = HtmlFile.from_bytes(file)

assert html_file.name == "test.html"
assert len(html_file.docs) == 1
assert html_file.docs[0].page_content == "Hello World"


# Test the `read_file` function with each file type
def test_read_file():
for ext, FileClass in [
(".docx", DocxFile),
(".pdf", PdfFile),
(".txt", TxtFile),
(".json", JsonFile),
(".html", HtmlFile),
]:
with open(SAMPLE_ROOT / f"test_hello{ext}", "rb") as f:
file = BytesIO(f.read())
file.name = f"test_hello{ext}"
file_obj = read_file(file)

assert isinstance(file_obj, FileClass)
assert file_obj.name == f"test_hello{ext}"
assert len(file_obj.docs) == 1
assert file_obj.docs[0].page_content == "Hello World" or '{"message": "Hello World"}'


# Test that read_file raises a NotImplementedError for unsupported file types
def test_read_file_not_implemented():
file = BytesIO(b"Hello World")
file.name = "test.unknown"
with pytest.raises(NotImplementedError):
read_file(file)


# Test the File.copy() method
def test_file_copy():
# Create a Document and FakeFile instance
document = Document(page_content="test content", metadata={"page": "1"})
file = FakeFile("test_file", "1234", {"author": "test"}, [document])

# Create a copy of the file
file_copy = file.copy()

# Check that the original and copy are distinct objects
assert file is not file_copy

# Check that the copy has the same attributes as the original
assert file.name == file_copy.name
assert file.id == file_copy.id

# Check that the mutable attributes were deeply copied
assert file.metadata == file_copy.metadata
assert file.metadata is not file_copy.metadata

# Check that the documents were deeply copied
assert file.docs == file_copy.docs
assert file.docs is not file_copy.docs

# Check that individual documents are not the same objects
assert file.docs[0] is not file_copy.docs[0]

# Check that the documents have the same attributes
assert file.docs[0].page_content == file_copy.docs[0].page_content
assert file.docs[0].metadata == file_copy.docs[0].metadata


# Test the strip_consecutive_newlines function
def test_strip_consecutive_newlines():
# Test with multiple consecutive newlines
text = "\n\n\n"
expected = "\n"
assert strip_consecutive_newlines(text) == expected

# Test with newlines and spaces
text = "\n \n \n"
expected = "\n"
assert strip_consecutive_newlines(text) == expected

# Test with newlines and tabs
text = "\n\t\n\t\n"
expected = "\n"
assert strip_consecutive_newlines(text) == expected

# Test with mixed whitespace characters
text = "\n \t\n \t \n"
expected = "\n"
assert strip_consecutive_newlines(text) == expected

# Test with no consecutive newlines
text = "\nHello\nWorld\n"
expected = "\nHello\nWorld\n"
assert strip_consecutive_newlines(text) == expected

0 comments on commit 54194b4

Please sign in to comment.