Removing statement buffer

The statement buffer previously mangled the AST to produce chunks
that were passed, in isolation, to the node visitor. This caused
some tests to become impossible, if the needed data from two chunks
had been separated. This change fixes things by allowing the
node_visitor to process the AST directly.

To make this work it was necessary to embellish the AST with a few
extra bits of data, a node parent and sibling. This is a none
destructive change that does not limit any of the testing potential
of bandit compared to the raw AST (indeed it enhances it).
As a beneficial side effect, "nosec" comments can now be used to
skip entire chunks of code, not just lines, for example

try:
 ...
except: #nosec
 ...

will now skip everything in except block.

Change-Id: I47c4963b143868a2bd2f171dbcf39e4f97a929da
This commit is contained in:
Tim Kelsey 2015-07-07 21:17:59 +01:00
parent 055598028a
commit ca48f32924
8 changed files with 126 additions and 297 deletions

View File

@ -23,128 +23,6 @@ from bandit.core import utils as b_utils
from bandit.core.utils import InvalidModulePath
if hasattr(ast, 'TryExcept'):
ast_Try = (ast.TryExcept, ast.TryFinally)
else: # Python 3.3+
ast_Try = ast.Try
class StatementBuffer():
'''Buffer for code statements
Creates a buffer to store a code file as individual statements
for AST processing
'''
def __init__(self):
self._buffer = []
self.skip_lines = []
def load_buffer(self, fdata):
'''Buffer initialization
Read the file as lines, so we can store the length of the file
so we don't lose multi-line statements at the bottom of the target
file
:param fdata: The code to be parsed into the buffer
'''
self._buffer = []
self.skip_lines = []
lines = fdata.readlines()
self.file_len = len(lines)
for lineno in range(self.file_len):
found = False
for flag in constants.SKIP_FLAGS:
if "#" + flag in lines[lineno].replace(" ", "").lower():
found = True
if found:
self.skip_lines.append(lineno + 1)
f_ast = ast.parse("".join(lines))
# We need to expand body blocks within compound statements
# into our statement buffer so each gets processed in
# isolation
tmp_buf = f_ast.body
while len(tmp_buf):
# For each statement, if it is one of the special statement
# types which contain a body, we first update the tmp_buf
# adding the internal body statements to the beginning of
# the temporary buffer, then clear the body of the special
# statement before adding it to the primary buffer
stmt = tmp_buf.pop(0)
if (isinstance(stmt, ast.ClassDef)
or isinstance(stmt, ast.FunctionDef)
or isinstance(stmt, ast.With)
or isinstance(stmt, ast.Module)
or isinstance(stmt, ast.Interactive)):
stmt.body.extend(tmp_buf)
tmp_buf = stmt.body
stmt.body = []
elif (isinstance(stmt, ast.For)
or isinstance(stmt, ast.While)
or isinstance(stmt, ast.If)):
stmt.body.extend(stmt.orelse)
stmt.body.extend(tmp_buf)
tmp_buf = stmt.body
stmt.body = []
stmt.orelse = []
elif isinstance(stmt, ast_Try):
for handler in getattr(stmt, 'handlers', []):
stmt.body.extend(handler.body)
stmt.body.extend(getattr(stmt, 'orelse', []))
stmt.body.extend(tmp_buf)
tmp_buf = stmt.body
stmt.body = []
stmt.orelse = []
stmt.handlers = []
stmt.finalbody = []
# once we are sure it's either a single statement or that
# any content in a compound statement body has been removed
# we can add it to our primary buffer. The compound body
# must be removed so the ast isn't walked multiple times
# and isn't included in line-by-line output
self._buffer.append(stmt)
def get_next(self, pop=True):
'''Statment Retrieval
Grab the next statement in the buffer for detailed processing
:param pop: shift next statement off array (default) or just lookahead
:return statement: the next statement to be processed, or None
'''
if len(self._buffer):
statement = {}
if pop:
# shift the next statement off the array
statement['node'] = self._buffer.pop(0)
else:
# get the next statement without shift
statement['node'] = self._buffer[0]
statement['linerange'] = self.linenumber_range(statement['node'])
return statement
return None
def linenumber_range(self, node):
'''Get set of line numbers for statement
Walks the given statement node, and creates a set
of line numbers covered by the code
:param node: The statment line numbers are required for
:return lines: A set of line numbers
'''
lines = set()
for n in ast.walk(node):
if hasattr(n, 'lineno'):
lines.add(n.lineno)
# we'll return a range here, because in some cases ast.walk skips over
# important parts, such as the middle lines in a multi-line string
return range(min(lines), max(lines) + 1)
def get_skip_lines(self):
return self.skip_lines
class BanditNodeVisitor(ast.NodeVisitor):
imports = set()
@ -192,8 +70,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.fname)
self.namespace = ""
self.logger.debug('Module qualified name: %s', self.namespace)
self.stmt_buffer = StatementBuffer()
self.statement = {}
self.lines = []
def visit_ClassDef(self, node):
'''Visitor for AST ClassDef node
@ -322,20 +199,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['str'] = node.s
self.logger.debug("visit_Str called (%s)", ast.dump(node))
# This check is to make sure we aren't running tests against
# docstrings (any statement that is just a string, nothing else)
node_object = self.context['statement']['node']
# docstrings can be represented as standalone ast.Str
is_str = isinstance(node_object, ast.Str)
# or ast.Expr with a value of type ast.Str
if (isinstance(node_object, ast.Expr) and
isinstance(node_object.value, ast.Str)):
is_standalone_expr = True
else:
is_standalone_expr = False
# if we don't have either one of those, run the test
if not (is_str or is_standalone_expr):
if not isinstance(node.parent, ast.Expr): # docstring
self.context['linerange'] = b_utils.linerange_fix(node.parent)
self.update_scores(self.tester.run_tests(self.context, 'Str'))
super(BanditNodeVisitor, self).generic_visit(node)
@ -360,31 +225,19 @@ class BanditNodeVisitor(ast.NodeVisitor):
:param node: The node that is being inspected
:return: -
'''
self.context = copy.copy(self.context_template)
self.logger.debug(ast.dump(node))
self.metaast.add_node(node, '', self.depth)
self.context = copy.copy(self.context_template)
self.context['statement'] = self.statement
self.context['node'] = node
self.context['filename'] = self.fname
if hasattr(node, 'lineno'):
self.context['lineno'] = node.lineno
if ("# nosec" in self.lines[node.lineno - 1] or
"#nosec" in self.lines[node.lineno - 1]):
self.logger.debug("skipped, nosec")
return self.scores # skip this node and all sub-nodes
# deal with multiline strings lineno behavior (Python issue #16806)
current_lineno = self.context['lineno']
next_statement = self.stmt_buffer.get_next(pop=False)
if next_statement is not None:
next_lineno = min(next_statement['linerange'])
else:
next_lineno = self.stmt_buffer.file_len
if next_lineno - current_lineno > 1:
self.context['statement']['linerange'] = range(
min(self.context['statement']['linerange']),
next_lineno
)
self.context['skip_lines'] = self.stmt_buffer.get_skip_lines()
self.context['node'] = node
self.context['linerange'] = b_utils.linerange_fix(node)
self.context['filename'] = self.fname
self.seen += 1
self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
@ -412,18 +265,13 @@ class BanditNodeVisitor(ast.NodeVisitor):
def process(self, fdata):
'''Main process loop
Iniitalizes the statement buffer, iterates over each statement
in the buffer testing each AST in turn
Build and process the AST
:param fdata: the open filehandle for the code to be processed
:return score: the aggregated score for the current file
'''
self.stmt_buffer.load_buffer(fdata)
self.statement = self.stmt_buffer.get_next()
while self.statement is not None:
self.logger.debug('New statement loaded')
self.logger.debug('s_node: %s', ast.dump(self.statement['node']))
self.logger.debug('s_lineno: %s', self.statement['linerange'])
self.visit(self.statement['node'])
self.statement = self.stmt_buffer.get_next()
fdata.seek(0)
self.lines = fdata.readlines()
f_ast = ast.parse("".join(self.lines))
b_utils.embelish_ast(f_ast)
self.visit(f_ast)
return self.scores

