1 | /* |
2 | * sort_ofp.cuh |
3 | * |
4 | * Created on: Aug 23, 2019 |
5 | * Author: i-bird |
6 | */ |
7 | |
8 | #ifndef SORT_OFP_CUH_ |
9 | #define SORT_OFP_CUH_ |
10 | |
11 | |
12 | #ifdef __NVCC__ |
13 | |
14 | #include "util/cuda_launch.hpp" |
15 | |
16 | #if CUDART_VERSION >= 11000 |
17 | #ifndef CUDA_ON_CPU |
18 | // Here we have for sure CUDA >= 11 |
19 | #include "cub/cub.cuh" |
20 | #ifndef SORT_WITH_CUB |
21 | #define SORT_WITH_CUB |
22 | #endif |
23 | #endif |
24 | #else |
25 | // Here we have old CUDA |
26 | #include "cub_old/cub.cuh" |
27 | #include "util/cuda/moderngpu/kernel_mergesort.hxx" |
28 | #endif |
29 | |
30 | #include "util/cuda/ofp_context.hxx" |
31 | |
template<typename key_t, typename val_t>
struct key_val_ref;

/*! \brief Value type holding an owned (key, value) pair
 *
 * value_type of key_val_it: sorting algorithms (e.g. std::sort on the
 * CUDA_ON_CPU path) use it to copy elements out of the zipped
 * key/value arrays. Ordering is by key only.
 */
template<typename key_t, typename val_t>
struct key_val
{
	//! copy of the key
	key_t key;
	//! copy of the value associated with the key
	val_t val;

	//! construct from an explicit key/value pair
	key_val(const key_t & k, const val_t & v)
	:key(k),val(v)
	{}

	//! construct from a reference-pair, copying the referenced key/value
	//! (direct member initialization: avoids default-construct + assign and
	//! works even when key_t/val_t have no default constructor)
	key_val(const key_val_ref<key_t,val_t> & tmp)
	:key(tmp.key),val(tmp.val)
	{}

	//! order by key only
	bool operator<(const key_val & tmp) const
	{
		return key < tmp.key;
	}

	//! reverse order by key only
	bool operator>(const key_val & tmp) const
	{
		return key > tmp.key;
	}

	//! copy the key/value out of a reference-pair
	key_val & operator=(const key_val_ref<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}
};
68 | |
69 | |
/*! \brief Proxy "reference" to a (key, value) pair living in two separate arrays
 *
 * Returned by key_val_it::operator*; assignment writes through to the
 * underlying arrays, so a sorting algorithm permutes keys and values
 * together. Ordering is by key only.
 */
template<typename key_t, typename val_t>
struct key_val_ref
{
	//! reference into the key array
	key_t & key;
	//! reference into the value array
	val_t & val;

	key_val_ref(key_t & k, val_t & v)
	:key(k),val(v)
	{}

	key_val_ref(key_val_ref<key_t,val_t> && tmp)
	:key(tmp.key),val(tmp.val)
	{}

