Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
dirac_to_dirac_approx_short.h
1#ifndef DIRAC_TO_DIRAC_SHORT_H
2#define DIRAC_TO_DIRAC_SHORT_H
3
4#include <gtest/gtest.h>
5
6#include "dirac_to_dirac_approx_i.h"
7
8template <typename T>
10 public:
11 using GSLVectorType = typename dirac_to_dirac_approx_i<T>::GSLVectorType;
12 using GSLVectorViewType =
13 typename dirac_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename dirac_to_dirac_approx_i<T>::GSLMatrixType;
15 using GSLMatrixViewType =
16 typename dirac_to_dirac_approx_i<T>::GSLMatrixViewType;
17
18 // clang-format off
19 bool approximate(const T* y,
20 size_t M,
21 size_t L,
22 size_t N,
23 T* x,
24 const T* wX = nullptr,
25 const T* wY = nullptr,
26 GslminimizerResult* result = nullptr,
27 const ApproximateOptions& options = ApproximateOptions{}) override;
28 // clang-format on
29
30 // clang-format off
31 void modified_van_mises_distance_sq(T* distance,
32 const T *y,
33 size_t M,
34 size_t L,
35 size_t N,
36 size_t bMax,
37 T *x,
38 const T *wX = nullptr,
39 const T *wY = nullptr) override;
40 // clang-format on
41
42 // clang-format off
44 const T *y,
45 size_t M,
46 size_t L,
47 size_t N,
48 size_t bMax,
49 T *x,
50 const T *wX = nullptr,
51 const T *wY = nullptr) override;
52 // clang-format on
53
54 // clang-format off
55 bool approximate(const GSLVectorType* y,
56 size_t L,
57 size_t N,
58 GSLVectorType* x,
59 const GSLVectorType* wX = nullptr,
60 const GSLVectorType* wY = nullptr,
61 GslminimizerResult* result = nullptr,
62 const ApproximateOptions& options = ApproximateOptions{}) override;
63 // clang-format on
64
65 // clang-format off
67 const GSLVectorType *y,
68 size_t L,
69 size_t N,
70 size_t bMax,
71 GSLVectorType *x,
72 const GSLVectorType *wX = nullptr,
73 const GSLVectorType *wY = nullptr) override;
74 // clang-format on
75
76 // clang-format off
77 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
78 const GSLVectorType *y,
79 size_t L,
80 size_t N,
81 size_t bMax,
82 GSLVectorType *x,
83 const GSLVectorType *wX = nullptr,
84 const GSLVectorType *wY = nullptr) override;
85 // clang-format on
86
87 // clang-format off
88 bool approximate(GSLMatrixType* y,
89 size_t L,
90 GSLMatrixType* x,
91 const GSLVectorType* wX = nullptr,
92 const GSLVectorType* wY = nullptr,
93 GslminimizerResult* result = nullptr,
94 const ApproximateOptions& options = ApproximateOptions{}) override;
95 // clang-format on
96
97 // clang-format off
98 void modified_van_mises_distance_sq(T* distance,
99 GSLMatrixType *y,
100 size_t L,
101 size_t bMax,
102 GSLMatrixType *x,
103 const GSLVectorType *wX = nullptr,
104 const GSLVectorType *wY = nullptr) override;
105 // clang-format on
106
107 // clang-format off
108 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
109 GSLMatrixType *y,
110 size_t L,
111 size_t bMax,
112 GSLMatrixType *x,
113 const GSLVectorType *wX = nullptr,
114 const GSLVectorType *wY = nullptr) override;
115 // clang-format on
116
117 private:
118 static double c_b(size_t bMax);
119 static double modified_van_mises_distance_sq(const gsl_vector* x,
120 void* params);
121 static void modified_van_mises_distance_sq_derivative(const gsl_vector* x,
122 void* params,
123 gsl_vector* grad);
124 static void combined_distance_metric(const gsl_vector* x, void* params,
125 double* f, gsl_vector* grad);
126
127 static inline void correctMean(const GSLVectorType* meanY, GSLVectorType* x,
128 const GSLVectorType* wX, size_t L, size_t N);
129
130 FRIEND_TEST(
131 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
132 parameterized_test_modified_van_mises_distance_sq_derivative);
133 FRIEND_TEST(
134 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
135 parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_distance);
136 FRIEND_TEST(
137 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
138 parameterized_test_modified_van_mises_distance_sq_derivative_wrapper_gradient);
139 FRIEND_TEST(dirac_to_dirac_approx_short_test_combined,
140 parameterized_test_combined);
141 friend class benchmark_dirac_to_dirac_approx_short;
142};
143
144template <>
146 const gsl_vector_float* y, size_t L, size_t N, gsl_vector_float* x,
147 const gsl_vector_float* wX, const gsl_vector_float* wY,
148 GslminimizerResult* result, const ApproximateOptions& options);
149
150template <>
152 const gsl_vector* y, size_t L, size_t N, gsl_vector* x,
153 const gsl_vector* wX, const gsl_vector* wY, GslminimizerResult* result,
154 const ApproximateOptions& options);
155
156template <>
158 float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
159 gsl_vector_float* x, const gsl_vector_float* wX,
160 const gsl_vector_float* wY);
161
162template <>
164 double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax,
165 gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY);
166
167template <>
169 modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient,
170 const gsl_vector_float* y,
171 size_t L, size_t N, size_t bMax,
172 gsl_vector_float* x,
173 const gsl_vector_float* wX,
174 const gsl_vector_float* wY);
175
176template <>
178 double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient,
179 const gsl_vector* y,
180 size_t L, size_t N,
181 size_t bMax,
182 gsl_vector* x,
183 const gsl_vector* wX,
184 const gsl_vector* wY);
185
186extern template class dirac_to_dirac_approx_short<double>;
187extern template class dirac_to_dirac_approx_short<float>;
188
189#endif // DIRAC_TO_DIRAC_SHORT_H
Interface for the Gaussian mixture to Dirac approximation.
Definition dirac_to_dirac_approx_i.h:20
Definition dirac_to_dirac_approx_short.h:9
bool approximate(const GSLVectorType *y, size_t L, size_t N, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using gsl vectors
void modified_van_mises_distance_sq(T *distance, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
Calculate the modified van Mises distance based on x and y.
Definition dirac_to_dirac_approx_short.cpp:39
void modified_van_mises_distance_sq_derivative(T *gradient, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
Calculate the derivative (gradient) of the modified van Mises distance based on x and y.
Definition dirac_to_dirac_approx_short.cpp:54
bool approximate(const T *y, size_t M, size_t L, size_t N, T *x, const T *wX=nullptr, const T *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using raw pointers
Definition dirac_to_dirac_approx_short.cpp:24
void modified_van_mises_distance_sq_derivative(GSLMatrixType *gradient, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
Calculate the derivative (gradient) of the modified van Mises distance based on x and y.
void modified_van_mises_distance_sq(T *distance, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
Calculate the modified van Mises distance based on x and y.
Definition approximate_options.h:6
struct to hold the result of the minimization
Definition gsl_minimizer.h:32