NumCpp 2.1.0
A C++ implementation of the Python NumPy library
where.hpp
Go to the documentation of this file.
1 #pragma once
30 
32 #include "NumCpp/Core/Shape.hpp"
33 #include "NumCpp/NdArray.hpp"
34 
35 #include <string>
36 
37 namespace nc
38 {
39  //============================================================================
40  // Method Description:
52  template<typename dtype>
53  NdArray<dtype> where(const NdArray<bool>& inMask, const NdArray<dtype>& inA, const NdArray<dtype>& inB)
54  {
55  const auto shapeMask = inMask.shape();
56  const auto shapeA = inA.shape();
57  if (shapeA != inB.shape())
58  {
59  THROW_INVALID_ARGUMENT_ERROR("input inA and inB must be the same shapes.");
60  }
61 
62  if (shapeMask != shapeA)
63  {
64  THROW_INVALID_ARGUMENT_ERROR("input inMask must be the same shape as the input arrays.");
65  }
66 
67  auto outArray = NdArray<dtype>(shapeMask);
68 
69  uint32 idx = 0;
70  for (auto maskValue : inMask)
71  {
72  if (maskValue)
73  {
74  outArray[idx] = inA[idx];
75  }
76  else
77  {
78  outArray[idx] = inB[idx];
79  }
80  ++idx;
81  }
82 
83  return outArray;
84  }
85 } // namespace nc
nc::NdArray::shape
Shape shape() const noexcept
Definition: NdArrayCore.hpp:4312
Error.hpp
nc::where
NdArray< dtype > where(const NdArray< bool > &inMask, const NdArray< dtype > &inA, const NdArray< dtype > &inB)
Definition: where.hpp:53
nc::NdArray< dtype >
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:41
NdArray.hpp
Shape.hpp
nc
Definition: Coordinate.hpp:45
THROW_INVALID_ARGUMENT_ERROR
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37