NumCpp  2.1.0
A C++ implementation of the Python Numpy library
Brent.hpp
Go to the documentation of this file.
1 #pragma once
34 
#include "NumCpp/Core/DtypeInfo.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Roots/Iteration.hpp"

#include <cmath>
#include <functional>
#include <limits>
#include <stdexcept>
#include <utility>
42 
43 namespace nc
44 {
45  namespace roots
46  {
47  //================================================================================
48  // Class Description:
51  class Brent : public Iteration
52  {
53  public:
54  //============================================================================
55  // Method Description:
61  Brent(const double epsilon,
62  std::function<double(double)> f) noexcept :
63  Iteration(epsilon),
64  f_(std::move(f))
65  {}
66 
67  //============================================================================
68  // Method Description:
75  Brent(const double epsilon,
76  const uint32 maxNumIterations,
77  std::function<double(double)> f) noexcept :
78  Iteration(epsilon, maxNumIterations),
79  f_(std::move(f))
80  {}
81 
82  //============================================================================
83  // Method Description:
86  ~Brent() override = default;
87 
88  //============================================================================
89  // Method Description:
96  double solve(double a, double b)
97  {
99 
100  double fa = f_(a);
101  double fb = f_(b);
102 
103  checkAndFixAlgorithmCriteria(a, b, fa, fb);
104 
105  double lastB = a; // b_{k-1}
106  double lastFb = fa;
107  double s = DtypeInfo<double>::max();
108  double fs = DtypeInfo<double>::max();
109  double penultimateB = a; // b_{k-2}
110 
111  bool bisection = true;
112  while (std::fabs(fb) > epsilon_ && std::fabs(fs) > epsilon_ && std::fabs(b - a) > epsilon_)
113  {
114  if (useInverseQuadraticInterpolation(fa, fb, lastFb))
115  {
116  s = calculateInverseQuadraticInterpolation(a, b, lastB, fa, fb, lastFb);
117  }
118  else
119  {
120  s = calculateSecant(a, b, fa, fb);
121  }
122 
123  if (useBisection(bisection, b, lastB, penultimateB, s))
124  {
125  s = calculateBisection(a, b);
126  bisection = true;
127  }
128  else
129  {
130  bisection = false;
131  }
132 
133  fs = f_(s);
134  penultimateB = lastB;
135  lastB = b;
136 
137  if (fa * fs < 0)
138  {
139  b = s;
140  }
141  else {
142  a = s;
143  }
144 
145  fa = f_(a);
146  lastFb = fb;
147  fb = f_(b);
148  checkAndFixAlgorithmCriteria(a, b, fa, fb);
149 
151  }
152 
153  return fb < fs ? b : s;
154  }
155 
156  private:
157  //============================================================================
158  const std::function<double(double)> f_;
159 
160  //============================================================================
161  // Method Description:
168  static double calculateBisection(const double a, const double b) noexcept
169  {
170  return 0.5 * (a + b);
171  }
172 
173  //============================================================================
174  // Method Description:
183  static double calculateSecant(const double a, const double b, const double fa, const double fb) noexcept
184  {
185  //No need to check division by 0, in this case the method returns NAN which is taken care by useSecantMethod method
186  return b - fb * (b - a) / (fb - fa);
187  }
188 
189  //============================================================================
190  // Method Description:
201  static double calculateInverseQuadraticInterpolation(const double a, const double b, const double lastB,
202  const double fa, const double fb, const double lastFb) noexcept
203  {
204  return a * fb * lastFb / ((fa - fb) * (fa - lastFb)) +
205  b * fa * lastFb / ((fb - fa) * (fb - lastFb)) +
206  lastB * fa * fb / ((lastFb - fa) * (lastFb - fb));
207  }
208 
209  //============================================================================
210  // Method Description:
218  static bool useInverseQuadraticInterpolation(const double fa, const double fb, const double lastFb) noexcept
219  {
220  return fa != lastFb && fb != lastFb;
221  }
222 
223  //============================================================================
224  // Method Description:
232  static void checkAndFixAlgorithmCriteria(double &a, double &b, double &fa, double &fb) noexcept
233  {
234  //Algorithm works in range [a,b] if criteria f(a)*f(b) < 0 and f(a) > f(b) is fulfilled
235  if (std::fabs(fa) < std::fabs(fb))
236  {
237  std::swap(a, b);
238  std::swap(fa, fb);
239  }
240  }
241 
242  //============================================================================
243  // Method Description:
253  bool useBisection(const bool bisection, const double b, const double lastB,
254  const double penultimateB, const double s) const noexcept
255  {
256  const double DELTA = epsilon_ + std::numeric_limits<double>::min();
257 
258  return (bisection && std::fabs(s - b) >= 0.5 * std::fabs(b - lastB)) || //Bisection was used in last step but |s-b|>=|b-lastB|/2 <- Interpolation step would be to rough, so still use bisection
259  (!bisection && std::fabs(s - b) >= 0.5 * std::fabs(lastB - penultimateB)) || //Interpolation was used in last step but |s-b|>=|lastB-penultimateB|/2 <- Interpolation step would be to small
260  (bisection && std::fabs(b - lastB) < DELTA) || //If last iteration was using bisection and difference between b and lastB is < delta use bisection for next iteration
261  (!bisection && std::fabs(lastB - penultimateB) < DELTA); //If last iteration was using interpolation but difference between lastB ond penultimateB is < delta use biscetion for next iteration
262  }
263  };
264  } // namespace roots
265 } // namespace nc
nc::roots::Iteration::Iteration
Iteration(double epsilon) noexcept
Definition: Iteration.hpp:56
nc::roots::Brent
Definition: Brent.hpp:51
nc::roots::Iteration
ABC for iteration classes to derive from.
Definition: Iteration.hpp:47
nc::roots::Brent::Brent
Brent(const double epsilon, const uint32 maxNumIterations, std::function< double(double)> f) noexcept
Definition: Brent.hpp:75
nc::roots::Brent::~Brent
~Brent() override=default
nc::roots::Iteration::incrementNumberOfIterations
void incrementNumberOfIterations()
Definition: Iteration.hpp:105
nc::uint32
std::uint32_t uint32
Definition: Types.hpp:41
nc::roots::Iteration::epsilon_
const double epsilon_
Definition: Iteration.hpp:115
nc::roots::Brent::Brent
Brent(const double epsilon, std::function< double(double)> f) noexcept
Definition: Brent.hpp:61
nc::roots::Brent::solve
double solve(double a, double b)
Definition: Brent.hpp:96
Iteration.hpp
nc::roots::Iteration::resetNumberOfIterations
void resetNumberOfIterations() noexcept
Definition: Iteration.hpp:94
nc
Definition: Coordinate.hpp:45
nc::swap
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:43
DtypeInfo.hpp
Types.hpp
nc::random::f
dtype f(dtype inDofN, dtype inDofD)
Definition: f.hpp:58
nc::DtypeInfo::max
static constexpr dtype max() noexcept
Definition: DtypeInfo.hpp:111
nc::min
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:46