SigmaTransform
The SigmaTransform unifies several well-known signal processing transforms, such as the STFT and the wavelet transform, which differ only in the choice of a specific diffeomorphism.
 All Classes Namespaces Files Functions Variables Typedefs Friends Macros
SigmaTransformN.h
Go to the documentation of this file.
1 #ifndef SIGMATRANSFORM_H
2 #define SIGMATRANSFORM_H
3 
4 #include <vector>
5 #include <map>
6 #include <complex>
7 #include <thread>
8 #include <array>
9 #include <iostream>
10 #include <fstream>
11 #include <algorithm>
12 #include <mutex>
13 #include <functional>
14 #include <fftw3.h>
15 
16 #define DEBUG
17 
18 #include "SigmaTransform_util.h"
19 
20 namespace SigmaTransform {
21 
22  // shorthands for cleaner code
23  using cmpx = std::complex<double>;
24  using cxVec = std::vector<cmpx>;
25  template<size_t N> using diffFunc = std::function<point<N>(const point<N>&)>;
26  template<size_t N> using winFunc = std::function<cmpx(const point<N>&)>;
27  template<size_t N> using actFunc = std::function<point<N>(const point<N>&,const point<N>&)>;
28  template<size_t N> using mskFunc = std::function<cmpx(point<N>const&,point<N>const&)>;
29 
34  template<size_t N>
36 
37  public:
48  SigmaTransform( diffFunc<N> sigma=NULL, const point<N> &winWidth=point<N>(0), const point<N> &Fs=point<N>(0), const point<N> &size=point<N>(0),
49  const std::vector<point<N>> &steps=std::vector<point<N>>(0), actFunc<N> action=minus<N> , int const& numThreads = 4 )
50  : SigmaTransform( sigma, (winFunc<N>)(NULL), Fs,size,steps,action,numThreads) { m_winWidth = winWidth; }
51 
62  SigmaTransform( diffFunc<N> sigma=NULL, winFunc<N> window=NULL, const point<N> &Fs=point<N>(0), const point<N> &size=point<N>(0),
63  const std::vector<point<N>> &steps=std::vector<point<N>>(0), actFunc<N> action=minus<N> , int const& numThreads = 4 )
64  : m_window(window),m_sigma(sigma?sigma:id<N>),m_action(action?action:minus<N>),m_windows(0),m_coeff(0),m_reconstructed(0),
65  m_size(size),m_fs(Fs) , m_winWidth(0.0) {
66  setSteps( steps );
67  if( !fftw_init_threads() )
68  std::cerr << "thread error\n";
69  setNumThreads( numThreads );
70  }
71 
78  SigmaTransform& setWindow( winFunc<N> window ) { m_window = window; return *this; }
79 
86  SigmaTransform& setSigma( diffFunc<N> sigma ) { m_sigma = sigma?sigma:id<N>; return *this; }
87 
94  SigmaTransform& setAction( actFunc<N> action ) { m_action = action?action:minus<N>; return *this; }
95 
102  SigmaTransform& setFs( const point<N> &Fs ) { m_fs = Fs; return *this; }
103 
110  SigmaTransform& setSize( const point<N> &size ) { m_size = size; return *this; }
111 
118  SigmaTransform& setWinWidth( const double& winWidth ) { m_window = NULL; m_winWidth = winWidth; return *this; }
119 
126  SigmaTransform& setNumThreads( const int &numThreads )
127  { m_numThreads = numThreads ; fftw_plan_with_nthreads(numThreads); return *this; }
128 
137  SigmaTransform& setSteps( const std::vector<point<N>> &steps ) {
138  if( steps.empty() ) {
139  throw std::runtime_error("steps must not be empty.");
140  }
141  // only number of stepps given? Create steps-array
142  if( steps.size() == 1 ) {
143  if( m_fs==0.0 || !m_sigma) {
144  throw std::runtime_error("Sampling rate and diffeomorphism must be set for automatic step-array determination");
145  }
146  // make Fourier-domain
147  std::array<std::vector<double>,N> stps;
148  for( int k = 0 ; k < N ; ++k ) {
149  stps[k] = linspace( -m_fs[k] / 2.0 , m_fs[k] / 2.0 , steps[0][k] );
150  }
151  m_steps = meshgridN( stps );
152  // get boundary points of warped domain
153  point<N> maxi,mini; maxi=mini=m_steps[0];
154  for( auto& step : m_steps ) {
155  step = m_sigma( step );
156  for( int k = 0 ; k < N ; ++k ) {
157  maxi[k] = (step[k]>maxi[k])?step[k]:maxi[k];
158  mini[k] = (step[k]<mini[k])?step[k]:mini[k];
159  }
160  }
161 
162  // make linear spacing in warped domain
163  for( int k = 0 ; k < N ; ++k ) {
164  stps[k] = linspace( mini[k] , maxi[k] , steps[0][k] );
165  }
166  m_steps = meshgridN( stps );
167  } else {
168  m_steps = steps;
169  }
170 
171  return *this;
172  }
173 
178  cxVec& getCoeffs(){ return m_coeff; }
179 
184  cxVec& getWindows(){ return m_windows; }
185 
190  cxVec& getReconstruction(){ return m_reconstructed; }
191 
198  SigmaTransform& operator()( cxVec const& sig ){ return analyze( sig ); }
199 
212  SigmaTransform& analyze( cxVec const& sig , std::function<void(SigmaTransform*)> onFinish = NULL ) {
213  if( !m_steps.size() ) {
214  throw std::runtime_error("steps not set");
215  }
216  if( !m_fs.prod() ) {
217  throw std::runtime_error("Fs not set");
218  }
219  if( !m_size.prod() ) {
220  throw std::runtime_error("size not set");
221  }
222  return (onFinish) ? asyncTransform( sig , onFinish ) : applyTransform( sig );
223  }
224 
234  SigmaTransform& synthesize( std::function<void(SigmaTransform* obj)> onFinish = NULL ){
235  return (onFinish) ? asyncInverseTransform( onFinish ) : applyInverseTransform( );
236  }
237 
249  cxVec& multiplier( cxVec const& sig, cxVec const& mask, std::function<void(SigmaTransform*)> onFinish = NULL ){
250  return (onFinish)?asyncMultiplier( sig, mask, onFinish ) : analyze(sig).applyMask(mask).synthesize().getReconstruction();
251  }
252 
264  cxVec& multiplier( cxVec const& sig, mskFunc<N> maskFunc, std::function<void(SigmaTransform*)> onFinish = NULL ){
265  return (onFinish)?asyncMultiplier( sig, maskFunc, onFinish ) : analyze(sig).applyMask(maskFunc).synthesize().getReconstruction();
266  }
267 
276  SigmaTransform& applyMask( const cxVec &mask ) {
277  // error?
278  if( mask.size() != m_coeff.size() ) {
279  throw std::runtime_error("Size of mask does not match size of coefficients.");
280  }
281  // deallocate threads after leaving scope
282  std::unique_ptr<std::thread[]> _threads( new std::thread[m_numThreads] );
283  // start threads
284  int stepsPerThread = ceil( (double) (m_steps.size()) / m_numThreads );
285  for( int k = 0, stepsLeft = m_steps.size() ; k < m_numThreads ; ++k, stepsLeft -= stepsPerThread ) {
286  _threads[k] = std::move( std::thread( [this,&stepsPerThread,&mask,stepsLeft,k]() {
287  // get iterator from correct offset
288  int offset = k*stepsPerThread*m_size.prod(),
289  numval = ((stepsLeft>stepsPerThread)?stepsPerThread:stepsLeft)*m_size.prod();
290  // multiply nth part of coeffs
291  for( int i = 0 ; i < numval ; ++i ) {
292  m_coeff[offset+i] *= mask[offset+i];
293  }
294  } ) );
295  }
296  // wait for all threads to finish
297  for( int k = 0 ; k < m_numThreads ; ++k ) {
298  _threads[k].join();
299  }
300  // return
301  return *this;
302  }
303 
310  SigmaTransform& applyMask( mskFunc<N> maskFunc ) {
311  // spatial domain
312  auto spatialDom = makeSpatialDomain();
313  // deallocates itself after leaving scope
314  std::unique_ptr<std::thread[]> _threads( new std::thread[m_numThreads] );
315  // start threads
316  int stepsPerThread = ceil( (double) (m_steps.size()) / m_numThreads );
317  for( int k = 0, stepsLeft = m_steps.size() ; k < m_numThreads ; ++k, stepsLeft -= stepsPerThread ) {
318  _threads[k] = std::move( std::thread( [this,&maskFunc,&spatialDom,&stepsPerThread,stepsLeft,k]() {
319  // get offsets and iterator
320  int stepsOffset = stepsPerThread*k,
321  numval = (stepsLeft>stepsPerThread)?stepsPerThread:stepsLeft;
322  auto coeff = m_coeff.begin() + stepsOffset*m_size.prod();
323  // run thru all points
324  for( int i=0;i<numval;++i) {
325  for( auto const& x : spatialDom ) {
326  *coeff++ *= maskFunc( x , m_steps[stepsOffset+i] );
327  }
328  }
329  } ) );
330  }
331  // wait for all threads to finish
332  for( int k = 0 ; k < m_numThreads ; ++k ) {
333  _threads[k].join();
334  }
335  // return reference
336  return *this;
337  }
338 
/**
 * Fills m_windows with the sampled window for every step.
 *
 * m_windows is resized to m_size.prod() * m_steps.size(). The steps are then
 * split over m_numThreads threads; each thread writes
 * m_window( m_action( x , step ) ) for every point x of m_domain and every
 * step of its range. All threads are joined before returning.
 *
 * NOTE(review): this listing skips source lines 347 and 350. As printed, the
 * `if( !m_window )` has no statement of its own and would govern the
 * `_threads` declaration below. Presumably line 347 built m_domain and line
 * 350 created a default window; confirm against the real header.
 */
343  void makeWindows( ) {
344  // reserve space for windows..
345  m_windows.resize( m_size.prod() * m_steps.size() );
346  // make Domain
348  // check if window was given, else calculate good width for a warped gaussian window
349  if( !m_window )
351  // deallocates itself after leaving scope
352  std::unique_ptr<std::thread[]> _threads( new std::thread[m_numThreads] );
353  // start threads
// stepsPerThread = ceil( number of steps / number of threads )
354  int stepsPerThread = ceil( (double) (m_steps.size()) / m_numThreads );
// stepsLeft counts the steps not yet handed out; it may become <= 0 for late threads
355  for( int k = 0, stepsLeft = m_steps.size() ; k < m_numThreads ; ++k, stepsLeft -= stepsPerThread ) {
// stepsPerThread is captured by reference; the threads are joined before it goes out of scope
356  _threads[k] = std::move( std::thread( [this,&stepsPerThread,k,stepsLeft]() {
357  // get iterator from correct offset
358  auto win = m_windows.begin() + k*stepsPerThread * m_size.prod();
359  // create n-th part of the windows
// loop bound is min( stepsLeft , stepsPerThread ); non-positive values skip the loop
360  for( int i = 0 ; i < ((stepsLeft>stepsPerThread)?stepsPerThread:stepsLeft) ; ++i ) {
361  for( auto const& x : m_domain ) {
362  *win++ = m_window( m_action( x , m_steps[k*stepsPerThread + i] ) );
363  }
364  }
365  } ) );
366  }
367  // wait for all threads to finish
368  for( int k = 0 ; k < m_numThreads ; ++k ) {
369  _threads[k].join();
370  }
371  }
372 
380  SigmaTransform& asyncTransform( cxVec const& sig, std::function<void(SigmaTransform*)> onFinish ) {
381  m_threads.insert( std::make_pair<std::string,std::thread>( "Transform" , std::move( std::thread( [this,sig,onFinish]() {
382  //std::cout << "Starting 'Transform' asynchronously."<<std::endl;
383  // try to acquire mutex
384  std::unique_lock<std::mutex> lk( m_mtx );
385  //std::cout << "Acquired mutex in 'Transform'"<<std::endl;
386  // transform asynchronously
387  this->analyze( sig );
388  // call callback
389  onFinish( this );
390  } ) ) ) );
391  return *this;
392  }
393 
400  SigmaTransform& asyncInverseTransform( std::function<void(SigmaTransform*)> onFinish ) {
401  m_threads.insert( std::make_pair<std::string,std::thread>( "Inverse" , std::move( std::thread( [this,onFinish]() {
402  //std::cout << "Starting 'Inverse' asynchronously."<<std::endl;
403  // try to acquire mutex
404  std::unique_lock<std::mutex> lk( m_mtx );
405  //std::cout << "Acquired mutex in 'Inverse'"<<std::endl;
406  // transform asynchronously
407  this->synthesize( );
408  // call callback
409  onFinish( this );
410  } ) ) ) );
411  return *this;
412  }
413 
422  SigmaTransform& asyncMultiplier( cxVec const& sig, cxVec const& mask, std::function<void(SigmaTransform*)> onFinish ) {
423  m_threads.insert( std::make_pair<std::string,std::thread>( "Multiplier" , std::move( std::thread( [this,sig,mask,onFinish]() {
424  //std::cout << "Starting 'Multiplier' asynchronously."<<std::endl;
425  // try to acquire mutex
426  std::unique_lock<std::mutex> lk( m_mtx );
427  //std::cout << "Acquired mutex in 'Multiplier'"<<std::endl;
428  // transform asynchronously
429  this->multiplier( sig,mask );
430  // call callback
431  onFinish( this );
432  } ) ) ) );
433  return *this;
434  }
435 
444  SigmaTransform& asyncMultiplier( cxVec const& sig, mskFunc<N> maskFunc, std::function<void(SigmaTransform*)> onFinish ) {
445  m_threads.insert( std::make_pair<std::string,std::thread>( "Multiplier" , std::move( std::thread( [this,sig,maskFunc,onFinish]() {
446  //std::cout << "Starting 'Multiplier' asynchronously."<<std::endl;
447  // try to acquire mutex
448  std::unique_lock<std::mutex> lk( m_mtx );
449  //std::cout << "Acquired mutex in 'Multiplier'"<<std::endl;
450  // transform asynchronously
451  this->multiplier( sig,maskFunc );
452  // call callback
453  onFinish( this );
454  } ) ) ) );
455  return *this;
456  }
457 
// NOTE(review): the signature line of this method is missing from the listing;
// the body below returns *this, so it presumably returns SigmaTransform&.
// Confirm against the real header.
//
// Body: joins every joinable thread stored in m_threads, remembers its key,
// and afterwards removes the remembered entries from the map.
463  // get threads to join
464  std::vector<std::string> toErase;
465  for( auto it = m_threads.begin(); it != m_threads.end() ; ++it ) {
466  if( it->second.joinable() ) {
467  //std::cout << "joining " << it->first << ".\n";
// blocks until the background job stored under this key has finished
468  it->second.join();
469  toErase.push_back( it->first );
470  }
471  }
472  // erase joined threads
// erasing happens in a second pass so the map is not modified while the first loop iterates it
473  for( auto const& name : toErase ) {
// linear search for the key; the loop is left right after erase() because `it` is invalid then
474  for( auto it = m_threads.begin(); it != m_threads.end() ; ++it ) {
475  if( name.compare( it->first) == 0 ) {
476  //std::cout << "erasing " << it->first << ".\n";
477  m_threads.erase( it );
478  break;
479  }
480  }
481  }
482  return *this;
483  }
484 
492  cxVec fft( cxVec const& in , int const& howmany = 1 ) {
493  // get space
494  cxVec out( in.size() );
495  // fft transform the signal
496  fftN( reinterpret_cast<fftw_complex*> (out.data()) ,
497  reinterpret_cast<fftw_complex*>(const_cast<cmpx*> (in.data())) ,
498  m_size , howmany , FFTW_FORWARD );
499  // return memory
500  return std::move( out );
501  }
502 
510  void fft_inplace( cxVec& inout , int const& howmany = 1 ) {
511  // ifft transform the signal
512  fftN( reinterpret_cast<fftw_complex*> (inout.data()) ,
513  reinterpret_cast<fftw_complex*>(const_cast<cmpx*> (inout.data())) ,
514  m_size , howmany , FFTW_FORWARD );
515  }
516 
524  cxVec ifft( cxVec const& in , int const& howmany = 1 ) {
525  // get space
526  cxVec out( in.size() );
527  // ifft transform the signal
528  fftN( reinterpret_cast<fftw_complex*> (out.data()) ,
529  reinterpret_cast<fftw_complex*>(const_cast<cmpx*> (in.data())) ,
530  m_size , howmany , FFTW_BACKWARD );
531  // return memory
532  return std::move( out );
533  }
534 
542  void ifft_inplace( cxVec& inout , int const& howmany = 1 ) {
543  // ifft transform the signal
544  fftN( reinterpret_cast<fftw_complex*> (inout.data()) ,
545  reinterpret_cast<fftw_complex*>(const_cast<cmpx*> (inout.data())) ,
546  m_size , howmany , FFTW_BACKWARD );
547  }
548 
549  protected:
550 
// NOTE(review): the signature line of this method is missing from the listing;
// confirm its name and return type against the real header.
//
// Body: builds one axis per dimension with FourierAxis( fs , size ), combines
// the axes with meshgridN() into m_domain, and then replaces every point of
// m_domain by m_sigma( point ).
556  // make (fft-shifted) domain...
557  std::array<std::vector<double>,N> doms;
// the two iterators walk m_fs and m_size in lock-step with the axes
558  auto itFs = m_fs.begin(),itSz = m_size.begin();
559  for(auto& d : doms) {
560  d = FourierAxis( *itFs++ , *itSz++ );
561  }
562  m_domain = meshgridN( doms );
563  // ...and warp the domain
564  std::for_each(m_domain.begin(),m_domain.end(),[&](point<N>&x){x=m_sigma(x);});
565  }
566 
571  std::vector<point<N>> makeSpatialDomain() {
572  std::array<std::vector<double>,N> doms;
573  auto itFs = m_fs.begin(), itSz = m_size.begin();
574  for(auto& d : doms) {
575  d = linspace( 0 , (*itSz-1) / *itFs , *itSz );
576  itFs++; itSz++;
577  }
578  return std::move( meshgridN( doms ) );
579  }
580 
// NOTE(review): the signature line of this method is missing from the listing;
// confirm its name and return type against the real header.
//
// Body: derives a width per dimension from the extent of m_steps and installs
// a window function based on gauss_stddev() in m_window.
//
// NOTE(review): the lambda stored in m_window captures the local variable
// `width` by reference ([&]). `width` is destroyed when this function returns,
// so every later call of m_window reads a dangling reference. Capturing
// `width` by value would avoid this.
586  // make adequate standard deviation
587  point<N> maxi,mini,num_steps{1};
588  maxi=mini=m_steps[0];
589  for( auto& step : m_steps ) {
590  for( int k = 0 ; k < N ; ++k ) {
// num_steps[k] is incremented each time a new maximum is seen in dimension k
591  if( step[k]>maxi[k] ) {
592  maxi[k] = step[k];
593  num_steps[k]++;
594  }
595  if( step[k]<mini[k] ) {
596  mini[k] = step[k];
597  }
598  }
599  }
600  // winWidth not set? then set it to "8.0"
601  if( m_winWidth==0.0 )
602  m_winWidth = point<N>(8.0);
603 
// width = extent of the steps divided by the counted steps, scaled by m_winWidth
604  auto width = (maxi-mini) / num_steps * m_winWidth;
605 
606  m_window = [&](const point<N>&x)->cmpx{ return gauss_stddev( x , width ); };
607  }
608 
615  SigmaTransform& applyTransform( const cxVec &in ) {
616  // convenient var
617  double sigsize = m_size.prod();
618  // make windows
619  makeWindows( );
620  // fft transform the signal
621  cxVec Fsig = fft( in );
622  // copy the windows
623  m_coeff = m_windows;
624  // deallocates itself after leaving scope
625  std::unique_ptr<std::thread[]> _threads( new std::thread[m_numThreads] );
626  // start threads
627  int stepsPerThread = ceil( (double) (m_steps.size()) / m_numThreads );
628  for( int k = 0, stepsLeft = m_steps.size() ; k < m_numThreads ; ++k, stepsLeft -= stepsPerThread ) {
629  _threads[k] = std::move( std::thread( [this,&Fsig,&stepsPerThread,stepsLeft,k]() {
630  cxVec::iterator coeff = m_coeff.begin() + (int)( k*stepsPerThread*m_size.prod() );
631  for( int i = 0 ; i < ((stepsLeft>stepsPerThread)?stepsPerThread:stepsLeft) ; ++i ) {
632  for( auto const& val : Fsig ) {
633  *coeff = conj(*coeff++) * val / ((double)Fsig.size());
634  }
635  }
636  } ) );
637  }
638  // wait for all threads to finish
639  for( int k = 0 ; k < m_numThreads ; ++k ) {
640  _threads[k].join();
641  }
642  // transform back
643  ifft_inplace( m_coeff , m_steps.size() );
644  // return
645  return *this;
646  }
647 
// NOTE(review): the signature line of this method is missing from the listing;
// the body returns *this, so it presumably returns SigmaTransform&. Confirm
// against the real header.
//
// Body: transforms the coefficients with fft(), multiplies them by the stored
// windows and sums the products over all steps into m_reconstructed.
653  // reserve vectorspace with zeros
// length of one signal = number of coefficients / number of steps
654  m_reconstructed = cxVec( m_coeff.size() / m_steps.size() , 0 );
655 
656  // fft transform the signal
657  cxVec temp = fft( m_coeff , m_steps.size() );
658 
659  // act on signal
// both iterators run continuously over all steps; the inner loop covers one signal length
660  auto coeff = temp.begin(),
661  window = m_windows.begin();
662  for(int k=0;k<m_steps.size();++k) {
663  for( auto &val : m_reconstructed ) {
664  val += (*coeff++) * (*window++);
665  }
666  }
667 
668  // transform back
// NOTE(review): listing line 669 is missing here, so no statement follows the
// comment above; presumably a backward FFT of m_reconstructed. Confirm.
670 
671  // return
672  return *this;
673  }
674 
685  void fftN( fftw_complex *out, fftw_complex *in, const point<N> &size, const int &howmany = 1, const int& DIR = FFTW_FORWARD ) {
686  int sz[N];
687  for( int k = 0 ; k < N ; ++k ) sz[k] = (int) size[k];
688  // make, perform and destroy FFTW-plan
689  fftw_plan p = fftw_plan_many_dft( N , sz , howmany , in , NULL , 1 , (int) size.prod() ,
690  out , NULL , 1 , (int) size.prod() ,
691  DIR , FFTW_ESTIMATE );
692  fftw_execute( p );
693  fftw_destroy_plan( p );
694  }
695 
// NOTE(review): this listing skips several source lines in the member section
// (704, 707-708, 711-712). Members used by the methods above but not declared
// here (m_reconstructed, m_size, m_fs, m_winWidth, m_numThreads) are
// presumably declared on those lines; confirm against the real header.
696  // function handles for the transform
// same signatures as the winFunc<N>, diffFunc<N> and actFunc<N> aliases
697  std::function<cmpx(point<N>const&)> m_window;
698  std::function<point<N>(point<N>const&)> m_sigma;
699  std::function<point<N>(point<N>const&,point<N>const&)> m_action;
700 
701  // holds data
// m_windows: sampled windows, one block per step (filled by makeWindows)
702  cxVec m_windows;
// m_coeff: transform coefficients (filled by applyTransform)
703  cxVec m_coeff;
705 
706  // holds information about data
709  std::vector<point<N>> m_steps;
710  std::vector<point<N>> m_domain;
713 
714  // for asynchronous computations
// background threads started by the async* methods, keyed by job name
715  std::map<std::string,std::thread> m_threads;
// locked by every background job before it runs
716  std::mutex m_mtx;
717 
718  }; // class SigmaTransform
719 
720 } // namespace SigmaTransform
721 
722 #endif //SIGMATRANSFORM_H