OLD | NEW |
1 # Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 # Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 # Use of this source code is governed by a BSD-style license that can be | 2 # Use of this source code is governed by a BSD-style license that can be |
3 # found in the LICENSE file. | 3 # found in the LICENSE file. |
4 import inspect | 4 import inspect |
5 import logging | 5 import logging |
6 import os | 6 import os |
7 import traceback | 7 import traceback |
8 | 8 |
9 def Discover(start_dir, suffix, clazz, import_error_should_raise=False): | 9 def Discover(start_dir, top_level_dir, suffix, clazz, |
| 10 import_error_should_raise=False): |
10 """Discover all classes in |start_dir| which subclass |clazz|. | 11 """Discover all classes in |start_dir| which subclass |clazz|. |
11 | 12 |
12 Args: | 13 Args: |
13 start_dir: The directory to recursively search. | 14 start_dir: The directory to recursively search. |
14 suffix: file name suffix for files to import, without the '.py' ending. | 15 suffix: file name suffix for files to import, without the '.py' ending. |
15 clazz: The base class to search for. | 16 clazz: The base class to search for. |
16 import_error_should_raise: If false, then import errors are logged but do | 17 import_error_should_raise: If false, then import errors are logged but do |
17 not stop discovery. | 18 not stop discovery. |
18 | 19 |
19 Returns: | 20 Returns: |
20 dict of {module_name: class}. | 21 dict of {module_name: class}. |
21 """ | 22 """ |
22 top_level_dir = os.path.join(start_dir, '..') | |
23 classes = {} | 23 classes = {} |
24 for dirpath, _, filenames in os.walk(start_dir): | 24 for dirpath, _, filenames in os.walk(start_dir): |
25 for filename in filenames: | 25 for filename in filenames: |
26 if not filename.endswith(suffix + '.py'): | 26 if not filename.endswith(suffix + '.py'): |
27 continue | 27 continue |
28 name, _ = os.path.splitext(filename) | 28 name, _ = os.path.splitext(filename) |
29 relpath = os.path.relpath(dirpath, top_level_dir) | 29 relpath = os.path.relpath(dirpath, top_level_dir) |
30 fqn = relpath.replace('/', '.') + '.' + name | 30 fqn = relpath.replace('/', '.') + '.' + name |
31 try: | 31 try: |
32 module = __import__(fqn, fromlist=[True]) | 32 module = __import__(fqn, fromlist=[True]) |
33 except Exception: | 33 except Exception: |
34 if import_error_should_raise: | 34 if import_error_should_raise: |
35 raise | 35 raise |
36 logging.error('While importing [%s]\n' % fqn) | 36 logging.error('While importing [%s]\n' % fqn) |
37 traceback.print_exc() | 37 traceback.print_exc() |
38 continue | 38 continue |
39 for name, obj in inspect.getmembers(module): | 39 for name, obj in inspect.getmembers(module): |
40 if inspect.isclass(obj): | 40 if inspect.isclass(obj): |
41 if clazz in inspect.getmro(obj): | 41 if clazz in inspect.getmro(obj): |
42 name = module.__name__.split('.')[-1] | 42 name = module.__name__.split('.')[-1] |
43 classes[name] = obj | 43 classes[name] = obj |
44 return classes | 44 return classes |
OLD | NEW |