Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
gm_to_dirac_short.h
1#ifndef GM_TO_DIRAC_SHORT_H
2#define GM_TO_DIRAC_SHORT_H
3
4#include <gtest/gtest.h>
5
6#include "gm_to_dirac_approx_i.h"
7#include "gm_to_dirac_optimization_params.h"
8
9template <typename T>
11 public:
12 using GSLVectorType = typename gm_to_dirac_approx_i<T>::GSLVectorType;
13 using GSLVectorViewType = typename gm_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename gm_to_dirac_approx_i<T>::GSLMatrixType;
15
16 // clang-format off
17 bool approximate(const T* covDiag,
18 size_t L,
19 size_t N,
20 size_t bMax,
21 T* x,
22 const T* wX = nullptr,
23 GslminimizerResult* result = nullptr,
25 // clang-format on
26
27 // clang-format off
28 void modified_van_mises_distance_sq(const T* covDiag,
29 T* distance,
30 size_t L,
31 size_t N,
32 size_t bMax,
33 T* x,
34 const T* wX) override;
35 // clang-format on
36
37 // clang-format off
39 T* gradient,
40 size_t L,
41 size_t N,
42 size_t bMax,
43 T* x,
44 const T* wX) override;
45 // clang-format on
46
47 // clang-format off
48 bool approximate(const GSLVectorType* covDiag,
49 size_t L,
50 size_t N,
51 size_t bMax,
52 GSLVectorType* x,
53 const GSLVectorType* wX = nullptr,
54 GslminimizerResult* result = nullptr,
56 // clang-format on
57
58 // clang-format off
59 void modified_van_mises_distance_sq(const GSLVectorType* covDiag,
60 T* distance,
61 size_t L,
62 size_t N,
63 size_t bMax,
64 GSLVectorType* x,
65 const GSLVectorType* wX) override;
66 // clang-format on
67
68 // clang-format off
69 void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag,
70 GSLVectorType* gradient,
71 size_t L,
72 size_t N,
73 size_t bMax,
74 GSLVectorType* x,
75 const GSLVectorType* wX) override;
76 // clang-format on
77
78 // clang-format off
79 bool approximate(const GSLVectorType* covDiag,
80 size_t L,
81 size_t N,
82 size_t bMax,
83 GSLMatrixType* x,
84 const GSLVectorType* wX = nullptr,
85 GslminimizerResult* result = nullptr,
87 // clang-format on
88
89 // clang-format off
90 void modified_van_mises_distance_sq(const GSLVectorType* covDiag,
91 T* distance,
92 size_t L,
93 size_t N,
94 size_t bMax,
95 GSLMatrixType* x,
96 const GSLVectorType* wX) override;
97 // clang-format on
98
99 // clang-format off
100 void modified_van_mises_distance_sq_derivative(const GSLVectorType* covDiag,
101 GSLMatrixType* gradient,
102 size_t L,
103 size_t N,
104 size_t bMax,
105 GSLMatrixType* x,
106 const GSLVectorType* wX) override;
107 // clang-format on
108
109 private:
110 static double modified_van_mises_distance_sq(const gsl_vector* x,
111 void* params);
113 void* params,
115 static void combined_distance_metric(const gsl_vector* x, void* params,
116 double* f, gsl_vector* grad);
117
118 static inline double calculateP2(double b, void* params);
119 static inline double calculateGradP2(double b, void* params);
120
121 static double c_b(size_t bMax);
122 static inline void calculateD2(const gsl_vector* x,
124 double* f, gsl_vector* grad);
125
126 static inline void calculateD3(const gsl_vector* x,
128 double* f, gsl_vector* grad);
129
130 static inline void correctMean(gsl_vector* x, const gsl_vector* wX, size_t L,
131 size_t N);
132
133 friend class benchmark_gm_to_dirac_short;
134 FRIEND_TEST(
137 FRIEND_TEST(
140 FRIEND_TEST(
143};
144
145#include "gm_to_dirac_short.tpp"
146
147template <>
149 size_t L, size_t N, size_t bMax,
151 const gsl_vector_float* wX,
154
155template <>
156bool gm_to_dirac_short<double>::approximate(const gsl_vector* covDiag, size_t L,
157 size_t N, size_t bMax,
158 gsl_vector* x, const gsl_vector* wX,
161
162template <>
164 const gsl_vector_float* covDiag, float* distance, size_t L, size_t N,
165 size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX);
166
167template <>
169 const gsl_vector* covDiag, double* distance, size_t L, size_t N,
170 size_t bMax, gsl_vector* x, const gsl_vector* wX);
171
172template <>
174 const gsl_vector_float* covDiag, gsl_vector_float* gradient, size_t L,
175 size_t N, size_t bMax, gsl_vector_float* x, const gsl_vector_float* wX);
176
177template <>
179 const gsl_vector* covDiag, gsl_vector* gradient, size_t L, size_t N,
180 size_t bMax, gsl_vector* x, const gsl_vector* wX);
181
182extern template class gm_to_dirac_short<double>;
183extern template class gm_to_dirac_short<float>;
184
185#endif // GM_TO_DIRAC_SHORT_H
Definition dirac_to_dirac_approx_short.h:9
Interface for the Gaussian mixture to Dirac approximation.
Definition gm_to_dirac_approx_i.h:20
Definition gm_to_dirac_short.h:10
void modified_van_mises_distance_sq(const GSLVectorType *covDiag, T *distance, size_t L, size_t N, size_t bMax, GSLMatrixType *x, const GSLVectorType *wX) override
Calculate the squared modified von Mises distance based on the standard normal deviation and x.
void modified_van_mises_distance_sq_derivative(const T *covDiag, T *gradient, size_t L, size_t N, size_t bMax, T *x, const T *wX) override
Calculate the gradient of the squared modified von Mises distance with respect to x, based on the standard normal deviation.
Definition gm_to_dirac_short.cpp:53
bool approximate(const GSLVectorType *covDiag, size_t L, size_t N, size_t bMax, GSLMatrixType *x, const GSLVectorType *wX=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
Approximate using GSL containers (GSL vectors for covDiag and wX, a GSL matrix for x).
bool approximate(const T *covDiag, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
approximate using raw pointers
Definition gm_to_dirac_short.cpp:26
void modified_van_mises_distance_sq(const T *covDiag, T *distance, size_t L, size_t N, size_t bMax, T *x, const T *wX) override
Calculate the squared modified von Mises distance based on the standard normal deviation and x.
Definition gm_to_dirac_short.cpp:41
void modified_van_mises_distance_sq_derivative(const GSLVectorType *covDiag, GSLMatrixType *gradient, size_t L, size_t N, size_t bMax, GSLMatrixType *x, const GSLVectorType *wX) override
Calculate the gradient of the squared modified von Mises distance with respect to x, based on the standard normal deviation.
Definition approximate_options.h:6
optimization parameters for the GMToDirac approximation with constant weights
Definition gm_to_dirac_optimization_params.h:163
struct to hold the result of the minimization
Definition gsl_minimizer.h:32