OLD | NEW |
(Empty) | |
| 1 import sys, os, re, inspect |
| 2 import imp |
| 3 |
| 4 try: |
| 5 import hashlib |
| 6 except ImportError: |
| 7 import md5 as hashlib |
| 8 |
| 9 from distutils.core import Distribution, Extension |
| 10 from distutils.command.build_ext import build_ext |
| 11 |
| 12 import Cython |
| 13 from Cython.Compiler.Main import Context, CompilationOptions, default_options |
| 14 |
| 15 from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclaration
s, AnalyseDeclarationsTransform |
| 16 from Cython.Compiler.TreeFragment import parse_from_strings |
| 17 from Cython.Build.Dependencies import strip_string_literals, cythonize, cached_f
unction |
| 18 from Cython.Compiler import Pipeline |
| 19 from Cython.Utils import get_cython_cache_dir |
| 20 import cython as cython_module |
| 21 |
| 22 # A utility function to convert user-supplied ASCII strings to unicode. |
| 23 if sys.version_info[0] < 3: |
| 24 def to_unicode(s): |
| 25 if not isinstance(s, unicode): |
| 26 return s.decode('ascii') |
| 27 else: |
| 28 return s |
| 29 else: |
| 30 to_unicode = lambda x: x |
| 31 |
| 32 |
| 33 class AllSymbols(CythonTransform, SkipDeclarations): |
| 34 def __init__(self): |
| 35 CythonTransform.__init__(self, None) |
| 36 self.names = set() |
| 37 def visit_NameNode(self, node): |
| 38 self.names.add(node.name) |
| 39 |
| 40 @cached_function |
| 41 def unbound_symbols(code, context=None): |
| 42 code = to_unicode(code) |
| 43 if context is None: |
| 44 context = Context([], default_options) |
| 45 from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform |
| 46 tree = parse_from_strings('(tree fragment)', code) |
| 47 for phase in Pipeline.create_pipeline(context, 'pyx'): |
| 48 if phase is None: |
| 49 continue |
| 50 tree = phase(tree) |
| 51 if isinstance(phase, AnalyseDeclarationsTransform): |
| 52 break |
| 53 symbol_collector = AllSymbols() |
| 54 symbol_collector(tree) |
| 55 unbound = [] |
| 56 try: |
| 57 import builtins |
| 58 except ImportError: |
| 59 import __builtin__ as builtins |
| 60 for name in symbol_collector.names: |
| 61 if not tree.scope.lookup(name) and not hasattr(builtins, name): |
| 62 unbound.append(name) |
| 63 return unbound |
| 64 |
| 65 def unsafe_type(arg, context=None): |
| 66 py_type = type(arg) |
| 67 if py_type is int: |
| 68 return 'long' |
| 69 else: |
| 70 return safe_type(arg, context) |
| 71 |
| 72 def safe_type(arg, context=None): |
| 73 py_type = type(arg) |
| 74 if py_type in [list, tuple, dict, str]: |
| 75 return py_type.__name__ |
| 76 elif py_type is complex: |
| 77 return 'double complex' |
| 78 elif py_type is float: |
| 79 return 'double' |
| 80 elif py_type is bool: |
| 81 return 'bint' |
| 82 elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray
): |
| 83 return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) |
| 84 else: |
| 85 for base_type in py_type.mro(): |
| 86 if base_type.__module__ in ('__builtin__', 'builtins'): |
| 87 return 'object' |
| 88 module = context.find_module(base_type.__module__, need_pxd=False) |
| 89 if module: |
| 90 entry = module.lookup(base_type.__name__) |
| 91 if entry.is_type: |
| 92 return '%s.%s' % (base_type.__module__, base_type.__name__) |
| 93 return 'object' |
| 94 |
| 95 def _get_build_extension(): |
| 96 dist = Distribution() |
| 97 # Ensure the build respects distutils configuration by parsing |
| 98 # the configuration files |
| 99 config_files = dist.find_config_files() |
| 100 dist.parse_config_files(config_files) |
| 101 build_extension = build_ext(dist) |
| 102 build_extension.finalize_options() |
| 103 return build_extension |
| 104 |
| 105 @cached_function |
| 106 def _create_context(cython_include_dirs): |
| 107 return Context(list(cython_include_dirs), default_options) |
| 108 |
| 109 def cython_inline(code, |
| 110 get_type=unsafe_type, |
| 111 lib_dir=os.path.join(get_cython_cache_dir(), 'inline'), |
| 112 cython_include_dirs=['.'], |
| 113 force=False, |
| 114 quiet=False, |
| 115 locals=None, |
| 116 globals=None, |
| 117 **kwds): |
| 118 if get_type is None: |
| 119 get_type = lambda x: 'object' |
| 120 code = to_unicode(code) |
| 121 orig_code = code |
| 122 code, literals = strip_string_literals(code) |
| 123 code = strip_common_indent(code) |
| 124 ctx = _create_context(tuple(cython_include_dirs)) |
| 125 if locals is None: |
| 126 locals = inspect.currentframe().f_back.f_back.f_locals |
| 127 if globals is None: |
| 128 globals = inspect.currentframe().f_back.f_back.f_globals |
| 129 try: |
| 130 for symbol in unbound_symbols(code): |
| 131 if symbol in kwds: |
| 132 continue |
| 133 elif symbol in locals: |
| 134 kwds[symbol] = locals[symbol] |
| 135 elif symbol in globals: |
| 136 kwds[symbol] = globals[symbol] |
| 137 else: |
| 138 print("Couldn't find ", symbol) |
| 139 except AssertionError: |
| 140 if not quiet: |
| 141 # Parsing from strings not fully supported (e.g. cimports). |
| 142 print("Could not parse code as a string (to extract unbound symbols)
.") |
| 143 cimports = [] |
| 144 for name, arg in kwds.items(): |
| 145 if arg is cython_module: |
| 146 cimports.append('\ncimport cython as %s' % name) |
| 147 del kwds[name] |
| 148 arg_names = kwds.keys() |
| 149 arg_names.sort() |
| 150 arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) |
| 151 key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__versio
n__ |
| 152 module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexd
igest() |
| 153 |
| 154 if module_name in sys.modules: |
| 155 module = sys.modules[module_name] |
| 156 |
| 157 else: |
| 158 build_extension = None |
| 159 if cython_inline.so_ext is None: |
| 160 # Figure out and cache current extension suffix |
| 161 build_extension = _get_build_extension() |
| 162 cython_inline.so_ext = build_extension.get_ext_filename('') |
| 163 |
| 164 module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext) |
| 165 |
| 166 if not os.path.exists(lib_dir): |
| 167 os.makedirs(lib_dir) |
| 168 if force or not os.path.isfile(module_path): |
| 169 cflags = [] |
| 170 c_include_dirs = [] |
| 171 qualified = re.compile(r'([.\w]+)[.]') |
| 172 for type, _ in arg_sigs: |
| 173 m = qualified.match(type) |
| 174 if m: |
| 175 cimports.append('\ncimport %s' % m.groups()[0]) |
| 176 # one special case |
| 177 if m.groups()[0] == 'numpy': |
| 178 import numpy |
| 179 c_include_dirs.append(numpy.get_include()) |
| 180 # cflags.append('-Wno-unused') |
| 181 module_body, func_body = extract_func_code(code) |
| 182 params = ', '.join(['%s %s' % a for a in arg_sigs]) |
| 183 module_code = """ |
| 184 %(module_body)s |
| 185 %(cimports)s |
| 186 def __invoke(%(params)s): |
| 187 %(func_body)s |
| 188 """ % {'cimports': '\n'.join(cimports), 'module_body': module_body,
'params': params, 'func_body': func_body } |
| 189 for key, value in literals.items(): |
| 190 module_code = module_code.replace(key, value) |
| 191 pyx_file = os.path.join(lib_dir, module_name + '.pyx') |
| 192 fh = open(pyx_file, 'w') |
| 193 try: |
| 194 fh.write(module_code) |
| 195 finally: |
| 196 fh.close() |
| 197 extension = Extension( |
| 198 name = module_name, |
| 199 sources = [pyx_file], |
| 200 include_dirs = c_include_dirs, |
| 201 extra_compile_args = cflags) |
| 202 if build_extension is None: |
| 203 build_extension = _get_build_extension() |
| 204 build_extension.extensions = cythonize([extension], include_path=cyt
hon_include_dirs, quiet=quiet) |
| 205 build_extension.build_temp = os.path.dirname(pyx_file) |
| 206 build_extension.build_lib = lib_dir |
| 207 build_extension.run() |
| 208 |
| 209 module = imp.load_dynamic(module_name, module_path) |
| 210 |
| 211 arg_list = [kwds[arg] for arg in arg_names] |
| 212 return module.__invoke(*arg_list) |
| 213 |
| 214 # Cached suffix used by cython_inline above. None should get |
| 215 # overridden with actual value upon the first cython_inline invocation |
| 216 cython_inline.so_ext = None |
| 217 |
| 218 non_space = re.compile('[^ ]') |
| 219 def strip_common_indent(code): |
| 220 min_indent = None |
| 221 lines = code.split('\n') |
| 222 for line in lines: |
| 223 match = non_space.search(line) |
| 224 if not match: |
| 225 continue # blank |
| 226 indent = match.start() |
| 227 if line[indent] == '#': |
| 228 continue # comment |
| 229 elif min_indent is None or min_indent > indent: |
| 230 min_indent = indent |
| 231 for ix, line in enumerate(lines): |
| 232 match = non_space.search(line) |
| 233 if not match or line[indent] == '#': |
| 234 continue |
| 235 else: |
| 236 lines[ix] = line[min_indent:] |
| 237 return '\n'.join(lines) |
| 238 |
| 239 module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor
t)|(from .+ import +[*]))') |
| 240 def extract_func_code(code): |
| 241 module = [] |
| 242 function = [] |
| 243 current = function |
| 244 code = code.replace('\t', ' ') |
| 245 lines = code.split('\n') |
| 246 for line in lines: |
| 247 if not line.startswith(' '): |
| 248 if module_statement.match(line): |
| 249 current = module |
| 250 else: |
| 251 current = function |
| 252 current.append(line) |
| 253 return '\n'.join(module), ' ' + '\n '.join(function) |
| 254 |
| 255 |
| 256 |
| 257 try: |
| 258 from inspect import getcallargs |
| 259 except ImportError: |
| 260 def getcallargs(func, *arg_values, **kwd_values): |
| 261 all = {} |
| 262 args, varargs, kwds, defaults = inspect.getargspec(func) |
| 263 if varargs is not None: |
| 264 all[varargs] = arg_values[len(args):] |
| 265 for name, value in zip(args, arg_values): |
| 266 all[name] = value |
| 267 for name, value in kwd_values.items(): |
| 268 if name in args: |
| 269 if name in all: |
| 270 raise TypeError("Duplicate argument %s" % name) |
| 271 all[name] = kwd_values.pop(name) |
| 272 if kwds is not None: |
| 273 all[kwds] = kwd_values |
| 274 elif kwd_values: |
| 275 raise TypeError("Unexpected keyword arguments: %s" % kwd_values.keys
()) |
| 276 if defaults is None: |
| 277 defaults = () |
| 278 first_default = len(args) - len(defaults) |
| 279 for ix, name in enumerate(args): |
| 280 if name not in all: |
| 281 if ix >= first_default: |
| 282 all[name] = defaults[ix - first_default] |
| 283 else: |
| 284 raise TypeError("Missing argument: %s" % name) |
| 285 return all |
| 286 |
| 287 def get_body(source): |
| 288 ix = source.index(':') |
| 289 if source[:5] == 'lambda': |
| 290 return "return %s" % source[ix+1:] |
| 291 else: |
| 292 return source[ix+1:] |
| 293 |
| 294 # Lots to be done here... It would be especially cool if compiled functions |
| 295 # could invoke each other quickly. |
| 296 class RuntimeCompiledFunction(object): |
| 297 |
| 298 def __init__(self, f): |
| 299 self._f = f |
| 300 self._body = get_body(inspect.getsource(f)) |
| 301 |
| 302 def __call__(self, *args, **kwds): |
| 303 all = getcallargs(self._f, *args, **kwds) |
| 304 return cython_inline(self._body, locals=self._f.func_globals, globals=se
lf._f.func_globals, **all) |
OLD | NEW |