Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(43)

Side by Side Diff: third_party/cython/src/Cython/Compiler/TypeInference.py

Issue 385073004: Adding cython v0.20.2 in third-party. (Closed) Base URL: svn://svn.chromium.org/chrome/trunk/src
Patch Set: Reference cython dev list thread. Created 6 years, 5 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch | Annotate | Revision Log
OLDNEW
(Empty)
1 from Errors import error, message
2 import ExprNodes
3 import Nodes
4 import Builtin
5 import PyrexTypes
6 from Cython import Utils
7 from PyrexTypes import py_object_type, unspecified_type
8 from Visitor import CythonTransform, EnvTransform
9
10
11 class TypedExprNode(ExprNodes.ExprNode):
12 # Used for declaring assignments of a specified type without a known entry.
13 def __init__(self, type):
14 self.type = type
15
16 object_expr = TypedExprNode(py_object_type)
17
18
19 class MarkParallelAssignments(EnvTransform):
20 # Collects assignments inside parallel blocks prange, with parallel.
21 # Perhaps it's better to move it to ControlFlowAnalysis.
22
23 # tells us whether we're in a normal loop
24 in_loop = False
25
26 parallel_errors = False
27
28 def __init__(self, context):
29 # Track the parallel block scopes (with parallel, for i in prange())
30 self.parallel_block_stack = []
31 super(MarkParallelAssignments, self).__init__(context)
32
33 def mark_assignment(self, lhs, rhs, inplace_op=None):
34 if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
35 if lhs.entry is None:
36 # TODO: This shouldn't happen...
37 return
38
39 if self.parallel_block_stack:
40 parallel_node = self.parallel_block_stack[-1]
41 previous_assignment = parallel_node.assignments.get(lhs.entry)
42
43 # If there was a previous assignment to the variable, keep the
44 # previous assignment position
45 if previous_assignment:
46 pos, previous_inplace_op = previous_assignment
47
48 if (inplace_op and previous_inplace_op and
49 inplace_op != previous_inplace_op):
50 # x += y; x *= y
51 t = (inplace_op, previous_inplace_op)
52 error(lhs.pos,
53 "Reduction operator '%s' is inconsistent "
54 "with previous reduction operator '%s'" % t)
55 else:
56 pos = lhs.pos
57
58 parallel_node.assignments[lhs.entry] = (pos, inplace_op)
59 parallel_node.assigned_nodes.append(lhs)
60
61 elif isinstance(lhs, ExprNodes.SequenceNode):
62 for arg in lhs.args:
63 self.mark_assignment(arg, object_expr)
64 else:
65 # Could use this info to infer cdef class attributes...
66 pass
67
68 def visit_WithTargetAssignmentStatNode(self, node):
69 self.mark_assignment(node.lhs, node.rhs)
70 self.visitchildren(node)
71 return node
72
73 def visit_SingleAssignmentNode(self, node):
74 self.mark_assignment(node.lhs, node.rhs)
75 self.visitchildren(node)
76 return node
77
78 def visit_CascadedAssignmentNode(self, node):
79 for lhs in node.lhs_list:
80 self.mark_assignment(lhs, node.rhs)
81 self.visitchildren(node)
82 return node
83
84 def visit_InPlaceAssignmentNode(self, node):
85 self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
86 self.visitchildren(node)
87 return node
88
89 def visit_ForInStatNode(self, node):
90 # TODO: Remove redundancy with range optimization...
91 is_special = False
92 sequence = node.iterator.sequence
93 target = node.target
94 if isinstance(sequence, ExprNodes.SimpleCallNode):
95 function = sequence.function
96 if sequence.self is None and function.is_name:
97 entry = self.current_env().lookup(function.name)
98 if not entry or entry.is_builtin:
99 if function.name == 'reversed' and len(sequence.args) == 1:
100 sequence = sequence.args[0]
101 elif function.name == 'enumerate' and len(sequence.args) == 1:
102 if target.is_sequence_constructor and len(target.args) = = 2:
103 iterator = sequence.args[0]
104 if iterator.is_name:
105 iterator_type = iterator.infer_type(self.current _env())
106 if iterator_type.is_builtin_type:
107 # assume that builtin types have a length wi thin Py_ssize_t
108 self.mark_assignment(
109 target.args[0],
110 ExprNodes.IntNode(target.pos, value='PY_ SSIZE_T_MAX',
111 type=PyrexTypes.c_py_s size_t_type))
112 target = target.args[1]
113 sequence = sequence.args[0]
114 if isinstance(sequence, ExprNodes.SimpleCallNode):
115 function = sequence.function
116 if sequence.self is None and function.is_name:
117 entry = self.current_env().lookup(function.name)
118 if not entry or entry.is_builtin:
119 if function.name in ('range', 'xrange'):
120 is_special = True
121 for arg in sequence.args[:2]:
122 self.mark_assignment(target, arg)
123 if len(sequence.args) > 2:
124 self.mark_assignment(
125 target,
126 ExprNodes.binop_node(node.pos,
127 '+',
128 sequence.args[0],
129 sequence.args[2]))
130
131 if not is_special:
132 # A for-loop basically translates to subsequent calls to
133 # __getitem__(), so using an IndexNode here allows us to
134 # naturally infer the base type of pointers, C arrays,
135 # Python strings, etc., while correctly falling back to an
136 # object type when the base type cannot be handled.
137 self.mark_assignment(target, ExprNodes.IndexNode(
138 node.pos,
139 base=sequence,
140 index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
141 type=PyrexTypes.c_py_ssize_t_type)))
142
143 self.visitchildren(node)
144 return node
145
146 def visit_ForFromStatNode(self, node):
147 self.mark_assignment(node.target, node.bound1)
148 if node.step is not None:
149 self.mark_assignment(node.target,
150 ExprNodes.binop_node(node.pos,
151 '+',
152 node.bound1,
153 node.step))
154 self.visitchildren(node)
155 return node
156
157 def visit_WhileStatNode(self, node):
158 self.visitchildren(node)
159 return node
160
161 def visit_ExceptClauseNode(self, node):
162 if node.target is not None:
163 self.mark_assignment(node.target, object_expr)
164 self.visitchildren(node)
165 return node
166
167 def visit_FromCImportStatNode(self, node):
168 pass # Can't be assigned to...
169
170 def visit_FromImportStatNode(self, node):
171 for name, target in node.items:
172 if name != "*":
173 self.mark_assignment(target, object_expr)
174 self.visitchildren(node)
175 return node
176
177 def visit_DefNode(self, node):
178 # use fake expressions with the right result type
179 if node.star_arg:
180 self.mark_assignment(
181 node.star_arg, TypedExprNode(Builtin.tuple_type))
182 if node.starstar_arg:
183 self.mark_assignment(
184 node.starstar_arg, TypedExprNode(Builtin.dict_type))
185 EnvTransform.visit_FuncDefNode(self, node)
186 return node
187
188 def visit_DelStatNode(self, node):
189 for arg in node.args:
190 self.mark_assignment(arg, arg)
191 self.visitchildren(node)
192 return node
193
194 def visit_ParallelStatNode(self, node):
195 if self.parallel_block_stack:
196 node.parent = self.parallel_block_stack[-1]
197 else:
198 node.parent = None
199
200 nested = False
201 if node.is_prange:
202 if not node.parent:
203 node.is_parallel = True
204 else:
205 node.is_parallel = (node.parent.is_prange or not
206 node.parent.is_parallel)
207 nested = node.parent.is_prange
208 else:
209 node.is_parallel = True
210 # Note: nested with parallel() blocks are handled by
211 # ParallelRangeTransform!
212 # nested = node.parent
213 nested = node.parent and node.parent.is_prange
214
215 self.parallel_block_stack.append(node)
216
217 nested = nested or len(self.parallel_block_stack) > 2
218 if not self.parallel_errors and nested and not node.is_prange:
219 error(node.pos, "Only prange() may be nested")
220 self.parallel_errors = True
221
222 if node.is_prange:
223 child_attrs = node.child_attrs
224 node.child_attrs = ['body', 'target', 'args']
225 self.visitchildren(node)
226 node.child_attrs = child_attrs
227
228 self.parallel_block_stack.pop()
229 if node.else_clause:
230 node.else_clause = self.visit(node.else_clause)
231 else:
232 self.visitchildren(node)
233 self.parallel_block_stack.pop()
234
235 self.parallel_errors = False
236 return node
237
238 def visit_YieldExprNode(self, node):
239 if self.parallel_block_stack:
240 error(node.pos, "Yield not allowed in parallel sections")
241
242 return node
243
244 def visit_ReturnStatNode(self, node):
245 node.in_parallel = bool(self.parallel_block_stack)
246 return node
247
248
249 class MarkOverflowingArithmetic(CythonTransform):
250
251 # It may be possible to integrate this with the above for
252 # performance improvements (though likely not worth it).
253
254 might_overflow = False
255
256 def __call__(self, root):
257 self.env_stack = []
258 self.env = root.scope
259 return super(MarkOverflowingArithmetic, self).__call__(root)
260
261 def visit_safe_node(self, node):
262 self.might_overflow, saved = False, self.might_overflow
263 self.visitchildren(node)
264 self.might_overflow = saved
265 return node
266
267 def visit_neutral_node(self, node):
268 self.visitchildren(node)
269 return node
270
271 def visit_dangerous_node(self, node):
272 self.might_overflow, saved = True, self.might_overflow
273 self.visitchildren(node)
274 self.might_overflow = saved
275 return node
276
277 def visit_FuncDefNode(self, node):
278 self.env_stack.append(self.env)
279 self.env = node.local_scope
280 self.visit_safe_node(node)
281 self.env = self.env_stack.pop()
282 return node
283
284 def visit_NameNode(self, node):
285 if self.might_overflow:
286 entry = node.entry or self.env.lookup(node.name)
287 if entry:
288 entry.might_overflow = True
289 return node
290
291 def visit_BinopNode(self, node):
292 if node.operator in '&|^':
293 return self.visit_neutral_node(node)
294 else:
295 return self.visit_dangerous_node(node)
296
297 visit_UnopNode = visit_neutral_node
298
299 visit_UnaryMinusNode = visit_dangerous_node
300
301 visit_InPlaceAssignmentNode = visit_dangerous_node
302
303 visit_Node = visit_safe_node
304
305 def visit_assignment(self, lhs, rhs):
306 if (isinstance(rhs, ExprNodes.IntNode)
307 and isinstance(lhs, ExprNodes.NameNode)
308 and Utils.long_literal(rhs.value)):
309 entry = lhs.entry or self.env.lookup(lhs.name)
310 if entry:
311 entry.might_overflow = True
312
313 def visit_SingleAssignmentNode(self, node):
314 self.visit_assignment(node.lhs, node.rhs)
315 self.visitchildren(node)
316 return node
317
318 def visit_CascadedAssignmentNode(self, node):
319 for lhs in node.lhs_list:
320 self.visit_assignment(lhs, node.rhs)
321 self.visitchildren(node)
322 return node
323
324 class PyObjectTypeInferer(object):
325 """
326 If it's not declared, it's a PyObject.
327 """
328 def infer_types(self, scope):
329 """
330 Given a dict of entries, map all unspecified types to a specified type.
331 """
332 for name, entry in scope.entries.items():
333 if entry.type is unspecified_type:
334 entry.type = py_object_type
335
336 class SimpleAssignmentTypeInferer(object):
337 """
338 Very basic type inference.
339
340 Note: in order to support cross-closure type inference, this must be
341 applies to nested scopes in top-down order.
342 """
343 def set_entry_type(self, entry, entry_type):
344 entry.type = entry_type
345 for e in entry.all_entries():
346 e.type = entry_type
347
348 def infer_types(self, scope):
349 enabled = scope.directives['infer_types']
350 verbose = scope.directives['infer_types.verbose']
351
352 if enabled == True:
353 spanning_type = aggressive_spanning_type
354 elif enabled is None: # safe mode
355 spanning_type = safe_spanning_type
356 else:
357 for entry in scope.entries.values():
358 if entry.type is unspecified_type:
359 self.set_entry_type(entry, py_object_type)
360 return
361
362 # Set of assignemnts
363 assignments = set([])
364 assmts_resolved = set([])
365 dependencies = {}
366 assmt_to_names = {}
367
368 for name, entry in scope.entries.items():
369 for assmt in entry.cf_assignments:
370 names = assmt.type_dependencies()
371 assmt_to_names[assmt] = names
372 assmts = set()
373 for node in names:
374 assmts.update(node.cf_state)
375 dependencies[assmt] = assmts
376 if entry.type is unspecified_type:
377 assignments.update(entry.cf_assignments)
378 else:
379 assmts_resolved.update(entry.cf_assignments)
380
381 def infer_name_node_type(node):
382 types = [assmt.inferred_type for assmt in node.cf_state]
383 if not types:
384 node_type = py_object_type
385 else:
386 entry = node.entry
387 node_type = spanning_type(
388 types, entry.might_overflow, entry.pos)
389 node.inferred_type = node_type
390
391 def infer_name_node_type_partial(node):
392 types = [assmt.inferred_type for assmt in node.cf_state
393 if assmt.inferred_type is not None]
394 if not types:
395 return
396 entry = node.entry
397 return spanning_type(types, entry.might_overflow, entry.pos)
398
399 def resolve_assignments(assignments):
400 resolved = set()
401 for assmt in assignments:
402 deps = dependencies[assmt]
403 # All assignments are resolved
404 if assmts_resolved.issuperset(deps):
405 for node in assmt_to_names[assmt]:
406 infer_name_node_type(node)
407 # Resolve assmt
408 inferred_type = assmt.infer_type()
409 assmts_resolved.add(assmt)
410 resolved.add(assmt)
411 assignments.difference_update(resolved)
412 return resolved
413
414 def partial_infer(assmt):
415 partial_types = []
416 for node in assmt_to_names[assmt]:
417 partial_type = infer_name_node_type_partial(node)
418 if partial_type is None:
419 return False
420 partial_types.append((node, partial_type))
421 for node, partial_type in partial_types:
422 node.inferred_type = partial_type
423 assmt.infer_type()
424 return True
425
426 partial_assmts = set()
427 def resolve_partial(assignments):
428 # try to handle circular references
429 partials = set()
430 for assmt in assignments:
431 if assmt in partial_assmts:
432 continue
433 if partial_infer(assmt):
434 partials.add(assmt)
435 assmts_resolved.add(assmt)
436 partial_assmts.update(partials)
437 return partials
438
439 # Infer assignments
440 while True:
441 if not resolve_assignments(assignments):
442 if not resolve_partial(assignments):
443 break
444 inferred = set()
445 # First pass
446 for entry in scope.entries.values():
447 if entry.type is not unspecified_type:
448 continue
449 entry_type = py_object_type
450 if assmts_resolved.issuperset(entry.cf_assignments):
451 types = [assmt.inferred_type for assmt in entry.cf_assignments]
452 if types and Utils.all(types):
453 entry_type = spanning_type(
454 types, entry.might_overflow, entry.pos)
455 inferred.add(entry)
456 self.set_entry_type(entry, entry_type)
457
458 def reinfer():
459 dirty = False
460 for entry in inferred:
461 types = [assmt.infer_type()
462 for assmt in entry.cf_assignments]
463 new_type = spanning_type(types, entry.might_overflow, entry.pos)
464 if new_type != entry.type:
465 self.set_entry_type(entry, new_type)
466 dirty = True
467 return dirty
468
469 # types propagation
470 while reinfer():
471 pass
472
473 if verbose:
474 for entry in inferred:
475 message(entry.pos, "inferred '%s' to be of type '%s'" % (
476 entry.name, entry.type))
477
478
479 def find_spanning_type(type1, type2):
480 if type1 is type2:
481 result_type = type1
482 elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
483 # type inference can break the coercion back to a Python bool
484 # if it returns an arbitrary int type here
485 return py_object_type
486 else:
487 result_type = PyrexTypes.spanning_type(type1, type2)
488 if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
489 Builtin.float_type):
490 # Python's float type is just a C double, so it's safe to
491 # use the C type instead
492 return PyrexTypes.c_double_type
493 return result_type
494
495 def aggressive_spanning_type(types, might_overflow, pos):
496 result_type = reduce(find_spanning_type, types)
497 if result_type.is_reference:
498 result_type = result_type.ref_base_type
499 if result_type.is_const:
500 result_type = result_type.const_base_type
501 if result_type.is_cpp_class:
502 result_type.check_nullary_constructor(pos)
503 return result_type
504
505 def safe_spanning_type(types, might_overflow, pos):
506 result_type = reduce(find_spanning_type, types)
507 if result_type.is_const:
508 result_type = result_type.const_base_type
509 if result_type.is_reference:
510 result_type = result_type.ref_base_type
511 if result_type.is_cpp_class:
512 result_type.check_nullary_constructor(pos)
513 if result_type.is_pyobject:
514 # In theory, any specific Python type is always safe to
515 # infer. However, inferring str can cause some existing code
516 # to break, since we are also now much more strict about
517 # coercion from str to char *. See trac #553.
518 if result_type.name == 'str':
519 return py_object_type
520 else:
521 return result_type
522 elif result_type is PyrexTypes.c_double_type:
523 # Python's float type is just a C double, so it's safe to use
524 # the C type instead
525 return result_type
526 elif result_type is PyrexTypes.c_bint_type:
527 # find_spanning_type() only returns 'bint' for clean boolean
528 # operations without other int types, so this is safe, too
529 return result_type
530 elif result_type.is_ptr:
531 # Any pointer except (signed|unsigned|) char* can't implicitly
532 # become a PyObject, and inferring char* is now accepted, too.
533 return result_type
534 elif result_type.is_cpp_class:
535 # These can't implicitly become Python objects either.
536 return result_type
537 elif result_type.is_struct:
538 # Though we have struct -> object for some structs, this is uncommonly
539 # used, won't arise in pure Python, and there shouldn't be side
540 # effects, so I'm declaring this safe.
541 return result_type
542 # TODO: double complex should be OK as well, but we need
543 # to make sure everything is supported.
544 elif (result_type.is_int or result_type.is_enum) and not might_overflow:
545 return result_type
546 return py_object_type
547
548
549 def get_type_inferer():
550 return SimpleAssignmentTypeInferer()
OLDNEW
« no previous file with comments | « third_party/cython/src/Cython/Compiler/TreePath.py ('k') | third_party/cython/src/Cython/Compiler/TypeSlots.py » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698