Merge "Faster Bandit"

This commit is contained in:
Jenkins 2015-07-09 18:20:42 +00:00 committed by Gerrit Code Review
commit acc5f3ca4a
2 changed files with 60 additions and 42 deletions

View File

@ -23,7 +23,7 @@ from bandit.core import utils as b_utils
from bandit.core.utils import InvalidModulePath from bandit.core.utils import InvalidModulePath
class BanditNodeVisitor(ast.NodeVisitor): class BanditNodeVisitor(object):
imports = set() imports = set()
import_aliases = {} import_aliases = {}
@ -35,7 +35,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
depth = 0 depth = 0
context = None context = None
context_template = {'node': None, 'filename': None, 'statement': None, context_template = {'node': None, 'filename': None,
'name': None, 'qualname': None, 'module': None, 'name': None, 'qualname': None, 'module': None,
'imports': None, 'import_aliases': None, 'call': None, 'imports': None, 'import_aliases': None, 'call': None,
'function': None, 'lineno': None, 'skip_lines': None} 'function': None, 'lineno': None, 'skip_lines': None}
@ -80,9 +80,12 @@ class BanditNodeVisitor(ast.NodeVisitor):
:return: - :return: -
''' '''
if self.debug:
self.logger.debug("visit_ClassDef called (%s)", ast.dump(node))
# For all child nodes, add this class name to current namespace # For all child nodes, add this class name to current namespace
self.namespace = b_utils.namespace_path_join(self.namespace, node.name) self.namespace = b_utils.namespace_path_join(self.namespace, node.name)
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
self.namespace = b_utils.namespace_path_split(self.namespace)[0] self.namespace = b_utils.namespace_path_split(self.namespace)[0]
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
@ -97,7 +100,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['function'] = node self.context['function'] = node
self.logger.debug("visit_FunctionDef called (%s)", ast.dump(node)) if self.debug:
self.logger.debug("visit_FunctionDef called (%s)", ast.dump(node))
qualname = self.namespace + '.' + b_utils.get_func_name(node) qualname = self.namespace + '.' + b_utils.get_func_name(node)
name = qualname.split('.')[-1] name = qualname.split('.')[-1]
@ -109,7 +113,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
# current namespace # current namespace
self.namespace = b_utils.namespace_path_join(self.namespace, name) self.namespace = b_utils.namespace_path_join(self.namespace, name)
self.update_scores(self.tester.run_tests(self.context, 'FunctionDef')) self.update_scores(self.tester.run_tests(self.context, 'FunctionDef'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
self.namespace = b_utils.namespace_path_split(self.namespace)[0] self.namespace = b_utils.namespace_path_split(self.namespace)[0]
def visit_Call(self, node): def visit_Call(self, node):
@ -123,7 +127,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['call'] = node self.context['call'] = node
self.logger.debug("visit_Call called (%s)", ast.dump(node)) if self.debug:
self.logger.debug("visit_Call called (%s)", ast.dump(node))
qualname = b_utils.get_call_name(node, self.import_aliases) qualname = b_utils.get_call_name(node, self.import_aliases)
name = qualname.split('.')[-1] name = qualname.split('.')[-1]
@ -132,7 +137,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['name'] = name self.context['name'] = name
self.update_scores(self.tester.run_tests(self.context, 'Call')) self.update_scores(self.tester.run_tests(self.context, 'Call'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit_Import(self, node): def visit_Import(self, node):
'''Visitor for AST Import nodes '''Visitor for AST Import nodes
@ -142,15 +147,16 @@ class BanditNodeVisitor(ast.NodeVisitor):
:param node: The node that is being inspected :param node: The node that is being inspected
:return: - :return: -
''' '''
if self.debug:
self.logger.debug("visit_Import called (%s)", ast.dump(node))
self.logger.debug("visit_Import called (%s)", ast.dump(node))
for nodename in node.names: for nodename in node.names:
if nodename.asname: if nodename.asname:
self.context['import_aliases'][nodename.asname] = nodename.name self.context['import_aliases'][nodename.asname] = nodename.name
self.context['imports'].add(nodename.name) self.context['imports'].add(nodename.name)
self.context['module'] = nodename.name self.context['module'] = nodename.name
self.update_scores(self.tester.run_tests(self.context, 'Import')) self.update_scores(self.tester.run_tests(self.context, 'Import'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
'''Visitor for AST Import nodes '''Visitor for AST Import nodes
@ -160,8 +166,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
:param node: The node that is being inspected :param node: The node that is being inspected
:return: - :return: -
''' '''
if self.debug:
self.logger.debug("visit_ImportFrom called (%s)", ast.dump(node)) self.logger.debug("visit_ImportFrom called (%s)", ast.dump(node))
module = node.module module = node.module
if module is None: if module is None:
@ -186,7 +192,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['module'] = module self.context['module'] = module
self.context['name'] = nodename.name self.context['name'] = nodename.name
self.update_scores(self.tester.run_tests(self.context, 'ImportFrom')) self.update_scores(self.tester.run_tests(self.context, 'ImportFrom'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit_Str(self, node): def visit_Str(self, node):
'''Visitor for AST String nodes '''Visitor for AST String nodes
@ -197,26 +203,32 @@ class BanditNodeVisitor(ast.NodeVisitor):
:return: - :return: -
''' '''
self.context['str'] = node.s self.context['str'] = node.s
self.logger.debug("visit_Str called (%s)", ast.dump(node))
if self.debug:
self.logger.debug("visit_Str called (%s)", ast.dump(node))
if not isinstance(node.parent, ast.Expr): # docstring if not isinstance(node.parent, ast.Expr): # docstring
self.context['linerange'] = b_utils.linerange_fix(node.parent) self.context['linerange'] = b_utils.linerange_fix(node.parent)
self.update_scores(self.tester.run_tests(self.context, 'Str')) self.update_scores(self.tester.run_tests(self.context, 'Str'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit_Exec(self, node): def visit_Exec(self, node):
self.context['str'] = 'exec' self.context['str'] = 'exec'
self.logger.debug("visit_Exec called (%s)", ast.dump(node)) if self.debug:
self.logger.debug("visit_Exec called (%s)", ast.dump(node))
self.update_scores(self.tester.run_tests(self.context, 'Exec')) self.update_scores(self.tester.run_tests(self.context, 'Exec'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit_Assert(self, node): def visit_Assert(self, node):
self.context['str'] = 'assert' self.context['str'] = 'assert'
self.logger.debug("visit_Assert called (%s)", ast.dump(node)) if self.debug:
self.logger.debug("visit_Assert called (%s)", ast.dump(node))
self.update_scores(self.tester.run_tests(self.context, 'Assert')) self.update_scores(self.tester.run_tests(self.context, 'Assert'))
super(BanditNodeVisitor, self).generic_visit(node) self.generic_visit(node)
def visit(self, node): def visit(self, node):
'''Generic visitor '''Generic visitor
@ -226,14 +238,17 @@ class BanditNodeVisitor(ast.NodeVisitor):
:return: - :return: -
''' '''
self.context = copy.copy(self.context_template) self.context = copy.copy(self.context_template)
self.logger.debug(ast.dump(node))
if self.debug:
self.logger.debug(ast.dump(node))
self.metaast.add_node(node, '', self.depth) self.metaast.add_node(node, '', self.depth)
if hasattr(node, 'lineno'): if hasattr(node, 'lineno'):
self.context['lineno'] = node.lineno self.context['lineno'] = node.lineno
if ("# nosec" in self.lines[node.lineno - 1] or if ("# nosec" in self.lines[node.lineno - 1] or
"#nosec" in self.lines[node.lineno - 1]): "#nosec" in self.lines[node.lineno - 1]):
self.logger.debug("skipped, nosec") self.logger.debug("skipped, nosec")
return self.scores # skip this node and all sub-nodes return
self.context['node'] = node self.context['node'] = node
self.context['linerange'] = b_utils.linerange_fix(node) self.context['linerange'] = b_utils.linerange_fix(node)
@ -243,10 +258,32 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node), self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
self.depth) self.depth)
self.depth += 1 self.depth += 1
super(BanditNodeVisitor, self).visit(node)
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
visitor(node)
self.depth -= 1 self.depth -= 1
self.logger.debug("%s\texiting : %s", self.depth, hex(id(node))) self.logger.debug("%s\texiting : %s", self.depth, hex(id(node)))
return self.scores
def generic_visit(self, node):
"""Drive the visitor."""
for _, value in ast.iter_fields(node):
if isinstance(value, list):
max_idx = len(value) - 1
for idx, item in enumerate(value):
if isinstance(item, ast.AST):
if idx < max_idx:
setattr(item, 'sibling', value[idx + 1])
else:
setattr(item, 'sibling', None)
setattr(item, 'parent', node)
self.visit(node=item)
elif isinstance(value, ast.AST):
setattr(value, 'sibling', None)
setattr(value, 'parent', node)
self.visit(node=value)
def update_scores(self, scores): def update_scores(self, scores):
'''Score updater '''Score updater
@ -272,6 +309,5 @@ class BanditNodeVisitor(ast.NodeVisitor):
fdata.seek(0) fdata.seek(0)
self.lines = fdata.readlines() self.lines = fdata.readlines()
f_ast = ast.parse("".join(self.lines)) f_ast = ast.parse("".join(self.lines))
b_utils.embelish_ast(f_ast) self.generic_visit(f_ast)
self.visit(f_ast)
return self.scores return self.scores

View File

@ -263,24 +263,6 @@ def safe_str(obj):
return unicode(obj).encode('unicode_escape') return unicode(obj).encode('unicode_escape')
def embelish_ast(node):
"""Add a parent and sibling info to every node."""
for _, value in ast.iter_fields(node):
if isinstance(value, list):
last = None
for item in value:
if isinstance(item, ast.AST):
if last is not None:
setattr(last, 'sibling', item)
last = item
setattr(item, 'parent', node)
embelish_ast(item)
elif isinstance(value, ast.AST):
setattr(value, 'parent', node)
embelish_ast(value)
def linerange(node): def linerange(node):
"""Get line number range from a node.""" """Get line number range from a node."""
strip = {"body": None, "orelse": None, strip = {"body": None, "orelse": None,