Skip to content

API Reference

Totally typed, plausibly practical, and remarkably random utilities—for me, and maybe for you too.

__all__ = ['bisect_predicate', 'get_first', 'is_sorted'] module-attribute

bisect_predicate(seq, predicate, lo=0, hi=None)

Find the first index where predicate flips from True to False using binary search.

In other words: Find first index where predicate(seq[i]) becomes False.

The sequence must be partitioned such that all elements where predicate(item) is True appear before elements where predicate(item) is False. This is a generalized version of bisect_right that works with arbitrary predicates rather than comparison operators.

Parameters:

  • seq (Sequence[T]) –

    Partitioned sequence to search. Must have all True-predicate elements before any False-predicate elements.

  • predicate (Callable[[T], bool]) –

    Function that returns True for elements that should be considered "left" of the insertion point. Typically a condition like lambda x: x < target.

  • lo (int, default: 0 ) –

    Lower bound index to start search (inclusive).

  • hi (int | None, default: None ) –

    Upper bound index to end search (exclusive). Defaults to len(seq).

Returns:

  • int

    Insertion point index where predicate first fails. This will be:

    • 0 if predicate fails for all elements
    • len(seq) if predicate holds for all elements
    • First index where predicate(seq[i]) == False otherwise

Examples:

Find the index of the first positive number, i.e. the first element for which the predicate lambda x: x <= 0 fails:

>>> bisect_predicate([-5, -3, 0, 2, 5], lambda x: x <= 0)
3

All elements satisfy predicate:

>>> bisect_predicate([True, True, True], lambda b: b)
3

Edge case - empty sequence:

>>> bisect_predicate([], lambda x: True)
0

Custom search range:

>>> bisect_predicate([1, 3, 5, 7, 9], lambda x: x < 6, lo=1, hi=4)
3
Note

Similar to bisect.bisect_right but works with arbitrary predicates. Requires the array to be properly partitioned - undefined behavior otherwise.

Source code in src/kajihs_utils/core.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def bisect_predicate[T](
    seq: Sequence[T],
    predicate: Callable[[T], bool],
    lo: int = 0,
    hi: int | None = None,
) -> int:
    """
    Find the first index where predicate flips from True to False using binary search.

    In other words: Find first index where predicate(seq[i]) becomes False.

    The sequence must be partitioned such that all elements where predicate(item) is True
    appear before elements where predicate(item) is False. This is a generalized version
    of bisect_right that works with arbitrary predicates rather than comparison operators.

    Args:
        seq: Partitioned sequence to search. Must have all True-predicate elements
            before any False-predicate elements.
        predicate: Function that returns True for elements that should be considered
            "left" of the insertion point. Typically a condition like lambda x: x < target.
        lo: Lower bound index to start search (inclusive).
        hi: Upper bound index to end search (exclusive). Defaults to len(seq).

    Returns:
        Insertion point index where predicate first fails. This will be:
        - 0 if predicate fails for all elements
        - len(seq) if predicate holds for all elements
        - First index where predicate(seq[i]) == False otherwise

    Examples:
        Find first non-positive number (predicate=lambda x: x <= 0):
        >>> bisect_predicate([-5, -3, 0, 2, 5], lambda x: x <= 0)
        3

        All elements satisfy predicate:
        >>> bisect_predicate([True, True, True], lambda b: b)
        3

        Edge case - empty sequence:
        >>> bisect_predicate([], lambda x: True)
        0

        Custom search range:
        >>> bisect_predicate([1, 3, 5, 7, 9], lambda x: x < 6, lo=1, hi=4)
        3

    Note:
        Similar to bisect.bisect_right but works with arbitrary predicates.
        Requires the array to be properly partitioned - undefined behavior otherwise.
    """
    hi = hi or len(seq)

    while lo < hi:
        mid = (lo + hi) // 2
        if predicate(seq[mid]):
            # print(f"Can complete {mid}")
            lo = mid + 1
        else:
            # print(f"Can't complete {mid}")
            hi = mid
    return lo

get_first(d, /, keys, default=None, *, no_default=False)

get_first(d: Mapping[K, V], /, keys: Iterable[K], default: D = None) -> V | D
get_first(d: Mapping[K, V], /, keys: Iterable[K], default: Any = None, *, no_default: Literal[True] = True) -> V
get_first(d: Mapping[K, V], /, keys: Iterable[K], default: D = None, *, no_default: bool = True) -> V | D

Return the value of the first key that exists in the mapping.

Parameters:

  • d (Mapping[K, V]) –

    The dictionary to search in.

  • keys (Iterable[K]) –

    The sequence of keys to look for.

  • default (D, default: None ) –

    The value to return if no key is found.

  • no_default (bool, default: False ) –

    If True, raises a KeyError if no key is found.

Returns:

  • V | D

    The value associated with the first found key, or the default value if not found.

Raises:

  • KeyError

    If no_default is True and none of the keys are found.

Examples:

