Merge "Faster Bandit"
This commit is contained in:
commit
acc5f3ca4a
@ -23,7 +23,7 @@ from bandit.core import utils as b_utils
|
||||
from bandit.core.utils import InvalidModulePath
|
||||
|
||||
|
||||
class BanditNodeVisitor(ast.NodeVisitor):
|
||||
class BanditNodeVisitor(object):
|
||||
|
||||
imports = set()
|
||||
import_aliases = {}
|
||||
@ -35,7 +35,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
depth = 0
|
||||
|
||||
context = None
|
||||
context_template = {'node': None, 'filename': None, 'statement': None,
|
||||
context_template = {'node': None, 'filename': None,
|
||||
'name': None, 'qualname': None, 'module': None,
|
||||
'imports': None, 'import_aliases': None, 'call': None,
|
||||
'function': None, 'lineno': None, 'skip_lines': None}
|
||||
@ -80,9 +80,12 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:return: -
|
||||
'''
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_ClassDef called (%s)", ast.dump(node))
|
||||
|
||||
# For all child nodes, add this class name to current namespace
|
||||
self.namespace = b_utils.namespace_path_join(self.namespace, node.name)
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
self.namespace = b_utils.namespace_path_split(self.namespace)[0]
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
@ -97,6 +100,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
|
||||
self.context['function'] = node
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_FunctionDef called (%s)", ast.dump(node))
|
||||
|
||||
qualname = self.namespace + '.' + b_utils.get_func_name(node)
|
||||
@ -109,7 +113,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
# current namespace
|
||||
self.namespace = b_utils.namespace_path_join(self.namespace, name)
|
||||
self.update_scores(self.tester.run_tests(self.context, 'FunctionDef'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
self.namespace = b_utils.namespace_path_split(self.namespace)[0]
|
||||
|
||||
def visit_Call(self, node):
|
||||
@ -123,6 +127,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
|
||||
self.context['call'] = node
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_Call called (%s)", ast.dump(node))
|
||||
|
||||
qualname = b_utils.get_call_name(node, self.import_aliases)
|
||||
@ -132,7 +137,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
self.context['name'] = name
|
||||
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Call'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Import(self, node):
|
||||
'''Visitor for AST Import nodes
|
||||
@ -142,15 +147,16 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:param node: The node that is being inspected
|
||||
:return: -
|
||||
'''
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_Import called (%s)", ast.dump(node))
|
||||
|
||||
for nodename in node.names:
|
||||
if nodename.asname:
|
||||
self.context['import_aliases'][nodename.asname] = nodename.name
|
||||
self.context['imports'].add(nodename.name)
|
||||
self.context['module'] = nodename.name
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Import'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node):
|
||||
'''Visitor for AST Import nodes
|
||||
@ -160,7 +166,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:param node: The node that is being inspected
|
||||
:return: -
|
||||
'''
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_ImportFrom called (%s)", ast.dump(node))
|
||||
|
||||
module = node.module
|
||||
@ -186,7 +192,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
self.context['module'] = module
|
||||
self.context['name'] = nodename.name
|
||||
self.update_scores(self.tester.run_tests(self.context, 'ImportFrom'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Str(self, node):
|
||||
'''Visitor for AST String nodes
|
||||
@ -197,26 +203,32 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:return: -
|
||||
'''
|
||||
self.context['str'] = node.s
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_Str called (%s)", ast.dump(node))
|
||||
|
||||
if not isinstance(node.parent, ast.Expr): # docstring
|
||||
self.context['linerange'] = b_utils.linerange_fix(node.parent)
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Str'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Exec(self, node):
|
||||
self.context['str'] = 'exec'
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_Exec called (%s)", ast.dump(node))
|
||||
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Exec'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_Assert(self, node):
|
||||
self.context['str'] = 'assert'
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug("visit_Assert called (%s)", ast.dump(node))
|
||||
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Assert'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit(self, node):
|
||||
'''Generic visitor
|
||||
@ -226,14 +238,17 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:return: -
|
||||
'''
|
||||
self.context = copy.copy(self.context_template)
|
||||
|
||||
if self.debug:
|
||||
self.logger.debug(ast.dump(node))
|
||||
|
||||
self.metaast.add_node(node, '', self.depth)
|
||||
if hasattr(node, 'lineno'):
|
||||
self.context['lineno'] = node.lineno
|
||||
if ("# nosec" in self.lines[node.lineno - 1] or
|
||||
"#nosec" in self.lines[node.lineno - 1]):
|
||||
self.logger.debug("skipped, nosec")
|
||||
return self.scores # skip this node and all sub-nodes
|
||||
return
|
||||
|
||||
self.context['node'] = node
|
||||
self.context['linerange'] = b_utils.linerange_fix(node)
|
||||
@ -243,10 +258,32 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
|
||||
self.depth)
|
||||
self.depth += 1
|
||||
super(BanditNodeVisitor, self).visit(node)
|
||||
|
||||
method = 'visit_' + node.__class__.__name__
|
||||
visitor = getattr(self, method, self.generic_visit)
|
||||
visitor(node)
|
||||
|
||||
self.depth -= 1
|
||||
self.logger.debug("%s\texiting : %s", self.depth, hex(id(node)))
|
||||
return self.scores
|
||||
|
||||
def generic_visit(self, node):
|
||||
"""Drive the visitor."""
|
||||
for _, value in ast.iter_fields(node):
|
||||
if isinstance(value, list):
|
||||
max_idx = len(value) - 1
|
||||
for idx, item in enumerate(value):
|
||||
if isinstance(item, ast.AST):
|
||||
if idx < max_idx:
|
||||
setattr(item, 'sibling', value[idx + 1])
|
||||
else:
|
||||
setattr(item, 'sibling', None)
|
||||
setattr(item, 'parent', node)
|
||||
self.visit(node=item)
|
||||
|
||||
elif isinstance(value, ast.AST):
|
||||
setattr(value, 'sibling', None)
|
||||
setattr(value, 'parent', node)
|
||||
self.visit(node=value)
|
||||
|
||||
def update_scores(self, scores):
|
||||
'''Score updater
|
||||
@ -272,6 +309,5 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
fdata.seek(0)
|
||||
self.lines = fdata.readlines()
|
||||
f_ast = ast.parse("".join(self.lines))
|
||||
b_utils.embelish_ast(f_ast)
|
||||
self.visit(f_ast)
|
||||
self.generic_visit(f_ast)
|
||||
return self.scores
|
||||
|
@ -263,24 +263,6 @@ def safe_str(obj):
|
||||
return unicode(obj).encode('unicode_escape')
|
||||
|
||||
|
||||
def embelish_ast(node):
|
||||
"""Add a parent and sibling info to every node."""
|
||||
for _, value in ast.iter_fields(node):
|
||||
if isinstance(value, list):
|
||||
last = None
|
||||
for item in value:
|
||||
if isinstance(item, ast.AST):
|
||||
if last is not None:
|
||||
setattr(last, 'sibling', item)
|
||||
last = item
|
||||
setattr(item, 'parent', node)
|
||||
embelish_ast(item)
|
||||
|
||||
elif isinstance(value, ast.AST):
|
||||
setattr(value, 'parent', node)
|
||||
embelish_ast(value)
|
||||
|
||||
|
||||
def linerange(node):
|
||||
"""Get line number range from a node."""
|
||||
strip = {"body": None, "orelse": None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user