Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
gm_to_dirac_short.h
1#ifndef GM_TO_DIRAC_SHORT_H
2#define GM_TO_DIRAC_SHORT_H
3
4#include <gtest/gtest.h>
5
6#include "gm_to_dirac_approx_i.h"
7#include "gm_to_dirac_optimization_params.h"
8
// Class template gm_to_dirac_short<T>.
// NOTE(review): the class-head line (listing line 10) is missing from this
// extract. The `override` specifiers below and the aliases taken from
// gm_to_dirac_approx_i<T> suggest the class derives from that interface --
// confirm against the real header.
9template <typename T>
11 public:
// Re-use the GSL vector / vector-view / matrix types that the interface
// gm_to_dirac_approx_i<T> defines for the scalar type T.
12 using GSLVectorType = typename gm_to_dirac_approx_i<T>::GSLVectorType;
13 using GSLVectorViewType = typename gm_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename gm_to_dirac_approx_i<T>::GSLMatrixType;
15
// --- Raw-pointer overloads -------------------------------------------------
// approximate(): raw-pointer entry point; wX and result are optional
// (default nullptr).
// NOTE(review): listing line 23 (the last parameter and closing of the
// declaration) is missing here; the generated docs show
// `const ApproximateOptions& options = ApproximateOptions{}) override`.
16 // clang-format off
17 bool approximate(const T* covDiag,
18 size_t L,
19 size_t N,
20 T* x,
21 const T* wX = nullptr,
22 GslminimizerResult* result = nullptr,
24 // clang-format on
25
// Squared modified von Mises distance, raw-pointer form.
// Presumably the result is written through `distance` (the function returns
// void) -- confirm in the implementation.
26 // clang-format off
27 void modified_van_mises_distance_sq(const T* covDiag,
28 T* distance,
29 size_t L,
30 size_t N,
31 size_t bMax,
32 T* x,
33 const T* wX) override;
34 // clang-format on
35
// Derivative counterpart of the function above, raw-pointer form.
// NOTE(review): listing line 37 (return type, name and first parameter) is
// missing here; the generated docs show
// `void modified_van_mises_distance_sq_derivative(const T* covDiag,`.
36 // clang-format off
38 T* gradient,
39 size_t L,
40 size_t N,
41 size_t bMax,
42 T* x,
43 const T* wX) override;
44 // clang-format on
45
// --- GSL-vector overloads (x is a GSLVectorType) ---------------------------
// Same parameter list as the raw-pointer approximate(), with covDiag, x and
// wX passed as GSL vectors.
// NOTE(review): listing line 53 (last parameter) is missing from this extract.
46 // clang-format off
47 bool approximate(const GSLVectorType* covDiag,
48 size_t L,
49 size_t N,
50 GSLVectorType* x,
51 const GSLVectorType* wX = nullptr,
52 GslminimizerResult* result = nullptr,
54 // clang-format on
55
// Distance overload taking GSL vectors; `distance` stays a plain T*.
56 // clang-format off
57 void modified_van_mises_distance_sq(const GSLVectorType* covDiag,
58 T* distance,
59 size_t L,
60 size_t N,
61 size_t bMax,
62 GSLVectorType* x,
63 const GSLVectorType* wX) override;
64 // clang-format on
65
// Derivative overload taking GSL vectors; `gradient` is a GSL vector.
66 // clang-format off
67 void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag,
68 GSLVectorType* gradient,
69 size_t L,
70 size_t N,
71 size_t bMax,
72 GSLVectorType* x,
73 const GSLVectorType* wX) override;
74 // clang-format on
75
// --- GSL-matrix overloads (x is a GSLMatrixType) ---------------------------
// As above, but the sample positions x are passed as a GSL matrix while
// covDiag and wX remain GSL vectors.
// NOTE(review): listing line 83 (last parameter) is missing from this extract.
76 // clang-format off
77 bool approximate(const GSLVectorType* covDiag,
78 size_t L,
79 size_t N,
80 GSLMatrixType* x,
81 const GSLVectorType* wX = nullptr,
82 GslminimizerResult* result = nullptr,
84 // clang-format on
85
// Distance overload with x as a GSL matrix.
86 // clang-format off
87 void modified_van_mises_distance_sq(const GSLVectorType* covDiag,
88 T* distance,
89 size_t L,
90 size_t N,
91 size_t bMax,
92 GSLMatrixType* x,
93 const GSLVectorType* wX) override;
94 // clang-format on
95
// Derivative overload with x and gradient as GSL matrices.
96 // clang-format off
97 void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag,
98 GSLMatrixType* gradient,
99 size_t L,
100 size_t N,
101 size_t bMax,
102 GSLMatrixType* x,
103 const GSLVectorType* wX) override;
104 // clang-format on
105
106 private:
// Static helpers operating on double-precision gsl_vector with an opaque
// `void* params`.
// NOTE(review): the (x, params) / (x, params, f, grad) signatures look like
// the f / fdf callback shapes of GSL's multidimensional minimizer -- confirm.
107 static double modified_van_mises_distance_sq(const gsl_vector* x,
108 void* params);
// NOTE(review): listing lines 109 and 111 are missing; the fragment below
// belongs to a declaration whose name is not visible in this extract.
110 void* params,
// Takes both a value output `f` and a gradient output `grad`.
112 static void combined_distance_metric(const gsl_vector* x, void* params,
113 double* f, gsl_vector* grad);
114
// Scalar helpers of a single double `b`, plus the opaque params pointer.
115 static inline double calculateP2(double b, void* params);
116 static inline double calculateGradP2(double b, void* params);
117
118 static double c_b(size_t bMax);
// NOTE(review): listing lines 120 and 124 (middle parameters of calculateD2 /
// calculateD3) are missing from this extract.
119 static inline void calculateD2(const gsl_vector* x,
121 double* f, gsl_vector* grad);
122
123 static inline void calculateD3(const gsl_vector* x,
125 double* f, gsl_vector* grad);
126
// x is non-const here, unlike in the helpers above.
127 static inline void correctMean(gsl_vector* x, const gsl_vector* wX, size_t L,
128 size_t N);
129
// Grant the benchmark class and three GoogleTest tests access to the private
// members. NOTE(review): the FRIEND_TEST arguments (listing lines 132-133,
// 135-136, 138-139) are missing from this extract.
130 friend class benchmark_gm_to_dirac_short;
131 FRIEND_TEST(
134 FRIEND_TEST(
137 FRIEND_TEST(
140};
141
142#include "gm_to_dirac_short.tpp"
143
// Explicit specialization declarations (`template <>`) of gm_to_dirac_short
// member functions, one float variant (gsl_vector_float) and one double
// variant (gsl_vector) each.
// NOTE(review): most declarator lines (listing lines 145, 147, 149-150,
// 156-157, 160, 165, 170, 175) are missing from this extract, so only the
// double `approximate` specialization shows its name. Judging by the
// `distance` / `gradient` parameters, the remaining pairs are presumably
// modified_van_mises_distance_sq and its _derivative -- confirm against the
// real header.
144template <>
146 size_t L, size_t N,
148 const gsl_vector_float* wX,
151
152template <>
153bool gm_to_dirac_short<double>::approximate(const gsl_vector* covDiag, size_t L,
154 size_t N, gsl_vector* x,
155 const gsl_vector* wX,
158
// float variant with a `float* distance` output parameter.
159template <>
161 const gsl_vector_float* covDiag, float* distance, size_t L, size_t N,
162 size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX);
163
// double variant with a `double* distance` output parameter.
164template <>
166 const gsl_vector* covDiag, double* distance, size_t L, size_t N,
167 size_t bMax, gsl_vector* x, const gsl_vector* wX);
168
// float variant with a `gsl_vector_float* gradient` output parameter.
169template <>
171 const gsl_vector_float* covDiag, gsl_vector_float* gradient, size_t L,
172 size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX);
173
// double variant with a `gsl_vector* gradient` output parameter.
174template <>
176 const gsl_vector* covDiag, gsl_vector* gradient, size_t L, size_t N,
177 size_t bMax, gsl_vector* x, const gsl_vector* wX);
178
// Explicit instantiation declarations: suppress implicit instantiation of
// gm_to_dirac_short<double> and gm_to_dirac_short<float> in translation units
// that include this header.
// NOTE(review): the matching explicit instantiation definitions are expected
// in a source file (the generated docs reference gm_to_dirac_short.cpp) --
// confirm.
179extern template class gm_to_dirac_short<double>;
180extern template class gm_to_dirac_short<float>;
181
182#endif // GM_TO_DIRAC_SHORT_H
Definition dirac_to_dirac_approx_short.h:9
Interface for the Gaussian mixture to Dirac approximation
Definition gm_to_dirac_approx_i.h:20
Definition gm_to_dirac_short.h:10
void modified_van_mises_distance_sq(const GSLVectorType *covDiag, T *distance, size_t L, size_t N, size_t bMax, GSLMatrixType *x, const GSLVectorType *wX) override
calculate the modified von Mises distance based on standard normal deviation and x
void modified_van_mises_distance_sq_derivative(const T *covDiag, T *gradient, size_t L, size_t N, size_t bMax, T *x, const T *wX) override
calculate the derivative of the modified von Mises distance based on standard normal deviation and x
Definition gm_to_dirac_short.cpp:53
bool approximate(const GSLVectorType *covDiag, size_t L, size_t N, GSLMatrixType *x, const GSLVectorType *wX=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
approximate using gsl vectors
void modified_van_mises_distance_sq(const T *covDiag, T *distance, size_t L, size_t N, size_t bMax, T *x, const T *wX) override
calculate the modified von Mises distance based on standard normal deviation and x
Definition gm_to_dirac_short.cpp:41
bool approximate(const T *covDiag, size_t L, size_t N, T *x, const T *wX=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
approximate using raw pointers
Definition gm_to_dirac_short.cpp:26
void modified_van_mises_distance_sq_derivative(const GSLVectorType *covDiag, GSLMatrixType *gradient, size_t L, size_t N, size_t bMax, GSLMatrixType *x, const GSLVectorType *wX) override
calculate the derivative of the modified von Mises distance based on standard normal deviation and x
Definition approximate_options.h:6
optimization parameters for the GMToDirac approximation with constant weights
Definition gm_to_dirac_optimization_params.h:163
struct to hold the result of the minimization
Definition gsl_minimizer.h:32