## Numpy: make batched version of quaternion multiplication

Question

I transformed the following function

```
import numpy

def quaternion_multiply(quaternion0, quaternion1):
    """Return multiplication of two quaternions.

    >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
    >>> numpy.allclose(q, [-44, -14, 48, 28])
    True
    """
    # Components are stored in (x, y, z, w) order
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return numpy.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)
```

to a batched version

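```
import numpy as np

# Defined as a method of a class, hence the `self` parameter
def quat_multiply(self, quaternion0, quaternion1):
    # Split the (N, 4) inputs along axis 1: each component has shape (N, 1)
    x0, y0, z0, w0 = np.split(quaternion0, 4, 1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, 1)
    # Stacking the four (N, 1) results gives shape (4, N, 1)
    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    # Squeeze to (4, N), then transpose back to (N, 4)
    return np.transpose(np.squeeze(result))
```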

This function handles `quaternion0` and `quaternion1` with shape `(?, 4)`. Now I want the function to handle an arbitrary number of dimensions, such as `(?, ?, 4)`. How can I do this?


## Answers (3)

You could make use of `np.rollaxis` to bring the last axis to the front, which lets us slice out the 4 arrays without actually splitting them. We then perform the required operations and finally send the first axis back to the end, so that the output array keeps the same shape as the inputs. This gives us a solution for generic n-dimensional ndarrays.
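A minimal sketch of the `np.rollaxis` idea, assuming inputs in the question's (x, y, z, w) order with the four components on the last axis. The name `quat_multiply_nd` is hypothetical, and converting the inputs with `np.asarray` is an added convenience so that lists are accepted too:

```python
import numpy as np

def quat_multiply_nd(quaternion0, quaternion1):
    # Hypothetical name; inputs have shape (..., 4) in (x, y, z, w) order
    # Roll the last axis to the front: (..., 4) -> (4, ...)
    q0 = np.rollaxis(np.asarray(quaternion0, dtype=np.float64), -1)
    q1 = np.rollaxis(np.asarray(quaternion1, dtype=np.float64), -1)
    # Iterating over the first axis yields the four components as views
    x0, y0, z0, w0 = q0
    x1, y1, z1, w1 = q1
    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    # Send the component axis back to the end: (4, ...) -> (..., 4)
    return np.rollaxis(result, 0, result.ndim)


# Usage: two stacks of shape (2, 3, 4) give a result of shape (2, 3, 4)
a = np.tile([-5.0, 6.0, 7.0, 8.0], (2, 3, 1))
b = np.tile([1.0, -2.0, 3.0, 4.0], (2, 3, 1))
out = quat_multiply_nd(a, b)
```

`np.moveaxis(q, -1, 0)` and `np.moveaxis(result, 0, -1)` perform the same axis moves; NumPy's documentation recommends `np.moveaxis` over `np.rollaxis` in new code.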

You're almost there! You just need to be a little careful about how you split and concatenate your arrays.

The trick is to use `axis=-1` both times: first to split along the last axis, and then to concatenate back along the last axis. Finally, we squeeze out the second-to-last axis, as you correctly noticed. Hope that's what you needed! This should work for arbitrary shapes and an arbitrary number of dimensions.
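A minimal sketch of one way to implement that description. Using `np.stack` along a new last axis as the joining step is an assumption; it is what produces the size-1 second-to-last axis that the squeeze then removes. The example passes the two quaternions in the order that reproduces the docstring's expected `[-44, -14, 48, 28]`:

```python
import numpy as np

def quat_multiply(quaternion0, quaternion1):
    # Split along the last axis: each component has shape (..., 1)
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    # Stack the four results along a new last axis: shape (..., 1, 4)
    result = np.stack((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1)
    # Squeeze only the size-1 second-to-last axis: shape (..., 4)
    return np.squeeze(result, axis=-2)


q0 = np.array([-5.0, 6.0, 7.0, 8.0])
q1 = np.array([1.0, -2.0, 3.0, 4.0])
single = quat_multiply(q0, q1)                   # approximately [-44, -14, 48, 28]
batch = quat_multiply(np.tile(q0, (2, 3, 1)), np.tile(q1, (2, 3, 1)))  # shape (2, 3, 4)
one = quat_multiply(q0[None, :], q1[None, :])    # batch of one: shape (1, 4)
```

Passing `axis=-2` to `np.squeeze`, instead of squeezing every size-1 axis as the question's version does, keeps a batch dimension of size 1 intact.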

Note: `np.split` appears not to work on lists, so you can only pass arrays to your new function, as I've done above. If you want to be able to pass lists, you can instead convert the inputs to arrays inside your function.

Also, your test case appears to be wrong. I think you've swapped the positions of `quaternion0` and `quaternion1`; I've swapped them back above while testing `q0` and `q1`.

You can get the behavior you are after by simply passing `axis=-1` to `np.split` to split along the last axis. And since your arrays have that annoying size-1 trailing dimension, rather than stacking along a new dimension and then squeezing that one away, you can simply concatenate them, again along the last axis (`axis=-1`).

Note that with this approach, not only can you multiply identically shaped quaternion stacks with any number of dimensions, but you also get NumPy's broadcasting, which lets you, e.g., multiply a stack of quaternions by a single quaternion without having to fiddle with the dimensions. With minimal fiddling, you can even compute all pairwise products between two stacks (every quaternion of one stack times every quaternion of the other) in a single line.
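A minimal sketch of the concatenate-based version, followed by the three uses just described. The stack shapes and the random data are illustrative assumptions; the pairwise case works by inserting size-1 axes with `None` so that broadcasting pairs every row of one stack with every row of the other:

```python
import numpy as np

def quat_multiply(quaternion0, quaternion1):
    # Split along the last axis; each component keeps a size-1 trailing axis
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    # Concatenate the size-1 results along that same last axis: shape (..., 4)
    return np.concatenate((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1)


q0 = np.array([-5.0, 6.0, 7.0, 8.0])
q1 = np.array([1.0, -2.0, 3.0, 4.0])

# 1) Identically shaped stacks of shape (2, 3, 4)
a = np.tile(q0, (2, 3, 1))
b = np.tile(q1, (2, 3, 1))
same = quat_multiply(a, b)                    # every entry is [-44, -14, 48, 28]

# 2) Broadcasting: a (2, 3, 4) stack times a single (4,) quaternion
single = quat_multiply(a, q1)                 # shape (2, 3, 4)

# 3) All pairwise products between a stack of 5 and a stack of 7 quaternions
rng = np.random.default_rng(0)
p = rng.standard_normal((5, 4))
r = rng.standard_normal((7, 4))
pairwise = quat_multiply(p[:, None, :], r[None, :, :])   # shape (5, 7, 4)
```

Here `pairwise[i, j]` is the product of `p[i]` and `r[j]`, computed without an explicit loop.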