>>> d = {"a": 1, "b": 2, "c": 3}
>>> get_first(d, ["x", "a", "b"])
1
>>> get_first(d, ["x", "y"], default=0)
0
>>> get_first(d, ["x", "y"], no_default=True)  # Raises: KeyError
Traceback (most recent call last):
...
KeyError: "None of the keys ['x', 'y'] were found in the mapping."
Source code in src/kajihs_utils/core.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def get_first[K, V, D](
    d: Mapping[K, V], /, keys: Iterable[K], default: D = None, *, no_default: bool = False
) -> V | D:
    """
    Return the value of the first key that exists in the mapping.

    Args:
        d: The dictionary to search in.
        keys: The sequence of keys to look for.
        default: The value to return if no key is found.
        no_default: If `True`, raises a `KeyError` if no key is found.

    Returns:
        The value associated with the first found key, or the default value if not found.

    Raises:
        KeyError: If `no_default` is `True` and none of the keys are found.

    Examples:
        >>> d = {"a": 1, "b": 2, "c": 3}
        >>> get_first(d, ["x", "a", "b"])
        1
        >>> get_first(d, ["x", "y"], default=0)
        0
        >>> get_first(d, ["x", "y"], no_default=True)  # Raises: KeyError
        Traceback (most recent call last):
        ...
        KeyError: "None of the keys ['x', 'y'] were found in the mapping."
    """
    for key in keys:
        if key in d:
            return d[key]

    if no_default:
        msg = f"None of the keys {list(keys)} were found in the mapping."
        raise KeyError(msg)

    return default

is_sorted(values, /, *, reverse=False)

Determine if the given iterable is sorted in ascending or descending order.

Parameters:

  • values (Iterable[SupportsDunderLT[Any]]) –

    An iterable of comparable items supporting the < operator.

  • reverse (bool, default: False ) –

    If False (default), checks for non-decreasing order; if True, checks for non-increasing order.

Returns:

  • bool

    True if the sequence is sorted according to the given order, False otherwise.

Examples:

>>> is_sorted([1, 2, 2, 3])
True
>>> is_sorted([3, 2, 1], reverse=True)
True
>>> is_sorted([2, 1, 3])
False
>>> is_sorted([])
True
>>> is_sorted([42])
True
>>> # Works with generators as well
>>> is_sorted(x * x for x in [1, 2, 3, 4])
True
>>> # Equal elements are considered sorted
>>> is_sorted([1, 1, 1])
True
Source code in src/kajihs_utils/core.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def is_sorted(values: Iterable[SupportsDunderLT[Any]], /, *, reverse: bool = False) -> bool:
    """
    Determine if the given iterable is sorted in ascending or descending order.

    Args:
        values: An iterable of comparable items supporting the < operator.
        reverse: If False (default), checks for non-decreasing order; if True,
            checks for non-increasing order.

    Returns:
        True if the sequence is sorted according to the given order, False otherwise.

    Examples:
        >>> is_sorted([1, 2, 2, 3])
        True
        >>> is_sorted([3, 2, 1], reverse=True)
        True
        >>> is_sorted([2, 1, 3])
        False
        >>> is_sorted([])
        True
        >>> is_sorted([42])
        True
        >>> # Works with generators as well
        >>> is_sorted(x * x for x in [1, 2, 3, 4])
        True
        >>> # Equal elements are considered sorted
        >>> is_sorted([1, 1, 1])
        True
    """
    op = operator.le if not reverse else operator.ge
    return all(starmap(op, pairwise(values)))

__about__

kajihs-utils metadata.

__author_emails__ = ['itskajih@gmail.com'] module-attribute

__authors__ = ['Julian Paquerot'] module-attribute

__changelog_url__ = f'{__repo_url__}/blob/main/CHANGELOG.md' module-attribute

__documentation_url__ = 'https://Kajiih.github.io/kajihs-utils/' module-attribute

__homepage_url__ = 'https://github.com/Kajiih/kajihs-utils' module-attribute

__issues_url__ = f'{__repo_url__}/issues' module-attribute

__module_name__ = 'kajihs_utils' module-attribute

__repo_url__ = 'https://github.com/Kajiih/kajihs-utils' module-attribute

__version__ = '0.10.3' module-attribute

arithmetic

Utils for arithmetic.

close_factors and almost_factors taken from: https://code.visualstudio.com/api/language-extensions/semantic-highlight-guide

almost_factors(n, /, ratio=0.5)

Find a pair of factors that are close enough.

Parameters:

  • n (int) –

    The number to almost-factorize.

  • ratio (float, default: 0.5 ) –

    The threshold ratio between both factors.

Returns:

  • tuple[int, int]

    A tuple containing the first two numbers factoring to n or more such that factor 1 is at most 1/ratio times larger than factor 2.

Example

>>> almost_factors(10, ratio=0.5)
(4, 3)

