Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
dirac_to_dirac_approx_short_thread.h
1#ifndef DIRAC_TO_DIRAC_SHORT_THREAD_H
2#define DIRAC_TO_DIRAC_SHORT_THREAD_H
3
4#include <gtest/gtest.h>
5
6#include "dirac_to_dirac_approx_i.h"
7
8template <typename T>
9class dirac_to_dirac_approx_short_thread : public dirac_to_dirac_approx_i<T> {
10 public:
11 using GSLVectorType = typename dirac_to_dirac_approx_i<T>::GSLVectorType;
12 using GSLVectorViewType =
13 typename dirac_to_dirac_approx_i<T>::GSLVectorViewType;
14 using GSLMatrixType = typename dirac_to_dirac_approx_i<T>::GSLMatrixType;
15 using GSLMatrixViewType =
16 typename dirac_to_dirac_approx_i<T>::GSLMatrixViewType;
17
18 // clang-format off
19 bool approximate(const T* y,
20 size_t M,
21 size_t L,
22 size_t N,
23 T* x,
24 const T* wX = nullptr,
25 const T* wY = nullptr,
26 GslminimizerResult* result = nullptr,
27 const ApproximateOptions& options = ApproximateOptions{}) override;
28 // clang-format on
29
30 // clang-format off
31 void modified_van_mises_distance_sq(T* distance,
32 const T *y,
33 size_t M,
34 size_t L,
35 size_t N,
36 size_t bMax,
37 T *x,
38 const T *wX = nullptr,
39 const T *wY = nullptr) override;
40 // clang-format on
41
42 // clang-format off
43 void modified_van_mises_distance_sq_derivative(T* gradient,
44 const T *y,
45 size_t M,
46 size_t L,
47 size_t N,
48 size_t bMax,
49 T *x,
50 const T *wX = nullptr,
51 const T *wY = nullptr) override;
52 // clang-format on
53
54 // clang-format off
55 bool approximate(const GSLVectorType* y,
56 size_t L,
57 size_t N,
58 GSLVectorType* x,
59 const GSLVectorType* wX = nullptr,
60 const GSLVectorType* wY = nullptr,
61 GslminimizerResult* result = nullptr,
62 const ApproximateOptions& options = ApproximateOptions{}) override;
63 // clang-format on
64
65 // clang-format off
66 void modified_van_mises_distance_sq(T* distance,
67 const GSLVectorType *y,
68 size_t L,
69 size_t N,
70 size_t bMax,
71 GSLVectorType *x,
72 const GSLVectorType *wX = nullptr,
73 const GSLVectorType *wY = nullptr) override;
74 // clang-format on
75
76 // clang-format off
77 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
78 const GSLVectorType *y,
79 size_t L,
80 size_t N,
81 size_t bMax,
82 GSLVectorType *x,
83 const GSLVectorType *wX = nullptr,
84 const GSLVectorType *wY = nullptr) override;
85 // clang-format on
86
87 // clang-format off
88 bool approximate(GSLMatrixType* y,
89 size_t L,
90 GSLMatrixType* x,
91 const GSLVectorType* wX = nullptr,
92 const GSLVectorType* wY = nullptr,
93 GslminimizerResult* result = nullptr,
94 const ApproximateOptions& options = ApproximateOptions{}) override;
95 // clang-format on
96
97 // clang-format off
98 void modified_van_mises_distance_sq(T* distance,
99 GSLMatrixType *y,
100 size_t L,
101 size_t bMax,
102 GSLMatrixType *x,
103 const GSLVectorType *wX = nullptr,
104 const GSLVectorType *wY = nullptr) override;
105 // clang-format on
106
107 // clang-format off
108 void modified_van_mises_distance_sq_derivative(GSLMatrixType* gradient,
109 GSLMatrixType *y,
110 size_t L,
111 size_t bMax,
112 GSLMatrixType *x,
113 const GSLVectorType *wX = nullptr,
114 const GSLVectorType *wY = nullptr) override;
115 // clang-format on
116
117 private:
118 static double c_b(size_t bMax);
119 static double modified_van_mises_distance_sq(const gsl_vector* x,
120 void* params);
121 static void modified_van_mises_distance_sq_derivative(const gsl_vector* x,
122 void* params,
123 gsl_vector* grad);
124 static void combined_distance_metric(const gsl_vector* x, void* params,
125 double* f, gsl_vector* grad);
126
127 static inline void correctMean(const gsl_vector* meanY, gsl_vector* x,
128 const gsl_vector* wX, size_t L, size_t N);
129
130 FRIEND_TEST(
131 dirac_to_dirac_approx_short_test_modified_van_mises_distance_sq_derivative,
132 parameterized_test_modified_van_mises_distance_sq_derivative);
133 FRIEND_TEST(dirac_to_dirac_approx_short_test_combined,
134 parameterized_test_combined);
135 friend class testero;
136 friend class benchmark_dirac_to_dirac_approx_short_thread;
137};
138
139template <>
140bool dirac_to_dirac_approx_short_thread<float>::approximate(
141 const gsl_vector_float* y, size_t L, size_t N, gsl_vector_float* x,
142 const GSLVectorType* wX, const GSLVectorType* wY,
143 GslminimizerResult* result, const ApproximateOptions& options);
144
145template <>
146bool dirac_to_dirac_approx_short_thread<double>::approximate(
147 const gsl_vector* y, size_t L, size_t N, gsl_vector* x,
148 const GSLVectorType* wX, const GSLVectorType* wY,
149 GslminimizerResult* result, const ApproximateOptions& options);
150
151template <>
152void dirac_to_dirac_approx_short_thread<float>::modified_van_mises_distance_sq(
153 float* distance, const gsl_vector_float* y, size_t L, size_t N, size_t bMax,
154 gsl_vector_float* x, const gsl_vector_float* wX,
155 const gsl_vector_float* wY);
156
157template <>
158void dirac_to_dirac_approx_short_thread<double>::modified_van_mises_distance_sq(
159 double* distance, const gsl_vector* y, size_t L, size_t N, size_t bMax,
160 gsl_vector* x, const gsl_vector* wX, const gsl_vector* wY);
161
162template <>
163void dirac_to_dirac_approx_short_thread<float>::
164 modified_van_mises_distance_sq_derivative(gsl_matrix_float* gradient,
165 const gsl_vector_float* y,
166 size_t L, size_t N, size_t bMax,
167 gsl_vector_float* x,
168 const gsl_vector_float* wX,
169 const gsl_vector_float* wY);
170
171template <>
172void dirac_to_dirac_approx_short_thread<
173 double>::modified_van_mises_distance_sq_derivative(gsl_matrix* gradient,
174 const gsl_vector* y,
175 size_t L, size_t N,
176 size_t bMax,
177 gsl_vector* x,
178 const gsl_vector* wX,
179 const gsl_vector* wY);
180
183
184#endif // DIRAC_TO_DIRAC_SHORT_THREAD_H
Interface for the Gaussian mixture to Dirac approximation
Definition dirac_to_dirac_approx_i.h:20
Definition dirac_to_dirac_approx_short_thread.h:9
void modified_van_mises_distance_sq(T *distance, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate modified van mises distance based on x and y
Definition dirac_to_dirac_approx_short_thread.cpp:40
void modified_van_mises_distance_sq_derivative(T *gradient, const T *y, size_t M, size_t L, size_t N, size_t bMax, T *x, const T *wX=nullptr, const T *wY=nullptr) override
calculate the derivative (gradient) of the modified van mises distance based on x and y
Definition dirac_to_dirac_approx_short_thread.cpp:56
bool approximate(const T *y, size_t M, size_t L, size_t N, T *x, const T *wX=nullptr, const T *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using raw pointers
Definition dirac_to_dirac_approx_short_thread.cpp:25
bool approximate(const GSLVectorType *y, size_t L, size_t N, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr, GslminimizerResult *result=nullptr, const ApproximateOptions &options=ApproximateOptions{}) override
reduce the data points using gsl vectors
void modified_van_mises_distance_sq_derivative(GSLMatrixType *gradient, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate the derivative (gradient) of the modified van mises distance based on x and y
void modified_van_mises_distance_sq(T *distance, const GSLVectorType *y, size_t L, size_t N, size_t bMax, GSLVectorType *x, const GSLVectorType *wX=nullptr, const GSLVectorType *wY=nullptr) override
calculate modified van mises distance based on x and y
Definition approximate_options.h:6
struct to hold the result of the minimization
Definition gsl_minimizer.h:32