Merge "Removing statement buffer"

This commit is contained in:
Jenkins 2015-07-09 18:20:30 +00:00 committed by Gerrit Code Review
commit 8ff7894cf2
8 changed files with 126 additions and 297 deletions

View File

@ -23,128 +23,6 @@ from bandit.core import utils as b_utils
from bandit.core.utils import InvalidModulePath
# Older interpreters expose try statements as two node classes
# (TryExcept / TryFinally); Python 3.3+ has the single ast.Try class.
# Either form can be passed straight to isinstance().
ast_Try = (
    (ast.TryExcept, ast.TryFinally) if hasattr(ast, 'TryExcept') else ast.Try
)
class StatementBuffer():
    '''Buffer for code statements

    Creates a buffer to store a code file as individual statements
    for AST processing. Compound statements are flattened: the statements
    nested inside them are queued individually and the compound statement
    itself is stored with its nested blocks emptied.
    '''

    def __init__(self):
        self._buffer = []
        self.skip_lines = []
        # Set here as well as in load_buffer() so the attribute always exists.
        self.file_len = 0

    @staticmethod
    def _has_skip_flag(line):
        '''Tell whether a source line carries one of the skip flags

        Spaces are removed and the line is lower-cased before looking for
        "#<flag>" for each flag in constants.SKIP_FLAGS.

        :param line: One line of source text
        :return: True if any flag is present, False otherwise
        '''
        squashed = line.replace(" ", "").lower()
        return any("#" + flag in squashed for flag in constants.SKIP_FLAGS)

    @staticmethod
    def _detach_nested(stmt):
        '''Remove and return the statements nested directly inside stmt

        The nested blocks of the statement are emptied so that the AST is
        not walked multiple times and the nested code is not included in
        the line range of the enclosing statement.

        :param stmt: The statement node to hollow out (modified in place)
        :return: List of the nested statements, in source-block order
        '''
        if isinstance(stmt, (ast.ClassDef, ast.FunctionDef, ast.With,
                             ast.Module, ast.Interactive)):
            nested = list(stmt.body)
            stmt.body = []
        elif isinstance(stmt, (ast.For, ast.While, ast.If)):
            nested = list(stmt.body) + list(stmt.orelse)
            stmt.body = []
            stmt.orelse = []
        elif isinstance(stmt, ast_Try):
            nested = list(stmt.body)
            for handler in getattr(stmt, 'handlers', []):
                nested.extend(handler.body)
            nested.extend(getattr(stmt, 'orelse', []))
            # The finally block has to be queued too, otherwise clearing
            # it below would silently drop those statements.
            nested.extend(getattr(stmt, 'finalbody', []))
            stmt.body = []
            stmt.orelse = []
            stmt.handlers = []
            stmt.finalbody = []
        else:
            nested = []
        return nested

    def load_buffer(self, fdata):
        '''Buffer initialization

        Read the file as lines, so we can store the length of the file
        so we don't lose multi-line statements at the bottom of the target
        file

        :param fdata: The code to be parsed into the buffer
        '''
        self._buffer = []
        lines = fdata.readlines()
        self.file_len = len(lines)

        # Line numbers are 1-based.
        self.skip_lines = [
            lineno for lineno, text in enumerate(lines, 1)
            if self._has_skip_flag(text)
        ]

        # Work stack kept in reverse order: the next statement to process
        # is always at the end, so taking it and pushing nested statements
        # in front of the remaining ones are both cheap operations.
        pending = list(reversed(ast.parse("".join(lines)).body))
        while pending:
            stmt = pending.pop()
            # Nested statements are processed before whatever follows the
            # compound statement.
            pending.extend(reversed(self._detach_nested(stmt)))
            # Once the statement is either simple or hollowed out it can
            # go into the primary buffer.
            self._buffer.append(stmt)

    def get_next(self, pop=True):
        '''Statement Retrieval

        Grab the next statement in the buffer for detailed processing

        :param pop: shift next statement off array (default) or just lookahead
        :return statement: dict with 'node' and 'linerange' keys for the
            next statement to be processed, or None if the buffer is empty
        '''
        if not self._buffer:
            return None
        node = self._buffer.pop(0) if pop else self._buffer[0]
        return {'node': node, 'linerange': self.linenumber_range(node)}

    def linenumber_range(self, node):
        '''Get range of line numbers for statement

        Walks the given statement node and collects the line numbers
        covered by the code

        :param node: The statement line numbers are required for
        :return lines: A range from the lowest to the highest line number
        '''
        lines = set(n.lineno for n in ast.walk(node) if hasattr(n, 'lineno'))
        # we'll return a range here, because in some cases ast.walk skips over
        # important parts, such as the middle lines in a multi-line string
        return range(min(lines), max(lines) + 1)

    def get_skip_lines(self):
        '''Return the list of line numbers flagged to be skipped'''
        return self.skip_lines