Source code in src/kajihs_utils/arithmetic.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def almost_factors(n: int, /, ratio: float = 0.5) -> tuple[int, int]:
    """
    Find a pair of factors that are close enough.

    Starting at ``n``, successive integers are tried until the closest factor
    pair of one of them satisfies ``ratio * factor1 <= factor2``.

    Args:
        n: The number to almost-factorize.
        ratio: The threshold ratio between both factors. For positive ``n`` it
            must be at most 1, because the larger factor always comes first.

    Returns:
        A tuple containing the first two numbers factoring to n or more such
        that factor 1 is at most 1/ratio times larger than factor 2.

    Raises:
        ValueError: If ``n`` is positive and ``ratio`` is greater than 1 or NaN,
            since no factor pair could ever satisfy the threshold.

    Example:
        >>> almost_factors(10, ratio=0.5)
        (4, 3)
    """
    # closest_factors returns the larger factor first, so for positive n the
    # test `ratio * factor1 <= factor2` can never pass once ratio exceeds 1
    # (or is NaN). Fail fast instead of looping forever.
    if n > 0 and not ratio <= 1:
        msg = f"ratio must be at most 1 for a positive n, got {ratio}"
        raise ValueError(msg)

    while True:
        factor1, factor2 = closest_factors(n)
        if ratio * factor1 <= factor2:
            return factor1, factor2
        # Factors of n are too far apart: try the next integer.
        n += 1

closest_factors(n)

Find the closest pair of factors.

Parameters:

  • n (int) –

    The number to find factors for.

Returns:

  • tuple[int, int]

    A tuple containing the two closest factors of n, the larger first.

Example

>>> closest_factors(99)
(11, 9)

Source code in src/kajihs_utils/arithmetic.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def closest_factors(n: int, /) -> tuple[int, int]:
    """
    Find the closest pair of factors.

    The search only walks candidates up to the square root of ``n``: the
    largest divisor not exceeding the square root, together with its cofactor,
    is the pair of factors with the smallest difference.

    Args:
        n: The number to find factors for.

    Returns:
        A tuple containing the two closest factors of n, the larger first.
        For ``n < 1`` there is no positive factor pair and ``(0, n)`` is returned.

    Example:
        >>> closest_factors(99)
        (11, 9)
    """
    if n < 1:
        # No positive factorization exists; keep the historical result.
        return 0, n

    smaller = 1
    candidate = 2
    # Any divisor above sqrt(n) pairs with one below it, so stopping at
    # candidate**2 > n is enough to find the largest "small" divisor.
    while candidate * candidate <= n:
        if n % candidate == 0:
            smaller = candidate
        candidate += 1

    return n // smaller, smaller

core

General utils without dependencies.

bisect_predicate(seq, predicate, lo=0, hi=None)

Find the first index where predicate flips from True to False using binary search.

In other words: Find first index where predicate(seq[i]) becomes False.

The sequence must be partitioned such that all elements where predicate(item) is True appear before elements where predicate(item) is False. This is a generalized version of bisect_right that works with arbitrary predicates rather than comparison operators.

Parameters:

  • seq (Sequence[T]) –

    Partitioned sequence to search. Must have all True-predicate elements before any False-predicate elements.

  • predicate (Callable[[T], bool]) –

    Function that returns True for elements that should be considered "left" of the insertion point. Typically a condition like lambda x: x < target.

  • lo (int, default: 0 ) –

    Lower bound index to start search (inclusive).

  • hi (int | None, default: None ) –

    Upper bound index to end search (exclusive). Defaults to len(seq).

Returns:

  • int

    Insertion point index where predicate first fails. This will be:

    • 0 if predicate fails for all elements
    • len(seq) if predicate holds for all elements
    • First index where predicate(seq[i]) == False otherwise

Examples:

Find the index of the first positive number, i.e. the first element for which the predicate lambda x: x <= 0 fails:

>>> bisect_predicate([-5, -3, 0, 2, 5], lambda x: x <= 0)
3

All elements satisfy predicate:

>>> bisect_predicate([True, True, True], lambda b: b)
3

Edge case - empty sequence:

>>> bisect_predicate([], lambda x: True)
0

Custom search range:

>>> bisect_predicate([1, 3, 5, 7, 9], lambda x: x < 6, lo=1, hi=4)
3
Note

Similar to bisect.bisect_right but works with arbitrary predicates. Requires the array to be properly partitioned - undefined behavior otherwise.

Source code in src/kajihs_utils/core.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def bisect_predicate[T](
    seq: Sequence[T],
    predicate: Callable[[T], bool],
    lo: int = 0,
    hi: int | None = None,
) -> int:
    """
    Find the first index where predicate flips from True to False using binary search.

    In other words: Find first index where predicate(seq[i]) becomes False.

    The sequence must be partitioned such that all elements where predicate(item) is True
    appear before elements where predicate(item) is False. This is a generalized version
    of bisect_right that works with arbitrary predicates rather than comparison operators.

    Args:
        seq: Partitioned sequence to search. Must have all True-predicate elements
            before any False-predicate elements.
        predicate: Function that returns True for elements that should be considered
            "left" of the insertion point. Typically a condition like lambda x: x < target.
        lo: Lower bound index to start search (inclusive).
        hi: Upper bound index to end search (exclusive). Defaults to len(seq).

    Returns:
        Insertion point index where predicate first fails. This will be:
        - 0 if predicate fails for all elements
        - len(seq) if predicate holds for all elements
        - First index where predicate(seq[i]) == False otherwise

    Examples:
        Find first non-positive number (predicate=lambda x: x <= 0):
        >>> bisect_predicate([-5, -3, 0, 2, 5], lambda x: x <= 0)
        3

        All elements satisfy predicate:
        >>> bisect_predicate([True, True, True], lambda b: b)
        3

        Edge case - empty sequence:
        >>> bisect_predicate([], lambda x: True)
        0

        Custom search range:
        >>> bisect_predicate([1, 3, 5, 7, 9], lambda x: x < 6, lo=1, hi=4)
        3

    Note:
        Similar to bisect.bisect_right but works with arbitrary predicates.
        Requires the array to be properly partitioned - undefined behavior otherwise.
    """
    hi = hi or len(seq)

    while lo < hi:
        mid = (lo + hi) // 2
        if predicate(seq[mid]):
            # print(f"Can complete {mid}")
            lo = mid + 1
        else:
            # print(f"Can't complete {mid}")
            hi = mid
    return lo

