Skip to content

Commit

Permalink
version bump
Browse files Browse the repository at this point in the history
  • Loading branch information
areibman committed Sep 16, 2024
1 parent 8cce891 commit b6afb02
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,5 @@ cython_debug/
# other
.DS_Store

./output.txt
output.txt
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "spellcaster"
version = "0.0.1"
version = "0.0.2"
authors = [
{ name="Alex Reibman", email="[email protected]" }
]
Expand Down
6 changes: 4 additions & 2 deletions spellcaster/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def main():
)

args = parser.parse_args()

path = ""
if args.url:
path = args.url
current_dir = Path.cwd()
repo_name = args.url.rstrip('/').split("/")[-1].replace(".git", "")
org = args.url.rstrip('/').split("/")[-2]
Expand All @@ -63,6 +64,7 @@ def main():
clone_repository(args.url, str(directory))
print(f"Repository cloned successfully to {directory}")
elif args.directory:
path = args.directory
directory = args.directory
print(f"Using existing directory: {directory}")
else:
Expand Down Expand Up @@ -92,7 +94,7 @@ def main():

total = 0
for result in results:
errors = display_results(result, args.url)
errors = display_results(result, result.file_path, args.url)
total += errors

console.print(f"[bold red]Total errors in the docs found: {total}[/bold red]")
Expand Down
60 changes: 9 additions & 51 deletions spellcaster/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import litellm
from agentops import record_action
from .config import MODEL
from agentops import ErrorEvent, record, ActionEvent


class Error(BaseModel):
Expand Down Expand Up @@ -92,9 +93,7 @@ def check_grammar(file_path: str, proper_nouns: str, model: str = MODEL) -> Gram
resp = Grammar.model_validate_json(text_response)

# Double check step
print('Double check step')
# validate_reasoning(text_response)
# print("while checking", resp.json())
console.print("[bold blue]Double check step[/bold blue]")

validated_response = validate_reasoning(resp.json(), model=model)

Expand All @@ -108,6 +107,8 @@ def check_grammar(file_path: str, proper_nouns: str, model: str = MODEL) -> Gram
return resp
except Exception as e:
print(f"Error: {e}")
action_event = ActionEvent()
record(ErrorEvent(exception=e, trigger_event=action_event))
return Grammar(spelling=[], punctuation=[], grammar=[], file_path=file_path)


Expand Down Expand Up @@ -195,40 +196,16 @@ def validate_reasoning(text: str, model: str = MODEL) -> bool:
return True


# def display_results(response: Grammar, github_url: str):
# """Display the grammar check results using Rich."""
# # Replace local file path with GitHub URL
# github_file_path = github_url.rstrip('/') + '/blob/main/' + \
# '/'.join(response.file_path.split("samples/")[1].split('/')[1:])
# console.print(f"\n[bold cyan]File: {github_file_path}[/bold cyan]")
# for category in ['spelling', 'punctuation', 'grammar']:
# table = Table(title=f"{category.capitalize()} Corrections", box=ROUNDED)
# table.add_column("Original", justify="left", style="bold red")
# table.add_column("Corrected", justify="left", style="bold green")
# table.add_column("Explanation", justify="left", style="italic")

# errors = getattr(response, category)
# for error in errors:
# if error.before != error.after:
# table.add_row(error.before, error.after, error.explanation)

# if table.row_count > 0:
# console.print(table)
# else:
# console.print(f"No {category} errors found.")

# # console.print(Text("\nCorrected Text:\n", style="bold cyan"))
# # console.print(Text(response.corrected, style="white"))

def display_results(response: Grammar, github_url: str):
def display_results(response: Grammar, path: str, repo_link: str = ""):
"""Display the grammar check results using Rich."""
# Replace local file path with GitHub URL
github_file_path = github_url.rstrip('/') + '/blob/main/' + \
'/'.join(response.file_path.split("samples/")[1].split('/')[2:])
if repo_link:
path = repo_link.rstrip('/') + '/blob/main/' + \
'/'.join(response.file_path.split("samples/")[1].split('/')[2:])
# Create a console for file output
console = Console(record=True)

console.print(f"\n[bold cyan]File: {github_file_path}[/bold cyan]")
console.print(f"\n[bold cyan]File: {path}[/bold cyan]")

total_errors = 0

Expand Down Expand Up @@ -257,22 +234,3 @@ def display_results(response: Grammar, github_url: str):
f.write(console.export_text())

return total_errors


def process_file(file_path: str, model: str = MODEL):
"""Process a single file and display results."""
console.print(f"\n[bold cyan]Processing file: {file_path}[/bold cyan]")
response = check_grammar(file_path, model)
display_results(response, "https://github.com/AgentOps-AI/spellcaster/blob/main/spellcaster/samples/")

output_file = f"{file_path.rsplit('.', 1)[0]}_corrected.mdx"
with open(output_file, "w") as file:
file.write(response.corrected)
console.print(f"[green]Corrected text saved to: {output_file}[/green]")


if __name__ == "__main__":
sample_files = ["../data/sample1.mdx", "../data/sample2.mdx", "../data/sample3.mdx"]

for file_path in sample_files:
process_file(file_path)

0 comments on commit b6afb02

Please sign in to comment.