class BanditNodeVisitor(ast.NodeVisitor):
imports = set()
@ -192,8 +70,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.fname)
self.namespace = ""
self.logger.debug('Module qualified name: %s', self.namespace)
self.stmt_buffer = StatementBuffer()
self.statement = {}
self.lines = []
def visit_ClassDef(self, node):
'''Visitor for AST ClassDef node
@ -322,20 +199,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
self.context['str'] = node.s
self.logger.debug("visit_Str called (%s)", ast.dump(node))
# This check is to make sure we aren't running tests against
# docstrings (any statement that is just a string, nothing else)
node_object = self.context['statement']['node']
# docstrings can be represented as standalone ast.Str
is_str = isinstance(node_object, ast.Str)
# or ast.Expr with a value of type ast.Str
if (isinstance(node_object, ast.Expr) and
isinstance(node_object.value, ast.Str)):
is_standalone_expr = True
else:
is_standalone_expr = False
# if we don't have either one of those, run the test
if not (is_str or is_standalone_expr):
if not isinstance(node.parent, ast.Expr): # docstring
self.context['linerange'] = b_utils.linerange_fix(node.parent)
self.update_scores(self.tester.run_tests(self.context, 'Str'))
super(BanditNodeVisitor, self).generic_visit(node)
@ -360,31 +225,19 @@ class BanditNodeVisitor(ast.NodeVisitor):
:param node: The node that is being inspected
:return: -
'''
self.context = copy.copy(self.context_template)
self.logger.debug(ast.dump(node))
self.metaast.add_node(node, '', self.depth)
self.context = copy.copy(self.context_template)
self.context['statement'] = self.statement
self.context['node'] = node
self.context['filename'] = self.fname
if hasattr(node, 'lineno'):
self.context['lineno'] = node.lineno
if ("# nosec" in self.lines[node.lineno - 1] or
"#nosec" in self.lines[node.lineno - 1]):
self.logger.debug("skipped, nosec")
return self.scores # skip this node and all sub-nodes
# deal with multiline strings lineno behavior (Python issue #16806)
current_lineno = self.context['lineno']
next_statement = self.stmt_buffer.get_next(pop=False)
if next_statement is not None:
next_lineno = min(next_statement['linerange'])
else:
next_lineno = self.stmt_buffer.file_len
if next_lineno - current_lineno > 1:
self.context['statement']['linerange'] = range(
min(self.context['statement']['linerange']),
next_lineno
)
self.context['skip_lines'] = self.stmt_buffer.get_skip_lines()
self.context['node'] = node
self.context['linerange'] = b_utils.linerange_fix(node)
self.context['filename'] = self.fname
self.seen += 1
self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
@ -412,18 +265,13 @@ class BanditNodeVisitor(ast.NodeVisitor):
def process(self, fdata):
'''Main process loop
Initializes the statement buffer, iterates over each statement
in the buffer testing each AST in turn
Build and process the AST
:param fdata: the open filehandle for the code to be processed
:return score: the aggregated score for the current file
'''
self.stmt_buffer.load_buffer(fdata)
self.statement = self.stmt_buffer.get_next()
while self.statement is not None:
self.logger.debug('New statement loaded')
self.logger.debug('s_node: %s', ast.dump(self.statement['node']))
self.logger.debug('s_lineno: %s', self.statement['linerange'])
self.visit(self.statement['node'])
self.statement = self.stmt_buffer.get_next()
fdata.seek(0)
self.lines = fdata.readlines()
f_ast = ast.parse("".join(self.lines))
b_utils.embelish_ast(f_ast)
self.visit(f_ast)
return self.scores

View File