get_first(d, /, keys, default=None, *, no_default=False)

get_first(d: Mapping[K, V], /, keys: Iterable[K], default: D = None) -> V | D
get_first(d: Mapping[K, V], /, keys: Iterable[K], default: Any = None, *, no_default: Literal[True] = True) -> V
get_first(d: Mapping[K, V], /, keys: Iterable[K], default: D = None, *, no_default: bool = True) -> V | D

Return the value of the first key that exists in the mapping.

Parameters:

  • d (Mapping[K, V]) –

    The dictionary to search in.

  • keys (Iterable[K]) –

    The sequence of keys to look for.

  • default (D, default: None ) –

    The value to return if no key is found.

  • no_default (bool, default: False ) –

    If True, raises a KeyError if no key is found.

Returns:

  • V | D

    The value associated with the first found key, or the default value if not found.

Raises:

  • KeyError

    If no_default is True and none of the keys are found.

Examples:

>>> d = {"a": 1, "b": 2, "c": 3}
>>> get_first(d, ["x", "a", "b"])
1
>>> get_first(d, ["x", "y"], default=0)
0
>>> get_first(d, ["x", "y"], no_default=True)  # Raises: KeyError
Traceback (most recent call last):
...
KeyError: "None of the keys ['x', 'y'] were found in the mapping."
Source code in src/kajihs_utils/core.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def get_first[K, V, D](
    d: Mapping[K, V], /, keys: Iterable[K], default: D = None, *, no_default: bool = False
) -> V | D:
    """
    Return the value of the first key that exists in the mapping.

    Args:
        d: The dictionary to search in.
        keys: The sequence of keys to look for.
        default: The value to return if no key is found.
        no_default: If `True`, raises a `KeyError` if no key is found.

    Returns:
        The value associated with the first found key, or the default value if not found.

    Raises:
        KeyError: If `no_default` is `True` and none of the keys are found.

    Examples:
        >>> d = {"a": 1, "b": 2, "c": 3}
        >>> get_first(d, ["x", "a", "b"])
        1
        >>> get_first(d, ["x", "y"], default=0)
        0
        >>> get_first(d, ["x", "y"], no_default=True)  # Raises: KeyError
        Traceback (most recent call last):
        ...
        KeyError: "None of the keys ['x', 'y'] were found in the mapping."
    """
    for key in keys:
        if key in d:
            return d[key]

    if no_default:
        msg = f"None of the keys {list(keys)} were found in the mapping."
        raise KeyError(msg)

    return default

is_sorted(values, /, *, reverse=False)

Determine if the given iterable is sorted in ascending or descending order.

Parameters:

  • values (Iterable[SupportsDunderLT[Any]]) –

    An iterable of comparable items supporting the < operator.

  • reverse (bool, default: False ) –

    If False (default), checks for non-decreasing order; if True, checks for non-increasing order.

Returns:

  • bool

    True if the sequence is sorted according to the given order, False otherwise.

Examples:

>>> is_sorted([1, 2, 2, 3])
True
>>> is_sorted([3, 2, 1], reverse=True)
True
>>> is_sorted([2, 1, 3])
False
>>> is_sorted([])
True
>>> is_sorted([42])
True
>>> # Works with generators as well
>>> is_sorted(x * x for x in [1, 2, 3, 4])
True
>>> # Equal elements are considered sorted
>>> is_sorted([1, 1, 1])
True
Source code in src/kajihs_utils/core.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def is_sorted(values: Iterable[SupportsDunderLT[Any]], /, *, reverse: bool = False) -> bool:
    """
    Determine if the given iterable is sorted in ascending or descending order.

    Args:
        values: An iterable of comparable items supporting the < operator.
        reverse: If False (default), checks for non-decreasing order; if True,
            checks for non-increasing order.

    Returns:
        True if the sequence is sorted according to the given order, False otherwise.

    Examples:
        >>> is_sorted([1, 2, 2, 3])
        True
        >>> is_sorted([3, 2, 1], reverse=True)
        True
        >>> is_sorted([2, 1, 3])
        False
        >>> is_sorted([])
        True
        >>> is_sorted([42])
        True
        >>> # Works with generators as well
        >>> is_sorted(x * x for x in [1, 2, 3, 4])
        True
        >>> # Equal elements are considered sorted
        >>> is_sorted([1, 1, 1])
        True
    """
    op = operator.le if not reverse else operator.ge
    return all(starmap(op, pairwise(values)))

