fml  0.1-0
Fused Matrix Library
crossprod.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_CPU_LINALG_CROSSPROD_H
6 #define FML_CPU_LINALG_CROSSPROD_H
7 #pragma once
8 
9 
10 #include "../cpumat.hh"
11 
12 #include "internals/blas.hh"
13 
14 
15 namespace fml
16 {
17 namespace linalg
18 {
36  template <typename REAL>
37  void crossprod(const REAL alpha, const cpumat<REAL> &x, cpumat<REAL> &ret)
38  {
39  const len_t m = x.nrows();
40  const len_t n = x.ncols();
41 
42  if (n != ret.nrows() || n != ret.ncols())
43  ret.resize(n, n);
44 
45  ret.fill_zero();
46  fml::blas::syrk('L', 'T', n, m, alpha, x.data_ptr(), m, (REAL)0.0, ret.data_ptr(), n);
47  }
48 
50  template <typename REAL>
51  cpumat<REAL> crossprod(const REAL alpha, const cpumat<REAL> &x)
52  {
53  const len_t n = x.ncols();
54  cpumat<REAL> ret(n, n);
55 
56  crossprod(alpha, x, ret);
57 
58  return ret;
59  }
60 
61 
62 
80  template <typename REAL>
81  void tcrossprod(const REAL alpha, const cpumat<REAL> &x, cpumat<REAL> &ret)
82  {
83  const len_t m = x.nrows();
84  const len_t n = x.ncols();
85 
86  if (m != ret.nrows() || m != ret.ncols())
87  ret.resize(m, m);
88 
89  ret.fill_zero();
90  fml::blas::syrk('L', 'N', m, n, alpha, x.data_ptr(), m, (REAL)0.0, ret.data_ptr(), m);
91  }
92 
93  template <typename REAL>
94  cpumat<REAL> tcrossprod(const REAL alpha, const cpumat<REAL> &x)
95  {
96  const len_t m = x.nrows();
97  cpumat<REAL> ret(m, m);
98 
99  tcrossprod(alpha, x, ret);
100 
101  return ret;
102  }
103 }
104 }
105 
106 
107 #endif
fml::cpumat
Matrix class for data held on a single CPU.
Definition: cpumat.hh:36
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: crossprod.hh:37
fml::cpumat::fill_zero
void fill_zero()
Set all values to zero.
Definition: cpumat.hh:362
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::cpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: cpumat.hh:233
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml::unimat::data_ptr
REAL * data_ptr()
Pointer to the internal array.
Definition: unimat.hh:40
fml
Core namespace.
Definition: dimops.hh:10
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: crossprod.hh:81