diff --git a/hypy_utils/__init__.py b/hypy_utils/__init__.py index b7393dc..d18117b 100644 --- a/hypy_utils/__init__.py +++ b/hypy_utils/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__version__ = "1.0.7" +__version__ = "1.0.8" import dataclasses import hashlib @@ -8,7 +8,7 @@ import json import time from datetime import datetime, date from pathlib import Path -from typing import Union +from typing import Union, Callable def ansi_rgb(r: int, g: int, b: int, foreground: bool = True) -> str: @@ -202,3 +202,16 @@ class Timer: def reset(self): self.start = time.time_ns() + + +def mem(var: str): + print(f'Memory usage for {var}: {eval(f"sys.getsizeof({var})") / 1024:.1f}KB') + + +def run_time(func: Callable, *args, **kwargs): + name = getattr(func, '__name__', 'function') + start = time.time_ns() + iter = kwargs.pop('iter', 10) + _ = [func(*args, **kwargs) for _ in range(iter)] + ms = (time.time_ns() - start) / 1e6 + print(f'RT {name:30} {ms:6.1f} ms') diff --git a/hypy_utils/scientific_utils.py b/hypy_utils/scientific_utils.py new file mode 100644 index 0000000..5e3519d --- /dev/null +++ b/hypy_utils/scientific_utils.py @@ -0,0 +1,72 @@ +""" +Importing this file requires numpy, matplotlib, and numba +""" +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Callable + +import numpy as np +from matplotlib import pyplot as plt +from numba import njit + + +@dataclass +class Statistics: + mean: float + median: float + lower_quartile: float + upper_quartile: float + iqr: float + minimum: float + maximum: float + count: int + total: float + + def get_metric_6(self) -> tuple[float, float, float, float, float, float]: + return self.mean, self.median, self.minimum, self.maximum, self.lower_quartile, self.upper_quartile + + +@njit(cache=True) +def _calc_col_stats_helper(col: np.ndarray) -> tuple[float, float, float, float, float, float, float, int, float]: + q1 = np.quantile(col, 0.25) + q3 = np.quantile(col, 0.75) + return ( + float(np.mean(col)), + float(np.median(col)), + float(q1), + float(q3), + float(q3 - q1), + float(np.min(col)), + float(np.max(col)), + len(col), + float(np.sum(col)) + ) + + +def calc_col_stats(col: np.ndarray | list) -> Statistics: + """ + Compute statistics for a data column + + :param col: Input column (tested on 1D array) + :return: Statistics + """ + if isinstance(col, list): + col = np.array(col) + return Statistics(*_calc_col_stats_helper(col)) + + +def plot(**kwargs) -> plt: + """ + Pyplot configurator shorthand + + Example: plt_cfg(xlabel="X", ylabel="Y") is equivalent to plt.xlabel("X"); plt.ylabel("Y") + """ + for k, args in kwargs.items(): + if isinstance(args, dict): + getattr(plt, k)(**args) + else: + getattr(plt, k)(args) + return plt + diff --git a/hypy_utils/serializer.py b/hypy_utils/serializer.py index 736ba28..2f4e6d1 100644 --- a/hypy_utils/serializer.py +++ b/hypy_utils/serializer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import pickle @@ -6,7 +8,11 @@ def pickle_encode(obj: any, protocol=None, fix_imports=True) -> bytes: """ Encode object to pickle bytes - >>> by = pickle_encode({'meow': 565656}) + >>> by = pickle_encode({'function': pickle_encode}) + >>> len(by) + 57 + >>> decoded = pickle_decode(by) + >>> by = decoded['function']({'meow': 565656}) >>> pickle_decode(by) {'meow': 565656} """ diff --git a/hypy_utils/tqdm_utils.py b/hypy_utils/tqdm_utils.py new file mode 100644 index 0000000..3d48ab6 --- /dev/null +++ b/hypy_utils/tqdm_utils.py @@ -0,0 +1,37 @@ +""" +Importing this file requires installing tqdm. +""" +from __future__ import annotations + +from functools import partial +from typing import Callable, Iterable + +import tqdm +from tqdm.contrib.concurrent import process_map, thread_map + + +def smap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: + return [fn(i) for i in tqdm.tqdm(lst, position=0, leave=True)] + + +def pmap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: + tqdm_args = dict(position=0, leave=True, chunksize=1, tqdm_class=tqdm.tqdm, max_workers=os.cpu_count()) + return process_map(fn, lst, *args, **{**tqdm_args, **kwargs}) + + +def tmap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: + tqdm_args = dict(position=0, leave=True, chunksize=1, tqdm_class=tqdm.tqdm, max_workers=os.cpu_count()) + return process_map(fn, lst, *args, **{**tqdm_args, **kwargs}) + + +def tq(it: Iterable, desc: str, *args, **kwargs) -> tqdm: + tqdm_args = dict(position=0, leave=True) + return tqdm.tqdm(it, desc, *args, **{**tqdm_args, **kwargs}) + + +def patch_tqdm(): + tqdm_args = dict(chunksize=1, position=0, leave=True, tqdm_class=tqdm.tqdm, max_workers=os.cpu_count()) + tq: Callable[[Iterable], tqdm.tqdm] = partial(tqdm.tqdm, position=0, leave=True) + pmap = partial(process_map, **tqdm_args) + tmap = partial(thread_map, **tqdm_args) + return tq, pmap, tmap