loguru

Utils for logging, specifically using Loguru.

InterceptHandler

Bases: Handler

Intercepts logs from the standard logging module and forwards them to Loguru.

Snippet from https://github.com/Delgan/loguru/tree/master

emit(record)

Forward log records from the standard logging system to Loguru.

Source code in src/kajihs_utils/loguru.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@override
def emit(self, record: logging.LogRecord) -> None:
    """
    Forward log records from the standard logging system to Loguru.

    Args:
        record: The standard-library log record to re-emit through ``logger``.
    """
    # Get corresponding Loguru level if it exists.
    # When `logger.level` raises ValueError for the record's level name, fall
    # back to the record's numeric level instead.
    level: str | int
    try:
        level = logger.level(record.levelname).name
    except ValueError:
        level = record.levelno

    # Find caller from where originated the logged message.
    # Starting at this frame, step outwards at least once and then keep going
    # while the frame's code lives in the `logging` module's own file, counting
    # each step in `depth`.
    frame = inspect.currentframe()
    depth = 0
    while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
        frame = frame.f_back
        depth += 1

    # Re-emit the formatted message with the computed stack depth and the
    # record's exception info attached.
    logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())

prompt(prompt)

Wrap rich.Prompt.ask to add newline and color.

Source code in src/kajihs_utils/loguru.py
13
14
15
16
17
def prompt(prompt: str, /) -> str:
    """
    Ask the user a question through ``Prompt.ask`` and return the answer.

    The question is shown on a new line in bold cyan markup, and the
    question/answer pair is logged at debug level.

    Args:
        prompt: The question text to display.

    Returns:
        The value returned by ``Prompt.ask``.
    """
    styled_question = f"\n[cyan bold]{prompt}[/cyan bold]"
    answer = Prompt.ask(styled_question)
    logger.debug(f'Prompt: "{prompt}" -> "{answer}"')
    return answer

setup_logging(prefix='app', log_dir='logs')

Set up beautiful loguru logging in files and console.

Redirects logging with Loguru, creates 2 log files (with and without colors) and logs to the console.

Parameters:

  • prefix (str, default: 'app' ) –

    Prefix for the log files without extensions.

  • log_dir (str | Path, default: 'logs' ) –

    Directory path to store log files.

Source code in src/kajihs_utils/loguru.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def setup_logging(prefix: str = "app", log_dir: str | Path = "logs") -> None:
    """
    Configure Loguru sinks for two log files and the console.

    Standard-library logging is routed through an ``InterceptHandler`` at
    WARNING level, all previously registered Loguru handlers are removed, and
    three sinks are added: a plain file ``<prefix>.log``, a colorized file
    ``<prefix>.clog`` and ``sys.stdout``.

    Args:
        prefix: Prefix for the log files without extensions.
        log_dir: Directory path to store log files.

    """
    # Route records emitted through the stdlib logging module into Loguru.
    logging.basicConfig(level=logging.WARNING, handlers=[InterceptHandler()], force=True)

    # Drop every handler registered so far so only the sinks below remain.
    logger.remove()

    directory = Path(log_dir)
    plain_file = directory / f"{prefix}.log"
    colored_file = directory / f"{prefix}.clog"

    # Plain-text file with an explicit timestamp | level | message layout.
    logger.add(
        plain_file,
        level="DEBUG",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        rotation="1 week",
        compression="zip",
    )
    # Colorized file using the default format.
    logger.add(
        colored_file,
        level="DEBUG",
        rotation="1 week",
        compression="zip",
        colorize=True,
    )
    # Console output: message only, styled by level, INFO and above.
    logger.add(
        sys.stdout,
        level="INFO",
        format="<level>{message}</level>",
    )

numpy_utils

Tools for numpy.

IncompatibleShapeError(arr1, arr2, obj)

Bases: ValueError

Shapes of input arrays are incompatible for a given function.

Source code in src/kajihs_utils/numpy_utils.py
22
23
24
25
def __init__(self, arr1: NDArray[Any], arr2: NDArray[Any], obj: Any) -> None:
    """
    Build the error message from both array shapes and the name of ``obj``.

    Args:
        arr1: First input array; its ``shape`` is reported.
        arr2: Second input array; its ``shape`` is reported.
        obj: Object whose ``__name__`` identifies where the mismatch occurred.
    """
    message = (
        f"Shapes of inputs arrays {arr1.shape} and {arr2.shape} are incompatible for {obj.__name__}"
    )
    super().__init__(message)

Vec2d

Bases: ndarray[Literal[2], dtype[float64]]

A 2D vector subclassing numpy.ndarray with .x and .y properties.

x property writable

X coordinate.

y property writable

Y coordinate.

__new__(x, y)

Source code in src/kajihs_utils/numpy_utils.py
131
132
133
def __new__(cls, x: AnyFloat, y: AnyFloat) -> Self:  # noqa: D102
    # Pack both coordinates into a float64 array, then reinterpret it as `cls`.
    coordinates = np.asarray([x, y], dtype=np.float64)
    return coordinates.view(cls)

