/* Mirror of https://github.com/pjreddie/darknet.git
 * (synced 2023-08-10 21:13:14 +03:00) -- 97 lines, 2.4 KiB, C */
//#include "mini_blas.h"
|
|
/*
 * PUT_IN_REGISTER is a storage-class prefix for local variable declarations.
 * In a C translation unit it expands to the `register` keyword; when the file
 * is compiled as C++ (__cplusplus defined) it expands to nothing.
 */
#ifndef __cplusplus
#define PUT_IN_REGISTER register
#else
#define PUT_IN_REGISTER
#endif
|
|
|
|
/*
 * cpu_gemm_nn: accumulate an ALPHA-scaled product of A and B into C, reading
 * both A and B without transposition.
 *
 * For every 0 <= i < M, 0 <= k < K, 0 <= j < N the kernel performs
 *     C[i*ldc + j] += (ALPHA * A[i*lda + k]) * B[k*ldb + j]
 * so lda, ldb and ldc act as the row strides of A, B and C respectively.
 *
 * TA, TB and BETA are not read by this kernel; it only adds to whatever C
 * already holds (cpu_gemm applies the BETA scaling before dispatching here).
 */
void cpu_gemm_nn(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int row, inner, col;

    for (row = 0; row < M; ++row) {
        for (inner = 0; inner < K; ++inner) {
            /* One scaled A element is reused across an entire row of B. */
            const float scaled_a = ALPHA * A[row * lda + inner];
            const int b_base = inner * ldb;
            const int c_base = row * ldc;

            for (col = 0; col < N; ++col) {
                C[c_base + col] += scaled_a * B[b_base + col];
            }
        }
    }
}
|
|
|
|
/*
 * cpu_gemm_nt: accumulate an ALPHA-scaled product into C, reading A row by
 * row and B with its two indices swapped (B is indexed as B[k + j*ldb]).
 *
 * For every 0 <= i < M and 0 <= j < N the kernel forms
 *     dot = sum over 0 <= k < K of (ALPHA * A[i*lda + k]) * B[j*ldb + k]
 * in a local accumulator and then performs C[i*ldc + j] += dot.
 *
 * TA, TB and BETA are not read by this kernel; it only adds to whatever C
 * already holds (cpu_gemm applies the BETA scaling before dispatching here).
 */
void cpu_gemm_nt(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int row, col, inner;

    for (row = 0; row < M; ++row) {
        const int a_base = row * lda;

        for (col = 0; col < N; ++col) {
            const int b_base = col * ldb;
            float dot = 0;

            /* Both operands are walked contiguously along k. */
            for (inner = 0; inner < K; ++inner) {
                dot += ALPHA * A[a_base + inner] * B[inner + b_base];
            }
            C[row * ldc + col] += dot;
        }
    }
}
|
|
|
|
/*
 * cpu_gemm_tn: accumulate an ALPHA-scaled product into C, reading A with its
 * two indices swapped (A is indexed as A[k*lda + i]) and B row by row.
 *
 * For every 0 <= i < M, 0 <= k < K, 0 <= j < N the kernel performs
 *     C[i*ldc + j] += (ALPHA * A[k*lda + i]) * B[k*ldb + j]
 *
 * TA, TB and BETA are not read by this kernel; it only adds to whatever C
 * already holds (cpu_gemm applies the BETA scaling before dispatching here).
 */
void cpu_gemm_tn(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int row, inner, col;

    for (row = 0; row < M; ++row) {
        for (inner = 0; inner < K; ++inner) {
            /* A is stepped with stride lda here; the scaled element is then
             * reused across an entire row of B. */
            const float scaled_a = ALPHA * A[inner * lda + row];
            const int b_base = inner * ldb;
            const int c_base = row * ldc;

            for (col = 0; col < N; ++col) {
                C[c_base + col] += scaled_a * B[b_base + col];
            }
        }
    }
}
|
|
/*
 * cpu_gemm_tt: accumulate an ALPHA-scaled product into C, reading both A and
 * B with their two indices swapped (A as A[i + k*lda], B as B[k + j*ldb]).
 *
 * For every 0 <= i < M and 0 <= j < N the kernel forms
 *     sum = sum over 0 <= k < K of (ALPHA * A[i + k*lda]) * B[k + j*ldb]
 * and then performs C[i*ldc + j] += sum.
 *
 * The running total is kept in a local (the same scheme cpu_gemm_nt uses)
 * so that C is loaded and stored once per output element instead of once per
 * k step; the compiler cannot do this itself because C, A and B are plain
 * float pointers that may alias.
 *
 * TA, TB and BETA are not read by this kernel; it only adds to whatever C
 * already holds (cpu_gemm applies the BETA scaling before dispatching here).
 */
void cpu_gemm_tt(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j, k;

    for (i = 0; i < M; ++i) {
        for (j = 0; j < N; ++j) {
            float sum = 0;

            for (k = 0; k < K; ++k) {
                sum += ALPHA * A[i + k * lda] * B[k + j * ldb];
            }
            C[i * ldc + j] += sum;
        }
    }
}
|
|
|
|
|
|
/*
 * cpu_gemm: CPU entry point for the GEMM update of C.
 *
 * Step 1 applies BETA to the M x N region of C (row stride ldc):
 *   - BETA == 0: every element is overwritten with 0 rather than multiplied,
 *     so NaN/Inf left in an uninitialised output buffer cannot survive
 *     (0 * NaN is NaN). This follows the BLAS convention that C need not be
 *     set on input when BETA is zero.
 *   - BETA == 1: the scaling pass is skipped (multiplying by 1 is a no-op).
 *   - otherwise: each element is multiplied by BETA.
 *
 * Step 2 dispatches on the TA / TB flags (tested only for zero / non-zero)
 * to the kernel that adds the ALPHA-scaled product into C:
 *   !TA && !TB -> cpu_gemm_nn
 *    TA && !TB -> cpu_gemm_tn
 *   !TA &&  TB -> cpu_gemm_nt
 *    TA &&  TB -> cpu_gemm_tt
 * All arguments are forwarded unchanged.
 */
void cpu_gemm(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j;

    if (BETA == 0.0f) {
        for (i = 0; i < M; ++i) {
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] = 0.0f;
            }
        }
    } else if (BETA != 1.0f) {
        for (i = 0; i < M; ++i) {
            for (j = 0; j < N; ++j) {
                C[i * ldc + j] *= BETA;
            }
        }
    }

    if (!TA && !TB) {
        cpu_gemm_nn(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else if (TA && !TB) {
        cpu_gemm_tn(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else if (!TA && TB) {
        cpu_gemm_nt(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    } else {
        cpu_gemm_tt(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
    }
}
|