Skip to content

Commit

Permalink
[Fix] More coverage for python interpreter (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax authored Sep 18, 2023
1 parent 03b7d27 commit 4c6a644
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 24 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ Code coverage measures the extent to which unit tests cover the code, helping id
To generate a report showing the current code coverage, execute one of the following commands.

To include all source files into coverage:
```

```bash
coverage erase
coverage run --source=. -m pytest .
coverage html
Expand Down
21 changes: 15 additions & 6 deletions camel/utils/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import difflib
import importlib
import typing
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Optional


class InterpreterError(ValueError):
Expand Down Expand Up @@ -165,6 +165,9 @@ def _execute_ast(self, expression: ast.AST) -> Any:
elif isinstance(expression, ast.Call):
# Function call -> return the value of the function call
return self._execute_call(expression)
elif isinstance(expression, ast.Compare):
# Compare -> return True or False
return self._execute_condition(expression)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
Expand Down Expand Up @@ -270,9 +273,11 @@ def _execute_subscript(self, subscript: ast.Subscript):
return value[int(index)]
if index in value:
return value[index]
if isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index,
list(value.keys()))
if isinstance(index, str) and isinstance(value, dict):
close_matches = difflib.get_close_matches(
index,
[key for key in list(value.keys()) if isinstance(key, str)],
)
if len(close_matches) > 0:
return value[close_matches[0]]

Expand All @@ -286,7 +291,7 @@ def _execute_name(self, name: ast.Name):
else:
raise InterpreterError(f"{name.ctx} is not supported.")

def _execute_condition(self, condition):
def _execute_condition(self, condition: ast.Compare):
if len(condition.ops) > 1:
raise InterpreterError(
"Cannot evaluate conditions with multiple operators")
Expand Down Expand Up @@ -316,10 +321,14 @@ def _execute_condition(self, condition):
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
raise InterpreterError(f"Operator not supported: {comparator}")
raise InterpreterError(f"Unsupported operator: {comparator}")

def _execute_if(self, if_statement: ast.If):
result = None
if not isinstance(if_statement.test, ast.Compare):
raise InterpreterError(
"Only Campare expr supported in if statement, get"
f" {if_statement.test.__class__.__name__}")
if self._execute_condition(if_statement.test):
for line in if_statement.body:
line_result = self._execute_ast(line)
Expand Down
125 changes: 108 additions & 17 deletions test/utils/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,29 @@ def action_function():

@pytest.fixture()
def interpreter():
action_space = {"action1": action_function}
action_space = {"action1": action_function, "str": str}
white_list = ["torch", "numpy.array", "openai"]
return PythonInterpreter(action_space=action_space,
import_white_list=white_list)


def test_import_success0(interpreter):
def test_state_update(interpreter: PythonInterpreter):
    """Check that a value passed in through ``state`` is visible to the code.

    The executed program only assigns ``input_variable`` to ``x``; the test
    expects ``execute`` to hand back that same value.
    """
    expected = 10
    result = interpreter.execute("x = input_variable",
                                 state={"input_variable": expected})
    assert result == expected


def test_syntax_error(interpreter: PythonInterpreter):
    """Check that unparsable code surfaces as an ``InterpreterError``."""
    with pytest.raises(InterpreterError) as exc_info:
        # Two bare names next to each other are not valid Python.
        interpreter.execute("x input_variable")
    message = exc_info.value.args[0]
    assert "Syntax error in code: invalid syntax" in message


def test_import_success0(interpreter: PythonInterpreter):
code = """import torch as pt, openai
a = pt.tensor([[1., -1.], [1., -1.]])
openai.__version__"""
Expand All @@ -41,21 +57,21 @@ def test_import_success0(interpreter):
assert isinstance(execution_res, str)


def test_import_success1(interpreter):
def test_import_success1(interpreter: PythonInterpreter):
code = """from torch import tensor
a = tensor([[1., -1.], [1., -1.]])"""
execution_res = interpreter.execute(code)
assert torch.equal(execution_res, torch.tensor([[1., -1.], [1., -1.]]))