@ -61,7 +61,7 @@ class BanditResultStore():
'''
filename = context['filename']
lineno = context['lineno']
linerange = context['statement']['linerange']
linerange = context['linerange']
(issue_severity, issue_confidence, issue_text) = issue
if self.agg_type == 'vuln':

View File

@ -50,50 +50,49 @@ class BanditTester():
'CONFIDENCE': [0] * len(constants.RANKING)
}
if not raw_context['lineno'] in raw_context['skip_lines']:
tests = self.testset.get_tests(checktype)
for name, test in six.iteritems(tests):
# execute test with an instance of the context class
temp_context = copy.copy(raw_context)
context = b_context.Context(temp_context)
try:
if hasattr(test, '_takes_config'):
# TODO(??): Possibly allow override from profile
test_config = self.config.get_option(
test._takes_config)
result = test(context, test_config)
else:
result = test(context)
tests = self.testset.get_tests(checktype)
for name, test in six.iteritems(tests):
# execute test with an instance of the context class
temp_context = copy.copy(raw_context)
context = b_context.Context(temp_context)
try:
if hasattr(test, '_takes_config'):
# TODO(??): Possibly allow override from profile
test_config = self.config.get_option(
test._takes_config)
result = test(context, test_config)
else:
result = test(context)
# the test call returns a 2- or 3-tuple
# - (issue_severity, issue_text) or
# - (issue_severity, issue_confidence, issue_text)
# the test call returns a 2- or 3-tuple
# - (issue_severity, issue_text) or
# - (issue_severity, issue_confidence, issue_text)
# add default confidence level, if not returned by test
if (result is not None and len(result) == 2):
result = (
result[0],
constants.CONFIDENCE_DEFAULT,
result[1]
)
# add default confidence level, if not returned by test
if (result is not None and len(result) == 2):
result = (
result[0],
constants.CONFIDENCE_DEFAULT,
result[1]
)
# if we have a result, record it and update scores
if result is not None:
self.results.add(temp_context, name, result)
self.logger.debug(
"Issue identified by %s: %s", name, result
)
sev = constants.RANKING.index(result[0])
val = constants.RANKING_VALUES[result[0]]
scores['SEVERITY'][sev] += val
con = constants.RANKING.index(result[1])
val = constants.RANKING_VALUES[result[1]]
scores['CONFIDENCE'][con] += val
# if we have a result, record it and update scores
if result is not None:
self.results.add(temp_context, name, result)
self.logger.debug(
"Issue identified by %s: %s", name, result
)
sev = constants.RANKING.index(result[0])
val = constants.RANKING_VALUES[result[0]]
scores['SEVERITY'][sev] += val
con = constants.RANKING.index(result[1])
val = constants.RANKING_VALUES[result[1]]
scores['CONFIDENCE'][con] += val
except Exception as e:
self.report_error(name, context, e)
if self.debug:
raise
except Exception as e:
self.report_error(name, context, e)
if self.debug:
raise
self.logger.debug("Returning scores: %s", scores)
return scores

View File

@ -261,3 +261,57 @@ def safe_str(obj):
except UnicodeEncodeError:
# obj is unicode
return unicode(obj).encode('unicode_escape')
def embelish_ast(node):
    """Add parent and sibling info to every node below ``node``.

    Every AST node reachable through the fields of ``node`` gets a
    ``parent`` attribute pointing at the node that holds it. Nodes stored
    in a list field additionally get a ``sibling`` attribute pointing at
    the next AST node in the same list; the last node of a list gets no
    ``sibling`` attribute. ``node`` itself is not given a ``parent``.

    The tree is traversed with an explicit stack rather than recursion so
    that deeply nested trees do not hit the interpreter recursion limit.

    :param node: Root of the tree to annotate (modified in place)
    """
    def annotate_children(parent):
        # Lazily annotate and yield the direct children of ``parent``.
        # Being a generator, each child is annotated only when the
        # traversal reaches it, giving the same depth-first order of
        # attribute assignments as a recursive walk.
        for _, value in ast.iter_fields(parent):
            if isinstance(value, list):
                previous = None
                for item in value:
                    if isinstance(item, ast.AST):
                        if previous is not None:
                            previous.sibling = item
                        previous = item
                        item.parent = parent
                        yield item
            elif isinstance(value, ast.AST):
                value.parent = parent
                yield value

    stack = [annotate_children(node)]
    while stack:
        child = next(stack[-1], None)
        if child is None:
            # All children at this level are done.
            stack.pop()
        else:
            # Descend into the child before continuing with its siblings.
            stack.append(annotate_children(child))
def linerange(node):
    """Get line number range from a node.

    Collects the ``lineno`` of ``node`` and of every node reachable from
    it, except through the ``body``, ``orelse``, ``handlers`` and
    ``finalbody`` fields of ``node`` itself, so a compound statement is
    reported without the lines of its nested blocks.

    ``node`` is not modified.

    :param node: The AST node to measure
    :return: ``range(lowest, highest + 1)`` over the line numbers found,
        or ``[0, 1]`` when no line number is available
    """
    skipped = ("body", "orelse", "handlers", "finalbody")

    lines = set()
    if hasattr(node, 'lineno'):
        lines.add(node.lineno)

    for name, value in ast.iter_fields(node):
        if name in skipped:
            continue
        children = value if isinstance(value, list) else [value]
        for child in children:
            if not isinstance(child, ast.AST):
                continue
            # Only the top-level blocks are skipped; everything below a
            # retained child is walked in full.
            for n in ast.walk(child):
                if hasattr(n, 'lineno'):
                    lines.add(n.lineno)

    if lines:
        return range(min(lines), max(lines) + 1)
    return [0, 1]
def linerange_fix(node):
    """Try and work around a known Python bug with multi-line strings.

    Starts from ``linerange(node)``. When ``node`` has a ``sibling``
    attribute whose ``lineno`` is more than one line past the start of
    that range, the range returned instead runs from the start up to, but
    not including, the sibling's line (Python issue #16806).

    :param node: The AST node to measure
    :return: The line range for ``node``, adjusted when a sibling allows
    """
    span = linerange(node)

    # Without a following sibling carrying a line number there is nothing
    # to correct against.
    if not (hasattr(node, 'sibling') and hasattr(node.sibling, 'lineno')):
        return span

    first = min(span)
    next_line = node.sibling.lineno
    if next_line - first > 1:
        return range(first, next_line)
    return span

View File

@ -47,7 +47,7 @@ def _ast_binop_stringify(data):
@checks('Str')
def hardcoded_sql_expressions(context):
statement = context.statement['node']
statement = context.node.parent
if isinstance(statement, ast.Assign):
test_str = _ast_build_string(statement.value).lower()

View File

@ -394,7 +394,7 @@ class FunctionalTests(unittest.TestCase):
self.assertEqual(range(2, 7), issues[0]['line_range'])
self.assertIn('/tmp', issues[0]['code'])
self.assertEqual(18, issues[1]['line_number'])
self.assertEqual(range(16, 21), issues[1]['line_range'])
self.assertEqual(range(16, 19), issues[1]['line_range'])
self.assertIn('/tmp', issues[1]['code'])
self.assertEqual(23, issues[2]['line_number'])
self.assertEqual(range(22, 31), issues[2]['line_range'])

View File

@ -1,85 +0,0 @@
# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import ast
import unittest
from bandit.core import node_visitor
class StatementBufferTests(unittest.TestCase):
    '''Tests for node_visitor.StatementBuffer

    Every test starts with a buffer loaded from
    ./examples/jinja2_templating.py.
    '''

    def setUp(self):
        super(StatementBufferTests, self).setUp()
        self.test_file = open("./examples/jinja2_templating.py")
        # Make sure the handle is released whether the test passes or not.
        self.addCleanup(self.test_file.close)
        self.buf = node_visitor.StatementBuffer()
        self.buf.load_buffer(self.test_file)

    def test_load_buffer(self):
        # Check buffer contains 10 statements
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next(self):
        # Check get_next returns an AST statement
        stmt = self.buf.get_next()
        self.assertIsInstance(stmt['node'], ast.AST)
        # Check get_next returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer has been reduced by one
        self.assertEqual(9, len(self.buf._buffer))

    def test_get_next_lookahead(self):
        # Check get_next(pop=False) returns an AST statement
        stmt = self.buf.get_next(pop=False)
        self.assertIsInstance(stmt['node'], ast.AST)
        # Check get_next(pop=False) returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer remains the same length
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next_count(self):
        # Check get_next returns exactly 10 statements
        count = 0
        stmt = self.buf.get_next()
        while stmt is not None:
            count += 1
            stmt = self.buf.get_next()
        self.assertEqual(10, count)

    def test_get_next_empty(self):
        # Check get_next on an empty buffer returns None
        # self.test_file was fully read in setUp, so reloading from the
        # same handle yields no lines
        self.buf.load_buffer(self.test_file)
        stmt = self.buf.get_next()
        self.assertIsNone(stmt)

    def test_linenumber_range(self):
        # Check linenumber_range returns the correct number of lines
        stmt = None
        for _ in range(9):
            stmt = self.buf.get_next()
        # the ninth statement should be three lines long
        self.assertEqual(3, len(stmt['linerange']))
        # the range should be the correct line numbers
        self.assertEqual([11, 12, 13], list(stmt['linerange']))

View File

@ -221,3 +221,16 @@ class UtilTests(unittest.TestCase):
self.assertEqual('', name)
# TODO(ljfisher) At best we might be able to get:
# self.assertEqual(name, 'a.list[0]')
def test_linerange(self):
    '''Check b_utils.linerange on a multi-line statement

    Parses ./examples/jinja2_templating.py and checks the range reported
    for the ninth top-level statement.
    '''
    # Read inside a with-block so the file handle is always closed.
    with open("./examples/jinja2_templating.py") as test_file:
        self.tree = ast.parse(test_file.read())
    # Check linerange returns the correct number of lines
    line = self.tree.body[8]
    lrange = b_utils.linerange(line)
    # the ninth statement should be three lines long
    self.assertEqual(3, len(lrange))
    # the range should be the correct line numbers
    self.assertEqual([11, 12, 13], list(lrange))