Source code for arsenal.cache.memoize
import atexit
import shelve
import pickle as pickle
from functools import partial
# TODO:
# * add option to pass a reference to another cache
[docs]class memoize(object):
""" cache a function's return value to avoid recalulation """
def __init__(self, func):
self.func = func
self.cache = {}
try:
self.__name__ = func.__name__
self.__doc__ = func.__doc__
except AttributeError:
pass
def __get__(self, obj, objtype=None):
"define `__get__` in case this function is a method."
#print(' method get', obj, objtype)
if obj is None: return self.func
return partial(self, obj)
def __call__(self, *args):
try:
return self.cache[args]
except KeyError:
value = self.func(*args)
try:
self.cache[args] = value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
raise TypeError('uncachable arguments %r passed to memoized function.' % (args,))
return value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
raise TypeError('uncachable arguments %r passed to memoized function.' % (args,))
def __repr__(self):
return '<memoize(%r)>' % self.func
[docs]class ShelfBasedCache(object):
""" cache a function's return value to avoid recalulation and save cache in a shelve. """
def __init__(self, func, key, None_is_bad=False):
self.func = func
self.filename = '{self.func.__name__}.shelf~'.format(self=self)
self.cache = shelve.open(self.filename, flag='c') #, writeback=True)
self.key = key
self.None_is_bad = None_is_bad
self.__name__ = 'ShelfBasedCache(%s)' % func.__name__
def __call__(self, *args):
p_args = self.key(args)
value = None
recompute = True
if self.cache.has_key(p_args):
recompute = False
value = self.cache[p_args]
if value is None and self.None_is_bad:
recompute = True
if recompute:
self.cache[p_args] = value = self.func(*args)
self.cache.sync()
return value
[docs]def persistent_cache(key=lambda x: x, None_is_bad=False):
def wrap(f):
return ShelfBasedCache(f, key, None_is_bad=None_is_bad)
return wrap
## TODO: automatically make a back-up of any previous pickles just in case the
## save fails. (Saving at-exit can be pretty flaky.)
[docs]class memoize_persistent(object):
"""
cache a function's return value to avoid recalulation and save the
cache (via pickle) at system exit so that it persists.
WARNING: retrieves cache for functions which might not be equivalent
if a revision is made to the code which is used to compute it.
"""
def __init__(self, func, filename=None):
self.func = func
self.filename = filename or '{self.func.__name__}.cache.pkl~'.format(self=self)
self.dirty = False
self.key = 0
self.cache = {}
self.loaded = False
atexit.register(self.save)
[docs] def save(self):
if self.cache and self.dirty:
with open(self.filename, 'w') as f:
pickle.dump((self.cache, self.key), f)
print('[ATEXIT] saved persistent cache for {self.func.__name__} to file "{self.filename}"'.format(self=self))
else:
print("[ATEXIT] found nothing to save in {self.func.__name__}'s cache.".format(self=self))
[docs] def load(self):
self.loaded = True
loaded_key = None
try:
with open(self.filename, 'r') as f:
(cache, loaded_key) = pickle.load(f)
except IOError:
pass
finally:
if self.key == loaded_key:
self.cache = cache
#print 'loaded cache for {self.func.__name__}'.format(self=self)
else:
self.cache = {}
#print 'failed to load cache for {self.func.__name__}'.format(self=self)
def __call__(self, *args):
# wait until you call the function to un-pickle
if not self.loaded:
self.load()
try:
return self.cache[args]
except KeyError:
value = self.func(*args)
try:
self.cache[args] = value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
raise TypeError('uncachable arguments %r passed to memoized function.' % (args,))
else:
self.dirty = True
return value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
raise TypeError('uncachable arguments %r passed to memoized function.' % (args,))
[docs] def get_cached(self, *args):
""" If result is cached return it, otherwise return `None`. """
# wait until you call the function to un-pickle
if not self.loaded:
self.load()
if args in self.cache:
return self.cache[args]
else:
return None
[docs]def test_memoize():
@memoize
def g(x):
return x**2
class foo:
def __init__(self, a):
self.a = a
@memoize
def goo(self, x):
return self.a * x
def __repr__(self):
return f'foo({self.a})'
a = foo(2)
b = foo(3)
print('created')
a_goo = a.goo
print(a.goo)
print('calling')
assert a_goo(5) == 2*5
assert b.goo(5) == 3*5
print('ok')
assert a_goo(5) == 2*5
print('xxx')
print(foo.goo) # triggers method.get
assert foo.goo(a, 4) == 2*4
print('----')
print(a.__class__.__dict__['goo']) # gives the memoize instance
print(a.__class__.__dict__['goo'].cache) # it's hashing (obj, args)
print('----')
assert g(4) == 4**2 # no __get__ call here.
print(g)
print(g.cache)
print('ok')
if __name__ == '__main__':
test_memoize()