NumCpp  1.0
A C++ implementation of the Python Numpy library
softmax.hpp
Go to the documentation of this file.
1 #pragma once
30 
#include "NumCpp/NdArray.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Functions/exp.hpp"

#include <algorithm>
#include <cmath>
#include <numeric>
#include <type_traits>
36 
37 namespace nc
38 {
39  namespace special
40  {
41  //============================================================================
42  // Method Description:
52  template<typename dtype>
54  {
56 
57  switch (inAxis)
58  {
59  case Axis::NONE:
60  {
61  auto returnArray = exp(inArray).template astype<double>();
62  returnArray /= static_cast<double>(returnArray.sum().item());
63  return returnArray;
64  }
65  case Axis::COL:
66  {
67  auto returnArray = exp(inArray).template astype<double>();
68  auto expSums = returnArray.sum(inAxis);
69 
70  for (uint32 row = 0; row < returnArray.shape().rows; ++row)
71  {
72  const double rowExpSum = static_cast<double>(expSums[row]);
73  stl_algorithms::for_each(returnArray.begin(row), returnArray.end(row),
74  [rowExpSum](double& value) { value /= rowExpSum; });
75  }
76 
77  return returnArray;
78  }
79  case Axis::ROW:
80  {
81  auto returnArray = exp(inArray.transpose()).template astype<double>();
82  auto expSums = returnArray.sum(Axis::COL);
83 
84  for (uint32 row = 0; row < returnArray.shape().rows; ++row)
85  {
86  const auto rowExpSum = static_cast<double>(expSums[row]);
87  stl_algorithms::for_each(returnArray.begin(row), returnArray.end(row),
88  [rowExpSum](double& value) { value /= rowExpSum; });
89  }
90 
91  return returnArray.transpose();
92  }
93  default:
94  {
95  // this isn't actually possible, just putting this here to get rid
96  // of the compiler warning.
97  return NdArray<double>(0);
98  }
99  }
100  }
101  }
102 }
StaticAsserts.hpp
nc::Axis::NONE
@ NONE
STATIC_ASSERT_ARITHMETIC
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:38
nc::Axis::ROW
@ ROW
nc::stl_algorithms::for_each
void for_each(InputIt first, InputIt last, UnaryFunction f) noexcept
Definition: StlAlgorithms.hpp:214
nc::NdArray::transpose
NdArray< dtype > transpose() const
Definition: NdArrayCore.hpp:4591
nc::NdArray< double >
nc::special::softmax
NdArray< double > softmax(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: softmax.hpp:53
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:41
NdArray.hpp
nc::Axis
Axis
Enum To describe an axis.
Definition: Types.hpp:47
nc::exp
auto exp(dtype inValue) noexcept
Definition: exp.hpp:52
nc
Definition: Coordinate.hpp:45
exp.hpp
StlAlgorithms.hpp
Types.hpp
nc::Axis::COL
@ COL