	//! write a value-pair through to the underlying arrays
	key_val_ref & operator=(const key_val<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	//! copy the pointed-to key/value from another reference-pair
	key_val_ref & operator=(const key_val_ref<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	// the four comparisons below order by key only; they are
	// const-qualified for consistency with key_val and so that they
	// also work on const proxies

	bool operator<(const key_val_ref<key_t,val_t> & tmp) const
	{
		return key < tmp.key;
	}

	bool operator>(const key_val_ref<key_t,val_t> & tmp) const
	{
		return key > tmp.key;
	}

	bool operator<(const key_val<key_t,val_t> & tmp) const
	{
		return key < tmp.key;
	}

	bool operator>(const key_val<key_t,val_t> & tmp) const
	{
		return key > tmp.key;
	}
};
120 | |
121 | |
/*! \brief Random-access iterator over a zipped pair of key/value arrays
 *
 * Dereferencing yields a key_val_ref proxy, so sorting algorithms can
 * swap keys and values together even though they are stored in two
 * separate buffers. Both pointers are advanced in lock-step.
 */
template<typename key_t, typename val_t>
struct key_val_it
{
	//! current position inside the key array
	key_t * key;
	//! current position inside the value array
	val_t * val;

	bool operator==(const key_val_it & tmp) const
	{
		return (key == tmp.key && val == tmp.val);
	}

	//! proxy reference to the current pair
	key_val_ref<key_t,val_t> operator*()
	{
		return key_val_ref<key_t,val_t>(*key,*val);
	}

	//! proxy reference to the pair i positions away
	//! (fixed: the original ignored i and always returned element 0)
	key_val_ref<key_t,val_t> operator[](int i)
	{
		return key_val_ref<key_t,val_t>(key[i],val[i]);
	}

	key_val_it operator+(size_t count) const
	{
		key_val_it tmp(key+count,val+count);

		return tmp;
	}

	//! distance between two iterators, measured on the key array
	size_t operator-(const key_val_it & tmp) const
	{
		return key - tmp.key;
	}

	key_val_it operator-(size_t count) const
	{
		key_val_it tmp(key-count,val-count);

		return tmp;
	}

	key_val_it & operator++()
	{
		++key;
		++val;

		return *this;
	}

	key_val_it & operator--()
	{
		--key;
		--val;

		return *this;
	}

	bool operator!=(const key_val_it & tmp) const
	{
		// logical negation of operator== (fixed: the original used &&,
		// which wrongly reported "not equal" only when BOTH pointers
		// differed)
		return key != tmp.key || val != tmp.val;
	}

	bool operator<(const key_val_it & tmp) const
	{
		return key < tmp.key;
	}

	key_val_it<key_t,val_t> & operator=(const key_val_it<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	//! key and val are left uninitialized: assign before use
	key_val_it() {}

	key_val_it(const key_val_it<key_t,val_t> & tmp)
	:key(tmp.key),val(tmp.val)
	{}

	key_val_it(key_t * key, val_t * val)
	:key(key),val(val)
	{}
};
207 | |
/*! \brief Swap the two (key, value) pairs referenced by a pair of proxies
 *
 * Found via ADL by sorting algorithms, so that swapping two zipped
 * elements exchanges both the keys and the values in their arrays.
 */
template<typename key_t, typename val_t>
void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
{
	key_t tmp_key = a.key;
	val_t tmp_val = a.val;

	a.key = b.key;
	a.val = b.val;

	b.key = tmp_key;
	b.val = tmp_val;
}
219 | |
namespace std
{
	//! Iterator traits for key_val_it so std::sort accepts it as a
	//! random-access iterator
	template<typename key_t, typename val_t>
	struct iterator_traits<key_val_it<key_t,val_t>>
	{
		typedef size_t difference_type; // NOTE(review): ptrdiff_t is the conventional choice; size_t kept to match key_val_it::operator- and preserve behavior
		typedef key_val<key_t,val_t> value_type;
		typedef key_val<key_t,val_t> & reference; // NOTE(review): operator* actually yields a key_val_ref proxy — confirm before tightening
		typedef key_val<key_t,val_t> * pointer; // fixed: was "key_val & " (a reference, not a pointer, contradicting its own comment)
		typedef std::random_access_iterator_tag iterator_category;
	};
}
232 | |
233 | |
namespace openfpm
{
	/*! \brief Sort key/value pairs, ordering the values together with the keys
	 *
	 * \param keys_input device pointer to the keys, sorted in place
	 * \param vals_input device pointer to the values, permuted together with the keys
	 * \param count number of elements
	 * \param comp comparator; on the CUB path only mgpu::less_t<key_t> (ascending)
	 *        and mgpu::greater_t<key_t> (descending) are supported
	 * \param context gpu context providing the temporal device buffers
	 */
	template<typename key_t, typename val_t,
		typename comp_t>
	void sort(key_t* keys_input, val_t* vals_input, int count,
		comp_t comp, mgpu::ofp_context_t& context)
	{
#ifdef CUDA_ON_CPU

	// sequential fallback: zip the two arrays and sort with std::sort
	key_val_it<key_t,val_t> kv(keys_input,vals_input);

	std::sort(kv,kv+count,comp);

#else

	#ifdef SORT_WITH_CUB

		void *d_temp_storage = NULL;
		size_t temp_storage_bytes = 0;

		// output buffers: CUB radix sort is out-of-place
		auto & temporal2 = context.getTemporalCUB2();
		temporal2.resize(sizeof(key_t)*count);

		auto & temporal3 = context.getTemporalCUB3();
		temporal3.resize(sizeof(val_t)*count);

		// tracks whether one of the radix-sort branches actually ran;
		// without this guard an unsupported comparator would copy the
		// uninitialized temporal buffers back over the caller's data
		bool sorted = false;

		if (std::is_same<mgpu::template less_t<key_t>,comp_t>::value == true)
		{
			// first call with NULL storage only computes temp_storage_bytes
			cub::DeviceRadixSort::SortPairs(d_temp_storage,
							temp_storage_bytes,
							keys_input,
							(key_t *)temporal2.template getDeviceBuffer<0>(),
							vals_input,
							(val_t *)temporal3.template getDeviceBuffer<0>(),
							count);

			auto & temporal = context.getTemporalCUB();
			temporal.resize(temp_storage_bytes);

			d_temp_storage = temporal.template getDeviceBuffer<0>();

			// Run
			cub::DeviceRadixSort::SortPairs(d_temp_storage,
							temp_storage_bytes,
							keys_input,
							(key_t *)temporal2.template getDeviceBuffer<0>(),
							vals_input,
							(val_t *)temporal3.template getDeviceBuffer<0>(),
							count);

			sorted = true;
		}
		else if (std::is_same<mgpu::template greater_t<key_t>,comp_t>::value == true)
		{
			// first call with NULL storage only computes temp_storage_bytes
			cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
							temp_storage_bytes,
							keys_input,
							(key_t *)temporal2.template getDeviceBuffer<0>(),
							vals_input,
							(val_t *)temporal3.template getDeviceBuffer<0>(),
							count);

			auto & temporal = context.getTemporalCUB();
			temporal.resize(temp_storage_bytes);

			d_temp_storage = temporal.template getDeviceBuffer<0>();

			// Run
			cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
							temp_storage_bytes,
							keys_input,
							(key_t *)temporal2.template getDeviceBuffer<0>(),
							vals_input,
							(val_t *)temporal3.template getDeviceBuffer<0>(),
							count);

			sorted = true;
		}
		// else: comparator not supported by the CUB back-end — the input
		// is intentionally left untouched instead of being overwritten

		if (sorted == true)
		{
			// copy the out-of-place result back over the input arrays
			cudaMemcpy(keys_input,temporal2.template getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
			cudaMemcpy(vals_input,temporal3.template getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
		}

	#else
		// old CUDA: moderngpu mergesort supports arbitrary comparators
		mgpu::mergesort(keys_input,vals_input,count,comp,context);
	#endif

#endif
	}
}
319 | |
320 | #endif |
321 | |
322 | |
323 | #endif /* SORT_OFP_CUH_ */ |
324 | |