diff --git a/constraints.txt b/constraints.txt index 42aa9e2..eaca344 100644 --- a/constraints.txt +++ b/constraints.txt @@ -8,15 +8,15 @@ asgiref==3.8.1 # via django build==1.2.1 # via pip-tools -certifi==2024.7.4 +certifi==2024.8.30 # via requests charset-normalizer==3.3.2 # via requests click==8.1.7 # via pip-tools -coverage==7.6.0 +coverage==7.6.1 # via pytest-cov -django==4.2.14 +django==5.0.8 # via kirppu (pyproject.toml) django-environ==0.11.2 # via kirppu (pyproject.toml) @@ -24,7 +24,7 @@ django-ipware==5.0.2 # via kirppu (pyproject.toml) django-ratelimit==4.1.0 # via kirppu (pyproject.toml) -factory-boy==3.3.0 +factory-boy==3.3.1 # via kirppu (pyproject.toml) faker==19.13.0 # via @@ -34,7 +34,7 @@ future==1.0.0 # via pubcode gunicorn==22.0.0 # via kirppu (pyproject.toml) -idna==3.7 +idna==3.8 # via requests iniconfig==2.0.0 # via pytest @@ -57,6 +57,8 @@ pluggy==1.5.0 # via pytest psycopg==3.1.20 # via kirppu (pyproject.toml) +psycopg-c==3.1.20 + # via psycopg pubcode==1.1.0 # via kirppu (pyproject.toml) pyproject-hooks==1.1.0 @@ -91,7 +93,7 @@ typing-extensions==4.12.2 # via psycopg urllib3==2.2.2 # via requests -wheel==0.43.0 +wheel==0.44.0 # via pip-tools whitenoise==6.5.0 # via kirppu (pyproject.toml) diff --git a/kirppu/provision.py b/kirppu/provision.py index e084e9b..255c1f4 100644 --- a/kirppu/provision.py +++ b/kirppu/provision.py @@ -65,7 +65,7 @@ def run_function(cls, provision_function, sold_and_compensated) -> Optional[Deci _r = run(provision_function, sold_and_compensated=sold_and_compensated) - assert _r is None or isinstance(_r, Decimal), "Value returned from function must be null or a number" + assert _r is None or isinstance(_r, (Decimal, int)), "Value returned from function must be null or a number" return _r def _run_function(self, items: Optional[QuerySet] = None) -> Optional[Decimal]: diff --git a/kirppu/provision_dsl/interpreter.py b/kirppu/provision_dsl/interpreter.py index 81d3f53..23ad118 100644 --- a/kirppu/provision_dsl/interpreter.py +++ b/kirppu/provision_dsl/interpreter.py @@ -149,18 +149,47 @@ def atomize(token: str) -> typing.Union[decimal.Decimal, Symbol, Literal]: return Symbol(token) +def list_index(t: list | tuple, v: typing.Any) -> int: + try: + return t.index(v) + except ValueError: + return -1 + + def ensure_args(fn, *types): + vararg_pos = list_index(types, ...) + type_count = len(types) + assert type_count > 0, "Must have at least one type to ensure" + + if vararg_pos >= 1: + type_count -= 1 + assert vararg_pos == type_count, "In-the-middle varargs are not supported" + else: + vararg_pos = None + @functools.wraps(fn) def inner(*args): arg_count = len(args) - type_count = len(types) - if arg_count > type_count: + if vararg_pos is None and arg_count > type_count: raise Error("Too many arguments, %d, expected %d" % (arg_count, type_count), ErrorType.ARGUMENT_COUNT) if arg_count < type_count: raise Error("Too few arguments, %d, expected %d" % (arg_count, type_count), ErrorType.ARGUMENT_COUNT) - for index, (arg, arg_type) in enumerate(zip(args, types), start=1): + + arg_iter = iter(args) + type_iter = iter(types) + prev_type = None + current_type = next(type_iter) + for index, arg in enumerate(arg_iter, start=1): + arg_type = current_type if arg_type == decimal.Decimal: arg_type = (int, decimal.Decimal) + + if current_type is ...: + arg_type = prev_type + else: + prev_type = arg_type + current_type = next(type_iter, StopIteration) + if not isinstance(arg, arg_type): raise Error("Wrong type of argument given in index %d, got %s" % (index, type(arg)), ErrorType.ARGUMENT_TYPE) @@ -168,27 +197,46 @@ def inner(*args): return inner +def ensure_arg_count(fn, count: int): + @functools.wraps(fn) + def inner(*args): + arg_count = len(args) + if arg_count > count: + raise Error("Too many arguments, %d, expected %d" % (arg_count, count), ErrorType.ARGUMENT_COUNT) + if arg_count < count: + raise Error("Too few arguments, %d, expected %d" % (arg_count, count), ErrorType.ARGUMENT_COUNT) + return fn(*args) + return inner + + +def va_op(op: typing.Callable): + @functools.wraps(op) + def inner(*args): + return functools.reduce(op, args) + return inner + + def make_std_env(): return { - "+": ensure_args(operator.add, decimal.Decimal, decimal.Decimal), - "-": ensure_args(operator.sub, decimal.Decimal, decimal.Decimal), - "*": ensure_args(operator.mul, decimal.Decimal, decimal.Decimal), + "+": ensure_args(va_op(operator.add), decimal.Decimal, decimal.Decimal, ...), + "-": ensure_args(va_op(operator.sub), decimal.Decimal, decimal.Decimal, ...), + "*": ensure_args(va_op(operator.mul), decimal.Decimal, decimal.Decimal, ...), "/": ensure_args(operator.truediv, decimal.Decimal, decimal.Decimal), "//": ensure_args(operator.floordiv, decimal.Decimal, decimal.Decimal), - "<": operator.lt, - ">": operator.gt, - "=": operator.eq, - "<=": operator.le, - ">=": operator.ge, - "!": operator.ne, - "not": operator.not_, + "<": ensure_arg_count(operator.lt, 2), + ">": ensure_arg_count(operator.gt, 2), + "=": ensure_arg_count(operator.eq, 2), + "<=": ensure_arg_count(operator.le, 2), + ">=": ensure_arg_count(operator.ge, 2), + "!": ensure_arg_count(operator.ne, 2), + "not": ensure_arg_count(operator.not_, 1), "null": None, "abs": ensure_args(abs, decimal.Decimal), "begin": lambda *x: x[-1], # arguments are evaluated in evaluate. - "length": len, - "max": ensure_args(max, decimal.Decimal, decimal.Decimal), - "min": ensure_args(min, decimal.Decimal, decimal.Decimal), + "length": ensure_arg_count(len, 1), + "max": ensure_args(max, decimal.Decimal, decimal.Decimal, ...), + "min": ensure_args(min, decimal.Decimal, decimal.Decimal, ...), "round": ensure_args(round, decimal.Decimal), "ceil": ensure_args(math.ceil, decimal.Decimal), diff --git a/kirppu/provision_dsl/test_dsl.py b/kirppu/provision_dsl/test_dsl.py index 1dde0c4..8d372cf 100644 --- a/kirppu/provision_dsl/test_dsl.py +++ b/kirppu/provision_dsl/test_dsl.py @@ -132,7 +132,7 @@ def test_wrong_addition(self): def test_wrong_argument_count(self): with self.assertRaises(Error) as e: - dsl.run("""(+ 1 2 3)""") + dsl.run("""(< 1 2 3)""") self.assertEqual(ErrorType.ARGUMENT_COUNT, e.exception.code, make_ex_str(e.exception)) def test_parens_1(self): @@ -147,6 +147,12 @@ def test_non_single_top(self): with self.assertRaises(ValueError): dsl.run("1 1") + def test_wrong_vararg_type(self): + with self.assertRaises(Error) as e: + dsl.run("(min 1 2 'a)") + self.assertEqual(ErrorType.ARGUMENT_TYPE, e.exception.code, make_ex_str(e.exception)) + + class ProvisionDslDjangoTestCase(DjangoTestCase): def setUp(self): diff --git a/kirppu/provision_dsl/test_dsl.scm b/kirppu/provision_dsl/test_dsl.scm index 203f6d4..7da23ef 100644 --- a/kirppu/provision_dsl/test_dsl.scm +++ b/kirppu/provision_dsl/test_dsl.scm @@ -87,3 +87,35 @@ ;TEST conditional-op ((if 1 + -) 2 3) ;= 5 + +;TEST min regular +(min 2 1) +;= 1 + +;TEST max regular +(max 2 1) +;= 2 + +;TEST min va +(min 3 2 4) +;= 2 + +;TEST min va 2 +(min 3 5 2 4) +;= 2 + +;TEST max va +(max 3 4 5 6) +;= 6 + +;TEST add va +(+ 2 3 4) +;= 9 + +;TEST sub va +(- 10 2 3) +;= 5 + +;TEST mul va +(* 2 2 2 2) +;= 16 diff --git a/pyproject.toml b/pyproject.toml index 3272731..56d6279 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dependencies = [ "django-environ~=0.11.0", "django-ipware~=5.0.0", "django-ratelimit~=4.0", - "django~=4.2.0", + "django~=5.0.0", "pillow~=10.0", "pubcode~=1.1.0", "mistune~=3.0.0", @@ -24,7 +24,7 @@ oauth = [ ] production = [ "gunicorn~=22.0", - "psycopg~=3.1.0", + "psycopg[c]~=3.1.0", ] dev = [ "factory-boy~=3.3.0", diff --git a/requirements-github.txt b/requirements-github.txt index 42aa9e2..eaca344 100644 --- a/requirements-github.txt +++ b/requirements-github.txt @@ -8,15 +8,15 @@ asgiref==3.8.1 # via django build==1.2.1 # via pip-tools -certifi==2024.7.4 +certifi==2024.8.30 # via requests charset-normalizer==3.3.2 # via requests click==8.1.7 # via pip-tools -coverage==7.6.0 +coverage==7.6.1 # via pytest-cov -django==4.2.14 +django==5.0.8 # via kirppu (pyproject.toml) django-environ==0.11.2 # via kirppu (pyproject.toml) @@ -24,7 +24,7 @@ django-ipware==5.0.2 # via kirppu (pyproject.toml) django-ratelimit==4.1.0 # via kirppu (pyproject.toml) -factory-boy==3.3.0 +factory-boy==3.3.1 # via kirppu (pyproject.toml) faker==19.13.0 # via @@ -34,7 +34,7 @@ future==1.0.0 # via pubcode gunicorn==22.0.0 # via kirppu (pyproject.toml) -idna==3.7 +idna==3.8 # via requests iniconfig==2.0.0 # via pytest @@ -57,6 +57,8 @@ pluggy==1.5.0 # via pytest psycopg==3.1.20 # via kirppu (pyproject.toml) +psycopg-c==3.1.20 + # via psycopg pubcode==1.1.0 # via kirppu (pyproject.toml) pyproject-hooks==1.1.0 @@ -91,7 +93,7 @@ typing-extensions==4.12.2 # via psycopg urllib3==2.2.2 # via requests -wheel==0.43.0 +wheel==0.44.0 # via pip-tools whitenoise==6.5.0 # via kirppu (pyproject.toml) diff --git a/requirements-production.txt b/requirements-production.txt index 83b394f..6f0d6cb 100644 --- a/requirements-production.txt +++ b/requirements-production.txt @@ -36,8 +36,10 @@ packaging # via gunicorn pillow # via kirppu (pyproject.toml) -psycopg +psycopg[c] # via kirppu (pyproject.toml) +psycopg-c + # via psycopg pubcode # via kirppu (pyproject.toml) requests