Deterministic Gaussian Sampling
Loading...
Searching...
No Matches
gsl_utils_view_helper.h
1#ifndef GSL_UTILS_VIEW_HELPER_H
2#define GSL_UTILS_VIEW_HELPER_H
3
4#include <type_traits>
5
6#include "gsl_vector_matrix_types.h"
7
8template <typename T, bool IsMatrix>
10 static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value,
11 "Only float and double supported");
12
13 public:
14 using GSLVectorType = typename GSLTemplateTypeAlias<T>::VectorType;
15 using GSLVectorViewType = typename GSLTemplateTypeAlias<T>::VectorViewType;
16 using GSLMatrixType = typename GSLTemplateTypeAlias<T>::MatrixType;
17 using GSLMatrixViewType = typename GSLTemplateTypeAlias<T>::MatrixViewType;
18
19 using GSLType =
20 typename std::conditional<IsMatrix, GSLMatrixType, GSLVectorType>::type;
21
22 using ViewType = typename std::conditional<IsMatrix, GSLMatrixViewType,
23 GSLVectorViewType>::type;
24
25 /**************************************************************************/
26 /********************************* pointer ********************************/
27 /**************************************************************************/
28 template <typename U>
29 GSLViewHelper(const U* ptr, size_t size) {
30 static_assert(!IsMatrix, "Vector constructor used for matrix");
31 static_assert(is_float_or_double<U>(), "Only float/double allowed");
32
33 if (!ptr) {
34 _ptr = nullptr;
35 return;
36 }
37
38 construct_vector_from_ptr(ptr, size);
39 }
40
41 template <typename U>
42 GSLViewHelper(const U* ptr, size_t rows, size_t cols) {
43 static_assert(IsMatrix, "Matrix constructor used for vector");
44 static_assert(is_float_or_double<U>(), "Only float/double allowed");
45
46 if (!ptr) {
47 _ptr = nullptr;
48 return;
49 }
50
51 construct_matrix_from_ptr(ptr, rows, cols);
52 }
53
54 /**************************************************************************/
55 /********************************* vector *********************************/
56 /**************************************************************************/
// Constructs the helper from an existing gsl_vector.
//
// A null `v` leaves the helper empty (_ptr == nullptr). Otherwise the
// compile-time IsMatrix flag selects whether the internal storage is a vector
// or a matrix.
//
// `rows` and `cols` default to 0; presumably they are only consulted by the
// matrix branch to reshape the vector — TODO confirm.
//
// NOTE(review): the statement inside each branch (listing lines 65 and 68) is
// missing from this copy of the file. Presumably they call
// construct_vector_from_vector / construct_matrix_from_vector — confirm
// against the original header.
57 GSLViewHelper(const gsl_vector* v, size_t rows = 0, size_t cols = 0) {
58 if (!v) {
59 _ptr = nullptr;
60 return;
61 }
62
63 if constexpr (!IsMatrix) {
64 // internal storage = vector
66 } else {
67 // internal storage = matrix
69 }
70 }
71
// Constructs the helper from an existing gsl_vector_float.
//
// A null `v` leaves the helper empty (_ptr == nullptr). Otherwise the
// compile-time IsMatrix flag selects whether the internal storage is a vector
// or a matrix.
//
// `rows` and `cols` default to 0; presumably they are only consulted by the
// matrix branch to reshape the vector — TODO confirm.
//
// NOTE(review): the statement inside each branch (listing lines 80 and 83) is
// missing from this copy of the file. Presumably they call
// construct_vector_from_vector / construct_matrix_from_vector — confirm
// against the original header.
72 GSLViewHelper(const gsl_vector_float* v, size_t rows = 0, size_t cols = 0) {
73 if (!v) {
74 _ptr = nullptr;
75 return;
76 }
77
78 if constexpr (!IsMatrix) {
79 // internal storage = vector
81 } else {
82 // internal storage = matrix
84 }
85 }
86
87 /**************************************************************************/
88 /********************************* matrix *********************************/
89 /**************************************************************************/
// Constructs the helper from an existing gsl_matrix.
//
// A null `m` leaves the helper empty (_ptr == nullptr). Otherwise the
// compile-time IsMatrix flag selects whether the internal storage is a matrix
// or a vector.
//
// NOTE(review): the statement inside each branch (listing lines 98 and 101)
// is missing from this copy of the file. Presumably they call
// construct_matrix_from_matrix / construct_vector_from_matrix — confirm
// against the original header.
90 GSLViewHelper(const gsl_matrix* m) {
91 if (!m) {
92 _ptr = nullptr;
93 return;
94 }
95
96 if constexpr (IsMatrix) {
97 // internal storage = matrix
99 } else {
100 // internal storage = vector
102 }
103 }
104
106 if (!m) {
107 _ptr = nullptr;
108 return;
109 }
110
111 if constexpr (IsMatrix) {
112 // internal storage = matrix
114 } else {
115 // internal storage = vector
117 }
118 }
119
120 /**************************************************************************/
121 /******************************* destructor *******************************/
122 /**************************************************************************/
124 if (!_freeMemory || !_ptr) return;
125
126 if constexpr (IsMatrix)
128 else
130 }
131
132 /**************************************************************************/
133 /********************************* access *********************************/
134 /**************************************************************************/
135 GSLType* get() { return _ptr; }
136 const GSLType* get() const { return _ptr; }
137
138 operator GSLType*() { return _ptr; }
139 operator const GSLType*() const { return _ptr; }
140
141 private:
// Initialises the helper as a vector backed by the raw array `ptr` of `size`
// elements.
//
// U == T: `_ptr` is pointed at the vector embedded in `_view`.
// U != T: `_freeMemory` is set and each of the `size` elements is converted
//         with static_cast<T> into `_ptr->data`.
//
// NOTE(review): listing lines 145 and 148 are missing from this copy of the
// file. Presumably line 145 fills `_view` with a view over `ptr` and line 148
// allocates the vector assigned to `_ptr` — confirm against the original
// header.
142 template <typename U>
143 void construct_vector_from_ptr(const U* ptr, size_t size) {
144 if constexpr (std::is_same<U, T>::value) {
146 _ptr = &_view.vector;
147 } else {
149 _freeMemory = true;
150
151 for (size_t i = 0; i < size; ++i) _ptr->data[i] = static_cast<T>(ptr[i]);
152 }
153 }
154
// Initialises the helper as a vector from an existing GSL vector whose scalar
// type is U.
//
// Null `v`: the helper is left empty (_ptr == nullptr).
// U == T:   `v` is aliased directly (constness is cast away); `_freeMemory`
//           is not touched.
// U != T:   `_freeMemory` is set and `v->size` elements are converted with
//           static_cast<T> into `_ptr->data`.
//
// NOTE(review): listing line 166 is missing from this copy of the file;
// presumably it allocates the vector assigned to `_ptr` — confirm against the
// original header.
// NOTE(review): the copy loop reads v->data[i] without applying v->stride.
// For a strided source this would pick up the wrong elements — confirm
// whether strided vectors can reach this function.
155 template <typename U>
156 void construct_vector_from_vector(
157 const typename GSLTemplateTypeAlias<U>::VectorType* v) {
158 if (!v) {
159 _ptr = nullptr;
160 return;
161 }
162
163 if constexpr (std::is_same<U, T>::value) {
164 _ptr = const_cast<GSLType*>(v);
165 } else {
167 _freeMemory = true;
168
169 for (size_t i = 0; i < v->size; ++i)
170 _ptr->data[i] = static_cast<T>(v->data[i]);
171 }
172 }
173
// Initialises the helper as a rows x cols matrix from an existing GSL vector
// whose scalar type is U.
//
// Null `v`: the helper is left empty (_ptr == nullptr).
// Throws std::runtime_error when `rows` or `cols` is 0, or when
// v->size != rows * cols.
// U == T:   `_view` is assigned and `_ptr` is pointed at the matrix embedded
//           in it.
// U != T:   `_freeMemory` is set and `v->size` elements are converted with
//           static_cast<T> into `_ptr->data`.
//
// NOTE(review): listing lines 192 and 196 are missing from this copy of the
// file. Line 192 is the right-hand side of the `_view =` assignment and line
// 196 presumably allocates the matrix assigned to `_ptr` — confirm against
// the original header.
// NOTE(review): the copy loop reads v->data[i] without applying v->stride —
// confirm whether strided vectors can reach this function.
174 template <typename U>
175 void construct_matrix_from_vector(
176 const typename GSLTemplateTypeAlias<U>::VectorType* v, size_t rows,
177 size_t cols) {
178 if (!v) {
179 _ptr = nullptr;
180 return;
181 }
182
183 if (rows == 0 || cols == 0)
184 throw std::runtime_error(
185 "Matrix construction from vector requires rows and cols");
186
187 if (v->size != rows * cols)
188 throw std::runtime_error("Size mismatch in reshape");
189
190 if constexpr (std::is_same<T, U>::value) {
191 _view =
193
194 _ptr = &_view.matrix;
195 } else {
197 _freeMemory = true;
198
199 for (size_t i = 0; i < v->size; ++i)
200 _ptr->data[i] = static_cast<T>(v->data[i]);
201 }
202 }
203
// Initialises the helper as a rows x cols matrix backed by the raw array
// `ptr`.
//
// U == T: `_ptr` is pointed at the matrix embedded in `_view`.
// U != T: `_freeMemory` is set and rows * cols elements are converted with
//         static_cast<T> into `_ptr->data`.
//
// NOTE(review): listing lines 207 and 210 are missing from this copy of the
// file. Presumably line 207 fills `_view` with a view over `ptr` and line 210
// allocates the matrix assigned to `_ptr` — confirm against the original
// header.
204 template <typename U>
205 void construct_matrix_from_ptr(const U* ptr, size_t rows, size_t cols) {
206 if constexpr (std::is_same<U, T>::value) {
208 _ptr = &_view.matrix;
209 } else {
211 _freeMemory = true;
212
213 size_t total = rows * cols;
214 for (size_t i = 0; i < total; ++i) _ptr->data[i] = static_cast<T>(ptr[i]);
215 }
216 }
217
218 template <typename U>
219 void construct_matrix_from_matrix(
220 const typename GSLTemplateTypeAlias<U>::MatrixType* m) {
221 if (!m) {
222 _ptr = nullptr;
223 return;
224 }
225
226 if constexpr (std::is_same<U, T>::value) {
227 _ptr = const_cast<GSLType*>(m);
228 } else {
229 _ptr = GSLTemplateTypeAlias<T>::allocate_matrix(m->size1, m->size2);
230 _freeMemory = true;
231
232 size_t total = m->size1 * m->size2;
233 for (size_t i = 0; i < total; ++i)
234 _ptr->data[i] = static_cast<T>(m->data[i]);
235 }
236 }
237
// Initialises the helper as a vector from an existing GSL matrix whose scalar
// type is U.
//
// Null `m`: the helper is left empty (_ptr == nullptr).
// U == T:   `_ptr` is pointed at the vector embedded in `_view`.
// U != T:   `_freeMemory` is set and size1 * size2 elements are converted
//           with static_cast<T> into `_ptr->data`.
//
// NOTE(review): listing lines 247 and 251 are missing from this copy of the
// file. Presumably line 247 fills `_view` with a vector view over the matrix
// data and line 251 allocates the vector assigned to `_ptr` — confirm against
// the original header.
// NOTE(review): the copy loop reads m->data[i] as one dense run and does not
// apply m->tda. For a matrix whose tda differs from size2 this would read the
// wrong elements — confirm whether such matrices can reach this function.
238 template <typename U>
239 void construct_vector_from_matrix(
240 const typename GSLTemplateTypeAlias<U>::MatrixType* m) {
241 if (!m) {
242 _ptr = nullptr;
243 return;
244 }
245
246 if constexpr (std::is_same<T, U>::value) {
248 _ptr = &_view.vector;
249 } else {
250 const size_t total = m->size1 * m->size2;
252 _freeMemory = true;
253
254 for (size_t i = 0; i < total; ++i)
255 _ptr->data[i] = static_cast<T>(m->data[i]);
256 }
257 }
258
// Compile-time predicate: true exactly when U is float or double.
template <typename U>
static constexpr bool is_float_or_double() {
  constexpr bool isFloat = std::is_same_v<U, float>;
  constexpr bool isDouble = std::is_same_v<U, double>;
  return isFloat || isDouble;
}
263
264 private:
// GSL object handed out by get() and the conversion operators; nullptr when
// the helper is empty.
265 GSLType* _ptr = nullptr;
// Set by the converting (U != T) construction paths; the cleanup code returns
// early when this flag is false or `_ptr` is null.
266 bool _freeMemory = false;
267
// View storage; the same-type construction paths point `_ptr` at its embedded
// vector or matrix member.
// NOTE(review): no copy or move operations are visible in this listing. A
// copied helper would presumably share `_ptr` (risking a double free when
// `_freeMemory` is set) or keep pointing into the source object's `_view` —
// confirm whether copying is disabled elsewhere.
268 ViewType _view{};
269};
270
271template <typename T>
273
274template <typename T>
276
277#endif // GSL_UTILS_VIEW_HELPER_H
Definition gsl_vector_matrix_types.h:13
Definition gsl_utils_view_helper.h:9
Definition dirac_to_dirac_approx_short.h:9