diff --git a/bandit/core/node_visitor.py b/bandit/core/node_visitor.py index cf7f190a..6a820cbc 100755 --- a/bandit/core/node_visitor.py +++ b/bandit/core/node_visitor.py @@ -23,7 +23,7 @@ from bandit.core import utils as b_utils from bandit.core.utils import InvalidModulePath -class BanditNodeVisitor(ast.NodeVisitor): +class BanditNodeVisitor(object): imports = set() import_aliases = {} @@ -35,7 +35,7 @@ class BanditNodeVisitor(ast.NodeVisitor): depth = 0 context = None - context_template = {'node': None, 'filename': None, 'statement': None, + context_template = {'node': None, 'filename': None, 'name': None, 'qualname': None, 'module': None, 'imports': None, 'import_aliases': None, 'call': None, 'function': None, 'lineno': None, 'skip_lines': None} @@ -80,9 +80,12 @@ class BanditNodeVisitor(ast.NodeVisitor): :return: - ''' + if self.debug: + self.logger.debug("visit_ClassDef called (%s)", ast.dump(node)) + # For all child nodes, add this class name to current namespace self.namespace = b_utils.namespace_path_join(self.namespace, node.name) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) self.namespace = b_utils.namespace_path_split(self.namespace)[0] def visit_FunctionDef(self, node): @@ -97,7 +100,8 @@ class BanditNodeVisitor(ast.NodeVisitor): self.context['function'] = node - self.logger.debug("visit_FunctionDef called (%s)", ast.dump(node)) + if self.debug: + self.logger.debug("visit_FunctionDef called (%s)", ast.dump(node)) qualname = self.namespace + '.' + b_utils.get_func_name(node) name = qualname.split('.')[-1] @@ -109,7 +113,7 @@ class BanditNodeVisitor(ast.NodeVisitor): # current namespace self.namespace = b_utils.namespace_path_join(self.namespace, name) self.update_scores(self.tester.run_tests(self.context, 'FunctionDef')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) self.namespace = b_utils.namespace_path_split(self.namespace)[0] def visit_Call(self, node): @@ -123,7 +127,8 @@ class BanditNodeVisitor(ast.NodeVisitor): self.context['call'] = node - self.logger.debug("visit_Call called (%s)", ast.dump(node)) + if self.debug: + self.logger.debug("visit_Call called (%s)", ast.dump(node)) qualname = b_utils.get_call_name(node, self.import_aliases) name = qualname.split('.')[-1] @@ -132,7 +137,7 @@ class BanditNodeVisitor(ast.NodeVisitor): self.context['name'] = name self.update_scores(self.tester.run_tests(self.context, 'Call')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit_Import(self, node): '''Visitor for AST Import nodes @@ -142,15 +147,16 @@ class BanditNodeVisitor(ast.NodeVisitor): :param node: The node that is being inspected :return: - ''' + if self.debug: + self.logger.debug("visit_Import called (%s)", ast.dump(node)) - self.logger.debug("visit_Import called (%s)", ast.dump(node)) for nodename in node.names: if nodename.asname: self.context['import_aliases'][nodename.asname] = nodename.name self.context['imports'].add(nodename.name) self.context['module'] = nodename.name self.update_scores(self.tester.run_tests(self.context, 'Import')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit_ImportFrom(self, node): '''Visitor for AST Import nodes @@ -160,8 +166,8 @@ class BanditNodeVisitor(ast.NodeVisitor): :param node: The node that is being inspected :return: - ''' - - self.logger.debug("visit_ImportFrom called (%s)", ast.dump(node)) + if self.debug: + self.logger.debug("visit_ImportFrom called (%s)", ast.dump(node)) module = node.module if module is None: @@ -186,7 +192,7 @@ class BanditNodeVisitor(ast.NodeVisitor): self.context['module'] = module self.context['name'] = nodename.name self.update_scores(self.tester.run_tests(self.context, 'ImportFrom')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit_Str(self, node): '''Visitor for AST String nodes @@ -197,26 +203,32 @@ class BanditNodeVisitor(ast.NodeVisitor): :return: - ''' self.context['str'] = node.s - self.logger.debug("visit_Str called (%s)", ast.dump(node)) + + if self.debug: + self.logger.debug("visit_Str called (%s)", ast.dump(node)) if not isinstance(node.parent, ast.Expr): # docstring self.context['linerange'] = b_utils.linerange_fix(node.parent) self.update_scores(self.tester.run_tests(self.context, 'Str')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit_Exec(self, node): self.context['str'] = 'exec' - self.logger.debug("visit_Exec called (%s)", ast.dump(node)) + if self.debug: + self.logger.debug("visit_Exec called (%s)", ast.dump(node)) + self.update_scores(self.tester.run_tests(self.context, 'Exec')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit_Assert(self, node): self.context['str'] = 'assert' - self.logger.debug("visit_Assert called (%s)", ast.dump(node)) + if self.debug: + self.logger.debug("visit_Assert called (%s)", ast.dump(node)) + self.update_scores(self.tester.run_tests(self.context, 'Assert')) - super(BanditNodeVisitor, self).generic_visit(node) + self.generic_visit(node) def visit(self, node): '''Generic visitor @@ -226,14 +238,17 @@ class BanditNodeVisitor(ast.NodeVisitor): :return: - ''' self.context = copy.copy(self.context_template) - self.logger.debug(ast.dump(node)) + + if self.debug: + self.logger.debug(ast.dump(node)) + self.metaast.add_node(node, '', self.depth) if hasattr(node, 'lineno'): self.context['lineno'] = node.lineno if ("# nosec" in self.lines[node.lineno - 1] or "#nosec" in self.lines[node.lineno - 1]): self.logger.debug("skipped, nosec") - return self.scores # skip this node and all sub-nodes + return self.context['node'] = node self.context['linerange'] = b_utils.linerange_fix(node) @@ -243,10 +258,32 @@ class BanditNodeVisitor(ast.NodeVisitor): self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node), self.depth) self.depth += 1 - super(BanditNodeVisitor, self).visit(node) + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visitor(node) + self.depth -= 1 self.logger.debug("%s\texiting : %s", self.depth, hex(id(node))) - return self.scores + + def generic_visit(self, node): + """Drive the visitor.""" + for _, value in ast.iter_fields(node): + if isinstance(value, list): + max_idx = len(value) - 1 + for idx, item in enumerate(value): + if isinstance(item, ast.AST): + if idx < max_idx: + setattr(item, 'sibling', value[idx + 1]) + else: + setattr(item, 'sibling', None) + setattr(item, 'parent', node) + self.visit(node=item) + + elif isinstance(value, ast.AST): + setattr(value, 'sibling', None) + setattr(value, 'parent', node) + self.visit(node=value) def update_scores(self, scores): '''Score updater @@ -272,6 +309,5 @@ class BanditNodeVisitor(ast.NodeVisitor): fdata.seek(0) self.lines = fdata.readlines() f_ast = ast.parse("".join(self.lines)) - b_utils.embelish_ast(f_ast) - self.visit(f_ast) + self.generic_visit(f_ast) return self.scores diff --git a/bandit/core/utils.py b/bandit/core/utils.py index abe1ac7f..a8cbbbbc 100644 --- a/bandit/core/utils.py +++ b/bandit/core/utils.py @@ -263,24 +263,6 @@ def safe_str(obj): return unicode(obj).encode('unicode_escape') -def embelish_ast(node): - """Add a parent and sibling info to every node.""" - for _, value in ast.iter_fields(node): - if isinstance(value, list): - last = None - for item in value: - if isinstance(item, ast.AST): - if last is not None: - setattr(last, 'sibling', item) - last = item - setattr(item, 'parent', node) - embelish_ast(item) - - elif isinstance(value, ast.AST): - setattr(value, 'parent', node) - embelish_ast(value) - - def linerange(node): """Get line number range from a node.""" strip = {"body": None, "orelse": None,