Merge "Removing statement buffer"
This commit is contained in:
commit
8ff7894cf2
@ -23,128 +23,6 @@ from bandit.core import utils as b_utils
|
||||
from bandit.core.utils import InvalidModulePath
|
||||
|
||||
|
||||
# Node type(s) that represent a ``try`` statement on the running interpreter:
# a (TryExcept, TryFinally) pair where those classes exist, otherwise the
# single ast.Try class (Python 3.3+). Usable directly with isinstance().
ast_Try = ((ast.TryExcept, ast.TryFinally) if hasattr(ast, 'TryExcept')
           else ast.Try)
|
||||
|
||||
|
||||
class StatementBuffer():
    '''Buffer for code statements

    Creates a buffer to store a code file as individual statements
    for AST processing. Compound statements are flattened: each nested
    statement becomes its own buffer entry and the enclosing statement is
    stored with its nested blocks emptied.
    '''
    def __init__(self):
        self._buffer = []
        self.skip_lines = []
        # Defined up front so it can be read before load_buffer() is called.
        self.file_len = 0

    def load_buffer(self, fdata):
        '''Buffer initialization

        Read the file as lines, so we can store the length of the file
        so we don't lose multi-line statements at the bottom of the target
        file
        :param fdata: The code to be parsed into the buffer
        '''
        self._buffer = []
        self.skip_lines = []
        lines = fdata.readlines()
        self.file_len = len(lines)

        # Record the (1-based) numbers of lines carrying a skip flag
        # comment; spaces and case are ignored when matching.
        for lineno, line in enumerate(lines, 1):
            squeezed = line.replace(" ", "").lower()
            if any("#" + flag in squeezed for flag in constants.SKIP_FLAGS):
                self.skip_lines.append(lineno)

        f_ast = ast.parse("".join(lines))
        # We need to expand body blocks within compound statements
        # into our statement buffer so each gets processed in
        # isolation. ``pending`` is a stack whose top is the next
        # statement in source order, so nested statements are emitted
        # directly after the statement that contained them.
        pending = f_ast.body[::-1]
        while pending:
            stmt = pending.pop()
            pending.extend(reversed(self._detach_children(stmt)))
            # once we are sure it's either a single statement or that
            # any content in a compound statement body has been removed
            # we can add it to our primary buffer. The compound body
            # must be removed so the ast isn't walked multiple times
            # and isn't included in line-by-line output
            self._buffer.append(stmt)

    @staticmethod
    def _detach_children(stmt):
        '''Remove and return the statements nested inside a statement

        :param stmt: The statement node to strip (modified in place)
        :return children: The nested statements in source order; an empty
                          list for statements that are not compound
        '''
        if isinstance(stmt, (ast.ClassDef, ast.FunctionDef, ast.With,
                             ast.Module, ast.Interactive)):
            children = list(stmt.body)
            stmt.body = []
        elif isinstance(stmt, (ast.For, ast.While, ast.If)):
            children = stmt.body + stmt.orelse
            stmt.body = []
            stmt.orelse = []
        elif isinstance(stmt, ast_Try):
            children = list(stmt.body)
            # getattr: not every try node type carries all of these fields
            for handler in getattr(stmt, 'handlers', []):
                children.extend(handler.body)
            children.extend(getattr(stmt, 'orelse', []))
            # statements in the finally block must be buffered as well,
            # otherwise they would never be processed
            children.extend(getattr(stmt, 'finalbody', []))
            stmt.body = []
            stmt.orelse = []
            stmt.handlers = []
            stmt.finalbody = []
        else:
            children = []
        return children

    def get_next(self, pop=True):
        '''Statement Retrieval

        Grab the next statement in the buffer for detailed processing
        :param pop: shift next statement off array (default) or just lookahead
        :return statement: a dict with the next statement under 'node' and
                           its line numbers under 'linerange', or None when
                           the buffer is empty
        '''
        if not self._buffer:
            return None
        node = self._buffer.pop(0) if pop else self._buffer[0]
        return {'node': node, 'linerange': self.linenumber_range(node)}

    def linenumber_range(self, node):
        '''Get range of line numbers for statement

        Walks the given statement node, and collects the line numbers
        covered by the code
        :param node: The statement line numbers are required for
        :return lines: A range from the lowest to the highest line number
        '''
        lines = set(n.lineno for n in ast.walk(node) if hasattr(n, 'lineno'))
        # we'll return a range here, because in some cases ast.walk skips over
        # important parts, such as the middle lines in a multi-line string
        return range(min(lines), max(lines) + 1)

    def get_skip_lines(self):
        '''Return the line numbers flagged to be skipped by load_buffer'''
        return self.skip_lines