__repr__()

Source code in src/kajihs_utils/numpy_utils.py
184
185
186
@override
def __repr__(self) -> str:
    # Show both coordinates rounded to two decimals.
    x, y = self.x, self.y
    return f"Vec2d({x:.2f}, {y:.2f})"

angle()

Return the angle (in degrees) between the vector and the positive x-axis.

Source code in src/kajihs_utils/numpy_utils.py
162
163
164
def angle(self) -> float:
    """
    Return the angle (in degrees) between the vector and the positive x-axis.

    Returns:
        ``arctan2(y, x)`` converted from radians to degrees.
    """
    angle_in_radians = np.arctan2(self.y, self.x)
    return np.degrees(angle_in_radians)

magnitude()

Magnitude or norm of the vector.

Source code in src/kajihs_utils/numpy_utils.py
153
154
155
def magnitude(self) -> floating[Any]:
    """
    Magnitude or norm of the vector.

    Returns:
        The result of ``np.linalg.norm`` applied to the vector.
    """
    length = np.linalg.norm(self)
    return length

normalized()

Return a normalized version of the vector.

Source code in src/kajihs_utils/numpy_utils.py
157
158
159
160
def normalized(self) -> Vec2d:
    """
    Return a normalized version of the vector.

    Returns:
        The vector itself when its magnitude is zero, otherwise a new ``Vec2d``
        with both coordinates divided by the magnitude.
    """
    length = self.magnitude()
    if length == 0:
        # A zero-length vector cannot be scaled; hand it back unchanged.
        return self
    return Vec2d(self.x / length, self.y / length)

rotate(degrees_angle, center=(0, 0))

Rotates the vector counterclockwise by a given angle (in degrees) around the point center.

Source code in src/kajihs_utils/numpy_utils.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def rotate(self, degrees_angle: float, center: tuple[float, float] = (0, 0)) -> Vec2d:
    """
    Rotate the vector counterclockwise around ``center``.

    Args:
        degrees_angle: Rotation angle in degrees.
        center: The ``(cx, cy)`` point to rotate around.

    Returns:
        A new ``Vec2d`` holding the rotated coordinates.
    """
    cx, cy = center[0], center[1]

    # Express the vector relative to the rotation center.
    offset = np.array([self.x - cx, self.y - cy])

    # Standard 2D counterclockwise rotation matrix.
    theta = np.radians(degrees_angle)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_theta, -sin_theta], [sin_theta, cos_theta]])
    turned = rotation @ offset

    # Shift the rotated offset back by the center.
    return Vec2d(turned[0] + cx, turned[1] + cy)

find_closest(x, targets, norm_ord=None)

Find the index of the closest element(s) from x for each target in targets.

Given one or multiple targets (vectors or scalars), this function computes the distance to each element in x and returns the indices of the closest matches. If targets is of the same shape as an element of x, the function returns a single integer index. If targets contains multiple elements, it returns an array of indices corresponding to each target.

If the dimensionality of the vectors in x is greater than 2, the vectors will be flattened into 1D before computing distances.

Parameters:

  • x (Iterable[T] | ArrayLike) –

    An iterable or array-like collection of elements (scalars, vectors, or higher-dimensional arrays). For example, x could be an array of shape (N,) (scalars), (N, D) (D-dimensional vectors), (N, H, W) (2D arrays), or higher-dimensional arrays.

  • targets (Iterable[T] | T | ArrayLike) –

    One or multiple target elements for which you want to find the closest match in x. Can be a single scalar/vector/array or an iterable of them. Must be shape-compatible with the elements of x.

  • norm_ord (Norm | None, default: None ) –

    The order of the norm used for distance computation. Uses the same conventions as numpy.linalg.norm.

Returns:

  • ndarray[tuple[int], dtype[int_]] | int_

    An array of indices. If a single target was given, a single index is returned. If multiple targets were given, an array of shape (M,) is returned, where M is the number of target elements. Each value is the index of the closest element in x to the corresponding target.

Raises:

  • IncompatibleShapeError

    If targets cannot be broadcast or reshaped to match the shape structure of the elements in x.

Examples:

