Pythonでデコレーターを継承先にも適用する2

こんにちは、Pythonエンジニア見習いです。
Pythonでデコレーターを継承先にも適用する
のデコレーターを継承先にも適用する方法に欠陥を発見したのでそれがどのようなものなのか
解説していきたいと思います。さらにその代替方法についても紹介します。

なぜクラスの__new__関数を用いる方法ではダメなのか

まず、Pythonのクラスに書かれた__new__関数はそのクラスのインスタンスが 生成されるたびに
実行されます 。これにより本来であれば各サブクラスごとに一回デコレーターを適用したい場面で
デコレーターがインスタンス生成ごとに適用されてしまいます。わかりやすいように以下に例を示します。
まず、以下に示すようにpythonのクラスや関数が定義されているとします。

def deco(f):
  @functools.wraps(f)
  def inner(*args, **kwargs):
    print("before")
    result = f(*args, **kwargs)
    print("after")
    return result

  return inner


class A(metaclass=abc.ABCMeta):
  @abc.abstractclassmethod
  def foo(self):
    pass

  # この関数がインスタンス生成ごとに実行される
  def __new__(cls, *args, **kwargs):
    # ここでデコレーターを適用している
    cls.foo = deco(cls.foo)
    return super().__new__(cls)


class B(A):
  def foo(self):
    print("B")

これを以下のように実行するとインスタンス生成ごとに
クラスBのメンバ関数fooがデコレートされていることがわかります。

In [2]: B().foo()
before
B
after

In [3]: B().foo()
before
before
B
after
after

どうやったら継承先に(複数回適用せずに)デコレーターを適用できるのか

解決方法は以下のとおりです。

  1. デコレータが適用されていないことを確認してから適用する
  2. メタクラスの __new__ を用いてデコレーターの適用をモジュール読み込み時にしか行わないようにする

1の方法はシンプルですが、
その関数が何でデコレートされているのかを簡単に知る方法は現在ないので
特定のデコレーターが適用されているか否かを識別するのは容易ではありません。
よって2の方法でデコレーターを継承先にも適用するようにしたコードを以下に示します。

import abc
import functools
from abc import ABCMeta
from collections import ChainMap


class HogeException(Exception):
  pass


class FooException(HogeException):
  pass


class BarException(HogeException):
  pass


# デコレーター
def exception_catcher(f, exception, callback):
  @functools.wraps(f)
  def inner(*args, **kwargs):
    print("before")
    try:
      result = f(*args, **kwargs)
    except exception as e:
      result = callback(*args, **kwargs)
    print("after")
    return result
  return inner


def get_bases_dict(bases):
  dicts = []
  for base in bases:
    if hasattr(base, "__bases__") and hasattr(base, "__dict__"):
      dicts.append(base.__dict__)
      dicts.extend(get_bases_dict(base.__bases__))
  return dicts


class MetaExceptionCatcher(ABCMeta):
  # この関数が各クラスが読み込まれるときに実行される
  def __new__(mcls, name, bases, namespace):
    # namespaceには継承元のメンバは含まれていないのでそれらを統合したnsを生成している
    ns = ChainMap(namespace, *get_bases_dict(bases))
    # ここでデコレートしている
    namespace["foo"] = exception_catcher(ns["foo"], FooException, ns["foo_sub"])
    return super().__new__(mcls, name, bases, namespace)


# AのメタクラスにMetaExceptionCatcherを指定することで
# Aを読み込むときにメタクラスの `__new__` が実行される
class A(metaclass=MetaExceptionCatcher):
  @abc.abstractclassmethod
  def foo(self):
    pass

  def foo_sub(self):
    print("foo_sub")


class B(A):
  def foo(self):
    raise FooException


class C(A):
  def foo(self):
    raise BarException


if __name__ == "__main__":
  print("- - -")
  B().foo()
  print("- - -")
  B().foo()
  print("- - -")
  C().foo() 

これを実行すると以下の結果になります。

- - -
before
foo_sub
after
- - -
before
foo_sub
after
- - -
before
Traceback (most recent call last):
  File "meta2.py", line 74, in <module>
    C().foo()
  File "meta2.py", line 24, in inner
    result = f(*args, **kwargs)
  File "meta2.py", line 65, in foo
    raise BarException
__main__.BarException

基本的にやっていることは前回の使用例と同じように例外をキャッチしてコールバック関数を
実行しているだけです。クラスAの継承先であるクラスBでメンバ関数 foo がデコレートされており、二回目の B().foo() 実行時に二重にデコレートされていないことが
わかります。少しコードは複雑になってしまいましたが、 継承先に(複数回適用せずに)デコレーターを適用できました。

結論

継承先に(複数回適用せずに)デコレーターを適用するにはメタクラスの __new__ 関数を用いる方法があります。