From 7ef2214956c7e74309ef2c1d66e24c5857ee7ae2 Mon Sep 17 00:00:00 2001 From: PPakalns Date: Thu, 12 Dec 2019 12:16:55 +0200 Subject: [PATCH] Correctly validate evaluation expression input * There must be exactly one statement in input * Given statement must be an expression (assignment, augmented assignment, etc. are not allowed) --- simpleeval.py | 37 +++++++++++++++++++++++++- test_simpleeval.py | 65 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 84 insertions(+), 18 deletions(-) diff --git a/simpleeval.py b/simpleeval.py index 731c37e..d84caf6 100644 --- a/simpleeval.py +++ b/simpleeval.py @@ -51,6 +51,7 @@ - JCavallo (Jean Cavallo) names dict shouldn't be modified - Birne94 (Daniel Birnstiel) for fixing leaking generators. - patricksurry (Patrick Surry) or should return last value, even if falsy. +- PPakalns (Peteris Pakalns) correctly handle incorrectly given expressions ------------------------------------- Basic Usage: @@ -185,6 +186,25 @@ class IterableTooLong(InvalidExpression): pass +class NotAnExpression(InvalidExpression): + """ Given statement is not an expression. e.g. `a += b`, `a = b` """ + + pass + + +class MultipleStatementsPassed(InvalidExpression): + """ When multiple statements are passed like 'a\nb' simple eval + can not decide which statement should be evaluated """ + + pass + + +class NoStatementPassed(InvalidExpression): + """ Input without any statements are passed can not be evaluated """ + + pass + + ######################################## # Default simple functions to include: @@ -329,7 +349,22 @@ def eval(self, expr): self.expr = expr # and evaluate: - return self._eval(ast.parse(expr.strip()).body[0].value) + statements = ast.parse(expr.strip()).body + if len(statements) == 0: + raise NoStatementPassed( + "Expression doesn't contain evaluable statement." + ) + if len(statements) > 1: + raise MultipleStatementsPassed( + "Expression contains more than one evaluable statement." + ) + statement = statements[0] + if not isinstance(statement, ast.Expr): + raise NotAnExpression( + "Given input expression is not pure." + " It contains assignment, raise or similar statement." + ) + return self._eval(statement.value) def _eval(self, node): """ The internal evaluator used on each node in the parsed tree. """ diff --git a/test_simpleeval.py b/test_simpleeval.py index 5ce703e..a58f57c 100644 --- a/test_simpleeval.py +++ b/test_simpleeval.py @@ -13,8 +13,9 @@ import simpleeval import os from simpleeval import ( - SimpleEval, EvalWithCompoundTypes, FeatureNotAvailable, FunctionNotDefined, NameNotDefined, - InvalidExpression, AttributeDoesNotExist, simple_eval + SimpleEval, EvalWithCompoundTypes, FeatureNotAvailable, FunctionNotDefined, + NameNotDefined, InvalidExpression, AttributeDoesNotExist, simple_eval, + NotAnExpression, MultipleStatementsPassed, NoStatementPassed ) @@ -170,6 +171,27 @@ def test_set_not_allowed(self): with self.assertRaises(FeatureNotAvailable): self.t('{22}', False) + def test_multiple_statemets(self): + with self.assertRaises(MultipleStatementsPassed): + self.t("1\n2", 1) + + self.t("(1 + \n 2)", 3) + + self.t("\n 1 \n ", 1) + + with self.assertRaises(MultipleStatementsPassed): + self.t("a = 11; x = 21; x + x", 11) + + def test_no_statement_passed(self): + with self.assertRaises(NoStatementPassed): + self.t("", None) + + with self.assertRaises(NoStatementPassed): + self.t("\n\n\n", None) + + with self.assertRaises(NoStatementPassed): + self.t("\n\t\n", None) + class TestFunctions(DRYTest): """ Functions for expressions to play with """ @@ -299,7 +321,7 @@ class TestTryingToBreakOut(DRYTest): def test_import(self): """ usual suspect. import """ # cannot import things: - with self.assertRaises(AttributeError): + with self.assertRaises(NotAnExpression): self.t("import sys", None) def test_long_running(self): @@ -374,12 +396,6 @@ def test_list_length_test(self): with self.assertRaises(simpleeval.IterableTooLong): self.t("('spam spam spam' * 5000).split() * 5000", None) - def test_python_stuff(self): - """ other various pythony things. """ - # it only evaluates the first statement: - self.t("a = 11; x = 21; x + x", 11) - - def test_function_globals_breakout(self): """ by accessing function.__globals__ or func_... """ # thanks perkinslr. @@ -662,7 +678,7 @@ def test_none(self): self.s.names["s"] = 21 - with self.assertRaises(NameNotDefined): + with self.assertRaises(NotAnExpression): self.t("s += a", 21) self.s.names = None @@ -687,7 +703,18 @@ def test_dict(self): # however, you can't assign to those names: - self.t("a = 200", 200) + with self.assertRaises(NotAnExpression): + self.t("a = 200", 200) + + self.assertEqual(self.s.names['a'], 42) + + # however, you can't augmented assign to those names: + + with self.assertRaises(NotAnExpression): + self.t("a += 200", 200) + + with self.assertRaises(NotAnExpression): + self.t("a -= 200", 200) self.assertEqual(self.s.names['a'], 42) @@ -695,7 +722,8 @@ def test_dict(self): self.s.names['b'] = [0] - self.t("b[0] = 11", 11) + with self.assertRaises(NotAnExpression): + self.t("b[0] = 11", 11) self.assertEqual(self.s.names['b'], [0]) @@ -716,7 +744,8 @@ def test_dict(self): # you still can't assign though: - self.t("c['b'] = 99", 99) + with self.assertRaises(NotAnExpression): + self.t("c['b'] = 99", 99) self.assertFalse('b' in self.s.names['c']) @@ -724,7 +753,8 @@ def test_dict(self): self.s.names['c']['c'] = {'c': 11} - self.t("c['c']['c'] = 21", 21) + with self.assertRaises(NotAnExpression): + self.t("c['c']['c'] = 21", 21) self.assertEqual(self.s.names['c']['c']['c'], 11) @@ -737,12 +767,13 @@ def test_dict_attr_access(self): self.t("a.b.c*2", 84) - self.t("a.b.c = 11", 11) + with self.assertRaises(NotAnExpression): + self.t("a.b.c = 11", 11) self.assertEqual(self.s.names['a']['b']['c'], 42) - # TODO: Wat? - self.t("a.d = 11", 11) + with self.assertRaises(NotAnExpression): + self.t("a.d = 11", 11) with self.assertRaises(KeyError): self.assertEqual(self.s.names['a']['d'], 11)