OLD | NEW |
(Empty) | |
| 1 from Errors import error, message |
| 2 import ExprNodes |
| 3 import Nodes |
| 4 import Builtin |
| 5 import PyrexTypes |
| 6 from Cython import Utils |
| 7 from PyrexTypes import py_object_type, unspecified_type |
| 8 from Visitor import CythonTransform, EnvTransform |
| 9 |
| 10 |
class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known entry.
    def __init__(self, type):
        # Deliberately does not call ExprNode.__init__(): only the .type
        # attribute is needed by the inference code in this module.
        self.type = type

# Shared fake rhs used by mark_assignment() whenever the assigned value is
# "some Python object" with no concrete expression (sequence unpacking
# targets, except clause targets, from-import targets).
object_expr = TypedExprNode(py_object_type)
| 17 |
| 18 |
class MarkParallelAssignments(EnvTransform):
    # Collects assignments inside parallel blocks prange, with parallel.
    # Perhaps it's better to move it to ControlFlowAnalysis.

    # tells us whether we're in a normal loop
    in_loop = False

    # set to True after the first "Only prange() may be nested" error so
    # the message is not repeated for every nested block
    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        """
        Record an assignment to 'lhs' in the innermost enclosing parallel
        block, if any.  'inplace_op' is the operator of an augmented
        assignment and is checked for consistency against any earlier
        in-place assignment (reduction) of the same entry.  Sequence
        targets are unpacked recursively.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                        inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            # each unpacking target receives a generic Python object
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # a = b = rhs: every lhs is assigned the same rhs
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        # record the equivalent binop expression as the assigned value and
        # pass the operator for reduction-consistency checking
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        # reversed(seq) yields the same values as seq
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(self.current_env())
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            # re-check: 'sequence' may have been unwrapped by the
            # reversed()/enumerate() handling above
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        # target spans the start/stop arguments ...
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        if len(sequence.args) > 2:
                            # ... and, with an explicit step, start + step
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        # 'for i from a <= i < b [by step]': i takes the type spanning
        # bound1 and (with a step) bound1 + step
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                                 ExprNodes.binop_node(node.pos,
                                                      '+',
                                                      node.bound1,
                                                      node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            # the except target is assigned some Python object
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        pass # Can't be assigned to...

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        # a deleted name still counts as assigned (to itself) so the
        # parallel bookkeeping sees the entry
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        """
        Link a 'with parallel()' / prange() node to its enclosing parallel
        block, set its is_parallel flag, and reject illegal nesting
        (only prange() may be nested).
        """
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            # only traverse the loop parts here; the else clause is
            # visited below, after the block has been popped
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "Yield not allowed in parallel sections")

        return node

    def visit_ReturnStatNode(self, node):
        # returns inside parallel sections need special handling later
        node.in_parallel = bool(self.parallel_block_stack)
        return node
| 247 |
| 248 |
class MarkOverflowingArithmetic(CythonTransform):
    """
    Mark entries that take part in arithmetic which might overflow a plain
    C integer, so that overflow-safe code can be generated for them.

    It may be possible to integrate this with the above for
    performance improvements (though likely not worth it).
    """

    # True while visiting a subexpression whose result may overflow.
    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def visit_safe_node(self, node):
        # This node cannot overflow: clear the flag for its children.
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        # Propagate the current flag unchanged to the children.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # This node may overflow: flag all names referenced inside it.
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        # Track the local scope so that NameNode lookups resolve correctly.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
        # Bitwise and/or/xor cannot produce values larger than their
        # operands.  Use a tuple membership test: the previous substring
        # test ("operator in '&|^'") would also have matched the empty
        # string and multi-character substrings such as '|^'.
        if node.operator in ('&', '|', '^'):
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        # A name assigned an integer literal too large for an int may
        # overflow in subsequent arithmetic.
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node
| 323 |
class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a dict of entries, map all unspecified types to a specified type.
        """
        # Only the entries themselves are needed, so iterate the values
        # directly instead of unpacking an unused key from items().
        for entry in scope.entries.values():
            if entry.type is unspecified_type:
                entry.type = py_object_type
| 335 |
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        # Propagate the type to all closure copies of the entry as well.
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

    def infer_types(self, scope):
        """
        Infer a type for every entry of 'scope' whose type is still
        unspecified, honouring the 'infer_types' directive:
        True -> aggressive spanning, None -> safe spanning,
        anything else -> everything becomes a Python object.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            # inference disabled: fall back to Python objects
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Set of assignments still waiting for a type
        assignments = set([])
        assmts_resolved = set([])
        # assignment -> set of assignments its inferred type depends on
        dependencies = {}
        # assignment -> name nodes whose types it depends on
        assmt_to_names = {}

        for name, entry in scope.entries.items():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                # explicitly declared entries are resolved from the start
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            # Span the inferred types of all assignments reaching this use.
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            # Like infer_name_node_type(), but only uses assignments whose
            # type is already known; returns None if there are none yet.
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos)

        def resolve_assignments(assignments):
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All assignments are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    inferred_type = assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Returns False if any dependency has no inferred type at all.
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments until neither full nor partial resolution
        # makes progress.
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = [assmt.inferred_type for assmt in entry.cf_assignments]
                if types and Utils.all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            # Re-run inference for inferred entries; reports whether any
            # entry type changed (i.e. another round is needed).
            dirty = False
            for entry in inferred:
                types = [assmt.infer_type()
                         for assmt in entry.cf_assignments]
                new_type = spanning_type(types, entry.might_overflow, entry.pos)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # types propagation: iterate to a fixed point
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))
| 477 |
| 478 |
def find_spanning_type(type1, type2):
    """Return a single type wide enough to hold values of both
    type1 and type2, with special handling for bint and float."""
    if type1 is type2:
        span = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # Type inference can break the coercion back to a Python bool
        # if an arbitrary int type were returned here.
        return py_object_type
    else:
        span = PyrexTypes.spanning_type(type1, type2)
    if span in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead.
        return PyrexTypes.c_double_type
    return span
| 494 |
def aggressive_spanning_type(types, might_overflow, pos):
    """Reduce 'types' to a spanning type with no safety fallbacks,
    stripping reference and const wrappers from the result."""
    spanned = reduce(find_spanning_type, types)
    if spanned.is_reference:
        spanned = spanned.ref_base_type
    if spanned.is_const:
        spanned = spanned.const_base_type
    if spanned.is_cpp_class:
        # an inferred C++ class type must be default-constructible
        spanned.check_nullary_constructor(pos)
    return spanned
| 504 |
def safe_spanning_type(types, might_overflow, pos):
    """Reduce 'types' to a spanning type, but fall back to object
    whenever using the spanned C-level type could change semantics
    (e.g. possible integer overflow, str/char* coercion)."""
    result_type = reduce(find_spanning_type, types)
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_cpp_class:
        # an inferred C++ class type must be default-constructible
        result_type.check_nullary_constructor(pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    return py_object_type
| 547 |
| 548 |
def get_type_inferer():
    """Return the type inferer instance used for scope type inference."""
    return SimpleAssignmentTypeInferer()
OLD | NEW |