Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
dirac_to_dirac_approx_short.h
1#ifndef DIRAC_TO_DIRAC_SHORT_H
2#define DIRAC_TO_DIRAC_SHORT_H
3
4#include <gtest/gtest.h>
5
6#include "dirac_to_dirac_approx_i.h"
7
8template <typename T>
10 public:
11 using GSLVectorType = typename dirac_to_dirac_approx_i<T>::GSLVectorType;
12 using GSLVectorViewType =
13 typename dirac_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename dirac_to_dirac_approx_i<T>::GSLMatrixType;
15 using GSLMatrixViewType =
16 typename dirac_to_dirac_approx_i<T>::GSLMatrixViewType;
17
18 // clang-format off
19 bool approximate(const T* y,
20 size_t M,
21 size_t L,
22 size_t N,
23 size_t bMax,
24 T* x,
25 const T* wX = nullptr,
26 const T* wY = nullptr,
27 GslminimizerResult* result = nullptr,
28 const ApproximateOptions& options = ApproximateOptions{}) override;
29 // clang-format on
30
31 // clang-format off
32 void modified_van_mises_distance_sq(T* distance,
33 const T *y,
34 size_t M,
35 size_t L,
36 size_t N,
37 size_t bMax,
38 T *x,
39 const T *wX = nullptr,
40 const T *wY = nullptr) override;
41 // clang-format on
42
43 // clang-format off
45 const T *y,
46 size_t M,
47 size_t L,
48 size_t N,
49 size_t bMax,
50 T *x,
51 const T *wX = nullptr,
52 const T *wY = nullptr) override;
53 // clang-format on
54
55 // clang-format off
56 bool approximate(const GSLVectorType* y,
57 size_t L,
58 size_t N,
59 size_t bMax,
60 GSLVectorType* x,
61 const GSLVectorType* wX = nullptr,
62 const GSLVectorType* wY = nullptr,
63 GslminimizerResult* result = nullptr,
64 const ApproximateOptions& options = ApproximateOptions{}) override;
65 // clang-format on
66
67 // clang-format off
69 const GSLVectorType *y,
70 size_t L,
71 size_t N,
72 size_t bMax,
73 GSLVectorType *x,
74 const GSLVectorType *wX = nullptr,
75 const GSLVectorType *wY = nullptr) override;
76 // clang-format on
77
78 // clang-format off
79 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
80 const GSLVectorType *y,
81 size_t L,
82 size_t N,
83 size_t bMax,
84 GSLVectorType *x,
85 const GSLVectorType *wX = nullptr,
86 const GSLVectorType *wY = nullptr) override;
87 // clang-format on
88
89 // clang-format off
90 bool approximate(GSLMatrixType* y,
91 size_t L,
92 size_t bMax,
93 GSLMatrixType* x,
94 const GSLVectorType* wX = nullptr,
95 const GSLVectorType* wY = nullptr,
96 GslminimizerResult* result = nullptr,
97 const ApproximateOptions& options = ApproximateOptions{}) override;
98 // clang-format on
99
100 // clang-format off
101 void modified_van_mises_distance_sq(T* distance,
102 GSLMatrixType *y,
103 size_t L,
104 size_t bMax,
105 GSLMatrixType *x,
106 const GSLVectorType *wX = nullptr,
107 const GSLVectorType *wY = nullptr) override;
108 // clang-format on
109
110 // clang-format off
111 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
112 GSLMatrixType *y,
113 size_t L,
114 size_t bMax,
115 GSLMatrixType *x,
116 const GSLVectorType *wX = nullptr,
117 const GSLVectorType *wY = nullptr) override;
118 // clang-format on
119
120 private:
121 static double c_b(size_t bMax);
122 static double modified_van_mises_distance_sq(const gsl_vector* x,
123 void* params);
124 static void modified_van_mises_distance_sq_derivative(const gsl_vector* x,
125 void* params,
126 gsl_vector* grad);
127 static void combined_distance_metric(const gsl_vector* x, void* params,
128 double* f, gsl_vector* grad);
129
130 static inline void correctMean(const GSLVectorType* meanY, GSLVectorType* x,
131 const GSLVectorType* wX, size_t L, size_t N);
132
133 FRIEND_TEST(
134 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
135 parameterized_test_modified_van_mises_distance_sq_derivative);
136 FRIEND_TEST(
137 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
138 parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance);
139 FRIEND_TEST(
140 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
141 parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient);
142 FRIEND_TEST(dirac_to_dirac_approx_short_test_combined,
143 parameterized_test_combined);
144 friend class benchmark_dirac_to_dirac_approx_short;
145};
146
147template <>
149 const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
150 gsl_vector_float* x, const gsl_vector_float* wX, const gsl_vector_float* wY,
151 GslminimizerResult* result, const ApproximateOptions& options);
152
153template <>
155 const gsl_vector* y, size_t L, size_t N, size_t bMax, gsl_vector* x,
156 const gsl_vector* wX, const gsl_vector* wY, GslminimizerResult* result,
157 const ApproximateOptions& options);
158
159template <>
161 float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
162 gsl_vector_float* x, const gsl_vector_float* wX,
163 const gsl_vector_float* wY);
164
165template <>
167 double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax,
168 gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY);
169
170template <>
172 modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient,
173 const gsl_vector_float* y,
174 size_t L, size_t N, size_t bMax,
175 gsl_vector_float* x,
176 const gsl_vector_float* wX,
177 const gsl_vector_float* wY);
178
179template <>
181 double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient,
182 const gsl_vector* y,
183 size_t L, size_t N,
184 size_t bMax,
185 gsl_vector* x,
186 const gsl_vector* wX,
187 const gsl_vector* wY);
188
189extern template class dirac_to_dirac_approx_short<double>;
190extern template class dirac_to_dirac_approx_short<float>;
191
192#endif // DIRAC_TO_DIRAC_SHORT_H
interface for the Gaussian mixture to Dirac approximation
Definition dirac_to_dirac_approx_i.h:20
Definition dirac_to_dirac_approx_short.h:9
void modified_van_mises_distance_sq(T *distance, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate the modified van Mises distance based on x and y
Definition dirac_to_dirac_approx_short.cpp:40
void modified_van_mises_distance_sq_derivative(T *gradient, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate the derivative of the modified van Mises distance based on x and y
Definition dirac_to_dirac_approx_short.cpp:55
bool approximate(const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using gsl vectors
bool approximate(const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using raw pointers
Definition dirac_to_dirac_approx_short.cpp:24
void modified_van_mises_distance_sq_derivative(GSLMatrixType *gradient, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate the derivative of the modified van Mises distance based on x and y
void modified_van_mises_distance_sq(T *distance, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate the modified van Mises distance based on x and y
Definition approximate_options.h:6
struct to hold the result of the minimization
Definition gsl_minimizer.h:32