utils.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from collections import defaultdict
  3. from copy import deepcopy
  4. from typing import Any, Callable, Dict, Optional, Tuple
  5. class OutputSaveObjectWrapper:
  6. """A wrapper class that saves the output of function calls on an object."""
  7. def __init__(self, obj: Any) -> None:
  8. self.obj = obj
  9. self.log = defaultdict(list)
  10. def __getattr__(self, attr: str) -> Any:
  11. """Overrides the default behavior when an attribute is accessed.
  12. - If the attribute is callable, hooks the attribute and saves the
  13. returned value of the function call to the log.
  14. - If the attribute is not callable, saves the attribute's value to the
  15. log and returns the value.
  16. """
  17. orig_attr = getattr(self.obj, attr)
  18. if not callable(orig_attr):
  19. self.log[attr].append(orig_attr)
  20. return orig_attr
  21. def hooked(*args: Tuple, **kwargs: Dict) -> Any:
  22. """The hooked function that logs the return value of the original
  23. function."""
  24. result = orig_attr(*args, **kwargs)
  25. self.log[attr].append(result)
  26. return result
  27. return hooked
  28. def clear(self):
  29. """Clears the log of function call outputs."""
  30. self.log.clear()
  31. def __deepcopy__(self, memo):
  32. """Only copy the object when applying deepcopy."""
  33. other = type(self)(deepcopy(self.obj))
  34. memo[id(self)] = other
  35. return other
  36. class OutputSaveFunctionWrapper:
  37. """A class that wraps a function and saves its outputs.
  38. This class can be used to decorate a function to save its outputs. It wraps
  39. the function with a `__call__` method that calls the original function and
  40. saves the results in a log attribute.
  41. Args:
  42. func (Callable): A function to wrap.
  43. spec (Optional[Dict]): A dictionary of global variables to use as the
  44. namespace for the wrapper. If `None`, the global namespace of the
  45. original function is used.
  46. """
  47. def __init__(self, func: Callable, spec: Optional[Dict]) -> None:
  48. """Initializes the OutputSaveFunctionWrapper instance."""
  49. assert callable(func)
  50. self.log = []
  51. self.func = func
  52. self.func_name = func.__name__
  53. if isinstance(spec, dict):
  54. self.spec = spec
  55. elif hasattr(func, '__globals__'):
  56. self.spec = func.__globals__
  57. else:
  58. raise ValueError
  59. def __call__(self, *args, **kwargs) -> Any:
  60. """Calls the wrapped function with the given arguments and saves the
  61. results in the `log` attribute."""
  62. results = self.func(*args, **kwargs)
  63. self.log.append(results)
  64. return results
  65. def __enter__(self) -> None:
  66. """Enters the context and sets the wrapped function to be a global
  67. variable in the specified namespace."""
  68. self.spec[self.func_name] = self
  69. return self.log
  70. def __exit__(self, exc_type, exc_val, exc_tb) -> None:
  71. """Exits the context and resets the wrapped function to its original
  72. value in the specified namespace."""
  73. self.spec[self.func_name] = self.func