NumCpp 1.0
A C++ implementation of the Python Numpy library
dot.hpp
Go to the documentation of this file.
1 #pragma once
30 
#include "NumCpp/NdArray.hpp"

#include <complex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
34 
35 namespace nc
36 {
37  //============================================================================
38  // Method Description:
47  template<typename dtype>
48  NdArray<dtype> dot(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2)
49  {
50  return inArray1.dot(inArray2);
51  }
52 
53  //============================================================================
54  // Method Description:
66  template<typename dtype>
67  NdArray<std::complex<dtype>> dot(const NdArray<dtype>& inArray1, const NdArray<std::complex<dtype>>& inArray2)
68  {
70 
71  const auto shape1 = inArray1.shape();
72  const auto shape2 = inArray2.shape();
73 
74  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
75  {
76  const std::complex<dtype> dotProduct = std::inner_product(inArray1.cbegin(), inArray1.cend(),
77  inArray2.cbegin(), std::complex<dtype>{0});
78  NdArray<std::complex<dtype>> returnArray = { dotProduct };
79  return returnArray;
80  }
81  else if (shape1.cols == shape2.rows)
82  {
83  // 2D array, use matrix multiplication
84  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
85  auto array2T = inArray2.transpose();
86 
87  for (uint32 i = 0; i < shape1.rows; ++i)
88  {
89  for (uint32 j = 0; j < shape2.cols; ++j)
90  {
91  returnArray(i, j) = std::inner_product(array2T.cbegin(j), array2T.cend(j),
92  inArray1.cbegin(i), std::complex<dtype>{0});
93  }
94  }
95 
96  return returnArray;
97  }
98  else
99  {
100  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
101  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
102  errStr += " are not consistent.";
104  }
105 
106  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
107  }
108 
109  //============================================================================
110  // Method Description:
122  template<typename dtype>
123  NdArray<std::complex<dtype>> dot(const NdArray<std::complex<dtype>>& inArray1, const NdArray<dtype>& inArray2)
124  {
126 
127  const auto shape1 = inArray1.shape();
128  const auto shape2 = inArray2.shape();
129 
130  if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
131  {
132  const std::complex<dtype> dotProduct = std::inner_product(inArray1.cbegin(), inArray1.cend(),
133  inArray2.cbegin(), std::complex<dtype>{0});
134  NdArray<std::complex<dtype>> returnArray = { dotProduct };
135  return returnArray;
136  }
137  else if (shape1.cols == shape2.rows)
138  {
139  // 2D array, use matrix multiplication
140  NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
141  auto array2T = inArray2.transpose();
142 
143  for (uint32 i = 0; i < shape1.rows; ++i)
144  {
145  for (uint32 j = 0; j < shape2.cols; ++j)
146  {
147  returnArray(i, j) = std::inner_product(array2T.cbegin(j), array2T.cend(j),
148  inArray1.cbegin(i), std::complex<dtype>{0});
149  }
150  }
151 
152  return returnArray;
153  }
154  else
155  {
156  std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
157  errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
158  errStr += " are not consistent.";
160  }
161 
162  return NdArray<std::complex<dtype>>(); // get rid of compiler warning
163  }
164 }
nc::NdArray::shape
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4296
STATIC_ASSERT_ARITHMETIC
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:38
nc::NdArray::dot
NdArray< dtype > dot(const NdArray< dtype > &inOtherArray) const
Definition: NdArrayCore.hpp:2635
nc::utils::num2str
std::string num2str(dtype inNumber)
Definition: num2str.hpp:47
nc::NdArray::transpose
NdArray< dtype > transpose() const
Definition: NdArrayCore.hpp:4591
nc::dot
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition: dot.hpp:48
nc::NdArray< dtype >
nc::constants::j
constexpr auto j
Definition: Constants.hpp:46
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:41
NdArray.hpp
nc::NdArray::cend
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1491
nc
Definition: Coordinate.hpp:45
THROW_INVALID_ARGUMENT_ERROR
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
nc::NdArray::cbegin
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1147