| Index: third_party/cython/src/Cython/Build/Inline.py
|
| diff --git a/third_party/cython/src/Cython/Build/Inline.py b/third_party/cython/src/Cython/Build/Inline.py
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..fcbb6c1282a3e9768ac21e25c692d51213498300
|
| --- /dev/null
|
| +++ b/third_party/cython/src/Cython/Build/Inline.py
|
| @@ -0,0 +1,304 @@
|
| +import sys, os, re, inspect
|
| +import imp
|
| +
|
| +try:
|
| + import hashlib
|
| +except ImportError:
|
| + import md5 as hashlib
|
| +
|
| +from distutils.core import Distribution, Extension
|
| +from distutils.command.build_ext import build_ext
|
| +
|
| +import Cython
|
| +from Cython.Compiler.Main import Context, CompilationOptions, default_options
|
| +
|
| +from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
|
| +from Cython.Compiler.TreeFragment import parse_from_strings
|
| +from Cython.Build.Dependencies import strip_string_literals, cythonize, cached_function
|
| +from Cython.Compiler import Pipeline
|
| +from Cython.Utils import get_cython_cache_dir
|
| +import cython as cython_module
|
| +
|
| +# A utility function to convert user-supplied ASCII strings to unicode.
|
| +if sys.version_info[0] < 3:
|
| + def to_unicode(s):
|
| + if not isinstance(s, unicode):
|
| + return s.decode('ascii')
|
| + else:
|
| + return s
|
| +else:
|
| + to_unicode = lambda x: x
|
| +
|
| +
|
| +class AllSymbols(CythonTransform, SkipDeclarations):
|
| + def __init__(self):
|
| + CythonTransform.__init__(self, None)
|
| + self.names = set()
|
| + def visit_NameNode(self, node):
|
| + self.names.add(node.name)
|
| +
|
| +@cached_function
|
| +def unbound_symbols(code, context=None):
|
| + code = to_unicode(code)
|
| + if context is None:
|
| + context = Context([], default_options)
|
| + from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
|
| + tree = parse_from_strings('(tree fragment)', code)
|
| + for phase in Pipeline.create_pipeline(context, 'pyx'):
|
| + if phase is None:
|
| + continue
|
| + tree = phase(tree)
|
| + if isinstance(phase, AnalyseDeclarationsTransform):
|
| + break
|
| + symbol_collector = AllSymbols()
|
| + symbol_collector(tree)
|
| + unbound = []
|
| + try:
|
| + import builtins
|
| + except ImportError:
|
| + import __builtin__ as builtins
|
| + for name in symbol_collector.names:
|
| + if not tree.scope.lookup(name) and not hasattr(builtins, name):
|
| + unbound.append(name)
|
| + return unbound
|
| +
|
| +def unsafe_type(arg, context=None):
|
| + py_type = type(arg)
|
| + if py_type is int:
|
| + return 'long'
|
| + else:
|
| + return safe_type(arg, context)
|
| +
|
| +def safe_type(arg, context=None):
|
| + py_type = type(arg)
|
| + if py_type in [list, tuple, dict, str]:
|
| + return py_type.__name__
|
| + elif py_type is complex:
|
| + return 'double complex'
|
| + elif py_type is float:
|
| + return 'double'
|
| + elif py_type is bool:
|
| + return 'bint'
|
| + elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
|
| + return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
|
| + else:
|
| + for base_type in py_type.mro():
|
| + if base_type.__module__ in ('__builtin__', 'builtins'):
|
| + return 'object'
|
| + module = context.find_module(base_type.__module__, need_pxd=False)
|
| + if module:
|
| + entry = module.lookup(base_type.__name__)
|
| + if entry.is_type:
|
| + return '%s.%s' % (base_type.__module__, base_type.__name__)
|
| + return 'object'
|
| +
|
| +def _get_build_extension():
|
| + dist = Distribution()
|
| + # Ensure the build respects distutils configuration by parsing
|
| + # the configuration files
|
| + config_files = dist.find_config_files()
|
| + dist.parse_config_files(config_files)
|
| + build_extension = build_ext(dist)
|
| + build_extension.finalize_options()
|
| + return build_extension
|
| +
|
| +@cached_function
|
| +def _create_context(cython_include_dirs):
|
| + return Context(list(cython_include_dirs), default_options)
|
| +
|
| +def cython_inline(code,
|
| + get_type=unsafe_type,
|
| + lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
|
| + cython_include_dirs=['.'],
|
| + force=False,
|
| + quiet=False,
|
| + locals=None,
|
| + globals=None,
|
| + **kwds):
|
| + if get_type is None:
|
| + get_type = lambda x: 'object'
|
| + code = to_unicode(code)
|
| + orig_code = code
|
| + code, literals = strip_string_literals(code)
|
| + code = strip_common_indent(code)
|
| + ctx = _create_context(tuple(cython_include_dirs))
|
| + if locals is None:
|
| + locals = inspect.currentframe().f_back.f_back.f_locals
|
| + if globals is None:
|
| + globals = inspect.currentframe().f_back.f_back.f_globals
|
| + try:
|
| + for symbol in unbound_symbols(code):
|
| + if symbol in kwds:
|
| + continue
|
| + elif symbol in locals:
|
| + kwds[symbol] = locals[symbol]
|
| + elif symbol in globals:
|
| + kwds[symbol] = globals[symbol]
|
| + else:
|
| + print("Couldn't find ", symbol)
|
| + except AssertionError:
|
| + if not quiet:
|
| + # Parsing from strings not fully supported (e.g. cimports).
|
| + print("Could not parse code as a string (to extract unbound symbols).")
|
| + cimports = []
|
| + for name, arg in kwds.items():
|
| + if arg is cython_module:
|
| + cimports.append('\ncimport cython as %s' % name)
|
| + del kwds[name]
|
| + arg_names = kwds.keys()
|
| + arg_names.sort()
|
| + arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
|
| + key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
|
| + module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
|
| +
|
| + if module_name in sys.modules:
|
| + module = sys.modules[module_name]
|
| +
|
| + else:
|
| + build_extension = None
|
| + if cython_inline.so_ext is None:
|
| + # Figure out and cache current extension suffix
|
| + build_extension = _get_build_extension()
|
| + cython_inline.so_ext = build_extension.get_ext_filename('')
|
| +
|
| + module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)
|
| +
|
| + if not os.path.exists(lib_dir):
|
| + os.makedirs(lib_dir)
|
| + if force or not os.path.isfile(module_path):
|
| + cflags = []
|
| + c_include_dirs = []
|
| + qualified = re.compile(r'([.\w]+)[.]')
|
| + for type, _ in arg_sigs:
|
| + m = qualified.match(type)
|
| + if m:
|
| + cimports.append('\ncimport %s' % m.groups()[0])
|
| + # one special case
|
| + if m.groups()[0] == 'numpy':
|
| + import numpy
|
| + c_include_dirs.append(numpy.get_include())
|
| + # cflags.append('-Wno-unused')
|
| + module_body, func_body = extract_func_code(code)
|
| + params = ', '.join(['%s %s' % a for a in arg_sigs])
|
| + module_code = """
|
| +%(module_body)s
|
| +%(cimports)s
|
| +def __invoke(%(params)s):
|
| +%(func_body)s
|
| + """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
|
| + for key, value in literals.items():
|
| + module_code = module_code.replace(key, value)
|
| + pyx_file = os.path.join(lib_dir, module_name + '.pyx')
|
| + fh = open(pyx_file, 'w')
|
| + try:
|
| + fh.write(module_code)
|
| + finally:
|
| + fh.close()
|
| + extension = Extension(
|
| + name = module_name,
|
| + sources = [pyx_file],
|
| + include_dirs = c_include_dirs,
|
| + extra_compile_args = cflags)
|
| + if build_extension is None:
|
| + build_extension = _get_build_extension()
|
| + build_extension.extensions = cythonize([extension], include_path=cython_include_dirs, quiet=quiet)
|
| + build_extension.build_temp = os.path.dirname(pyx_file)
|
| + build_extension.build_lib = lib_dir
|
| + build_extension.run()
|
| +
|
| + module = imp.load_dynamic(module_name, module_path)
|
| +
|
| + arg_list = [kwds[arg] for arg in arg_names]
|
| + return module.__invoke(*arg_list)
|
| +
|
| +# Cached suffix used by cython_inline above. None should get
|
| +# overridden with actual value upon the first cython_inline invocation
|
| +cython_inline.so_ext = None
|
| +
|
| +non_space = re.compile('[^ ]')
|
| +def strip_common_indent(code):
|
| + min_indent = None
|
| + lines = code.split('\n')
|
| + for line in lines:
|
| + match = non_space.search(line)
|
| + if not match:
|
| + continue # blank
|
| + indent = match.start()
|
| + if line[indent] == '#':
|
| + continue # comment
|
| + elif min_indent is None or min_indent > indent:
|
| + min_indent = indent
|
| + for ix, line in enumerate(lines):
|
| + match = non_space.search(line)
|
| + if not match or line[indent] == '#':
|
| + continue
|
| + else:
|
| + lines[ix] = line[min_indent:]
|
| + return '\n'.join(lines)
|
| +
|
| +module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
|
| +def extract_func_code(code):
|
| + module = []
|
| + function = []
|
| + current = function
|
| + code = code.replace('\t', ' ')
|
| + lines = code.split('\n')
|
| + for line in lines:
|
| + if not line.startswith(' '):
|
| + if module_statement.match(line):
|
| + current = module
|
| + else:
|
| + current = function
|
| + current.append(line)
|
| + return '\n'.join(module), ' ' + '\n '.join(function)
|
| +
|
| +
|
| +
|
| +try:
|
| + from inspect import getcallargs
|
| +except ImportError:
|
| + def getcallargs(func, *arg_values, **kwd_values):
|
| + all = {}
|
| + args, varargs, kwds, defaults = inspect.getargspec(func)
|
| + if varargs is not None:
|
| + all[varargs] = arg_values[len(args):]
|
| + for name, value in zip(args, arg_values):
|
| + all[name] = value
|
| + for name, value in kwd_values.items():
|
| + if name in args:
|
| + if name in all:
|
| + raise TypeError("Duplicate argument %s" % name)
|
| + all[name] = kwd_values.pop(name)
|
| + if kwds is not None:
|
| + all[kwds] = kwd_values
|
| + elif kwd_values:
|
| + raise TypeError("Unexpected keyword arguments: %s" % kwd_values.keys())
|
| + if defaults is None:
|
| + defaults = ()
|
| + first_default = len(args) - len(defaults)
|
| + for ix, name in enumerate(args):
|
| + if name not in all:
|
| + if ix >= first_default:
|
| + all[name] = defaults[ix - first_default]
|
| + else:
|
| + raise TypeError("Missing argument: %s" % name)
|
| + return all
|
| +
|
| +def get_body(source):
|
| + ix = source.index(':')
|
| + if source[:5] == 'lambda':
|
| + return "return %s" % source[ix+1:]
|
| + else:
|
| + return source[ix+1:]
|
| +
|
| +# Lots to be done here... It would be especially cool if compiled functions
|
| +# could invoke each other quickly.
|
| +class RuntimeCompiledFunction(object):
|
| +
|
| + def __init__(self, f):
|
| + self._f = f
|
| + self._body = get_body(inspect.getsource(f))
|
| +
|
| + def __call__(self, *args, **kwds):
|
| + all = getcallargs(self._f, *args, **kwds)
|
| + return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)
|
|
|