class MultiModalCache:
@classmethod
def get_leaf_size(
cls,
leaf: object,
*,
debug: bool = False,
) -> int:
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
# These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargsItem):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf)
@classmethod
def get_item_size(
cls,
value: MultiModalCacheValue,
*,
debug: bool = False,
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
size / GiB_bytes)
return size
@classmethod
def get_lru_cache(
cls,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
return LRUCache(
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)