>>> import numpy as np
>>> x = np.array([0, 10, 20, 30])
>>> int(find_closest(x, 12))
1
>>> # Multiple targets
>>> find_closest(x, [2, 26])
array([0, 3])
>>> # Using vectors
>>> x = np.array([[0, 0], [10, 10], [20, 20]])
>>> int(find_closest(x, [6, 5]))  # Single target vector
1
>>> find_closest(x, [[-1, -1], [15, 12]])  # Multiple target vectors
array([0, 1])
>>> # Higher dimensional arrays
>>> x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
>>> int(find_closest(x, [[2, 2], [2, 2]]))
0
>>> find_closest(x, [[[0, 0], [1, 1]], [[19, 19], [19, 19]]])
array([0, 2])
Source code in src/kajihs_utils/numpy_utils.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def find_closest[T](
    x: Iterable[T] | ArrayLike,
    targets: Iterable[T] | T | ArrayLike,
    norm_ord: Norm | None = None,
) -> ndarray[tuple[int], dtype[int_]] | int_:
    """
    Find the index of the closest element(s) from `x` for each target in `targets`.

    Given one or multiple `targets` (vectors vectors or scalars),
    this function computes the distance to each element in `x` and returns the
    indices of the closest matches. If `targets` is of the same shape as an
    element of `x`, the function returns a single integer index. If `targets`
    contains multiple elements, it returns an array of indices corresponding to
    each target.

    If the dimensionality of the vectors in `x` is greater than 2, the vectors
    will be flattened into 1D before computing distances.

    Args:
        x: An iterable or array-like collection of elements (scalars, vectors,
            or higher-dimensional arrays). For example, `x` could be an array of
            shape `(N,)` (scalars), `(N, D)` (D-dimensional vectors),
            `(N, H, W)` (2D arrays), or higher-dimensional arrays.
        targets: One or multiple target elements for which you want to find the
            closest match in `x`. Can be a single scalar/vector/array or an
            iterable of them.
            Must be shape-compatible with the elements of `x`.
        norm_ord: The order of the norm used for distance computation.
            Uses the same conventions as `numpy.linalg.norm`.

    Returns:
        An array of indices. If a single target was given, a single index is
        returned. If multiple targets were given, an array of shape `(M,)` is
        returned, where `M` is the number of target elements. Each value is the
        index of the closest element in `x` to the corresponding target.

    Raises:
        IncompatibleShapeError: If `targets` cannot be broadcast or reshaped to
            match the shape structure of the elements in `x`.

    Examples:
        >>> import numpy as np
        >>> x = np.array([0, 10, 20, 30])
        >>> int(find_closest(x, 12))
        1
        >>> # Multiple targets
        >>> find_closest(x, [2, 26])
        array([0, 3])

        >>> # Using vectors
        >>> x = np.array([[0, 0], [10, 10], [20, 20]])
        >>> int(find_closest(x, [6, 5]))  # Single target vector
        1
        >>> find_closest(x, [[-1, -1], [15, 12]])  # Multiple target vectors
        array([0, 1])

        >>> # Higher dimensional arrays
        >>> x = np.array([[[0, 0], [0, 0]], [[10, 10], [10, 10]], [[20, 20], [20, 20]]])
        >>> int(find_closest(x, [[2, 2], [2, 2]]))
        0
        >>> find_closest(x, [[[0, 0], [1, 1]], [[19, 19], [19, 19]]])
        array([0, 2])
    """
    x = np.array(x)  # (N, vector_shape)
    targets = np.array(targets)
    vector_shape = x.shape[1:]

    # Check that shapes are compatible
    do_unsqueeze = False
    if targets.shape == vector_shape:
        targets = np.atleast_1d(targets)[np.newaxis, :]  # (M, vector_shape)
        do_unsqueeze = True
    elif targets.shape[1:] != vector_shape:
        raise IncompatibleShapeError(x, targets, find_closest)

    nb_vectors = x.shape[0]  # N
    nb_targets = targets.shape[0]  # M

    diffs = x[:, np.newaxis] - targets

    match vector_shape:
        case ():
            distances = np.linalg.norm(diffs[:, np.newaxis], ord=norm_ord, axis=1)
        case (_,):
            distances = np.linalg.norm(diffs, ord=norm_ord, axis=2)
        case (_, _):
            distances = np.linalg.norm(diffs, ord=norm_ord, axis=(2, 3))
        case _:  # Tensors
            # Reshape to 1d vectors
            diffs = diffs.reshape(nb_vectors, nb_targets, -1)
            distances = np.linalg.norm(diffs, ord=norm_ord, axis=2)

    closest_indices = np.argmin(distances, axis=0)
    if do_unsqueeze:
        closest_indices = closest_indices[0]

    return closest_indices

protocols

Useful protocols for structural subtyping.

SupportsRichComparisonT = TypeVar('SupportsRichComparisonT', bound=SupportsRichComparison) module-attribute

SupportsAllComparisons

Bases: SupportsDunderLT[T], SupportsDunderGT[T], SupportsDunderLE[T], SupportsDunderGE[T], Protocol

__ge__(other)

Source code in src/kajihs_utils/protocols.py
32
# Protocol stub: declares that `self >= other` is supported and yields a bool; `other` is positional-only and the body is intentionally empty.
def __ge__(self, other: _T_contra, /) -> bool: ...

__gt__(other)

Source code in src/kajihs_utils/protocols.py
22
# Protocol stub: declares that `self > other` is supported and yields a bool; `other` is positional-only and the body is intentionally empty.
def __gt__(self, other: _T_contra, /) -> bool: ...

__le__(other)

Source code in src/kajihs_utils/protocols.py
27
# Protocol stub: declares that `self <= other` is supported and yields a bool; `other` is positional-only and the body is intentionally empty.
def __le__(self, other: _T_contra, /) -> bool: ...

__lt__(other)