def test_import_success2(interpreter):
def test_import_success2(interpreter: PythonInterpreter):
code = """from numpy import array
x = array([[1, 2, 3], [4, 5, 6]])"""
execution_res = interpreter.execute(code)
assert np.equal(execution_res, np.array([[1, 2, 3], [4, 5, 6]])).all()


def test_import_fail0(interpreter):
def test_import_fail0(interpreter: PythonInterpreter):
code = """import os
os.mkdir("/tmp/test")"""
with pytest.raises(InterpreterError) as e:
Expand All @@ -66,7 +82,7 @@ def test_import_fail0(interpreter):
" white list (try to import os).")


def test_import_fail1(interpreter):
def test_import_fail1(interpreter: PythonInterpreter):
code = """import numpy as np
x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)"""
with pytest.raises(InterpreterError) as e:
Expand All @@ -77,13 +93,13 @@ def test_import_fail1(interpreter):
" white list (try to import numpy).")


def test_action_space(interpreter):
def test_action_space(interpreter: PythonInterpreter):
code = "res = action1()"
execution_res = interpreter.execute(code)
assert execution_res == "access action function"


def test_fuzz_space(interpreter):
def test_fuzz_space(interpreter: PythonInterpreter):
from PIL import Image
fuzz_state = {"image": Image.new("RGB", (256, 256))}
code = "output_image = input_image.crop((20, 20, 100, 100))"
Expand All @@ -92,7 +108,7 @@ def test_fuzz_space(interpreter):
assert execution_res.height == 80


def test_keep_state0(interpreter):
def test_keep_state0(interpreter: PythonInterpreter):
code1 = "a = 42"
code2 = "b = a"
code3 = "c = b"
Expand All @@ -108,7 +124,7 @@ def test_keep_state0(interpreter):
"The variable `b` is not defined.")


def test_keep_state1(interpreter):
def test_keep_state1(interpreter: PythonInterpreter):
code1 = "from torch import tensor"
code2 = "a = tensor([[1., -1.], [1., -1.]])"
execution_res = interpreter.execute(code1, keep_state=True)
Expand All @@ -121,22 +137,22 @@ def test_keep_state1(interpreter):
"The variable `tensor` is not defined.")


def test_assign0(interpreter):
def test_assign0(interpreter: PythonInterpreter):
code = "a = b = 1"
interpreter.execute(code)
assert interpreter.state["a"] == 1
assert interpreter.state["b"] == 1


def test_assign1(interpreter):
def test_assign1(interpreter: PythonInterpreter):
code = "a, b = c = 2, 3"
interpreter.execute(code)
assert interpreter.state["a"] == 2
assert interpreter.state["b"] == 3
assert interpreter.state["c"] == (2, 3)


def test_assign_fail(interpreter):
def test_assign_fail(interpreter: PythonInterpreter):
code = "x = a, b, c = 2, 3"
with pytest.raises(InterpreterError) as e:
interpreter.execute(code, keep_state=False)
Expand All @@ -145,7 +161,7 @@ def test_assign_fail(interpreter):
"Expected 3 values but got 2.")


def test_if(interpreter):
def test_if0(interpreter: PythonInterpreter):
code = """a = 0
b = 1
if a < b:
Expand All @@ -159,7 +175,49 @@ def test_if(interpreter):
assert interpreter.state["b"] == 0


def test_for(interpreter):
def test_if1(interpreter: PythonInterpreter):
    """Run an ``if``/``else`` whose condition is false.

    With ``a = 1`` and ``b = 0`` the test ``a < b`` does not hold, so the
    expected final state is the one produced by the ``else`` branch
    (``b = a``): both variables equal 1.
    """
    code = """a = 1
b = 0
if a < b:
    t = a
    a = b
    b = t
