1/*
2 * sort_ofp.cuh
3 *
4 * Created on: Aug 23, 2019
5 * Author: i-bird
6 */
7
8#ifndef SORT_OFP_CUH_
9#define SORT_OFP_CUH_
10
11
12#ifdef __NVCC__
13
14#include "util/cuda_launch.hpp"
15
16#if CUDART_VERSION >= 11000
17 #ifndef CUDA_ON_CPU
18 // Here we have for sure CUDA >= 11
19 #include "cub/cub.cuh"
20 #ifndef SORT_WITH_CUB
21 #define SORT_WITH_CUB
22 #endif
23 #endif
24#else
25 // Here we have old CUDA
26 #include "cub_old/cub.cuh"
27 #include "util/cuda/moderngpu/kernel_mergesort.hxx"
28#endif
29
30#include "util/cuda/ofp_context.hxx"
31
template<typename key_t, typename val_t>
struct key_val_ref;

/* Owning key/value pair.
 *
 * Holds copies of one key and one value. Ordering is defined on the key
 * only; the value never takes part in a comparison. A key_val can be
 * built from, and assigned from, a key_val_ref proxy, in which case the
 * referenced key and value are copied into this object.
 */
template<typename key_t, typename val_t>
struct key_val
{
	// Copy of the key (the only member used for ordering)
	key_t key;

	// Copy of the value carried along with the key
	val_t val;

	// Build a pair from copies of k and v
	key_val(const key_t & k, const val_t & v)
	:key(k),val(v)
	{}

	// Build a pair by copying the key and value a proxy refers to.
	// The members are copy-constructed directly, so key_t and val_t do
	// not need a default constructor and are not initialised twice.
	key_val(const key_val_ref<key_t,val_t> & tmp)
	:key(tmp.key),val(tmp.val)
	{}

	// true when this key orders before the key of tmp
	bool operator<(const key_val & tmp) const
	{
		return key < tmp.key;
	}

	// true when this key orders after the key of tmp
	bool operator>(const key_val & tmp) const
	{
		return key > tmp.key;
	}

