diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 598c0197e..6aa2b50d8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -111,7 +111,8 @@ Code coverage measures the extent to which unit tests cover the code, helping id To generate a report showing the current code coverage, execute one of the following commands. To include all source files into coverage: -``` + +```bash coverage erase coverage run --source=. -m pytest . coverage html diff --git a/camel/utils/python_interpreter.py b/camel/utils/python_interpreter.py index 4dad4d927..a7bfc13fd 100644 --- a/camel/utils/python_interpreter.py +++ b/camel/utils/python_interpreter.py @@ -15,7 +15,7 @@ import difflib import importlib import typing -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Optional class InterpreterError(ValueError): @@ -165,6 +165,9 @@ def _execute_ast(self, expression: ast.AST) -> Any: elif isinstance(expression, ast.Call): # Function call -> return the value of the function call return self._execute_call(expression) + elif isinstance(expression, ast.Compare): + # Compare -> return True or False + return self._execute_condition(expression) elif isinstance(expression, ast.Constant): # Constant -> just return the value return expression.value @@ -270,9 +273,11 @@ def _execute_subscript(self, subscript: ast.Subscript): return value[int(index)] if index in value: return value[index] - if isinstance(index, str) and isinstance(value, Mapping): - close_matches = difflib.get_close_matches(index, - list(value.keys())) + if isinstance(index, str) and isinstance(value, dict): + close_matches = difflib.get_close_matches( + index, + [key for key in list(value.keys()) if isinstance(key, str)], + ) if len(close_matches) > 0: return value[close_matches[0]] @@ -286,7 +291,7 @@ def _execute_name(self, name: ast.Name): else: raise InterpreterError(f"{name.ctx} is not supported.") - def _execute_condition(self, condition): + def _execute_condition(self, condition: 
ast.Compare): if len(condition.ops) > 1: raise InterpreterError( "Cannot evaluate conditions with multiple operators") @@ -316,10 +321,14 @@ def _execute_condition(self, condition): elif isinstance(comparator, ast.NotIn): return left not in right else: - raise InterpreterError(f"Operator not supported: {comparator}") + raise InterpreterError(f"Unsupported operator: {comparator}") def _execute_if(self, if_statement: ast.If): result = None + if not isinstance(if_statement.test, ast.Compare): + raise InterpreterError( + "Only Compare expr supported in if statement, got" + f" {if_statement.test.__class__.__name__}") if self._execute_condition(if_statement.test): for line in if_statement.body: line_result = self._execute_ast(line) diff --git a/test/utils/test_python_interpreter.py b/test/utils/test_python_interpreter.py index 631d89eba..d5f688559 100644 --- a/test/utils/test_python_interpreter.py +++ b/test/utils/test_python_interpreter.py @@ -25,13 +25,29 @@ def action_function(): @pytest.fixture() def interpreter(): - action_space = {"action1": action_function} + action_space = {"action1": action_function, "str": str} white_list = ["torch", "numpy.array", "openai"] return PythonInterpreter(action_space=action_space, import_white_list=white_list) -def test_import_success0(interpreter): +def test_state_update(interpreter: PythonInterpreter): + code = "x = input_variable" + input_variable = 10 + execution_res = interpreter.execute( + code, state={"input_variable": input_variable}) + assert execution_res == input_variable + + +def test_syntax_error(interpreter: PythonInterpreter): + code = "x input_variable" + with pytest.raises(InterpreterError) as e: + interpreter.execute(code) + exec_msg = e.value.args[0] + assert "Syntax error in code: invalid syntax" in exec_msg + + +def test_import_success0(interpreter: PythonInterpreter): code = """import torch as pt, openai a = pt.tensor([[1., -1.], [1., -1.]]) openai.__version__""" @@ -41,21 +57,21 @@ def 
test_import_success0(interpreter): assert isinstance(execution_res, str) -def test_import_success1(interpreter): +def test_import_success1(interpreter: PythonInterpreter): code = """from torch import tensor a = tensor([[1., -1.], [1., -1.]])""" execution_res = interpreter.execute(code) assert torch.equal(execution_res, torch.tensor([[1., -1.], [1., -1.]])) -def test_import_success2(interpreter): +def test_import_success2(interpreter: PythonInterpreter): code = """from numpy import array x = array([[1, 2, 3], [4, 5, 6]])""" execution_res = interpreter.execute(code) assert np.equal(execution_res, np.array([[1, 2, 3], [4, 5, 6]])).all() -def test_import_fail0(interpreter): +def test_import_fail0(interpreter: PythonInterpreter): code = """import os os.mkdir("/tmp/test")""" with pytest.raises(InterpreterError) as e: @@ -66,7 +82,7 @@ def test_import_fail0(interpreter): " white list (try to import os).") -def test_import_fail1(interpreter): +def test_import_fail1(interpreter: PythonInterpreter): code = """import numpy as np x = np.array([[1, 2, 3], [4, 5, 6]], np.int32)""" with pytest.raises(InterpreterError) as e: @@ -77,13 +93,13 @@ def test_import_fail1(interpreter): " white list (try to import numpy).") -def test_action_space(interpreter): +def test_action_space(interpreter: PythonInterpreter): code = "res = action1()" execution_res = interpreter.execute(code) assert execution_res == "access action function" -def test_fuzz_space(interpreter): +def test_fuzz_space(interpreter: PythonInterpreter): from PIL import Image fuzz_state = {"image": Image.new("RGB", (256, 256))} code = "output_image = input_image.crop((20, 20, 100, 100))" @@ -92,7 +108,7 @@ def test_fuzz_space(interpreter): assert execution_res.height == 80 -def test_keep_state0(interpreter): +def test_keep_state0(interpreter: PythonInterpreter): code1 = "a = 42" code2 = "b = a" code3 = "c = b" @@ -108,7 +124,7 @@ def test_keep_state0(interpreter): "The variable `b` is not defined.") -def 
test_keep_state1(interpreter): +def test_keep_state1(interpreter: PythonInterpreter): code1 = "from torch import tensor" code2 = "a = tensor([[1., -1.], [1., -1.]])" execution_res = interpreter.execute(code1, keep_state=True) @@ -121,14 +137,14 @@ def test_keep_state1(interpreter): "The variable `tensor` is not defined.") -def test_assign0(interpreter): +def test_assign0(interpreter: PythonInterpreter): code = "a = b = 1" interpreter.execute(code) assert interpreter.state["a"] == 1 assert interpreter.state["b"] == 1 -def test_assign1(interpreter): +def test_assign1(interpreter: PythonInterpreter): code = "a, b = c = 2, 3" interpreter.execute(code) assert interpreter.state["a"] == 2 @@ -136,7 +152,7 @@ def test_assign1(interpreter): assert interpreter.state["c"] == (2, 3) -def test_assign_fail(interpreter): +def test_assign_fail(interpreter: PythonInterpreter): code = "x = a, b, c = 2, 3" with pytest.raises(InterpreterError) as e: interpreter.execute(code, keep_state=False) @@ -145,7 +161,7 @@ def test_assign_fail(interpreter): "Expected 3 values but got 2.") -def test_if(interpreter): +def test_if0(interpreter: PythonInterpreter): code = """a = 0 b = 1 if a < b: @@ -159,7 +175,49 @@ def test_if(interpreter): assert interpreter.state["b"] == 0 -def test_for(interpreter): +def test_if1(interpreter: PythonInterpreter): + code = """a = 1 +b = 0 +if a < b: + t = a + a = b + b = t +else: + b = a""" + interpreter.execute(code) + assert interpreter.state["a"] == 1 + assert interpreter.state["b"] == 1 + + +def test_compare(interpreter: PythonInterpreter): + assert interpreter.execute("2 > 1") is True + assert interpreter.execute("2 >= 1") is True + assert interpreter.execute("2 < 1") is False + assert interpreter.execute("2 == 1") is False + assert interpreter.execute("2 != 1") is True + assert interpreter.execute("1 <= 1") is True + assert interpreter.execute("True is True") is True + assert interpreter.execute("1 is not str") is True + assert interpreter.execute("1 in [1, 
2]") is True + assert interpreter.execute("1 not in [1, 2]") is False + + +def test_oprators(interpreter: PythonInterpreter): + assert interpreter.execute("1 + 1") == 2 + assert interpreter.execute("1 - 1") == 0 + assert interpreter.execute("1 * 1") == 1 + assert interpreter.execute("1 / 2") == 0.5 + assert interpreter.execute("1 // 2") == 0 + assert interpreter.execute("1 % 2") == 1 + assert interpreter.execute("2 ** 2") == 4 + assert interpreter.execute("10 >> 2") == 2 + assert interpreter.execute("1 << 2") == 4 + assert interpreter.execute("+1") == 1 + assert interpreter.execute("-1") == -1 + assert interpreter.execute("not True") is False + + +def test_for(interpreter: PythonInterpreter): code = """l = [2, 3, 5, 7, 11] sum = 0 for i in l: @@ -168,14 +226,14 @@ def test_for(interpreter): assert execution_res == 28 -def test_subscript_access(interpreter): +def test_subscript_access(interpreter: PythonInterpreter): code = """l = [2, 3, 5, 7, 11] res = l[3]""" execution_res = interpreter.execute(code) assert execution_res == 7 -def test_subscript_assign(interpreter): +def test_subscript_assign(interpreter: PythonInterpreter): code = """l = [2, 3, 5, 7, 11] l[3] = 1""" with pytest.raises(InterpreterError) as e: @@ -184,3 +242,36 @@ def test_subscript_assign(interpreter): assert exec_msg == ("Evaluation of the code stopped at node 1. See:\n" "Unsupported variable type. 
Expected ast.Name or " "ast.Tuple, got Subscript instead.") + + +def test_dict(interpreter: PythonInterpreter): + code = """x = {1: 10, 2: 20} +y = {"number": 30, **x} +res = y[1] + y[2] + y["numbers"]""" + execution_res = interpreter.execute(code) + assert execution_res == 60 + + +def test_formatted_value(interpreter: PythonInterpreter): + code = """x = 3 +res = f"x = {x}" + """ + execution_res = interpreter.execute(code) + assert execution_res == "x = 3" + + +def test_joined_str(interpreter: PythonInterpreter): + code = """l = ["2", "3", "5", "7", "11"] +res = ",".join(l)""" + execution_res = interpreter.execute(code) + assert execution_res == "2,3,5,7,11" + + +def test_expression_not_support(interpreter: PythonInterpreter): + code = """x = 1 +x += 1""" + with pytest.raises(InterpreterError) as e: + interpreter.execute(code, keep_state=False) + exec_msg = e.value.args[0] + assert exec_msg == ("Evaluation of the code stopped at node 1. See:" + "\nAugAssign is not supported.")