else:
    b = a"""
    interpreter.execute(code)
    # `a` is expected to be untouched; `b` is expected to have been set to it.
    assert interpreter.state["a"] == 1
    assert interpreter.state["b"] == 1


def test_compare(interpreter: PythonInterpreter):
    """Evaluate one single-operator comparison per row and check the result.

    Each result is compared with ``is`` so that only the real ``True`` /
    ``False`` singletons are accepted.
    """
    cases = (
        ("2 > 1", True),
        ("2 >= 1", True),
        ("2 < 1", False),
        ("2 == 1", False),
        ("2 != 1", True),
        ("1 <= 1", True),
        ("True is True", True),
        ("1 is not str", True),
        ("1 in [1, 2]", True),
        ("1 not in [1, 2]", False),
    )
    for source, expected in cases:
        assert interpreter.execute(source) is expected


def test_oprators(interpreter: PythonInterpreter):
    """Evaluate binary and unary operator expressions and check the values."""
    # Arithmetic, shift and sign operators: compared by value.
    value_cases = (
        ("1 + 1", 2),
        ("1 - 1", 0),
        ("1 * 1", 1),
        ("1 / 2", 0.5),
        ("1 // 2", 0),
        ("1 % 2", 1),
        ("2 ** 2", 4),
        ("10 >> 2", 2),
        ("1 << 2", 4),
        ("+1", 1),
        ("-1", -1),
    )
    for source, expected in value_cases:
        assert interpreter.execute(source) == expected

    # Logical negation must yield the actual ``False`` singleton.
    assert interpreter.execute("not True") is False


def test_for(interpreter: PythonInterpreter):
code = """l = [2, 3, 5, 7, 11]
sum = 0
for i in l:
Expand All @@ -168,14 +226,14 @@ def test_for(interpreter):
assert execution_res == 28


def test_subscript_access(interpreter):
def test_subscript_access(interpreter: PythonInterpreter):
code = """l = [2, 3, 5, 7, 11]
res = l[3]"""
execution_res = interpreter.execute(code)
assert execution_res == 7


def test_subscript_assign(interpreter):
def test_subscript_assign(interpreter: PythonInterpreter):
code = """l = [2, 3, 5, 7, 11]
l[3] = 1"""
with pytest.raises(InterpreterError) as e:
Expand All @@ -184,3 +242,36 @@ def test_subscript_assign(interpreter):
assert exec_msg == ("Evaluation of the code stopped at node 1. See:\n"
"Unsupported variable type. Expected ast.Name or "
"ast.Tuple, got Subscript instead.")


def test_dict(interpreter: PythonInterpreter):
    """Exercise dict literals, ``**`` unpacking and subscript lookups.

    The last lookup uses the key ``"numbers"`` although the dict only defines
    ``"number"``; the expected total of 60 (10 + 20 + 30) therefore only holds
    if that lookup resolves to the ``"number"`` entry.
    """
    program_lines = [
        'x = {1: 10, 2: 20}',
        'y = {"number": 30, **x}',
        'res = y[1] + y[2] + y["numbers"]',
    ]
    total = interpreter.execute("\n".join(program_lines))
    assert total == 60


def test_formatted_value(interpreter: PythonInterpreter):
    """Check that an f-string with one interpolated variable is evaluated."""
    # The program text ends with a newline, as in a typical source file.
    program = 'x = 3\nres = f"x = {x}"\n'
    rendered = interpreter.execute(program)
    assert rendered == "x = 3"


def test_joined_str(interpreter: PythonInterpreter):
    """Check that ``",".join(...)`` on a list of strings gives the joined text."""
    program = "\n".join([
        'l = ["2", "3", "5", "7", "11"]',
        'res = ",".join(l)',
    ])
    joined = interpreter.execute(program)
    assert joined == "2,3,5,7,11"


def test_expression_not_support(interpreter: PythonInterpreter):
    """Check the error raised for an augmented assignment (``x += 1``).

    The expected message names node 1, i.e. the second statement of the
    program, and the ``AugAssign`` node type.
    """
    program = "x = 1\nx += 1"
    with pytest.raises(InterpreterError) as exc_info:
        interpreter.execute(program, keep_state=False)
    message = exc_info.value.args[0]
    expected_message = ("Evaluation of the code stopped at node 1. See:"
                        "\nAugAssign is not supported.")
    assert message == expected_message

0 comments on commit 4c6a644

Please sign in to comment.