diff --git a/hypy_utils/tqdm_utils.py b/hypy_utils/tqdm_utils.py index 1095fcc..76cdf6c 100644 --- a/hypy_utils/tqdm_utils.py +++ b/hypy_utils/tqdm_utils.py @@ -12,17 +12,19 @@ from tqdm.contrib.concurrent import process_map, thread_map def smap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: - return [fn(i) for i in tqdm.tqdm(lst, position=0, leave=True)] + return [fn(i) for i in tqdm.tqdm(lst, position=0, leave=True, *args, **kwargs)] def pmap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: tqdm_args = dict(position=0, leave=True, chunksize=1, tqdm_class=tqdm.tqdm, max_workers=os.cpu_count()) - return process_map(fn, lst, *args, **{**tqdm_args, **kwargs}) + tqdm_args.update(kwargs) + return process_map(fn, lst, *args, **tqdm_args) def tmap(fn: Callable, lst: Iterable, *args, **kwargs) -> list: tqdm_args = dict(position=0, leave=True, chunksize=1, tqdm_class=tqdm.tqdm, max_workers=os.cpu_count()) - return thread_map(fn, lst, *args, **{**tqdm_args, **kwargs}) + tqdm_args.update(kwargs) + return thread_map(fn, lst, *args, **tqdm_args) def tq(it: Iterable, desc: str, *args, **kwargs) -> tqdm: