|
3 | 3 | // Copyright(c) 2022 Intel Corporation. All rights reserved. |
4 | 4 | // |
5 | 5 | // Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com> |
| 6 | +// Shriram Shastry <malladi.sastry@linux.intel.com> |
| 7 | +// |
6 | 8 |
|
7 | 9 | #include <sof/math/matrix.h> |
8 | 10 | #include <errno.h> |
9 | 11 | #include <stdint.h> |
10 | 12 |
|
11 | | -int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) |
| 13 | +/* Performs matrix multiplication of two fixed-point 16-bit integer matrices, |
| 14 | + * storing the result in a third matrix. It accounts for fractional bits for |
| 15 | + * fixed-point arithmetic, adjusting the result accordingly. |
| 16 | + * |
| 17 | + * Arguments: |
| 18 | + * a: pointer to the first input matrix |
| 19 | + * b: pointer to the second input matrix |
| 20 | + * c: pointer to the output matrix to store result |
| 21 | + * |
| 22 | + * Return: |
| 23 | + * 0 on successful multiplication. |
| 24 | + * -EINVAL if input dimensions do not allow for multiplication. |
| 25 | + * -ERANGE if the shift operation might cause integer overflow. |
| 26 | + */ |
| 27 | +int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, |
| 28 | + struct mat_matrix_16b *c) |
12 | 29 | { |
13 | | - int64_t s; |
14 | | - int16_t *x; |
15 | | - int16_t *y; |
16 | | - int16_t *z = c->data; |
17 | | - int i, j, k; |
18 | | - int y_inc = b->columns; |
19 | | - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; |
| 30 | + int32_t acc; /* Accumulator for dot product calculation */ |
| 31 | + int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */ |
| 32 | + int i, j, k; /* Loop counters */ |
| 33 | + int y_inc = b->columns; /* Column increment for matrix b elements */ |
| 34 | + /* Calculate shift amount for adjusting fractional bits in the result */ |
| 35 | + const int shift = a->fractions + b->fractions - c->fractions - 1; |
20 | 36 |
|
| 37 | + /* Validate matrix dimensions are compatible for multiplication */ |
21 | 38 | if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) |
22 | 39 | return -EINVAL; |
23 | 40 |
|
24 | | - /* If all data is Q0 */ |
25 | | - if (shift_minus_one == -1) { |
26 | | - for (i = 0; i < a->rows; i++) { |
27 | | - for (j = 0; j < b->columns; j++) { |
28 | | - s = 0; |
29 | | - x = a->data + a->columns * i; |
30 | | - y = b->data + j; |
31 | | - for (k = 0; k < b->rows; k++) { |
32 | | - s += (int32_t)(*x) * (*y); |
33 | | - x++; |
34 | | - y += y_inc; |
35 | | - } |
36 | | - *z = (int16_t)s; /* For Q16.0 */ |
37 | | - z++; |
38 | | - } |
39 | | - } |
40 | | - |
41 | | - return 0; |
42 | | - } |
| 41 | + /* Check shift to ensure no integer overflow occurs during shifting */ |
| 42 | + if (shift < -1 || shift > 31) |
| 43 | + return -ERANGE; |
43 | 44 |
|
| 45 | + /* Matrix multiplication loop */ |
44 | 46 | for (i = 0; i < a->rows; i++) { |
45 | 47 | for (j = 0; j < b->columns; j++) { |
46 | | - s = 0; |
47 | | - x = a->data + a->columns * i; |
48 | | - y = b->data + j; |
| 48 | + acc = 0; /* Initialize accumulator for each element */ |
| 49 | + x = a->data + a->columns * i; /* Set x at the start of ith row of a */ |
| 50 | + y = b->data + j; /* Set y at the top of jth column of b */ |
| 51 | + /* Dot product loop */ |
49 | 52 | for (k = 0; k < b->rows; k++) { |
50 | | - s += (int32_t)(*x) * (*y); |
51 | | - x++; |
52 | | - y += y_inc; |
| 53 | + acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */ |
| 54 | + y += y_inc; /* Move to next row in the current column of b */ |
53 | 55 | } |
54 | | - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ |
55 | | - z++; |
| 56 | + /* Assign computed value to c matrix, adjusting for fractional bits */ |
| 57 | + if (shift == -1) |
| 58 | + *z = (int16_t)acc; |
| 59 | + else |
| 60 | + *z = (int16_t)(((acc >> shift) + 1) >> 1); |
| 61 | + z++; /* Move to the next element in the output matrix */ |
56 | 62 | } |
57 | 63 | } |
58 | 64 | return 0; |
59 | 65 | } |
60 | 66 |
|
| 67 | +/* Description: Performs element-wise multiplication of two matrices with 16-bit integer elements |
| 68 | + * and stores the result in a third matrix. Checks that the input matrices a and b have the same |
| 69 | + * dimensions (the dimensions of the output matrix c are not checked) and adjusts for fractional bits appropriately. This operation handles |
| 70 | + * the manipulation of fixed-point precision based on the fractional bits present in |
| 71 | + * the matrices. |
| 72 | + * |
| 73 | + * Arguments: |
| 74 | + * a - pointer to the first input matrix |
| 75 | + * b - pointer to the second input matrix |
| 76 | + * c - pointer to the output matrix where the result will be stored |
| 77 | + * |
| 78 | + * Returns: |
| 79 | + * 0 on successful multiplication, |
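| 80 | + * -EINVAL if the dimensions of a and b do not match. NOTE(review): the NULL test runs only after a->data, b->data and c->data have been evaluated, so NULL inputs are not reliably rejected. |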
| 81 | + */ |
61 | 82 | int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, |
62 | 83 | struct mat_matrix_16b *c) |
63 | | -{ int64_t p; |
| 84 | +{ |
64 | 85 | int16_t *x = a->data; |
65 | 86 | int16_t *y = b->data; |
66 | 87 | int16_t *z = c->data; |
67 | | - int i; |
68 | | - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; |
| 88 | + int32_t prod; |
69 | 89 |
|
70 | | - if (a->columns != b->columns || b->columns != c->columns || |
71 | | - a->rows != b->rows || b->rows != c->rows) { |
| 90 | + /* Validate matrix dimensions and non-null pointers */ |
| 91 | + if (!a || !b || !c || a->columns != b->columns || a->rows != b->rows) |
72 | 92 | return -EINVAL; |
73 | | - } |
74 | 93 |
|
75 | | - /* If all data is Q0 */ |
76 | | - if (shift_minus_one == -1) { |
77 | | - for (i = 0; i < a->rows * a->columns; i++) { |
| 94 | + /* Compute the total number of elements in the matrices */ |
| 95 | + const int total_elements = a->rows * a->columns; |
| 96 | + /* Compute the required bit shift based on the fractional part of each matrix */ |
| 97 | + const int shift = a->fractions + b->fractions - c->fractions - 1; |
| 98 | + |
| 99 | + /* Perform multiplication with or without adjusting the fractional bits */ |
| 100 | + if (shift == -1) { |
| 101 | + /* Direct multiplication when no adjustment for fractional bits is needed */ |
| 102 | + for (int i = 0; i < total_elements; i++, x++, y++, z++) |
78 | 103 | *z = *x * *y; |
79 | | - x++; |
80 | | - y++; |
81 | | - z++; |
| 104 | + } else { |
| 105 | + /* Multiplication with rounding to account for the fractional bits */ |
| 106 | + for (int i = 0; i < total_elements; i++, x++, y++, z++) { |
| 107 | + /* Multiply elements as int32_t */ |
| 108 | + prod = (int32_t)(*x) * (*y); |
| 109 | + /* Adjust and round the result */ |
| 110 | + *z = (int16_t)(((prod >> shift) + 1) >> 1); |
82 | 111 | } |
83 | | - |
84 | | - return 0; |
85 | | - } |
86 | | - |
87 | | - for (i = 0; i < a->rows * a->columns; i++) { |
88 | | - p = (int32_t)(*x) * *y; |
89 | | - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ |
90 | | - x++; |
91 | | - y++; |
92 | | - z++; |
93 | 112 | } |
94 | 113 |
|
95 | 114 | return 0; |
|
0 commit comments