Amir

Reputation: 16597

Partitioned matrix multiplication in tensorflow or pytorch

Assume I have a matrix P of size [4, 4] which is partitioned (blocked) into 4 smaller matrices of size [2, 2]. How can I efficiently multiply this block matrix by another matrix (not a partitioned matrix, but a smaller one)?

Let's assume our original matrix is:

P = [ 1 1 2 2
      1 1 2 2
      3 3 4 4
      3 3 4 4]

Which split into submatrices:

P_1 = [1 1    , P_2 = [2 2  , P_3 = [3 3 , P_4 = [4 4
       1 1]            2 2]          3 3]         4 4]

Now our P is:

P = [P_1 P_2
     P_3 P_4]

In the next step, I want to do element-wise multiplication between P and a smaller matrix whose size equals the number of sub-matrices:

P * [ 1 0   =   [ P_1 0   =   [ 1 1 0 0
      0 0 ]       0   0 ]       1 1 0 0
                                0 0 0 0
                                0 0 0 0 ]

Upvotes: 8

Views: 2508

Answers (5)

GZ0

Reputation: 4273

Following is a general Tensorflow-based solution that works for input matrices p (large) and m (small) of arbitrary shapes as long as the sizes of p are divisible by the sizes of m on both axes.

def block_mul(p, m):
    p_x, p_y = p.shape
    m_x, m_y = m.shape
    # expand m so that each of its entries can be broadcast over one block of p
    m_4d = tf.reshape(m, (m_x, 1, m_y, 1))
    # replicate every entry of m over a (p_x // m_x, p_y // m_y) block
    m_broadcasted = tf.broadcast_to(m_4d, (m_x, p_x // m_x, m_y, p_y // m_y))
    # collapse the blocks back to the shape of p and multiply element-wise
    mp = tf.reshape(m_broadcasted, (p_x, p_y))
    return p * mp

Test:

import tensorflow as tf

tf.enable_eager_execution()

p = tf.reshape(tf.constant(range(36)), (6, 6))
m = tf.reshape(tf.constant(range(9)), (3, 3))
print(f"p:\n{p}\n")
print(f"m:\n{m}\n")
print(f"block_mul(p, m):\n{block_mul(p, m)}")

Output (Python 3.7.3, Tensorflow 1.13.1):

p:
[[ 0  1  2  3  4  5]
 [ 6  7  8  9 10 11]
 [12 13 14 15 16 17]
 [18 19 20 21 22 23]
 [24 25 26 27 28 29]
 [30 31 32 33 34 35]]

m:
[[0 1 2]
 [3 4 5]
 [6 7 8]]

block_mul(p, m):
[[  0   0   2   3   8  10]
 [  0   0   8   9  20  22]
 [ 36  39  56  60  80  85]
 [ 54  57  80  84 110 115]
 [144 150 182 189 224 232]
 [180 186 224 231 272 280]]

Another solution that uses implicit broadcasting is the following:

def block_mul2(p, m):
    p_x, p_y = p.shape
    m_x, m_y = m.shape
    # view p as a grid of blocks; broadcasting then scales each block by the
    # corresponding entry of m without materializing a full-size mask
    p_4d = tf.reshape(p, (m_x, p_x // m_x, m_y, p_y // m_y))
    m_4d = tf.reshape(m, (m_x, 1, m_y, 1))
    return tf.reshape(p_4d * m_4d, (p_x, p_y))
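
Since the question also asks about PyTorch, the same implicit-broadcasting idea ports over directly. This is just a sketch (block_mul_torch is a hypothetical name, and the same divisibility assumptions apply):

import torch

def block_mul_torch(p, m):
    # same idea as block_mul2: view p as a grid of blocks and let
    # broadcasting scale each block by the matching entry of m
    p_x, p_y = p.shape
    m_x, m_y = m.shape
    p_4d = p.reshape(m_x, p_x // m_x, m_y, p_y // m_y)
    m_4d = m.reshape(m_x, 1, m_y, 1)
    return (p_4d * m_4d).reshape(p_x, p_y)

p = torch.arange(36).reshape(6, 6)
m = torch.arange(9).reshape(3, 3)
print(block_mul_torch(p, m))  # should match the Tensorflow output above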

Upvotes: 2

iacolippo

Reputation: 4513

If the matrices are small, you are probably fine with cat or pad. The solution based on factorization is very elegant, as is the one with a block_mul implementation.

Another solution is turning the 2D block matrix into a 3D volume where each 2D slice is a block (P_1, P_2, P_3, P_4), then using the power of broadcasting to multiply each 2D slice by a scalar, and finally reshaping the output. The reshaping is not immediate, but it is doable; it is a port from NumPy to PyTorch of https://stackoverflow.com/a/16873755/4892874

In Pytorch:

import torch

h = w = 4
x = torch.ones(h, w)
x[:2, 2:] = 2
x[2:, :2] = 3
x[2:, 2:] = 4

# size of each block (rows and columns per block)
nrows = 2
ncols = 2

# split x into a stack of (nrows, ncols) blocks
vol3d = x.reshape(h//nrows, nrows, -1, ncols)
vol3d = vol3d.permute(0, 2, 1, 3).reshape(-1, nrows, ncols)

# multiply each block (2D slice) by its scalar coefficient
out = vol3d * torch.Tensor([1, 0, 0, 0])[:, None, None].float()

# reshape to original
n, nrows, ncols = out.shape
out = out.reshape(h//nrows, -1, nrows, ncols)
out = out.permute(0, 2, 1, 3)
out = out.reshape(h, w)

print(out)

tensor([[1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

I haven't benchmarked this against the others, but it doesn't consume additional memory like padding would, and it doesn't do slow operations like concatenation. It also has the advantage of being easy to understand and visualize.

You can generalize it to any situation by playing with h, w, nrows, ncols.
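
As a sketch of that generalization (blockwise_scale is a hypothetical helper wrapping exactly the steps above):

import torch

def blockwise_scale(x, coeffs, nrows, ncols):
    # x: (h, w) matrix; coeffs: flat tensor with one scalar per block, in
    # row-major block order; nrows, ncols: size of each block
    # (assumes h % nrows == 0 and w % ncols == 0)
    h, w = x.shape
    vol3d = x.reshape(h // nrows, nrows, -1, ncols)
    vol3d = vol3d.permute(0, 2, 1, 3).reshape(-1, nrows, ncols)
    out = vol3d * coeffs[:, None, None].float()
    out = out.reshape(h // nrows, -1, nrows, ncols).permute(0, 2, 1, 3)
    return out.reshape(h, w)

# reproduces the output above, reusing x from the snippet above
print(blockwise_scale(x, torch.tensor([1., 0., 0., 0.]), 2, 2))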

Upvotes: 1

Shai

Reputation: 114866

You can think of representing your large block matrix in a more efficient way.

For instance, a block matrix

P = [ 1 1 2 2
      1 1 2 2
      3 3 4 4
      3 3 4 4]

Can be represented using

a = [ 1 0    b = [ 1 1 0 0    p = [ 1 2
      1 0          0 0 1 1 ]        3 4 ]
      0 1
      0 1 ]

As

P = a @ p @ b

With @ denoting matrix multiplication. The matrices a and b encode the block structure of P, and the small p holds the values of each block.

Now, if you want to multiply (element-wise) p with a small (2x2) matrix q, you simply compute

a @ (p * q) @ b

A simple pytorch example

In [1]: a = torch.tensor([[1., 0], [1., 0], [0., 1], [0, 1]])
In [2]: b = torch.tensor([[1., 1., 0, 0], [0, 0, 1., 1]]) 
In [3]: p=torch.tensor([[1., 2.], [3., 4.]])
In [4]: q = torch.tensor([[1., 0], [0., 0]])

In [5]: a @ p @ b
Out[5]:
tensor([[1., 1., 2., 2.],
        [1., 1., 2., 2.],
        [3., 3., 4., 4.],
        [3., 3., 4., 4.]])
In [6]: a @ (p*q) @ b
Out[6]:
tensor([[1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

I leave it to you as an exercise to work out how to efficiently produce the "structure" matrices a and b given the sizes of the blocks.
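
As a sketch of one way to do it (structure_matrices is a hypothetical helper, assuming a PyTorch version that provides torch.kron), a and b can be built as the Kronecker product of an identity matrix with a vector of ones:

import torch

def structure_matrices(block_rows, block_cols, block_h, block_w):
    # block_rows x block_cols blocks, each of size block_h x block_w
    a = torch.kron(torch.eye(block_rows), torch.ones(block_h, 1))  # (block_rows * block_h, block_rows)
    b = torch.kron(torch.eye(block_cols), torch.ones(1, block_w))  # (block_cols, block_cols * block_w)
    return a, b

a, b = structure_matrices(2, 2, 2, 2)  # reproduces the a and b above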

Upvotes: 4

Amir

Reputation: 16597

Although the other answer may be a solution, it is not an efficient one. I came up with another approach to tackle the problem (still not perfect): the following implementation needs too much memory when the inputs are 3- or 4-dimensional. For example, for an input of size 20*75*1024*1024, the calculation below needs around 12 GB of RAM.

Here is my implementation:

import tensorflow as tf

tf.enable_eager_execution()


inps = tf.constant([
    [1, 1, 1, 1, 2, 2, 2, 2],
    [1, 1, 1, 1, 2, 2, 2, 2],
    [1, 1, 1, 1, 2, 2, 2, 2],
    [1, 1, 1, 1, 2, 2, 2, 2],
    [3, 3, 3, 3, 4, 4, 4, 4],
    [3, 3, 3, 3, 4, 4, 4, 4],
    [3, 3, 3, 3, 4, 4, 4, 4],
    [3, 3, 3, 3, 4, 4, 4, 4]])

on_cells = tf.constant([[1, 0, 0, 1]])

on_cells = tf.expand_dims(on_cells, axis=-1)

# replicate the value to block-size (4*4)
on_cells = tf.tile(on_cells, [1, 1, 4 * 4])

# reshape to a format for permutation
on_cells = tf.reshape(on_cells, (1, 2, 2, 4, 4))

# permutation
on_cells = tf.transpose(on_cells, [0, 1, 3, 2, 4])

# reshape
on_cells = tf.reshape(on_cells, [1, 8, 8])

# element-wise operation
print(inps * on_cells)

Output:

tf.Tensor(
[[[1 1 1 1 0 0 0 0]
  [1 1 1 1 0 0 0 0]
  [1 1 1 1 0 0 0 0]
  [1 1 1 1 0 0 0 0]
  [0 0 0 0 4 4 4 4]
  [0 0 0 0 4 4 4 4]
  [0 0 0 0 4 4 4 4]
  [0 0 0 0 4 4 4 4]]], shape=(1, 8, 8), dtype=int32)
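
To reduce the memory pressure of the tiled mask, a variant in the spirit of block_mul2 above (a sketch under the same eager setup) reshapes the input instead and lets broadcasting do the scaling:

# (batch, block-row, row-in-block, block-col, col-in-block)
inps_5d = tf.reshape(inps, (1, 2, 4, 2, 4))
on_cells_5d = tf.reshape(tf.constant([[1, 0, 0, 1]]), (1, 2, 1, 2, 1))
# broadcasting scales each 4x4 block by its coefficient without tiling the mask
print(tf.reshape(inps_5d * on_cells_5d, (1, 8, 8)))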

Upvotes: 0

Anubhav Singh

Reputation: 8709

I don't know about the most efficient method, but you can try these:

Method 1:

Using torch.cat()

import torch

def multiply(a, b):
    x1 = a[0:2, 0:2]*b[0,0]
    x2 = a[0:2, 2:]*b[0,1]
    x3 = a[2:, 0:2]*b[1,0]
    x4 = a[2:, 2:]*b[1,1]
    return torch.cat((torch.cat((x1, x2), 1), torch.cat((x3, x4), 1)), 0)

a = torch.tensor([[1, 1, 2, 2],[1, 1, 2, 2],[3, 3, 4, 4,],[3, 3, 4, 4]])
b = torch.tensor([[1, 0],[0, 0]])
print(multiply(a, b))

output:

tensor([[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])

Method 2:

Using torch.nn.functional.pad()

import torch.nn.functional as F
import torch

def multiply(a, b):
    # zero-pad b to a's shape, which places b's values in the interior,
    # then manually fill in the rest of the top-left block
    # (note: these assignments are hard-coded to this particular mask b)
    b = F.pad(input=b, pad=(1, 1, 1, 1), mode='constant', value=0)
    b[0,0] = 1
    b[0,1] = 1
    b[1,0] = 1
    return a*b

a = torch.tensor([[1, 1, 2, 2],[1, 1, 2, 2],[3, 3, 4, 4,],[3, 3, 4, 4]])
b = torch.tensor([[1, 0],[0, 0]])
print(multiply(a, b))

output:

tensor([[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]])

Upvotes: 1
