
ENH: vectorized / batch tensorsolve() #28099

Open
34j opened this issue Jan 4, 2025 · 2 comments

34j commented Jan 4, 2025

Proposed new feature or change:

It would be nice if tensorsolve() could solve multiple tensor equations simultaneously, i.e.

a = np.random.randn(2, 2, 3, 6)
b = np.random.randn(2, 2, 3)
np.allclose(np.einsum("...ijk,...k->...ij", a, tensorsolve(a, b)), b)  # tensorsolve(a, b).shape == (2, 6)

I think there are two options for the design of a vectorized tensorsolve(). Let n: int be the number of degrees of freedom of the linear system:

  • Case 1: Assume that a.shape[axis_batch:b.ndim] == b.shape[axis_batch:] and that the batch shapes a.shape[:axis_batch] and b.shape[:axis_batch] are broadcastable.
    In this case n = prod(a.shape[b.ndim:]); we then find axis_batch such that prod(a.shape[axis_batch:b.ndim]) == n (and if none exists an error should be raised), and reshape a to a.shape[:axis_batch] + (n, n) and b to b.shape[:axis_batch] + (n,) (using b.shape, not a.shape, so that the batch dimensions remain broadcastable).
  • Case 2: Do not assume that a.shape[:b.ndim] == b.shape, but assume that the batch shapes a.shape[:axis_batch] and b.shape[:axis_batch] are exactly the same (not recommended?).
    In this case n = prod(a.shape) / prod(b.shape) (and if n is not an integer an error should be raised), and for each of a and b there must exist an axis_sol_{a,b} such that prod({a,b}.shape[axis_batch:axis_sol_{a,b}]) == n (and if not an error should be raised); then reshape a to (-1, n, n) and b to (-1, n).
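A minimal sketch of the Case 1 shape inference, using the example shapes from above (the variable names here are illustrative, not part of any proposed API):

```python
import numpy as np

# Hypothetical Case 1 input shapes: a = (batch, i, j, k), b = (batch, i, j)
a_shape, b_shape = (2, 2, 3, 6), (2, 2, 3)

# n is the size of the linear system: the product of a's trailing
# (post-b.ndim) dimensions.
n = int(np.prod(a_shape[len(b_shape):]))  # 6

# Walk b's axes from the right until their product equals n;
# everything to the left of axis_batch is the batch.
prod = 1
for axis_batch in range(len(b_shape) - 1, -1, -1):
    prod *= b_shape[axis_batch]
    if prod == n:
        break
else:
    raise ValueError("Unable to divide batch dimensions and solution dimensions")

# axis_batch == 1: batch shape (2,), system dimensions (2, 3) with 2*3 == n
```

This is the same right-to-left scan that the implementation below performs on `a.shape`.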

Implementation for Case 1:

import numpy as np

def btensorsolve(a, b, axes=None):
    """
    Solve the tensor equation ``a x = b`` for x.

    It is assumed that all indices of `x` are summed over in the product,
    together with the rightmost indices of `a`, as is done in, for example,
    ``tensordot(a, x, axes=x.ndim)``.

    Parameters
    ----------
    a : array_like
        Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
        the shape of that sub-tensor of `a` consisting of the appropriate
        number of its rightmost indices, and must be such that
        **there exists i such that**
        ``prod(Q) == prod(b.shape[i:])``
    b : array_like
        Right-hand tensor, which can be of any shape.
    axes : tuple of ints, optional
        Axes in `a` to reorder to the right, before inversion.
        If None (default), no reordering is done.

    Returns
    -------
    x : ndarray, shape ``b.shape[:i] + Q``

    Raises
    ------
    LinAlgError
        If `a` is singular or not 'square' (in the above sense).

    See Also
    --------
    numpy.tensordot, tensorinv, numpy.einsum

    Examples
    --------
    >>> import numpy as np
    >>> rng = np.random.default_rng()
    >>> a = rng.normal(size=(2, 2*3, 4, 2, 3, 4))
    >>> b = rng.normal(size=(2, 2*3, 4))
    >>> x = btensorsolve(a, b)
    >>> x.shape
    (2, 2, 3, 4)
    >>> np.allclose(np.einsum('...ijklm,...klm->...ij', a, x), b)
    True

    """
    # https://github.com/numpy/numpy/blob/
    # e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/linalg/_linalg.py#L291
    # accept array_like inputs, as the docstring promises
    a, b = np.asarray(a), np.asarray(b)
    an = a.ndim
    if axes is not None:
        allaxes = list(range(0, an))
        for k in axes:
            allaxes.remove(k)
            allaxes.insert(an, k)
        a = a.transpose(allaxes)

    # find right dimensions
    # a = [2 (dim1) 2 2 3 (dim2) 2 6]
    # b = [2 (dim1) 2 2 3 (dim2)]
    axis_sol_last = b.ndim
    if a.shape[:axis_sol_last] != b.shape:
        raise ValueError(
            f"Shapes of a and b are incompatible: " f"{a.shape} and {b.shape}"
        )

    # the dimension of the linear system
    sol_dim = int(np.prod(a.shape[axis_sol_last:]))
    sol_dim_ = 1
    for axis_batch_last in range(axis_sol_last - 1, -1, -1):
        sol_dim_ *= a.shape[axis_batch_last]
        if sol_dim_ == sol_dim:
            break
    else:
        raise ValueError("Unable to divide batch dimensions and solution dimensions")

    a_ = a.reshape(a.shape[:axis_batch_last] + (sol_dim, sol_dim))
    b_ = b.reshape(b.shape[:axis_batch_last] + (sol_dim, 1))
    x = np.linalg.solve(a_, b_)
    return x.reshape(a.shape[:axis_batch_last] + a.shape[axis_sol_last:])
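As a sanity check (not part of the original post), the batched reshape-and-solve route that btensorsolve uses internally can be compared slice-by-slice against the stock np.linalg.tensorsolve:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 2 * 3, 4, 2, 3, 4))
b = rng.normal(size=(2, 2 * 3, 4))

# Batched solve via explicit reshape: batch shape (2,), system size n = 2*3*4 = 24.
a_ = a.reshape(2, 24, 24)
b_ = b.reshape(2, 24, 1)
x = np.linalg.solve(a_, b_).reshape(2, 2, 3, 4)

# Each batch slice agrees with the non-batched np.linalg.tensorsolve.
for i in range(2):
    assert np.allclose(x[i], np.linalg.tensorsolve(a[i], b[i]))

# And the batched result satisfies the defining tensor equation.
assert np.allclose(np.einsum("...ijklm,...klm->...ij", a, x), b)
```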

34j commented Jan 4, 2025

I realized that it is impossible to automatically detect which axis is the "batch" axis if an axis of size 1 appears at the front. For example, if a.shape == (1, 2, 2) and b.shape == (1, 2), there is no way to infer whether the 0th axis is a batch axis or not; the result shape could be either (1, 2) or (2,) 😢
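This ambiguity can be shown concretely: both readings solve the same 2x2 system and produce the same numbers, differing only in shape (a sketch using np.linalg.solve for the batched reading):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(1, 2, 2))
b = rng.normal(size=(1, 2))

# Interpretation A: axis 0 is a size-1 batch axis -> x.shape == (1, 2)
x_a = np.linalg.solve(a, b[..., None])[..., 0]

# Interpretation B: no batch axis, plain tensorsolve -> x.shape == (2,)
x_b = np.linalg.tensorsolve(a, b)

# Same values, different shapes: the batch axis cannot be inferred
# from the shapes alone.
print(x_a.shape, x_b.shape)  # (1, 2) (2,)
assert np.allclose(x_a[0], x_b)
```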


34j commented May 10, 2025

Since I could not get any feedback, I have created a package, batch-tensorsolve, in the meantime. However, I still hope this will one day be officially available in NumPy.
https://github.com/34j/batch-tensorsolve