	// Overwrite this pair with copies of the key and value a proxy refers to
	key_val & operator=(const key_val_ref<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}
};
68
69
70template<typename key_t, typename val_t>
71struct key_val_ref
72{
73 key_t & key;
74 val_t & val;
75
76 key_val_ref(key_t & k, val_t & v)
77 :key(k),val(v)
78 {}
79
80 key_val_ref(key_val_ref<key_t,val_t> && tmp)
81 :key(tmp.key),val(tmp.val)
82 {}
83
84 key_val_ref & operator=(const key_val<key_t,val_t> & tmp)
85 {
86 key = tmp.key;
87 val = tmp.val;
88
89 return *this;
90 }
91
92 key_val_ref & operator=(const key_val_ref<key_t,val_t> & tmp)
93 {
94 key = tmp.key;
95 val = tmp.val;
96
97 return *this;
98 }
99
100 bool operator<(const key_val_ref<key_t,val_t> & tmp)
101 {
102 return key < tmp.key;
103 }
104
105 bool operator>(const key_val_ref<key_t,val_t> & tmp)
106 {
107 return key > tmp.key;
108 }
109
110 bool operator<(const key_val<key_t,val_t> & tmp)
111 {
112 return key < tmp.key;
113 }
114
115 bool operator>(const key_val<key_t,val_t> & tmp)
116 {
117 return key > tmp.key;
118 }
119};
120
121
122template<typename key_t, typename val_t>
123struct key_val_it
124{
125 key_t * key;
126 val_t * val;
127
128 bool operator==(const key_val_it & tmp)
129 {
130 return (key == tmp.key && val == tmp.val);
131 }
132
133 key_val_ref<key_t,val_t> operator*()
134 {
135 return key_val_ref<key_t,val_t>(*key,*val);
136 }
137
138 key_val_ref<key_t,val_t> operator[](int i)
139 {
140 return key_val_ref<key_t,val_t>(*key,*val);
141 }
142
143 key_val_it operator+(size_t count) const
144 {
145 key_val_it tmp(key+count,val+count);
146
147 return tmp;
148 }
149
150
151 size_t operator-(key_val_it & tmp) const
152 {
153 return key - tmp.key;
154 }
155
156 key_val_it operator-(size_t count) const
157 {
158 key_val_it tmp(key-count,val-count);
159
160 return tmp;
161 }
162
163 key_val_it & operator++()
164 {
165 ++key;
166 ++val;
167
168 return *this;
169 }
170
171 key_val_it & operator--()
172 {
173 --key;
174 --val;
175
176 return *this;
177 }
178
179 bool operator!=(const key_val_it & tmp) const
180 {
181 return key != tmp.key && val != tmp.val;
182 }
183
184 bool operator<(const key_val_it & tmp) const
185 {
186 return key < tmp.key;
187 }
188
189 key_val_it<key_t,val_t> & operator=(key_val_it<key_t,val_t> & tmp)
190 {
191 key = tmp.key;
192 val = tmp.val;
193
194 return *this;
195 }
196
197 key_val_it() {}
198
199 key_val_it(const key_val_it<key_t,val_t> & tmp)
200 :key(tmp.key),val(tmp.val)
201 {}
202
203 key_val_it(key_t * key, val_t * val)
204 :key(key),val(val)
205 {}
206};
207
208template<typename key_t, typename val_t>
209void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
210{
211 key_t kt = a.key;
212 a.key = b.key;
213 b.key = kt;
214
215 val_t vt = a.val;
216 a.val = b.val;
217 b.val = vt;
218}
219
220namespace std
221{
222 template<typename key_t, typename val_t>
223 struct iterator_traits<key_val_it<key_t,val_t>>
224 {
225 typedef size_t difference_type; //almost always ptrdiff_t
226 typedef key_val<key_t,val_t> value_type; //almost always T
227 typedef key_val<key_t,val_t> & reference; //almost always T& or const T&
228 typedef key_val<key_t,val_t> & pointer; //almost always T* or const T*
229 typedef std::random_access_iterator_tag iterator_category; //usually std::forward_iterator_tag or similar
230 };
231}
232
233
234namespace openfpm
235{
236 template<typename key_t, typename val_t,
237 typename comp_t>
238 void sort(key_t* keys_input, val_t* vals_input, int count,
239 comp_t comp, mgpu::ofp_context_t& context)
240 {
241#ifdef CUDA_ON_CPU
242
243 key_val_it<key_t,val_t> kv(keys_input,vals_input);
244
245 std::sort(kv,kv+count,comp);
246
247#else
248
249 #ifdef SORT_WITH_CUB
250
251 void *d_temp_storage = NULL;
252 size_t temp_storage_bytes = 0;
253
254 auto & temporal2 = context.getTemporalCUB2();
255 temporal2.resize(sizeof(key_t)*count);
256
257 auto & temporal3 = context.getTemporalCUB3();
258 temporal3.resize(sizeof(val_t)*count);
259
260 if (std::is_same<mgpu::template less_t<key_t>,comp_t>::value == true)
261 {
262 cub::DeviceRadixSort::SortPairs(d_temp_storage,
263 temp_storage_bytes,
264 keys_input,
265 (key_t *)temporal2.template getDeviceBuffer<0>(),
266 vals_input,
267 (val_t *)temporal3.template getDeviceBuffer<0>(),
268 count);
269
270 auto & temporal = context.getTemporalCUB();
271 temporal.resize(temp_storage_bytes);
272
273 d_temp_storage = temporal.template getDeviceBuffer<0>();
274
275 // Run
276 cub::DeviceRadixSort::SortPairs(d_temp_storage,
277 temp_storage_bytes,
278 keys_input,
279 (key_t *)temporal2.template getDeviceBuffer<0>(),
280 vals_input,
281 (val_t *)temporal3.template getDeviceBuffer<0>(),
282 count);
283 }
284 else if (std::is_same<mgpu::template greater_t<key_t>,comp_t>::value == true)
285 {
286 cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
287 temp_storage_bytes,
288 keys_input,
289 (key_t *)temporal2.template getDeviceBuffer<0>(),
290 vals_input,
291 (val_t *)temporal3.template getDeviceBuffer<0>(),
292 count);
293
294 auto & temporal = context.getTemporalCUB();
295 temporal.resize(temp_storage_bytes);
296
297 d_temp_storage = temporal.template getDeviceBuffer<0>();
298
299 // Run
300 cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
301 temp_storage_bytes,
302 keys_input,
303 (key_t *)temporal2.template getDeviceBuffer<0>(),
304 vals_input,
305 (val_t *)temporal3.template getDeviceBuffer<0>(),
306 count);
307 }
308
309 cudaMemcpy(keys_input,temporal2.getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
310 cudaMemcpy(vals_input,temporal3.getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
311
312 #else
313 mgpu::mergesort(keys_input,vals_input,count,comp,context);
314 #endif
315
316#endif
317 }
318}
319
320#endif
321
322
323#endif /* SORT_OFP_CUH_ */
324