Skip to content

Commit

Permalink
Add nested enums to stub generation
Browse files Browse the repository at this point in the history
  • Loading branch information
njroussel committed Sep 21, 2022
1 parent fc1aca9 commit ad72a53
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions resources/generate_stub_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import re
import sys
import logging
import collections.abc

# ------------------------------------------------------------------------------

top_level_objects = {}
buffer = ''

def w(s):
Expand Down Expand Up @@ -118,11 +120,14 @@ def process_enums(name, e, indent=0):

# ------------------------------------------------------------------------------

def process_class(obj):
def process_class(_, obj, indent=0):
methods = []
py_methods = []
properties = []
enums = []
classes = []

indent_str = ' ' * indent

for k in dir(obj):
# Skip private attributes
Expand All @@ -141,49 +146,58 @@ def process_class(obj):
properties.append((k, v))
elif str(v).endswith(k):
enums.append((k, v))
elif inspect.isclass(v):
if (hasattr(v, '__module__') and (v.__module__.startswith('mitsuba'))):
classes.append((k, v))

base = obj.__bases__[0]
base_module = base.__module__
base_name = base.__name__
has_base = not (base_module == 'builtins' or base_name == 'object' or base_name == 'pybind11_object')
has_base = not (base_module == 'builtins' or
base_name == 'object' or
base_name == 'pybind11_object' or
base_module.startswith('drjit'))
base_name = base_name.replace(f'.{mi.variant()}', '')

if has_base and not base.__module__.startswith('mitsuba'):
w(f'from {base.__module__} import {base_name}')
w(f'{indent_str}from {base.__module__} import {base_name}')

w(f'class {obj.__name__}{"(" + base_name + ")" if has_base else ""}:')
w(f'{indent_str}class {obj.__name__}{"(" + base_name + ")" if has_base else ""}:')
if obj.__doc__ is not None:
doc = obj.__doc__.splitlines()
if len(doc) > 0:
if doc[0].strip() == '':
doc = doc[1:]
if obj.__doc__:
w(f' \"\"\"')
w(f'{indent_str} \"\"\"')
for l in doc:
w(f' {l}')
w(f' \"\"\"')
w(f'{indent_str} {l}')
w(f'{indent_str} \"\"\"')
w(f'')

process_function('__init__', obj.__init__, indent=4)
process_function('__call__', obj.__call__, indent=4)
process_function('__init__', obj.__init__, indent=indent + 4)
process_function('__call__', obj.__call__, indent=indent + 4)

for k, v in classes:
process_class(k, v, indent=indent + 4)

if len(properties) > 0:
for k, v in properties:
process_properties(k, v, indent=4)
process_properties(k, v, indent=indent + 4)
w(f'')

if len(enums) > 0:
for k, v in enums:
process_enums(k, v, indent=4)
process_enums(k, v, indent=indent + 4)
w(f'')

for k, v in methods:
process_function(k, v, indent=4)
process_function(k, v, indent=indent + 4)

for k, v in py_methods:
process_py_function(k, v, indent=4)
process_py_function(k, v, indent=indent + 4)

w(f' ...')
w(f'{indent_str} ...')
w('')

# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -232,7 +246,6 @@ def process_py_function(name, obj, indent=0):
has_doc = obj.__doc__ is not None

signature = str(inspect.signature(obj))
signature = signature.replace('\'', '')
signature = signature.replace('mi.', 'mitsuba.')

# Fix parameters that have enums as default values
Expand Down Expand Up @@ -273,9 +286,10 @@ def process_module(m, top_module=False):
submodules = []
buffer = ''

w('from typing import Any, Callable, Iterable, Iterator, Tuple, List, TypeVar, Union, overload, ModuleType')
w('from typing import Any, Callable, Iterable, Iterator, Tuple, List, TypeVar, Union, overload')
w('import mitsuba')
w('import mitsuba as mi')
w('import drjit as dr')
w('')

# Ignore initialization errors of invalid variants on this system
Expand All @@ -284,16 +298,23 @@ def process_module(m, top_module=False):
except Exception:
pass

if m not in top_level_objects:
top_level_objects[m] = set()

for k in dir(m):
v = getattr(m, k)

# Already seen object
if isinstance(v, collections.abc.Hashable) and v in top_level_objects[m]:
continue

if inspect.isclass(v):
if (hasattr(v, '__module__') and
not (v.__module__.startswith('mitsuba') or v.__module__.startswith('drjit'))):
if v in [bool, int, float]:
process_builtin_type(v, k)
continue
process_class(v)
process_class(k, v)
elif type(v).__name__ in ['method', 'function']:
process_py_function(k, v)
elif type(v).__name__ == 'builtin_function_or_method':
Expand All @@ -320,8 +341,10 @@ def process_module(m, top_module=False):

submodules.append((module_filename, v))

if isinstance(v, collections.abc.Hashable):
top_level_objects[m].add(v)

# Adjust DrJIT type hints manually here
buffer = re.sub(r'from drjit(\.scalar|\.llvm|\.cuda)?(\.ad)? ' , 'from mitsuba ', buffer)
buffer = re.sub(r'drjit\.(scalar|llvm|cuda).(ad\.)?' , 'mitsuba.', buffer)

return buffer, submodules
Expand Down

0 comments on commit ad72a53

Please sign in to comment.