Removing statement buffer
The statement buffer previously mangled the AST to produce chunks that were passed, in isolation, to the node visitor. This made some tests impossible to write when the data they needed had been split across two chunks. This change fixes that by letting the node_visitor process the AST directly. To make this work, the AST is embellished with a few extra pieces of data: each node's parent and sibling. This is a non-destructive change that does not limit any of Bandit's testing potential compared to the raw AST (indeed, it enhances it).

As a beneficial side effect, "nosec" comments can now be used to skip entire blocks of code, not just single lines; for example, try: ... except: # nosec ... will now skip everything in the except block.

Change-Id: I47c4963b143868a2bd2f171dbcf39e4f97a929da
This commit is contained in:
parent 055598028a
commit ca48f32924
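For illustration, a minimal sketch of the block-level "nosec" behaviour described above. The sample file below is hypothetical and not part of this change; the point is that, after this commit, a marker on the except line suppresses findings for every statement inside that handler, not just the marked line:

    import subprocess

    try:
        subprocess.check_call("ls -l", shell=True)
    except Exception:  # nosec
        # everything in this handler is now skipped by bandit,
        # including this shell=True call
        subprocess.call("echo failed", shell=True)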
@@ -23,128 +23,6 @@ from bandit.core import utils as b_utils
from bandit.core.utils import InvalidModulePath


if hasattr(ast, 'TryExcept'):
    ast_Try = (ast.TryExcept, ast.TryFinally)
else:  # Python 3.3+
    ast_Try = ast.Try


class StatementBuffer():
    '''Buffer for code statements

    Creates a buffer to store a code file as individual statements
    for AST processing
    '''
    def __init__(self):
        self._buffer = []
        self.skip_lines = []

    def load_buffer(self, fdata):
        '''Buffer initialization

        Read the file as lines, so we can store the length of the file
        so we don't lose multi-line statements at the bottom of the target
        file
        :param fdata: The code to be parsed into the buffer
        '''
        self._buffer = []
        self.skip_lines = []
        lines = fdata.readlines()
        self.file_len = len(lines)

        for lineno in range(self.file_len):
            found = False
            for flag in constants.SKIP_FLAGS:
                if "#" + flag in lines[lineno].replace(" ", "").lower():
                    found = True
            if found:
                self.skip_lines.append(lineno + 1)

        f_ast = ast.parse("".join(lines))
        # We need to expand body blocks within compound statements
        # into our statement buffer so each gets processed in
        # isolation
        tmp_buf = f_ast.body
        while len(tmp_buf):
            # For each statement, if it is one of the special statement
            # types which contain a body, we first update the tmp_buf
            # adding the internal body statements to the beginning of
            # the temporary buffer, then clear the body of the special
            # statement before adding it to the primary buffer
            stmt = tmp_buf.pop(0)
            if (isinstance(stmt, ast.ClassDef)
                    or isinstance(stmt, ast.FunctionDef)
                    or isinstance(stmt, ast.With)
                    or isinstance(stmt, ast.Module)
                    or isinstance(stmt, ast.Interactive)):
                stmt.body.extend(tmp_buf)
                tmp_buf = stmt.body
                stmt.body = []
            elif (isinstance(stmt, ast.For)
                    or isinstance(stmt, ast.While)
                    or isinstance(stmt, ast.If)):
                stmt.body.extend(stmt.orelse)
                stmt.body.extend(tmp_buf)
                tmp_buf = stmt.body
                stmt.body = []
                stmt.orelse = []
            elif isinstance(stmt, ast_Try):
                for handler in getattr(stmt, 'handlers', []):
                    stmt.body.extend(handler.body)
                stmt.body.extend(getattr(stmt, 'orelse', []))
                stmt.body.extend(tmp_buf)
                tmp_buf = stmt.body
                stmt.body = []
                stmt.orelse = []
                stmt.handlers = []
                stmt.finalbody = []

            # once we are sure it's either a single statement or that
            # any content in a compound statement body has been removed
            # we can add it to our primary buffer. The compound body
            # must be removed so the ast isn't walked multiple times
            # and isn't included in line-by-line output
            self._buffer.append(stmt)

    def get_next(self, pop=True):
        '''Statement retrieval

        Grab the next statement in the buffer for detailed processing
        :param pop: shift next statement off array (default) or just lookahead
        :return statement: the next statement to be processed, or None
        '''
        if len(self._buffer):
            statement = {}
            if pop:
                # shift the next statement off the array
                statement['node'] = self._buffer.pop(0)
            else:
                # get the next statement without shift
                statement['node'] = self._buffer[0]
            statement['linerange'] = self.linenumber_range(statement['node'])
            return statement
        return None

    def linenumber_range(self, node):
        '''Get set of line numbers for statement

        Walks the given statement node, and creates a set
        of line numbers covered by the code
        :param node: The statement line numbers are required for
        :return lines: A set of line numbers
        '''
        lines = set()
        for n in ast.walk(node):
            if hasattr(n, 'lineno'):
                lines.add(n.lineno)
        # we'll return a range here, because in some cases ast.walk skips over
        # important parts, such as the middle lines in a multi-line string
        return range(min(lines), max(lines) + 1)

    def get_skip_lines(self):
        return self.skip_lines


class BanditNodeVisitor(ast.NodeVisitor):

    imports = set()
@@ -192,8 +70,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
            self.fname)
        self.namespace = ""
        self.logger.debug('Module qualified name: %s', self.namespace)
        self.stmt_buffer = StatementBuffer()
        self.statement = {}
        self.lines = []

    def visit_ClassDef(self, node):
        '''Visitor for AST ClassDef node
@@ -322,20 +199,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
        self.context['str'] = node.s
        self.logger.debug("visit_Str called (%s)", ast.dump(node))

        # This check is to make sure we aren't running tests against
        # docstrings (any statement that is just a string, nothing else)
        node_object = self.context['statement']['node']

        # docstrings can be represented as standalone ast.Str
        is_str = isinstance(node_object, ast.Str)
        # or ast.Expr with a value of type ast.Str
        if (isinstance(node_object, ast.Expr) and
                isinstance(node_object.value, ast.Str)):
            is_standalone_expr = True
        else:
            is_standalone_expr = False
        # if we don't have either one of those, run the test
        if not (is_str or is_standalone_expr):
        if not isinstance(node.parent, ast.Expr):  # docstring
            self.context['linerange'] = b_utils.linerange_fix(node.parent)
            self.update_scores(self.tester.run_tests(self.context, 'Str'))
        super(BanditNodeVisitor, self).generic_visit(node)
@@ -360,31 +225,19 @@ class BanditNodeVisitor(ast.NodeVisitor):
        :param node: The node that is being inspected
        :return: -
        '''
        self.context = copy.copy(self.context_template)
        self.logger.debug(ast.dump(node))
        self.metaast.add_node(node, '', self.depth)

        self.context = copy.copy(self.context_template)
        self.context['statement'] = self.statement
        self.context['node'] = node
        self.context['filename'] = self.fname
        if hasattr(node, 'lineno'):
            self.context['lineno'] = node.lineno
            if ("# nosec" in self.lines[node.lineno - 1] or
                    "#nosec" in self.lines[node.lineno - 1]):
                self.logger.debug("skipped, nosec")
                return self.scores  # skip this node and all sub-nodes

        # deal with multiline strings lineno behavior (Python issue #16806)
        current_lineno = self.context['lineno']
        next_statement = self.stmt_buffer.get_next(pop=False)
        if next_statement is not None:
            next_lineno = min(next_statement['linerange'])
        else:
            next_lineno = self.stmt_buffer.file_len

        if next_lineno - current_lineno > 1:
            self.context['statement']['linerange'] = range(
                min(self.context['statement']['linerange']),
                next_lineno
            )

        self.context['skip_lines'] = self.stmt_buffer.get_skip_lines()
        self.context['node'] = node
        self.context['linerange'] = b_utils.linerange_fix(node)
        self.context['filename'] = self.fname

        self.seen += 1
        self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
@@ -412,18 +265,13 @@ class BanditNodeVisitor(ast.NodeVisitor):
    def process(self, fdata):
        '''Main process loop

        Initializes the statement buffer, iterates over each statement
        in the buffer testing each AST in turn
        Build and process the AST
        :param fdata: the open filehandle for the code to be processed
        :return score: the aggregated score for the current file
        '''
        self.stmt_buffer.load_buffer(fdata)
        self.statement = self.stmt_buffer.get_next()
        while self.statement is not None:
            self.logger.debug('New statement loaded')
            self.logger.debug('s_node: %s', ast.dump(self.statement['node']))
            self.logger.debug('s_lineno: %s', self.statement['linerange'])

            self.visit(self.statement['node'])
            self.statement = self.stmt_buffer.get_next()
        fdata.seek(0)
        self.lines = fdata.readlines()
        f_ast = ast.parse("".join(self.lines))
        b_utils.embelish_ast(f_ast)
        self.visit(f_ast)
        return self.scores
@@ -61,7 +61,7 @@ class BanditResultStore():
        '''
        filename = context['filename']
        lineno = context['lineno']
        linerange = context['statement']['linerange']
        linerange = context['linerange']
        (issue_severity, issue_confidence, issue_text) = issue

        if self.agg_type == 'vuln':
@@ -50,50 +50,49 @@ class BanditTester():
            'CONFIDENCE': [0] * len(constants.RANKING)
        }

        if not raw_context['lineno'] in raw_context['skip_lines']:
            tests = self.testset.get_tests(checktype)
            for name, test in six.iteritems(tests):
                # execute test with an instance of the context class
                temp_context = copy.copy(raw_context)
                context = b_context.Context(temp_context)
                try:
                    if hasattr(test, '_takes_config'):
                        # TODO(??): Possibly allow override from profile
                        test_config = self.config.get_option(
                            test._takes_config)
                        result = test(context, test_config)
                    else:
                        result = test(context)
        tests = self.testset.get_tests(checktype)
        for name, test in six.iteritems(tests):
            # execute test with an instance of the context class
            temp_context = copy.copy(raw_context)
            context = b_context.Context(temp_context)
            try:
                if hasattr(test, '_takes_config'):
                    # TODO(??): Possibly allow override from profile
                    test_config = self.config.get_option(
                        test._takes_config)
                    result = test(context, test_config)
                else:
                    result = test(context)

                    # the test call returns a 2- or 3-tuple
                    # - (issue_severity, issue_text) or
                    # - (issue_severity, issue_confidence, issue_text)
                # the test call returns a 2- or 3-tuple
                # - (issue_severity, issue_text) or
                # - (issue_severity, issue_confidence, issue_text)

                    # add default confidence level, if not returned by test
                    if (result is not None and len(result) == 2):
                        result = (
                            result[0],
                            constants.CONFIDENCE_DEFAULT,
                            result[1]
                        )
                # add default confidence level, if not returned by test
                if (result is not None and len(result) == 2):
                    result = (
                        result[0],
                        constants.CONFIDENCE_DEFAULT,
                        result[1]
                    )

                    # if we have a result, record it and update scores
                    if result is not None:
                        self.results.add(temp_context, name, result)
                        self.logger.debug(
                            "Issue identified by %s: %s", name, result
                        )
                        sev = constants.RANKING.index(result[0])
                        val = constants.RANKING_VALUES[result[0]]
                        scores['SEVERITY'][sev] += val
                        con = constants.RANKING.index(result[1])
                        val = constants.RANKING_VALUES[result[1]]
                        scores['CONFIDENCE'][con] += val
                # if we have a result, record it and update scores
                if result is not None:
                    self.results.add(temp_context, name, result)
                    self.logger.debug(
                        "Issue identified by %s: %s", name, result
                    )
                    sev = constants.RANKING.index(result[0])
                    val = constants.RANKING_VALUES[result[0]]
                    scores['SEVERITY'][sev] += val
                    con = constants.RANKING.index(result[1])
                    val = constants.RANKING_VALUES[result[1]]
                    scores['CONFIDENCE'][con] += val

                except Exception as e:
                    self.report_error(name, context, e)
                    if self.debug:
                        raise
            except Exception as e:
                self.report_error(name, context, e)
                if self.debug:
                    raise
        self.logger.debug("Returning scores: %s", scores)
        return scores
@@ -261,3 +261,57 @@ def safe_str(obj):
    except UnicodeEncodeError:
        # obj is unicode
        return unicode(obj).encode('unicode_escape')


def embelish_ast(node):
    """Add parent and sibling info to every node."""
    for _, value in ast.iter_fields(node):
        if isinstance(value, list):
            last = None
            for item in value:
                if isinstance(item, ast.AST):
                    if last is not None:
                        setattr(last, 'sibling', item)
                    last = item
                    setattr(item, 'parent', node)
                    embelish_ast(item)

        elif isinstance(value, ast.AST):
            setattr(value, 'parent', node)
            embelish_ast(value)


def linerange(node):
    """Get line number range from a node."""
    strip = {"body": None, "orelse": None,
             "handlers": None, "finalbody": None}
    fields = dir(node)
    for key in strip.keys():
        if key in fields:
            strip[key] = getattr(node, key)
            setattr(node, key, [])

    lines = set()
    for n in ast.walk(node):
        if hasattr(n, 'lineno'):
            lines.add(n.lineno)

    for key in strip.keys():
        if strip[key] is not None:
            setattr(node, key, strip[key])

    if len(lines):
        return range(min(lines), max(lines) + 1)
    return [0, 1]


def linerange_fix(node):
    """Try and work around a known Python bug with multi-line strings."""
    # deal with multiline strings lineno behavior (Python issue #16806)
    lines = linerange(node)
    if hasattr(node, 'sibling') and hasattr(node.sibling, 'lineno'):
        start = min(lines)
        delta = node.sibling.lineno - start
        if delta > 1:
            return range(start, node.sibling.lineno)
    return lines
@@ -47,7 +47,7 @@ def _ast_binop_stringify(data):

@checks('Str')
def hardcoded_sql_expressions(context):
    statement = context.statement['node']
    statement = context.node.parent
    if isinstance(statement, ast.Assign):
        test_str = _ast_build_string(statement.value).lower()
@@ -394,7 +394,7 @@ class FunctionalTests(unittest.TestCase):
        self.assertEqual(range(2, 7), issues[0]['line_range'])
        self.assertIn('/tmp', issues[0]['code'])
        self.assertEqual(18, issues[1]['line_number'])
        self.assertEqual(range(16, 21), issues[1]['line_range'])
        self.assertEqual(range(16, 19), issues[1]['line_range'])
        self.assertIn('/tmp', issues[1]['code'])
        self.assertEqual(23, issues[2]['line_number'])
        self.assertEqual(range(22, 31), issues[2]['line_range'])
@@ -1,85 +0,0 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import os
import ast

import unittest
from bandit.core import node_visitor


class StatementBufferTests(unittest.TestCase):

    def setUp(self):
        super(StatementBufferTests, self).setUp()
        self.test_file = open("./examples/jinja2_templating.py")
        self.buf = node_visitor.StatementBuffer()
        self.buf.load_buffer(self.test_file)

    def tearDown(self):
        pass

    def test_load_buffer(self):
        # Check buffer contains 10 statements
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next(self):
        # Check get_next returns an AST statement
        stmt = self.buf.get_next()
        self.assertTrue(isinstance(stmt['node'], ast.AST))
        # Check get_next returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer has been reduced by one
        self.assertEqual(9, len(self.buf._buffer))

    def test_get_next_lookahead(self):
        # Check get_next(pop=False) returns an AST statement
        stmt = self.buf.get_next(pop=False)
        self.assertTrue(isinstance(stmt['node'], ast.AST))
        # Check get_next(pop=False) returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer remains the same length
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next_count(self):
        # Check get_next returns exactly 10 statements
        count = 0
        stmt = self.buf.get_next()
        while stmt is not None:
            count = count + 1
            stmt = self.buf.get_next()

        self.assertEqual(10, count)

    def test_get_next_empty(self):
        # Check get_next on an empty buffer returns None
        # self.test_file has already been read, so is an empty file handle
        self.buf.load_buffer(self.test_file)
        stmt = self.buf.get_next()
        self.assertEqual(None, stmt)

    def test_linenumber_range(self):
        # Check linenumber_range returns the correct number of lines
        count = 9
        while count > 0:
            stmt = self.buf.get_next()
            count = count - 1

        # line 9 should be three lines long
        self.assertEqual(3, len(stmt['linerange']))

        # the range should be the correct line numbers
        self.assertEqual([11, 12, 13], list(stmt['linerange']))
@@ -221,3 +221,16 @@ class UtilTests(unittest.TestCase):
        self.assertEqual('', name)
        # TODO(ljfisher) At best we might be able to get:
        # self.assertEqual(name, 'a.list[0]')

    def test_linerange(self):
        self.test_file = open("./examples/jinja2_templating.py")
        self.tree = ast.parse(self.test_file.read())
        # Check linerange returns the correct number of lines
        line = self.tree.body[8]
        lrange = b_utils.linerange(line)

        # line 9 should be three lines long
        self.assertEqual(3, len(lrange))

        # the range should be the correct line numbers
        self.assertEqual([11, 12, 13], list(lrange))