| Index: third_party/cython/src/Cython/Compiler/TypeInference.py
|
| diff --git a/third_party/cython/src/Cython/Compiler/TypeInference.py b/third_party/cython/src/Cython/Compiler/TypeInference.py
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..f089f330af0fa4ca8ac53f06be2037aebc5120f5
|
| --- /dev/null
|
| +++ b/third_party/cython/src/Cython/Compiler/TypeInference.py
|
| @@ -0,0 +1,550 @@
|
| +from Errors import error, message
|
| +import ExprNodes
|
| +import Nodes
|
| +import Builtin
|
| +import PyrexTypes
|
| +from Cython import Utils
|
| +from PyrexTypes import py_object_type, unspecified_type
|
| +from Visitor import CythonTransform, EnvTransform
|
| +
|
| +try:
|
| +    reduce
|
| +except NameError:
|
| +    # 'reduce' is a builtin only on Python 2; the spanning-type helpers
|
| +    # below rely on it.
|
| +    from functools import reduce
|
| +
|
| +
|
| +class TypedExprNode(ExprNodes.ExprNode):
|
| + # Used for declaring assignments of a specified type without a known entry.
|
| + def __init__(self, type):
|
| + self.type = type
|
| +
|
| +object_expr = TypedExprNode(py_object_type)
|
| +
|
| +
|
| +class MarkParallelAssignments(EnvTransform):
|
| +    # Collects assignments inside parallel blocks (prange, with parallel).
|
| +    # It might be better to move this into ControlFlowAnalysis.
|
| +
|
| + # tells us whether we're in a normal loop
|
| + in_loop = False
|
| +
|
| + parallel_errors = False
|
| +
|
| + def __init__(self, context):
|
| + # Track the parallel block scopes (with parallel, for i in prange())
|
| + self.parallel_block_stack = []
|
| + super(MarkParallelAssignments, self).__init__(context)
|
| +
|
| + def mark_assignment(self, lhs, rhs, inplace_op=None):
|
| + if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
|
| + if lhs.entry is None:
|
| + # TODO: This shouldn't happen...
|
| + return
|
| +
|
| + if self.parallel_block_stack:
|
| + parallel_node = self.parallel_block_stack[-1]
|
| + previous_assignment = parallel_node.assignments.get(lhs.entry)
|
| +
|
| + # If there was a previous assignment to the variable, keep the
|
| + # previous assignment position
|
| + if previous_assignment:
|
| + pos, previous_inplace_op = previous_assignment
|
| +
|
| + if (inplace_op and previous_inplace_op and
|
| + inplace_op != previous_inplace_op):
|
| + # x += y; x *= y
|
| + t = (inplace_op, previous_inplace_op)
|
| + error(lhs.pos,
|
| + "Reduction operator '%s' is inconsistent "
|
| + "with previous reduction operator '%s'" % t)
|
| + else:
|
| + pos = lhs.pos
|
| +
|
| + parallel_node.assignments[lhs.entry] = (pos, inplace_op)
|
| + parallel_node.assigned_nodes.append(lhs)
|
| +
|
| + elif isinstance(lhs, ExprNodes.SequenceNode):
|
| + for arg in lhs.args:
|
| + self.mark_assignment(arg, object_expr)
|
| + else:
|
| + # Could use this info to infer cdef class attributes...
|
| + pass
|
| +
|
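| +    # Illustrative sketch (not part of the original source): inside a
|
| +    # prange() block, mixing reduction operators on one variable is
|
| +    # rejected by the check above, e.g.
|
| +    #
|
| +    #     for i in prange(n, nogil=True):
|
| +    #         s += i
|
| +    #         s *= i    # error: inconsistent reduction operator
|
| +
|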
| + def visit_WithTargetAssignmentStatNode(self, node):
|
| + self.mark_assignment(node.lhs, node.rhs)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_SingleAssignmentNode(self, node):
|
| + self.mark_assignment(node.lhs, node.rhs)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_CascadedAssignmentNode(self, node):
|
| + for lhs in node.lhs_list:
|
| + self.mark_assignment(lhs, node.rhs)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_InPlaceAssignmentNode(self, node):
|
| + self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_ForInStatNode(self, node):
|
| + # TODO: Remove redundancy with range optimization...
|
| + is_special = False
|
| + sequence = node.iterator.sequence
|
| + target = node.target
|
| + if isinstance(sequence, ExprNodes.SimpleCallNode):
|
| + function = sequence.function
|
| + if sequence.self is None and function.is_name:
|
| + entry = self.current_env().lookup(function.name)
|
| + if not entry or entry.is_builtin:
|
| + if function.name == 'reversed' and len(sequence.args) == 1:
|
| + sequence = sequence.args[0]
|
| + elif function.name == 'enumerate' and len(sequence.args) == 1:
|
| + if target.is_sequence_constructor and len(target.args) == 2:
|
| + iterator = sequence.args[0]
|
| + if iterator.is_name:
|
| + iterator_type = iterator.infer_type(self.current_env())
|
| + if iterator_type.is_builtin_type:
|
| + # assume that builtin types have a length within Py_ssize_t
|
| + self.mark_assignment(
|
| + target.args[0],
|
| + ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
|
| + type=PyrexTypes.c_py_ssize_t_type))
|
| + target = target.args[1]
|
| + sequence = sequence.args[0]
|
| + if isinstance(sequence, ExprNodes.SimpleCallNode):
|
| + function = sequence.function
|
| + if sequence.self is None and function.is_name:
|
| + entry = self.current_env().lookup(function.name)
|
| + if not entry or entry.is_builtin:
|
| + if function.name in ('range', 'xrange'):
|
| + is_special = True
|
| + for arg in sequence.args[:2]:
|
| + self.mark_assignment(target, arg)
|
| + if len(sequence.args) > 2:
|
| + self.mark_assignment(
|
| + target,
|
| + ExprNodes.binop_node(node.pos,
|
| + '+',
|
| + sequence.args[0],
|
| + sequence.args[2]))
|
| +
|
| + if not is_special:
|
| + # A for-loop basically translates to subsequent calls to
|
| + # __getitem__(), so using an IndexNode here allows us to
|
| + # naturally infer the base type of pointers, C arrays,
|
| + # Python strings, etc., while correctly falling back to an
|
| + # object type when the base type cannot be handled.
|
| + self.mark_assignment(target, ExprNodes.IndexNode(
|
| + node.pos,
|
| + base=sequence,
|
| + index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
|
| + type=PyrexTypes.c_py_ssize_t_type)))
|
| +
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
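| +    # Sketch of the special-casing above (hedged, illustrative): for a
|
| +    # builtin range()/xrange() loop, the target is marked as assigned
|
| +    # each of the first two bounds, plus 'start + step' when a step is
|
| +    # given:
|
| +    #
|
| +    #     for i in range(a, b, c):   # i sees a, b and a + c
|
| +    #         ...
|
| +
|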
| + def visit_ForFromStatNode(self, node):
|
| + self.mark_assignment(node.target, node.bound1)
|
| + if node.step is not None:
|
| + self.mark_assignment(node.target,
|
| + ExprNodes.binop_node(node.pos,
|
| + '+',
|
| + node.bound1,
|
| + node.step))
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_WhileStatNode(self, node):
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_ExceptClauseNode(self, node):
|
| + if node.target is not None:
|
| + self.mark_assignment(node.target, object_expr)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_FromCImportStatNode(self, node):
|
| +        # Returning None would drop the node from the tree; there is
|
| +        # nothing to mark, since cimported names can't be assigned to.
|
| +        return node
|
| +
|
| + def visit_FromImportStatNode(self, node):
|
| + for name, target in node.items:
|
| + if name != "*":
|
| + self.mark_assignment(target, object_expr)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_DefNode(self, node):
|
| + # use fake expressions with the right result type
|
| + if node.star_arg:
|
| + self.mark_assignment(
|
| + node.star_arg, TypedExprNode(Builtin.tuple_type))
|
| + if node.starstar_arg:
|
| + self.mark_assignment(
|
| + node.starstar_arg, TypedExprNode(Builtin.dict_type))
|
| + EnvTransform.visit_FuncDefNode(self, node)
|
| + return node
|
| +
|
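| +    # Sketch: with the fake typed expressions above,
|
| +    #
|
| +    #     def f(*args, **kwargs): ...
|
| +    #
|
| +    # marks 'args' as assigned a tuple and 'kwargs' a dict, so inference
|
| +    # can give them those builtin types.
|
| +
|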
| + def visit_DelStatNode(self, node):
|
| + for arg in node.args:
|
| + self.mark_assignment(arg, arg)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_ParallelStatNode(self, node):
|
| + if self.parallel_block_stack:
|
| + node.parent = self.parallel_block_stack[-1]
|
| + else:
|
| + node.parent = None
|
| +
|
| + nested = False
|
| + if node.is_prange:
|
| + if not node.parent:
|
| + node.is_parallel = True
|
| + else:
|
| + node.is_parallel = (node.parent.is_prange or not
|
| + node.parent.is_parallel)
|
| + nested = node.parent.is_prange
|
| + else:
|
| + node.is_parallel = True
|
| + # Note: nested with parallel() blocks are handled by
|
| + # ParallelRangeTransform!
|
| +            nested = node.parent and node.parent.is_prange
|
| +
|
| + self.parallel_block_stack.append(node)
|
| +
|
| + nested = nested or len(self.parallel_block_stack) > 2
|
| + if not self.parallel_errors and nested and not node.is_prange:
|
| + error(node.pos, "Only prange() may be nested")
|
| + self.parallel_errors = True
|
| +
|
| + if node.is_prange:
|
| + child_attrs = node.child_attrs
|
| + node.child_attrs = ['body', 'target', 'args']
|
| + self.visitchildren(node)
|
| + node.child_attrs = child_attrs
|
| +
|
| + self.parallel_block_stack.pop()
|
| + if node.else_clause:
|
| + node.else_clause = self.visit(node.else_clause)
|
| + else:
|
| + self.visitchildren(node)
|
| + self.parallel_block_stack.pop()
|
| +
|
| + self.parallel_errors = False
|
| + return node
|
| +
|
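| +    # Illustrative example of the nesting check above: a with-parallel
|
| +    # block directly inside a prange() loop is rejected,
|
| +    #
|
| +    #     for i in prange(n, nogil=True):
|
| +    #         with parallel():   # error: "Only prange() may be nested"
|
| +    #             pass
|
| +
|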
| + def visit_YieldExprNode(self, node):
|
| + if self.parallel_block_stack:
|
| + error(node.pos, "Yield not allowed in parallel sections")
|
| +
|
| + return node
|
| +
|
| + def visit_ReturnStatNode(self, node):
|
| + node.in_parallel = bool(self.parallel_block_stack)
|
| + return node
|
| +
|
| +
|
| +class MarkOverflowingArithmetic(CythonTransform):
|
| +
|
| + # It may be possible to integrate this with the above for
|
| + # performance improvements (though likely not worth it).
|
| +
|
| + might_overflow = False
|
| +
|
| + def __call__(self, root):
|
| + self.env_stack = []
|
| + self.env = root.scope
|
| + return super(MarkOverflowingArithmetic, self).__call__(root)
|
| +
|
| + def visit_safe_node(self, node):
|
| + self.might_overflow, saved = False, self.might_overflow
|
| + self.visitchildren(node)
|
| + self.might_overflow = saved
|
| + return node
|
| +
|
| + def visit_neutral_node(self, node):
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_dangerous_node(self, node):
|
| + self.might_overflow, saved = True, self.might_overflow
|
| + self.visitchildren(node)
|
| + self.might_overflow = saved
|
| + return node
|
| +
|
| + def visit_FuncDefNode(self, node):
|
| + self.env_stack.append(self.env)
|
| + self.env = node.local_scope
|
| + self.visit_safe_node(node)
|
| + self.env = self.env_stack.pop()
|
| + return node
|
| +
|
| + def visit_NameNode(self, node):
|
| + if self.might_overflow:
|
| + entry = node.entry or self.env.lookup(node.name)
|
| + if entry:
|
| + entry.might_overflow = True
|
| + return node
|
| +
|
| + def visit_BinopNode(self, node):
|
| + if node.operator in '&|^':
|
| + return self.visit_neutral_node(node)
|
| + else:
|
| + return self.visit_dangerous_node(node)
|
| +
|
| + visit_UnopNode = visit_neutral_node
|
| +
|
| + visit_UnaryMinusNode = visit_dangerous_node
|
| +
|
| + visit_InPlaceAssignmentNode = visit_dangerous_node
|
| +
|
| + visit_Node = visit_safe_node
|
| +
|
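| +    # Hedged sketch of the classification above:
|
| +    #
|
| +    #     c = a & b    # bitwise: neutral, no overflow marking
|
| +    #     c = a + b    # arithmetic: marks a and b as might_overflow
|
| +    #     c = -a       # unary minus: also marks a as might_overflow
|
| +
|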
| + def visit_assignment(self, lhs, rhs):
|
| + if (isinstance(rhs, ExprNodes.IntNode)
|
| + and isinstance(lhs, ExprNodes.NameNode)
|
| + and Utils.long_literal(rhs.value)):
|
| + entry = lhs.entry or self.env.lookup(lhs.name)
|
| + if entry:
|
| + entry.might_overflow = True
|
| +
|
| + def visit_SingleAssignmentNode(self, node):
|
| + self.visit_assignment(node.lhs, node.rhs)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| + def visit_CascadedAssignmentNode(self, node):
|
| + for lhs in node.lhs_list:
|
| + self.visit_assignment(lhs, node.rhs)
|
| + self.visitchildren(node)
|
| + return node
|
| +
|
| +class PyObjectTypeInferer(object):
|
| + """
|
| + If it's not declared, it's a PyObject.
|
| + """
|
| + def infer_types(self, scope):
|
| + """
|
| + Given a dict of entries, map all unspecified types to a specified type.
|
| + """
|
| + for name, entry in scope.entries.items():
|
| + if entry.type is unspecified_type:
|
| + entry.type = py_object_type
|
| +
|
| +class SimpleAssignmentTypeInferer(object):
|
| + """
|
| + Very basic type inference.
|
| +
|
| +    Note: in order to support cross-closure type inference, this must be
|
| +    applied to nested scopes in top-down order.
|
| + """
|
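| +    # Sketch of why the top-down order matters (illustrative):
|
| +    #
|
| +    #     def outer():
|
| +    #         x = 1              # inferred in outer's scope first
|
| +    #         def inner():
|
| +    #             return x + 1   # inner then sees the inferred type of x
|
| +
|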
| + def set_entry_type(self, entry, entry_type):
|
| + entry.type = entry_type
|
| + for e in entry.all_entries():
|
| + e.type = entry_type
|
| +
|
| + def infer_types(self, scope):
|
| + enabled = scope.directives['infer_types']
|
| + verbose = scope.directives['infer_types.verbose']
|
| +
|
| +        if enabled == True:  # aggressive mode
|
| + spanning_type = aggressive_spanning_type
|
| + elif enabled is None: # safe mode
|
| + spanning_type = safe_spanning_type
|
| + else:
|
| + for entry in scope.entries.values():
|
| + if entry.type is unspecified_type:
|
| + self.set_entry_type(entry, py_object_type)
|
| + return
|
| +
|
| +        # Sets of assignments
|
| +        assignments = set()
|
| +        assmts_resolved = set()
|
| + dependencies = {}
|
| + assmt_to_names = {}
|
| +
|
| + for name, entry in scope.entries.items():
|
| + for assmt in entry.cf_assignments:
|
| + names = assmt.type_dependencies()
|
| + assmt_to_names[assmt] = names
|
| + assmts = set()
|
| + for node in names:
|
| + assmts.update(node.cf_state)
|
| + dependencies[assmt] = assmts
|
| + if entry.type is unspecified_type:
|
| + assignments.update(entry.cf_assignments)
|
| + else:
|
| + assmts_resolved.update(entry.cf_assignments)
|
| +
|
| + def infer_name_node_type(node):
|
| + types = [assmt.inferred_type for assmt in node.cf_state]
|
| + if not types:
|
| + node_type = py_object_type
|
| + else:
|
| + entry = node.entry
|
| + node_type = spanning_type(
|
| + types, entry.might_overflow, entry.pos)
|
| + node.inferred_type = node_type
|
| +
|
| + def infer_name_node_type_partial(node):
|
| + types = [assmt.inferred_type for assmt in node.cf_state
|
| + if assmt.inferred_type is not None]
|
| + if not types:
|
| + return
|
| + entry = node.entry
|
| + return spanning_type(types, entry.might_overflow, entry.pos)
|
| +
|
| + def resolve_assignments(assignments):
|
| + resolved = set()
|
| + for assmt in assignments:
|
| + deps = dependencies[assmt]
|
| +            # All assignments this one depends on are resolved
|
| + if assmts_resolved.issuperset(deps):
|
| + for node in assmt_to_names[assmt]:
|
| + infer_name_node_type(node)
|
| +                # Resolve this assignment; infer_type() stores the result
|
| +                # as assmt.inferred_type.
|
| +                assmt.infer_type()
|
| + assmts_resolved.add(assmt)
|
| + resolved.add(assmt)
|
| + assignments.difference_update(resolved)
|
| + return resolved
|
| +
|
| + def partial_infer(assmt):
|
| + partial_types = []
|
| + for node in assmt_to_names[assmt]:
|
| + partial_type = infer_name_node_type_partial(node)
|
| + if partial_type is None:
|
| + return False
|
| + partial_types.append((node, partial_type))
|
| + for node, partial_type in partial_types:
|
| + node.inferred_type = partial_type
|
| + assmt.infer_type()
|
| + return True
|
| +
|
| + partial_assmts = set()
|
| + def resolve_partial(assignments):
|
| + # try to handle circular references
|
| + partials = set()
|
| + for assmt in assignments:
|
| + if assmt in partial_assmts:
|
| + continue
|
| + if partial_infer(assmt):
|
| + partials.add(assmt)
|
| + assmts_resolved.add(assmt)
|
| + partial_assmts.update(partials)
|
| + return partials
|
| +
|
| + # Infer assignments
|
| + while True:
|
| + if not resolve_assignments(assignments):
|
| + if not resolve_partial(assignments):
|
| + break
|
| + inferred = set()
|
| + # First pass
|
| + for entry in scope.entries.values():
|
| + if entry.type is not unspecified_type:
|
| + continue
|
| + entry_type = py_object_type
|
| + if assmts_resolved.issuperset(entry.cf_assignments):
|
| + types = [assmt.inferred_type for assmt in entry.cf_assignments]
|
| + if types and Utils.all(types):
|
| + entry_type = spanning_type(
|
| + types, entry.might_overflow, entry.pos)
|
| + inferred.add(entry)
|
| + self.set_entry_type(entry, entry_type)
|
| +
|
| + def reinfer():
|
| + dirty = False
|
| + for entry in inferred:
|
| + types = [assmt.infer_type()
|
| + for assmt in entry.cf_assignments]
|
| + new_type = spanning_type(types, entry.might_overflow, entry.pos)
|
| + if new_type != entry.type:
|
| + self.set_entry_type(entry, new_type)
|
| + dirty = True
|
| + return dirty
|
| +
|
| +        # Propagate types until a fixed point is reached
|
| + while reinfer():
|
| + pass
|
| +
|
| + if verbose:
|
| + for entry in inferred:
|
| + message(entry.pos, "inferred '%s' to be of type '%s'" % (
|
| + entry.name, entry.type))
|
| +
|
| +
|
| +def find_spanning_type(type1, type2):
|
| + if type1 is type2:
|
| + result_type = type1
|
| + elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
|
| + # type inference can break the coercion back to a Python bool
|
| + # if it returns an arbitrary int type here
|
| + return py_object_type
|
| + else:
|
| + result_type = PyrexTypes.spanning_type(type1, type2)
|
| + if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
|
| + Builtin.float_type):
|
| + # Python's float type is just a C double, so it's safe to
|
| + # use the C type instead
|
| + return PyrexTypes.c_double_type
|
| + return result_type
|
| +
|
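| +# Hedged sketch of the behaviour above (names from PyrexTypes/Builtin):
|
| +#
|
| +#     find_spanning_type(c_int_type,  c_int_type)    -> c_int_type
|
| +#     find_spanning_type(c_int_type,  c_double_type) -> c_double_type
|
| +#     find_spanning_type(c_bint_type, c_long_type)   -> py_object_type
|
| +
|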
| +def aggressive_spanning_type(types, might_overflow, pos):
|
| + result_type = reduce(find_spanning_type, types)
|
| + if result_type.is_reference:
|
| + result_type = result_type.ref_base_type
|
| + if result_type.is_const:
|
| + result_type = result_type.const_base_type
|
| + if result_type.is_cpp_class:
|
| + result_type.check_nullary_constructor(pos)
|
| + return result_type
|
| +
|
| +def safe_spanning_type(types, might_overflow, pos):
|
| + result_type = reduce(find_spanning_type, types)
|
| + if result_type.is_const:
|
| + result_type = result_type.const_base_type
|
| + if result_type.is_reference:
|
| + result_type = result_type.ref_base_type
|
| + if result_type.is_cpp_class:
|
| + result_type.check_nullary_constructor(pos)
|
| + if result_type.is_pyobject:
|
| + # In theory, any specific Python type is always safe to
|
| + # infer. However, inferring str can cause some existing code
|
| + # to break, since we are also now much more strict about
|
| + # coercion from str to char *. See trac #553.
|
| + if result_type.name == 'str':
|
| + return py_object_type
|
| + else:
|
| + return result_type
|
| + elif result_type is PyrexTypes.c_double_type:
|
| + # Python's float type is just a C double, so it's safe to use
|
| + # the C type instead
|
| + return result_type
|
| + elif result_type is PyrexTypes.c_bint_type:
|
| + # find_spanning_type() only returns 'bint' for clean boolean
|
| + # operations without other int types, so this is safe, too
|
| + return result_type
|
| + elif result_type.is_ptr:
|
| + # Any pointer except (signed|unsigned|) char* can't implicitly
|
| + # become a PyObject, and inferring char* is now accepted, too.
|
| + return result_type
|
| + elif result_type.is_cpp_class:
|
| + # These can't implicitly become Python objects either.
|
| + return result_type
|
| + elif result_type.is_struct:
|
| + # Though we have struct -> object for some structs, this is uncommonly
|
| + # used, won't arise in pure Python, and there shouldn't be side
|
| + # effects, so I'm declaring this safe.
|
| + return result_type
|
| + # TODO: double complex should be OK as well, but we need
|
| + # to make sure everything is supported.
|
| + elif (result_type.is_int or result_type.is_enum) and not might_overflow:
|
| + return result_type
|
| + return py_object_type
|
| +
|
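| +# Illustrative contrast with aggressive_spanning_type (hedged sketch):
|
| +#
|
| +#     i = 0; i = n         # C int in both modes, if it cannot overflow
|
| +#     i = 0; i += huge     # might_overflow: safe mode falls back to object
|
| +#     s = 'x'; s = 'y'     # spans to str; safe mode returns object (#553)
|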
| +
|
| +def get_type_inferer():
|
| + return SimpleAssignmentTypeInferer()
|
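| +
|
| +
|
| +# Hedged usage sketch (the caller below is an assumption, not part of
|
| +# this file): a scope asks for the inferer and runs it over its entries,
|
| +#
|
| +#     inferer = get_type_inferer()
|
| +#     inferer.infer_types(scope)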
|
|