Artificial Intelligence Fight IV. – Introducing 8×8 matrix partitions and Advanced Vector Extensions (AVX) intrinsics

I woke up at midnight and thought – what if I tried to further optimise the SSE block multiplier? I started browsing the intrinsics documentation from Intel, and I quickly came to the conclusion that it would not be too hard to implement AVX intrinsics instead of SSE, which allows blocks of 8×8 single precision matrices to be multiplied. About 15 minutes later I am here and already looking at the results.

The AVX instruction set is also known as Sandy Bridge extensions and was released by both Intel and AMD in 2011. It supports calculations with vectors of 8 float values – perfect for an 8×8 matrix product implementation.

In the C# WinForms test harness application I have changed the graph update frequency from 20 iterations to 1000, because this update causes noticeable slowdown in the overall performance when I am testing with a large pattern set of 30000 patterns.

This pattern set has 13 inputs and 8 outputs. Up until now I’ve been using one hidden layer of 100 neurons for testing, now I decided to change this network. Instead of one I am using two hidden layers with 64 neurons in each – the connections of which will fit nicely into the 8×8 block partitions.

The results are great: I can see about 30% performance improvement compared to the SSE solution. However, I am still planning to retest the SSE version with this new network architecture, as it may also provide better results with a larger network.

On a 4GHz Core i7 computer with 8 hardware threads the speed reached 318 MCpS (which is 9937 kCpS/GHz/core), and on the dual Xeon with 40 hardware threads the speed reached 935 MCpS (which is 7791 kCpS/GHz/core). Getting very close to the GCpS territory, aren’t we? 🙂

Until I find a better way to include code in this blog I am adding a picture of the AVX multiplier function (you can find it in my github anyway):

[code language="cpp"]
struct Mat88f { float value[8][8]; };

void MultiplyAndAddMM8x8F(Mat88f &result, const Mat88f &MatrixA, const Mat88f &MatrixB)
{
__m256 rightRow0 = _mm256_loadu_ps((const float *)MatrixB.value[0]);
__m256 rightRow1 = _mm256_loadu_ps((const float *)MatrixB.value[1]);
__m256 rightRow2 = _mm256_loadu_ps((const float *)MatrixB.value[2]);
__m256 rightRow3 = _mm256_loadu_ps((const float *)MatrixB.value[3]);
__m256 rightRow4 = _mm256_loadu_ps((const float *)MatrixB.value[4]);
__m256 rightRow5 = _mm256_loadu_ps((const float *)MatrixB.value[5]);
__m256 rightRow6 = _mm256_loadu_ps((const float *)MatrixB.value[6]);
__m256 rightRow7 = _mm256_loadu_ps((const float *)MatrixB.value[7]);

__m256 resultRow0 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[0][0]));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[0][1])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[0][2])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[0][3])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[0][4])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[0][5])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[0][6])));
resultRow0 = _mm256_add_ps(resultRow0, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[0][7])));

__m256 resultRow1 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[1][0]));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[1][1])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[1][2])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[1][3])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[1][4])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[1][5])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[1][6])));
resultRow1 = _mm256_add_ps(resultRow1, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[1][7])));

__m256 resultRow2 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[2][0]));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[2][1])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[2][2])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[2][3])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[2][4])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[2][5])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[2][6])));
resultRow2 = _mm256_add_ps(resultRow2, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[2][7])));

__m256 resultRow3 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[3][0]));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[3][1])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[3][2])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[3][3])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[3][4])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[3][5])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[3][6])));
resultRow3 = _mm256_add_ps(resultRow3, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[3][7])));

__m256 resultRow4 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[4][0]));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[4][1])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[4][2])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[4][3])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[4][4])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[4][5])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[4][6])));
resultRow4 = _mm256_add_ps(resultRow4, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[4][7])));

__m256 resultRow5 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[5][0]));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[5][1])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[5][2])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[5][3])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[5][4])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[5][5])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[5][6])));
resultRow5 = _mm256_add_ps(resultRow5, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[5][7])));

__m256 resultRow6 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[6][0]));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[6][1])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[6][2])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[6][3])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[6][4])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[6][5])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[6][6])));
resultRow6 = _mm256_add_ps(resultRow6, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[6][7])));

__m256 resultRow7 = _mm256_mul_ps(rightRow0, _mm256_set1_ps(MatrixA.value[7][0]));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow1, _mm256_set1_ps(MatrixA.value[7][1])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow2, _mm256_set1_ps(MatrixA.value[7][2])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow3, _mm256_set1_ps(MatrixA.value[7][3])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow4, _mm256_set1_ps(MatrixA.value[7][4])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow5, _mm256_set1_ps(MatrixA.value[7][5])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow6, _mm256_set1_ps(MatrixA.value[7][6])));
resultRow7 = _mm256_add_ps(resultRow7, _mm256_mul_ps(rightRow7, _mm256_set1_ps(MatrixA.value[7][7])));

__m256 outRow0 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[0]), resultRow0);
__m256 outRow1 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[1]), resultRow1);
__m256 outRow2 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[2]), resultRow2);
__m256 outRow3 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[3]), resultRow3);
__m256 outRow4 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[4]), resultRow4);
__m256 outRow5 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[5]), resultRow5);
__m256 outRow6 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[6]), resultRow6);
__m256 outRow7 = _mm256_add_ps(_mm256_loadu_ps((const float *)result.value[7]), resultRow7);

_mm256_storeu_ps((float *)result.value[0], outRow0);
_mm256_storeu_ps((float *)result.value[1], outRow1);
_mm256_storeu_ps((float *)result.value[2], outRow2);
_mm256_storeu_ps((float *)result.value[3], outRow3);
_mm256_storeu_ps((float *)result.value[4], outRow4);
_mm256_storeu_ps((float *)result.value[5], outRow5);
_mm256_storeu_ps((float *)result.value[6], outRow6);
_mm256_storeu_ps((float *)result.value[7], outRow7);
}
[/code]

Leave a Reply

Your email address will not be published. Required fields are marked *