setup.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. #!/usr/bin/env python
  2. # Copyright (c) OpenMMLab. All rights reserved.
  3. import os
  4. import os.path as osp
  5. import platform
  6. import shutil
  7. import sys
  8. import warnings
  9. from setuptools import find_packages, setup
  10. import torch
  11. from torch.utils.cpp_extension import (BuildExtension, CppExtension,
  12. CUDAExtension)
  13. def readme():
  14. with open('README.md', encoding='utf-8') as f:
  15. content = f.read()
  16. return content
  17. version_file = 'mmdet/version.py'
  18. def get_version():
  19. with open(version_file, 'r') as f:
  20. exec(compile(f.read(), version_file, 'exec'))
  21. return locals()['__version__']
  22. def make_cuda_ext(name, module, sources, sources_cuda=[]):
  23. define_macros = []
  24. extra_compile_args = {'cxx': []}
  25. if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
  26. define_macros += [('WITH_CUDA', None)]
  27. extension = CUDAExtension
  28. extra_compile_args['nvcc'] = [
  29. '-D__CUDA_NO_HALF_OPERATORS__',
  30. '-D__CUDA_NO_HALF_CONVERSIONS__',
  31. '-D__CUDA_NO_HALF2_OPERATORS__',
  32. ]
  33. sources += sources_cuda
  34. else:
  35. print(f'Compiling {name} without CUDA')
  36. extension = CppExtension
  37. return extension(
  38. name=f'{module}.{name}',
  39. sources=[os.path.join(*module.split('.'), p) for p in sources],
  40. define_macros=define_macros,
  41. extra_compile_args=extra_compile_args)
  42. def parse_requirements(fname='requirements.txt', with_version=True):
  43. """Parse the package dependencies listed in a requirements file but strips
  44. specific versioning information.
  45. Args:
  46. fname (str): path to requirements file
  47. with_version (bool, default=False): if True include version specs
  48. Returns:
  49. List[str]: list of requirements items
  50. CommandLine:
  51. python -c "import setup; print(setup.parse_requirements())"
  52. """
  53. import re
  54. import sys
  55. from os.path import exists
  56. require_fpath = fname
  57. def parse_line(line):
  58. """Parse information from a line in a requirements text file."""
  59. if line.startswith('-r '):
  60. # Allow specifying requirements in other files
  61. target = line.split(' ')[1]
  62. for info in parse_require_file(target):
  63. yield info
  64. else:
  65. info = {'line': line}
  66. if line.startswith('-e '):
  67. info['package'] = line.split('#egg=')[1]
  68. elif '@git+' in line:
  69. info['package'] = line
  70. else:
  71. # Remove versioning from the package
  72. pat = '(' + '|'.join(['>=', '==', '>']) + ')'
  73. parts = re.split(pat, line, maxsplit=1)
  74. parts = [p.strip() for p in parts]
  75. info['package'] = parts[0]
  76. if len(parts) > 1:
  77. op, rest = parts[1:]
  78. if ';' in rest:
  79. # Handle platform specific dependencies
  80. # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
  81. version, platform_deps = map(str.strip,
  82. rest.split(';'))
  83. info['platform_deps'] = platform_deps
  84. else:
  85. version = rest # NOQA
  86. info['version'] = (op, version)
  87. yield info
  88. def parse_require_file(fpath):
  89. with open(fpath, 'r') as f:
  90. for line in f.readlines():
  91. line = line.strip()
  92. if line and not line.startswith('#'):
  93. for info in parse_line(line):
  94. yield info
  95. def gen_packages_items():
  96. if exists(require_fpath):
  97. for info in parse_require_file(require_fpath):
  98. parts = [info['package']]
  99. if with_version and 'version' in info:
  100. parts.extend(info['version'])
  101. if not sys.version.startswith('3.4'):
  102. # apparently package_deps are broken in 3.4
  103. platform_deps = info.get('platform_deps')
  104. if platform_deps is not None:
  105. parts.append(';' + platform_deps)
  106. item = ''.join(parts)
  107. yield item
  108. packages = list(gen_packages_items())
  109. return packages
  110. def add_mim_extension():
  111. """Add extra files that are required to support MIM into the package.
  112. These files will be added by creating a symlink to the originals if the
  113. package is installed in `editable` mode (e.g. pip install -e .), or by
  114. copying from the originals otherwise.
  115. """
  116. # parse installment mode
  117. if 'develop' in sys.argv:
  118. # installed by `pip install -e .`
  119. if platform.system() == 'Windows':
  120. # set `copy` mode here since symlink fails on Windows.
  121. mode = 'copy'
  122. else:
  123. mode = 'symlink'
  124. elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
  125. # installed by `pip install .`
  126. # or create source distribution by `python setup.py sdist`
  127. mode = 'copy'
  128. else:
  129. return
  130. filenames = ['tools', 'configs', 'demo', 'model-index.yml']
  131. repo_path = osp.dirname(__file__)
  132. mim_path = osp.join(repo_path, 'mmdet', '.mim')
  133. os.makedirs(mim_path, exist_ok=True)
  134. for filename in filenames:
  135. if osp.exists(filename):
  136. src_path = osp.join(repo_path, filename)
  137. tar_path = osp.join(mim_path, filename)
  138. if osp.isfile(tar_path) or osp.islink(tar_path):
  139. os.remove(tar_path)
  140. elif osp.isdir(tar_path):
  141. shutil.rmtree(tar_path)
  142. if mode == 'symlink':
  143. src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
  144. os.symlink(src_relpath, tar_path)
  145. elif mode == 'copy':
  146. if osp.isfile(src_path):
  147. shutil.copyfile(src_path, tar_path)
  148. elif osp.isdir(src_path):
  149. shutil.copytree(src_path, tar_path)
  150. else:
  151. warnings.warn(f'Cannot copy file {src_path}.')
  152. else:
  153. raise ValueError(f'Invalid mode {mode}')
  154. if __name__ == '__main__':
  155. add_mim_extension()
  156. setup(
  157. name='mmdet',
  158. version=get_version(),
  159. description='OpenMMLab Detection Toolbox and Benchmark',
  160. long_description=readme(),
  161. long_description_content_type='text/markdown',
  162. author='MMDetection Contributors',
  163. author_email='openmmlab@gmail.com',
  164. keywords='computer vision, object detection',
  165. url='https://github.com/open-mmlab/mmdetection',
  166. packages=find_packages(exclude=('configs', 'tools', 'demo')),
  167. include_package_data=True,
  168. classifiers=[
  169. 'Development Status :: 5 - Production/Stable',
  170. 'License :: OSI Approved :: Apache Software License',
  171. 'Operating System :: OS Independent',
  172. 'Programming Language :: Python :: 3',
  173. 'Programming Language :: Python :: 3.7',
  174. 'Programming Language :: Python :: 3.8',
  175. 'Programming Language :: Python :: 3.9',
  176. ],
  177. license='Apache License 2.0',
  178. install_requires=parse_requirements('requirements/runtime.txt'),
  179. extras_require={
  180. 'all': parse_requirements('requirements.txt'),
  181. 'tests': parse_requirements('requirements/tests.txt'),
  182. 'build': parse_requirements('requirements/build.txt'),
  183. 'optional': parse_requirements('requirements/optional.txt'),
  184. 'mim': parse_requirements('requirements/mminstall.txt'),
  185. },
  186. ext_modules=[],
  187. cmdclass={'build_ext': BuildExtension},
  188. zip_safe=False)