View File

@ -61,7 +61,7 @@ class BanditResultStore():
'''
filename = context['filename']
lineno = context['lineno']
linerange = context['statement']['linerange']
linerange = context['linerange']
(issue_severity, issue_confidence, issue_text) = issue
if self.agg_type == 'vuln':

View File

@ -50,50 +50,49 @@ class BanditTester():
'CONFIDENCE': [0] * len(constants.RANKING)
}
if not raw_context['lineno'] in raw_context['skip_lines']:
tests = self.testset.get_tests(checktype)
for name, test in six.iteritems(tests):
# execute test with the an instance of the context class
temp_context = copy.copy(raw_context)
context = b_context.Context(temp_context)
try:
if hasattr(test, '_takes_config'):
# TODO(??): Possibly allow override from profile
test_config = self.config.get_option(
test._takes_config)
result = test(context, test_config)
else:
result = test(context)
tests = self.testset.get_tests(checktype)
for name, test in six.iteritems(tests):
# execute test with the an instance of the context class
temp_context = copy.copy(raw_context)
context = b_context.Context(temp_context)
try:
if hasattr(test, '_takes_config'):
# TODO(??): Possibly allow override from profile
test_config = self.config.get_option(
test._takes_config)
result = test(context, test_config)
else:
result = test(context)
# the test call returns a 2- or 3-tuple
# - (issue_severity, issue_text) or
# - (issue_severity, issue_confidence, issue_text)
# the test call returns a 2- or 3-tuple
# - (issue_severity, issue_text) or
# - (issue_severity, issue_confidence, issue_text)
# add default confidence level, if not returned by test
if (result is not None and len(result) == 2):
result = (
result[0],
constants.CONFIDENCE_DEFAULT,
result[1]
)
# add default confidence level, if not returned by test
if (result is not None and len(result) == 2):
result = (
result[0],
constants.CONFIDENCE_DEFAULT,
result[1]
)
# if we have a result, record it and update scores
if result is not None:
self.results.add(temp_context, name, result)
self.logger.debug(
"Issue identified by %s: %s", name, result
)
sev = constants.RANKING.index(result[0])
val = constants.RANKING_VALUES[result[0]]
scores['SEVERITY'][sev] += val
con = constants.RANKING.index(result[1])
val = constants.RANKING_VALUES[result[1]]
scores['CONFIDENCE'][con] += val
# if we have a result, record it and update scores
if result is not None:
self.results.add(temp_context, name, result)
self.logger.debug(
"Issue identified by %s: %s", name, result
)
sev = constants.RANKING.index(result[0])
val = constants.RANKING_VALUES[result[0]]
scores['SEVERITY'][sev] += val
con = constants.RANKING.index(result[1])
val = constants.RANKING_VALUES[result[1]]
scores['CONFIDENCE'][con] += val
except Exception as e:
self.report_error(name, context, e)
if self.debug:
raise
except Exception as e:
self.report_error(name, context, e)
if self.debug:
raise
self.logger.debug("Returning scores: %s", scores)
return scores

View File

@ -261,3 +261,57 @@ def safe_str(obj):
except UnicodeEncodeError:
# obj is unicode
return unicode(obj).encode('unicode_escape')
def embelish_ast(node):
"""Add a parent and sibling info to every node."""
for _, value in ast.iter_fields(node):
if isinstance(value, list):
last = None
for item in value:
if isinstance(item, ast.AST):
if last is not None:
setattr(last, 'sibling', item)
last = item
setattr(item, 'parent', node)
embelish_ast(item)
elif isinstance(value, ast.AST):
setattr(value, 'parent', node)
embelish_ast(value)
def linerange(node):
"""Get line number range from a node."""
strip = {"body": None, "orelse": None,
"handlers": None, "finalbody": None}
fields = dir(node)
for key in strip.keys():
if key in fields:
strip[key] = getattr(node, key)
setattr(node, key, [])
lines = set()
for n in ast.walk(node):
if hasattr(n, 'lineno'):
lines.add(n.lineno)
for key in strip.keys():
if strip[key] is not None:
setattr(node, key, strip[key])
if len(lines):
return range(min(lines), max(lines) + 1)
return [0, 1]
def linerange_fix(node):
"""Try and work around a known Python bug with multi-line strings."""
# deal with multiline strings lineno behavior (Python issue #16806)
lines = linerange(node)
if hasattr(node, 'sibling') and hasattr(node.sibling, 'lineno'):
start = min(lines)
delta = node.sibling.lineno - start
if delta > 1:
return range(start, node.sibling.lineno)
return lines

View File

@ -47,7 +47,7 @@ def _ast_binop_stringify(data):
@checks('Str')
def hardcoded_sql_expressions(context):
statement = context.statement['node']
statement = context.node.parent
if isinstance(statement, ast.Assign):
test_str = _ast_build_string(statement.value).lower()

View File

@ -394,7 +394,7 @@ class FunctionalTests(unittest.TestCase):
self.assertEqual(range(2, 7), issues[0]['line_range'])
self.assertIn('/tmp', issues[0]['code'])
self.assertEqual(18, issues[1]['line_number'])
self.assertEqual(range(16, 21), issues[1]['line_range'])
self.assertEqual(range(16, 19), issues[1]['line_range'])
self.assertIn('/tmp', issues[1]['code'])
self.assertEqual(23, issues[2]['line_number'])
self.assertEqual(range(22, 31), issues[2]['line_range'])

View File

@ -1,85 +0,0 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import ast
import unittest
from bandit.core import node_visitor
class StatementBufferTests(unittest.TestCase):
def setUp(self):
super(StatementBufferTests, self).setUp()
self.test_file = open("./examples/jinja2_templating.py")
self.buf = node_visitor.StatementBuffer()
self.buf.load_buffer(self.test_file)
def tearDown(self):
pass
def test_load_buffer(self):
# Check buffer contains 10 statements
self.assertEqual(10, len(self.buf._buffer))
def test_get_next(self):
# Check get_next returns an AST statement
stmt = self.buf.get_next()
self.assertTrue(isinstance(stmt['node'], ast.AST))
# Check get_next returned the first statement
self.assertEqual(1, stmt['linerange'][0])
# Check buffer has been reduced by one
self.assertEqual(9, len(self.buf._buffer))
def test_get_next_lookahead(self):
# Check get_next(pop=False) returns an AST statement
stmt = self.buf.get_next(pop=False)
self.assertTrue(isinstance(stmt['node'], ast.AST))
# Check get_next(pop=False) returned the first statement
self.assertEqual(1, stmt['linerange'][0])
# Check buffer remains the same length
self.assertEqual(10, len(self.buf._buffer))
def test_get_next_count(self):
# Check get_next returns exactly 10 statements
count = 0
stmt = self.buf.get_next()
while stmt is not None:
count = count + 1
stmt = self.buf.get_next()
self.assertEqual(10, count)
def test_get_next_empty(self):
# Check get_next on an empty buffer returns None
# self.test_file has already been read, so is empty file handle
self.buf.load_buffer(self.test_file)
stmt = self.buf.get_next()
self.assertEqual(None, stmt)
def test_linenumber_range(self):
# Check linenumber_range returns corrent number of lines
count = 9
while count > 0:
stmt = self.buf.get_next()
count = count - 1
# line 9 should be three lines long
self.assertEqual(3, len(stmt['linerange']))
# the range should be the correct line numbers
self.assertEqual([11, 12, 13], list(stmt['linerange']))

View File

@ -221,3 +221,16 @@ class UtilTests(unittest.TestCase):
self.assertEqual('', name)
# TODO(ljfisher) At best we might be able to get:
# self.assertEqual(name, 'a.list[0]')
def test_linerange(self):
self.test_file = open("./examples/jinja2_templating.py")
self.tree = ast.parse(self.test_file.read())
# Check linerange returns corrent number of lines
line = self.tree.body[8]
lrange = b_utils.linerange(line)
# line 9 should be three lines long
self.assertEqual(3, len(lrange))
# the range should be the correct line numbers
self.assertEqual([11, 12, 13], list(lrange))