fml  0.1-0
Fused Matrix Library
crossprod.hh
1 // This file is part of fml which is released under the Boost Software
2 // License, Version 1.0. See accompanying file LICENSE or copy at
3 // https://www.boost.org/LICENSE_1_0.txt
4 
5 #ifndef FML_GPU_LINALG_CROSSPROD_H
6 #define FML_GPU_LINALG_CROSSPROD_H
7 #pragma once
8 
9 
10 #include <stdexcept>
11 
12 #include "../../_internals/linalgutils.hh"
13 
14 #include "../arch/arch.hh"
15 
16 #include "../gpumat.hh"
17 
18 #include "internals/err.hh"
19 #include "matmult.hh"
20 
21 
22 namespace fml
23 {
24 namespace linalg
25 {
43  template <typename REAL>
44  void crossprod(const REAL alpha, const gpumat<REAL> &x, gpumat<REAL> &ret)
45  {
46  err::check_card(x, ret);
47 
48  const len_t m = x.nrows();
49  const len_t n = x.ncols();
50 
51  if (n != ret.nrows() || n != ret.ncols())
52  ret.resize(n, n);
53 
54  matmult(true, false, alpha, x, x, ret);
55  }
56 
58  template <typename REAL>
59  gpumat<REAL> crossprod(const REAL alpha, const gpumat<REAL> &x)
60  {
61  const len_t n = x.ncols();
62  gpumat<REAL> ret(x.get_card(), n, n);
63 
64  crossprod(alpha, x, ret);
65 
66  return ret;
67  }
68 
69 
70 
88  template <typename REAL>
89  void tcrossprod(const REAL alpha, const gpumat<REAL> &x, gpumat<REAL> &ret)
90  {
91  err::check_card(x, ret);
92 
93  const len_t m = x.nrows();
94  const len_t n = x.ncols();
95 
96  if (m != ret.nrows() || m != ret.ncols())
97  ret.resize(m, m);
98 
99  matmult(false, true, alpha, x, x, ret);
100  }
101 
103  template <typename REAL>
104  gpumat<REAL> tcrossprod(const REAL alpha, const gpumat<REAL> &x)
105  {
106  const len_t m = x.nrows();
107  gpumat<REAL> ret(x.get_card(), m, m);
108 
109  tcrossprod(alpha, x, ret);
110 
111  return ret;
112  }
113 }
114 }
115 
116 
117 #endif
fml::linalg::crossprod
void crossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x^T*x.
Definition: crossprod.hh:37
fml::unimat::nrows
len_t nrows() const
Number of rows.
Definition: unimat.hh:36
fml::unimat::ncols
len_t ncols() const
Number of columns.
Definition: unimat.hh:38
fml
Core namespace.
Definition: dimops.hh:10
fml::gpumat::resize
void resize(len_t nrows, len_t ncols)
Resize the internal object storage.
Definition: gpumat.hh:256
fml::gpumat
Matrix class for data held on a single GPU.
Definition: gpumat.hh:35
fml::linalg::tcrossprod
void tcrossprod(const REAL alpha, const cpumat< REAL > &x, cpumat< REAL > &ret)
Computes lower triangle of alpha*x*x^T.
Definition: crossprod.hh:81
fml::linalg::matmult
void matmult(const bool transx, const bool transy, const REAL alpha, const cpumat< REAL > &x, const cpumat< REAL > &y, cpumat< REAL > &ret)
Computes ret = alpha*op(x)*op(y) where op(A) is A or A^T.
Definition: matmult.hh:43