Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
dirac_to_dirac_approx_short_thread.h
1#ifndef DIRAC_TO_DIRAC_SHORT_THREAD_H
2#define DIRAC_TO_DIRAC_SHORT_THREAD_H
3
4#include <gtest/gtest.h>
5
6#include "dirac_to_dirac_approx_i.h"
7
8template <typename T>
10 public:
11 using GSLVectorType = typename dirac_to_dirac_approx_i<T>::GSLVectorType;
12 using GSLVectorViewType =
13 typename dirac_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename dirac_to_dirac_approx_i<T>::GSLMatrixType;
15 using GSLMatrixViewType = typename dirac_to_dirac_approx_i<T>::GSLMatrixViewType;
16
17 // clang-format off
// Raw-pointer overload: reduce the data points using plain T arrays.
// y is read-only; x is handed in as a mutable buffer. wX, wY and result are
// optional (default nullptr); options defaults to ApproximateOptions{}.
// NOTE(review): the meaning of M, L, N, bMax and of the bool return value is
// not visible in this header - confirm against dirac_to_dirac_approx_i.h.
// NOTE(review): default arguments on an overriding virtual are selected by
// the static type of the call expression, so they must stay identical to the
// defaults declared in the base interface.
18 bool approximate(const T* y,
19 size_t M,
20 size_t L,
21 size_t N,
22 size_t bMax,
23 T* x,
24 const T* wX = nullptr,
25 const T* wY = nullptr,
26 GslminimizerResult* result = nullptr,
27 const ApproximateOptions& options = ApproximateOptions{}) override;
28 // clang-format on
29
30 // clang-format off
// Raw-pointer overload: calculate the modified van mises distance based on
// x and y. Returns void; presumably the value is written through `distance`
// - confirm against the definition in the .cpp.
// y, wX and wY are read-only; wX / wY are optional (default nullptr).
// NOTE(review): x is taken as non-const T* although this looks like a pure
// evaluation - check whether the implementation really modifies x.
31 void modified_van_mises_distance_sq(T* distance,
32 const T *y,
33 size_t M,
34 size_t L,
35 size_t N,
36 size_t bMax,
37 T *x,
38 const T *wX = nullptr,
39 const T *wY = nullptr) override;
40 // clang-format on
41
42 // clang-format off
44 const T *y,
45 size_t M,
46 size_t L,
47 size_t N,
48 size_t bMax,
49 T *x,
50 const T *wX = nullptr,
51 const T *wY = nullptr) override;
52 // clang-format on
53
54 // clang-format off
// GSL-vector overload: reduce the data points using gsl vectors.
// Same parameter list as the raw-pointer overload except that there is no M
// argument and y / x / wX / wY are GSLVectorType pointers.
// wX, wY and result are optional (default nullptr); options defaults to
// ApproximateOptions{}.
// NOTE(review): how the samples are laid out inside the flat vectors is not
// visible here - confirm against the base interface documentation.
55 bool approximate(const GSLVectorType* y,
56 size_t L,
57 size_t N,
58 size_t bMax,
59 GSLVectorType* x,
60 const GSLVectorType* wX = nullptr,
61 const GSLVectorType* wY = nullptr,
62 GslminimizerResult* result = nullptr,
63 const ApproximateOptions& options = ApproximateOptions{}) override;
64 // clang-format on
65
66 // clang-format off
68 const GSLVectorType *y,
69 size_t L,
70 size_t N,
71 size_t bMax,
72 GSLVectorType *x,
73 const GSLVectorType *wX = nullptr,
74 const GSLVectorType *wY = nullptr) override;
75 // clang-format on
76
77 // clang-format off
// GSL-vector overload of the distance derivative. Returns void; presumably
// the result is stored in the GSLMatrixType pointed to by `gradient` -
// confirm against the definition.
// y, wX and wY are read-only GSLVectorType pointers; wX / wY are optional
// (default nullptr).
// NOTE(review): required dimensions of `gradient` are not stated in this
// header - verify before calling.
78 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
79 const GSLVectorType *y,
80 size_t L,
81 size_t N,
82 size_t bMax,
83 GSLVectorType *x,
84 const GSLVectorType *wX = nullptr,
85 const GSLVectorType *wY = nullptr) override;
86 // clang-format on
87
88 // clang-format off
// GSL-matrix overload of approximate(): y and x are GSLMatrixType pointers,
// the weights wX / wY remain GSLVectorType pointers (optional, default
// nullptr). Takes only L and bMax as size arguments; presumably the remaining
// sizes come from the matrix dimensions - TODO confirm.
// NOTE(review): unlike the pointer and vector overloads, y is non-const here.
// Check whether the implementation modifies y or whether the base interface
// signature could be made const.
89 bool approximate(GSLMatrixType* y,
90 size_t L,
91 size_t bMax,
92 GSLMatrixType* x,
93 const GSLVectorType* wX = nullptr,
94 const GSLVectorType* wY = nullptr,
95 GslminimizerResult* result = nullptr,
96 const ApproximateOptions& options = ApproximateOptions{}) override;
97 // clang-format on
98
99 // clang-format off
// GSL-matrix overload of the distance evaluation: y and x are GSLMatrixType
// pointers, wX / wY are optional GSLVectorType weights (default nullptr).
// Returns void; presumably the value is written through `distance` - confirm.
// NOTE(review): both y and x are non-const in this overload.
100 void modified_van_mises_distance_sq(T* distance,
101 GSLMatrixType *y,
102 size_t L,
103 size_t bMax,
104 GSLMatrixType *x,
105 const GSLVectorType *wX = nullptr,
106 const GSLVectorType *wY = nullptr) override;
107 // clang-format on
108
109 // clang-format off
// GSL-matrix overload of the distance derivative: y, x and the `gradient`
// argument are all GSLMatrixType pointers; wX / wY are optional GSLVectorType
// weights (default nullptr).
// NOTE(review): presumably `gradient` must match the dimensions of x -
// not visible in this header, confirm against the definition.
110 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
111 GSLMatrixType *y,
112 size_t L,
113 size_t bMax,
114 GSLMatrixType *x,
115 const GSLVectorType *wX = nullptr,
116 const GSLVectorType *wY = nullptr) override;
117 // clang-format on
118
119 private:
// Private static helper taking bMax and returning a double regardless of T.
// NOTE(review): the definition is not part of this listing, so what the
// returned constant represents must be checked in the .cpp.
120 static double c_b(size_t bMax);
// Private static overload working on plain gsl_vector / double irrespective
// of T; context is passed through the untyped `params` pointer.
// NOTE(review): the (const gsl_vector*, void*) -> double shape matches the
// `f` callback of a GSL multimin function object - presumably this is the
// objective handed to the minimizer; confirm in the .cpp.
121 static double modified_van_mises_distance_sq(const gsl_vector* x,
122 void* params);
// Private static overload working on plain gsl_vector irrespective of T;
// `grad` is the only writable argument.
// NOTE(review): the (const gsl_vector*, void*, gsl_vector*) shape matches the
// `df` callback of a GSL multimin function object - presumably the gradient
// callback handed to the minimizer; confirm in the .cpp.
123 static void modified_van_mises_distance_sq_derivative(const gsl_vector* x,
124 void* params,
125 gsl_vector* grad);
// Private static helper with two writable outputs: a double via `f` and a
// gsl_vector via `grad`.
// NOTE(review): the signature matches the `fdf` callback of a GSL multimin
// function object (value and gradient in one call) - presumably used as
// such; confirm in the .cpp.
126 static void combined_distance_metric(const gsl_vector* x, void* params,
127 double* f, gsl_vector* grad);
128
// Private static helper; x is the only non-const argument, so it is the one
// that can be modified. Presumably shifts x so that its wX-weighted mean
// equals meanY - TODO confirm against the definition.
// NOTE(review): declared `inline` but not defined in this header; the
// definition must be visible in every translation unit that calls it.
129 static inline void correctMean(const gsl_vector* meanY, gsl_vector* x,
130 const gsl_vector* wX, size_t L, size_t N);
131
// Grant the listed GoogleTest cases and the test/benchmark helper classes
// access to the private static helpers declared above.
// NOTE(review): FRIEND_TEST is why this production header includes
// <gtest/gtest.h>; GoogleTest provides the lightweight <gtest/gtest_prod.h>
// for exactly this purpose - consider switching to it.
132 FRIEND_TEST(
133 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
134 parameterized_test_modified_van_mises_distance_sq_derivative);
135 FRIEND_TEST(dirac_to_dirac_approx_short_test_combined,
136 parameterized_test_combined);
137 friend class testero;
138 friend class benchmark_dirac_to_dirac_approx_short_thread;
139};
140
141template <>
143 const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
144 gsl_vector_float* x, const GSLVectorType* wX, const GSLVectorType* wY,
145 GslminimizerResult* result, const ApproximateOptions& options);
146
147template <>
149 const gsl_vector* y, size_t L, size_t N, size_t bMax, gsl_vector* x,
150 const GSLVectorType* wX, const GSLVectorType* wY,
151 GslminimizerResult* result, const ApproximateOptions& options);
152
153template <>
155 float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
156 gsl_vector_float* x, const gsl_vector_float* wX,
157 const gsl_vector_float* wY);
158
159template <>
161 double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax,
162 gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY);
163
164template <>
166 modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient,
167 const gsl_vector_float* y,
168 size_t L, size_t N, size_t bMax,
169 gsl_vector_float* x,
170 const gsl_vector_float* wX,
171 const gsl_vector_float* wY);
172
173template <>
175 double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient,
176 const gsl_vector* y,
177 size_t L, size_t N,
178 size_t bMax,
179 gsl_vector* x,
180 const gsl_vector* wX,
181 const gsl_vector* wY);
182
185
186#endif // DIRAC_TO_DIRAC_SHORT_THREAD_H
Interface for the Gaussian mixture to Dirac approximation
Definition dirac_to_dirac_approx_i.h:20
Definition dirac_to_dirac_approx_short_thread.h:9
void modified_van_mises_distance_sq(T *distance, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate modified van mises distance based on x and y
Definition dirac_to_dirac_approx_short_thread.cpp:41
void modified_van_mises_distance_sq_derivative(T *gradient, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate the derivative of the modified van mises distance based on x and y
Definition dirac_to_dirac_approx_short_thread.cpp:57
bool approximate(const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using gsl vectors
bool approximate(const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using raw pointers
Definition dirac_to_dirac_approx_short_thread.cpp:25
void modified_van_mises_distance_sq_derivative(GSLMatrixType *gradient, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate the derivative of the modified van mises distance based on x and y
void modified_van_mises_distance_sq(T *distance, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate modified van mises distance based on x and y
Definition approximate_options.h:6
struct to hold the result of the minimization
Definition gsl_minimizer.h:32