|
||||
|
||||
|
||||
class BanditNodeVisitor(ast.NodeVisitor):
|
||||
|
||||
imports = set()
|
||||
@ -192,8 +70,7 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
self.fname)
|
||||
self.namespace = ""
|
||||
self.logger.debug('Module qualified name: %s', self.namespace)
|
||||
self.stmt_buffer = StatementBuffer()
|
||||
self.statement = {}
|
||||
self.lines = []
|
||||
|
||||
def visit_ClassDef(self, node):
|
||||
'''Visitor for AST ClassDef node
|
||||
@ -322,20 +199,8 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
self.context['str'] = node.s
|
||||
self.logger.debug("visit_Str called (%s)", ast.dump(node))
|
||||
|
||||
# This check is to make sure we aren't running tests against
|
||||
# docstrings (any statement that is just a string, nothing else)
|
||||
node_object = self.context['statement']['node']
|
||||
|
||||
# docstrings can be represented as standalone ast.Str
|
||||
is_str = isinstance(node_object, ast.Str)
|
||||
# or ast.Expr with a value of type ast.Str
|
||||
if (isinstance(node_object, ast.Expr) and
|
||||
isinstance(node_object.value, ast.Str)):
|
||||
is_standalone_expr = True
|
||||
else:
|
||||
is_standalone_expr = False
|
||||
# if we don't have either one of those, run the test
|
||||
if not (is_str or is_standalone_expr):
|
||||
if not isinstance(node.parent, ast.Expr): # docstring
|
||||
self.context['linerange'] = b_utils.linerange_fix(node.parent)
|
||||
self.update_scores(self.tester.run_tests(self.context, 'Str'))
|
||||
super(BanditNodeVisitor, self).generic_visit(node)
|
||||
|
||||
@ -360,31 +225,19 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
:param node: The node that is being inspected
|
||||
:return: -
|
||||
'''
|
||||
self.context = copy.copy(self.context_template)
|
||||
self.logger.debug(ast.dump(node))
|
||||
self.metaast.add_node(node, '', self.depth)
|
||||
|
||||
self.context = copy.copy(self.context_template)
|
||||
self.context['statement'] = self.statement
|
||||
self.context['node'] = node
|
||||
self.context['filename'] = self.fname
|
||||
if hasattr(node, 'lineno'):
|
||||
self.context['lineno'] = node.lineno
|
||||
if ("# nosec" in self.lines[node.lineno - 1] or
|
||||
"#nosec" in self.lines[node.lineno - 1]):
|
||||
self.logger.debug("skipped, nosec")
|
||||
return self.scores # skip this node and all sub-nodes
|
||||
|
||||
# deal with multiline strings lineno behavior (Python issue #16806)
|
||||
current_lineno = self.context['lineno']
|
||||
next_statement = self.stmt_buffer.get_next(pop=False)
|
||||
if next_statement is not None:
|
||||
next_lineno = min(next_statement['linerange'])
|
||||
else:
|
||||
next_lineno = self.stmt_buffer.file_len
|
||||
|
||||
if next_lineno - current_lineno > 1:
|
||||
self.context['statement']['linerange'] = range(
|
||||
min(self.context['statement']['linerange']),
|
||||
next_lineno
|
||||
)
|
||||
|
||||
self.context['skip_lines'] = self.stmt_buffer.get_skip_lines()
|
||||
self.context['node'] = node
|
||||
self.context['linerange'] = b_utils.linerange_fix(node)
|
||||
self.context['filename'] = self.fname
|
||||
|
||||
self.seen += 1
|
||||
self.logger.debug("entering: %s %s [%s]", hex(id(node)), type(node),
|
||||
@ -412,18 +265,13 @@ class BanditNodeVisitor(ast.NodeVisitor):
|
||||
def process(self, fdata):
|
||||
'''Main process loop
|
||||
|
||||
Iniitalizes the statement buffer, iterates over each statement
|
||||
in the buffer testing each AST in turn
|
||||
Build and process the AST
|
||||
:param fdata: the open filehandle for the code to be processed
|
||||
:return score: the aggregated score for the current file
|
||||
'''
|
||||
self.stmt_buffer.load_buffer(fdata)
|
||||
self.statement = self.stmt_buffer.get_next()
|
||||
while self.statement is not None:
|
||||
self.logger.debug('New statement loaded')
|
||||
self.logger.debug('s_node: %s', ast.dump(self.statement['node']))
|
||||
self.logger.debug('s_lineno: %s', self.statement['linerange'])
|
||||
|
||||
self.visit(self.statement['node'])
|
||||
self.statement = self.stmt_buffer.get_next()
|
||||
fdata.seek(0)
|
||||
self.lines = fdata.readlines()
|
||||
f_ast = ast.parse("".join(self.lines))
|
||||
b_utils.embelish_ast(f_ast)
|
||||
self.visit(f_ast)
|
||||
return self.scores
|
||||
|
@ -61,7 +61,7 @@ class BanditResultStore():
|
||||
'''
|
||||
filename = context['filename']
|
||||
lineno = context['lineno']
|
||||
linerange = context['statement']['linerange']
|
||||
linerange = context['linerange']
|
||||
(issue_severity, issue_confidence, issue_text) = issue
|
||||
|
||||
if self.agg_type == 'vuln':
|
||||
|
@ -50,50 +50,49 @@ class BanditTester():
|
||||
'CONFIDENCE': [0] * len(constants.RANKING)
|
||||
}
|
||||
|
||||
if not raw_context['lineno'] in raw_context['skip_lines']:
|
||||
tests = self.testset.get_tests(checktype)
|
||||
for name, test in six.iteritems(tests):
|
||||
# execute test with the an instance of the context class
|
||||
temp_context = copy.copy(raw_context)
|
||||
context = b_context.Context(temp_context)
|
||||
try:
|
||||
if hasattr(test, '_takes_config'):
|
||||
# TODO(??): Possibly allow override from profile
|
||||
test_config = self.config.get_option(
|
||||
test._takes_config)
|
||||
result = test(context, test_config)
|
||||
else:
|
||||
result = test(context)
|
||||
tests = self.testset.get_tests(checktype)
|
||||
for name, test in six.iteritems(tests):
|
||||
# execute test with the an instance of the context class
|
||||
temp_context = copy.copy(raw_context)
|
||||
context = b_context.Context(temp_context)
|
||||
try:
|
||||
if hasattr(test, '_takes_config'):
|
||||
# TODO(??): Possibly allow override from profile
|
||||
test_config = self.config.get_option(
|
||||
test._takes_config)
|
||||
result = test(context, test_config)
|
||||
else:
|
||||
result = test(context)
|
||||
|
||||
# the test call returns a 2- or 3-tuple
|
||||
# - (issue_severity, issue_text) or
|
||||
# - (issue_severity, issue_confidence, issue_text)
|
||||
# the test call returns a 2- or 3-tuple
|
||||
# - (issue_severity, issue_text) or
|
||||
# - (issue_severity, issue_confidence, issue_text)
|
||||
|
||||
# add default confidence level, if not returned by test
|
||||
if (result is not None and len(result) == 2):
|
||||
result = (
|
||||
result[0],
|
||||
constants.CONFIDENCE_DEFAULT,
|
||||
result[1]
|
||||
)
|
||||
# add default confidence level, if not returned by test
|
||||
if (result is not None and len(result) == 2):
|
||||
result = (
|
||||
result[0],
|
||||
constants.CONFIDENCE_DEFAULT,
|
||||
result[1]
|
||||
)
|
||||
|
||||
# if we have a result, record it and update scores
|
||||
if result is not None:
|
||||
self.results.add(temp_context, name, result)
|
||||
self.logger.debug(
|
||||
"Issue identified by %s: %s", name, result
|
||||
)
|
||||
sev = constants.RANKING.index(result[0])
|
||||
val = constants.RANKING_VALUES[result[0]]
|
||||
scores['SEVERITY'][sev] += val
|
||||
con = constants.RANKING.index(result[1])
|
||||
val = constants.RANKING_VALUES[result[1]]
|
||||
scores['CONFIDENCE'][con] += val
|
||||
# if we have a result, record it and update scores
|
||||
if result is not None:
|
||||
self.results.add(temp_context, name, result)
|
||||
self.logger.debug(
|
||||
"Issue identified by %s: %s", name, result
|
||||
)
|
||||
sev = constants.RANKING.index(result[0])
|
||||
val = constants.RANKING_VALUES[result[0]]
|
||||
scores['SEVERITY'][sev] += val
|
||||
con = constants.RANKING.index(result[1])
|
||||
val = constants.RANKING_VALUES[result[1]]
|
||||
scores['CONFIDENCE'][con] += val
|
||||
|
||||
except Exception as e:
|
||||
self.report_error(name, context, e)
|
||||
if self.debug:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.report_error(name, context, e)
|
||||
if self.debug:
|
||||
raise
|
||||
self.logger.debug("Returning scores: %s", scores)
|
||||
return scores
|
||||
|
||||
|
@ -261,3 +261,57 @@ def safe_str(obj):
|
||||
except UnicodeEncodeError:
|
||||
# obj is unicode
|
||||
return unicode(obj).encode('unicode_escape')
|
||||
|
||||
|
||||
def embelish_ast(node):
    """Annotate the tree below *node* with parent and sibling links.

    Every AST node reachable through the fields of *node* gets a ``parent``
    attribute naming the node that holds it. Inside a list-valued field,
    each AST entry except the last one also gets a ``sibling`` attribute
    pointing at the next AST entry of that list. *node* itself is left
    without either attribute.
    """
    for _name, field in ast.iter_fields(node):
        if isinstance(field, ast.AST):
            field.parent = node
            embelish_ast(field)
        elif isinstance(field, list):
            # list fields may also hold non-node entries; those are ignored
            children = [entry for entry in field if isinstance(entry, ast.AST)]
            for earlier, later in zip(children, children[1:]):
                earlier.sibling = later
            for child in children:
                child.parent = node
                embelish_ast(child)
|
||||
|
||||
|
||||
def linerange(node):
    """Get line number range from a node.

    The node's ``body``, ``orelse``, ``handlers`` and ``finalbody``
    attributes are emptied while the node is walked, so the result covers
    the node without the content of those attributes. They are always put
    back afterwards, whatever their original value and even if the walk
    raises.

    :param node: the AST node to measure
    :return: a range from the lowest to the highest line number found, or
             ``[0, 1]`` when no visited node carries a line number
    """
    missing = object()
    saved = {}
    for key in ("body", "orelse", "handlers", "finalbody"):
        value = getattr(node, key, missing)
        if value is not missing:
            saved[key] = value
            setattr(node, key, [])

    try:
        lines = set(n.lineno for n in ast.walk(node) if hasattr(n, 'lineno'))
    finally:
        # restore unconditionally so the caller's tree is never left stripped
        for key, value in saved.items():
            setattr(node, key, value)

    if lines:
        return range(min(lines), max(lines) + 1)
    return [0, 1]
|
||||
|
||||
|
||||
def linerange_fix(node):
    """Try and work around a known Python bug with multi-line strings.

    Starts from ``linerange(node)``. When the node has a ``sibling``
    attribute whose ``lineno`` lies more than one line past the first line
    of that range, the range is stretched to end just before the sibling's
    line; otherwise the plain range is returned unchanged.
    """
    # deal with multiline strings lineno behavior (Python issue #16806)
    span = linerange(node)
    following = getattr(node, 'sibling', None)
    if not hasattr(following, 'lineno'):
        return span

    first_line = min(span)
    if following.lineno - first_line > 1:
        return range(first_line, following.lineno)
    return span
|
||||
|
@ -47,7 +47,7 @@ def _ast_binop_stringify(data):
|
||||
|
||||
@checks('Str')
|
||||
def hardcoded_sql_expressions(context):
|
||||
statement = context.statement['node']
|
||||
statement = context.node.parent
|
||||
if isinstance(statement, ast.Assign):
|
||||
test_str = _ast_build_string(statement.value).lower()
|
||||
|
||||
|
@ -394,7 +394,7 @@ class FunctionalTests(unittest.TestCase):
|
||||
self.assertEqual(range(2, 7), issues[0]['line_range'])
|
||||
self.assertIn('/tmp', issues[0]['code'])
|
||||
self.assertEqual(18, issues[1]['line_number'])
|
||||
self.assertEqual(range(16, 21), issues[1]['line_range'])
|
||||
self.assertEqual(range(16, 19), issues[1]['line_range'])
|
||||
self.assertIn('/tmp', issues[1]['code'])
|
||||
self.assertEqual(23, issues[2]['line_number'])
|
||||
self.assertEqual(range(22, 31), issues[2]['line_range'])
|
||||
|
@ -1,85 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# Copyright 2014 Hewlett-Packard Development Company, L.P.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import os
|
||||
import ast
|
||||
|
||||
import unittest
|
||||
from bandit.core import node_visitor
|
||||
|
||||
|
||||
class StatementBufferTests(unittest.TestCase):
    """Tests for node_visitor.StatementBuffer.

    Every test starts with a buffer loaded from the jinja2_templating
    example file.
    """

    def setUp(self):
        super(StatementBufferTests, self).setUp()
        self.test_file = open("./examples/jinja2_templating.py")
        # Close the handle during cleanup rather than right after loading:
        # test_get_next_empty reuses it in its exhausted (but open) state.
        self.addCleanup(self.test_file.close)
        self.buf = node_visitor.StatementBuffer()
        self.buf.load_buffer(self.test_file)

    def test_load_buffer(self):
        # Check buffer contains 10 statements
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next(self):
        # Check get_next returns an AST statement
        stmt = self.buf.get_next()
        self.assertIsInstance(stmt['node'], ast.AST)
        # Check get_next returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer has been reduced by one
        self.assertEqual(9, len(self.buf._buffer))

    def test_get_next_lookahead(self):
        # Check get_next(pop=False) returns an AST statement
        stmt = self.buf.get_next(pop=False)
        self.assertIsInstance(stmt['node'], ast.AST)
        # Check get_next(pop=False) returned the first statement
        self.assertEqual(1, stmt['linerange'][0])
        # Check buffer remains the same length
        self.assertEqual(10, len(self.buf._buffer))

    def test_get_next_count(self):
        # Check get_next returns exactly 10 statements
        count = 0
        stmt = self.buf.get_next()
        while stmt is not None:
            count += 1
            stmt = self.buf.get_next()

        self.assertEqual(10, count)

    def test_get_next_empty(self):
        # Check get_next on an empty buffer returns None
        # self.test_file has already been read, so is empty file handle
        self.buf.load_buffer(self.test_file)
        stmt = self.buf.get_next()
        self.assertIsNone(stmt)

    def test_linenumber_range(self):
        # Check linenumber_range returns correct number of lines
        stmt = None
        for _ in range(9):
            stmt = self.buf.get_next()

        # the ninth statement should be three lines long
        self.assertEqual(3, len(stmt['linerange']))

        # the range should be the correct line numbers
        self.assertEqual([11, 12, 13], list(stmt['linerange']))
|
@ -221,3 +221,16 @@ class UtilTests(unittest.TestCase):
|
||||
self.assertEqual('', name)
|
||||
# TODO(ljfisher) At best we might be able to get:
|
||||
# self.assertEqual(name, 'a.list[0]')
|
||||
|
||||
def test_linerange(self):
    # Read the example through a context manager so the file handle is
    # closed as soon as its content has been parsed.
    with open("./examples/jinja2_templating.py") as test_file:
        tree = ast.parse(test_file.read())
    # Check linerange returns correct number of lines
    line = tree.body[8]
    lrange = b_utils.linerange(line)

    # the ninth top-level statement should be three lines long
    self.assertEqual(3, len(lrange))

    # the range should be the correct line numbers
    self.assertEqual([11, 12, 13], list(lrange))
|
||||
|
Loading…
x
Reference in New Issue
Block a user