Source code in src/kajihs_utils/protocols.py
17
# Protocol stub: declares that `self < other` is supported and yields a bool; `other` is positional-only and the body is intentionally empty.
def __lt__(self, other: _T_contra, /) -> bool: ...

SupportsDunderGE

Bases: Protocol[_T_contra]

__ge__(other)

Source code in src/kajihs_utils/protocols.py
32
# Structural-typing stub for the `>=` operator: takes one positional-only operand and returns a bool; no implementation is provided here.
def __ge__(self, other: _T_contra, /) -> bool: ...

SupportsDunderGT

Bases: Protocol[_T_contra]

__gt__(other)

Source code in src/kajihs_utils/protocols.py
22
# Structural-typing stub for the `>` operator: takes one positional-only operand and returns a bool; no implementation is provided here.
def __gt__(self, other: _T_contra, /) -> bool: ...

SupportsDunderLE

Bases: Protocol[_T_contra]

__le__(other)

Source code in src/kajihs_utils/protocols.py
27
# Structural-typing stub for the `<=` operator: takes one positional-only operand and returns a bool; no implementation is provided here.
def __le__(self, other: _T_contra, /) -> bool: ...

SupportsDunderLT

Bases: Protocol[_T_contra]

__lt__(other)

Source code in src/kajihs_utils/protocols.py
17
# Structural-typing stub for the `<` operator: takes one positional-only operand and returns a bool; no implementation is provided here.
def __lt__(self, other: _T_contra, /) -> bool: ...

pyplot

Utils for matplotlib.pyplot.

auto_subplot(size, /, ratio=9 / 16, *, more_cols=False, transposed=False, return_all_axes=False, **subplot_params)

Automatically creates a subplot grid with an adequate number of rows and columns.

Parameters:

  • size (int) –

    The total number of subplots.

  • ratio (float, default: 9 / 16 ) –

    The threshold aspect ratio between rows and columns.

  • more_cols (bool, default: False ) –

    Whether there should be more columns than rows instead of the opposite.

  • transposed (bool, default: False ) –

    Whether to transpose the indexes before flattening.

  • return_all_axes (bool, default: False ) –

    Whether to also return the axes beyond size in the flattened axes.

  • **subplot_params (Any, default: {} ) –

    Additional keyword parameters for subplot.

Returns:

  • tuple[Figure, AxesFlatArray]

    Tuple containing the figure and the flattened axes.

Source code in src/kajihs_utils/pyplot.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def auto_subplot(
    size: int,
    /,
    ratio: float = 9 / 16,
    *,
    more_cols: bool = False,
    transposed: bool = False,
    return_all_axes: bool = False,
    **subplot_params: Any,
) -> tuple[Figure, AxesFlatArray]:
    """
    Automatically creates a subplot grid with an adequate number of rows and columns.

    Args:
        size: The total number of subplots.
        ratio: The threshold aspect ratio between rows and columns.
        more_cols: Whether there should be more columns than rows instead of
            the opposite.
        transposed: Whether to transpose the indexes before flattening.
        return_all_axes: Whether to also return the axes beyond `size` in the
            flattened axes. Those extra axes are switched off either way.
        **subplot_params: Additional keyword parameters forwarded to
            `plt.subplots`. A `squeeze` entry is overridden with `False`,
            because the grid is always flattened here.

    Returns:
        Tuple containing the figure and the flattened axes.
    """
    # Special case for 2
    large, small = (2, 1) if size == 2 else almost_factors(size, ratio)  # noqa: PLR2004
    rows, cols = (small, large) if more_cols else (large, small)

    # With the default squeeze=True, a 1x1 grid comes back as a bare Axes
    # (no `.T` / `.flatten()`), which used to crash below. Forcing
    # squeeze=False always yields a 2D array; the flattened order is the same
    # as before for every grid shape that already worked.
    fig, axes = plt.subplots(rows, cols, **{**subplot_params, "squeeze": False})

    if transposed:
        axes = axes.T
    flat_axes: AxesFlatArray = axes.flatten()

    # Hide the remaining axes if there are more axes than subplots
    for unused_ax in flat_axes[size:]:
        unused_ax.set_axis_off()

    if not return_all_axes:
        flat_axes = cast(AxesFlatArray, flat_axes[:size])  # noqa: TC006

    return fig, flat_axes

whenever

Utils for better date and times, specifically using whenever.

dt_to_system_datetime(dt)

Convert into exact time by assuming system timezone if necessary.

Source code in src/kajihs_utils/whenever.py
12
13
14
15
16
17
18
def dt_to_system_datetime(dt: datetime) -> SystemDateTime:
    """
    Convert into exact time by assuming system timezone if necessary.

    Args:
        dt: Standard-library datetime, with or without `tzinfo`.

    Returns:
        A `SystemDateTime`. When `dt.tzinfo` is `None`, `dt` is read as a
        plain date-time and the system timezone is assumed; otherwise `dt`
        is handed to `SystemDateTime.from_py_datetime` as is.
    """
    if dt.tzinfo is not None:
        return SystemDateTime.from_py_datetime(dt)

    # No tzinfo attached: treat the value as local wall-clock time.
    return PlainDateTime.from_py_datetime(dt).assume_system_tz()