| 1 | /* |
| 2 | * sort_ofp.cuh |
| 3 | * |
| 4 | * Created on: Aug 23, 2019 |
| 5 | * Author: i-bird |
| 6 | */ |
| 7 | |
| 8 | #ifndef SORT_OFP_CUH_ |
| 9 | #define SORT_OFP_CUH_ |
| 10 | |
| 11 | |
| 12 | #ifdef __NVCC__ |
| 13 | |
| 14 | #include "util/cuda_launch.hpp" |
| 15 | |
| 16 | #if CUDART_VERSION >= 11000 |
| 17 | #ifndef CUDA_ON_CPU |
| 18 | // Here we have for sure CUDA >= 11 |
| 19 | #include "cub/cub.cuh" |
| 20 | #ifndef SORT_WITH_CUB |
| 21 | #define SORT_WITH_CUB |
| 22 | #endif |
| 23 | #endif |
| 24 | #else |
| 25 | // Here we have old CUDA |
| 26 | #include "cub_old/cub.cuh" |
| 27 | #include "util/cuda/moderngpu/kernel_mergesort.hxx" |
| 28 | #endif |
| 29 | |
| 30 | #include "util/cuda/ofp_context.hxx" |
| 31 | |
template<typename key_t, typename val_t>
struct key_val_ref;

/*! \brief Value type pairing a sort key with its payload
 *
 * Acts as the value_type of key_val_it so that std::sort can copy an
 * element out of the two parallel arrays. Ordering is decided by the
 * key alone; the payload never takes part in comparisons.
 */
template<typename key_t, typename val_t>
struct key_val
{
	//! sorting key
	key_t key;
	//! payload carried along with the key
	val_t val;

	//! construct from an explicit key and value
	key_val(const key_t & key_, const val_t & val_)
	:key(key_),val(val_)
	{}

	//! materialize a copy of the element a proxy reference points at
	key_val(const key_val_ref<key_t,val_t> & other)
	{
		this->operator=(other);
	}

	//! copy the referenced key/value into this pair
	key_val & operator=(const key_val_ref<key_t,val_t> & other)
	{
		key = other.key;
		val = other.val;

		return *this;
	}

	//! ordering by key only
	bool operator<(const key_val & other) const
	{
		return key < other.key;
	}

	//! reverse ordering by key only
	bool operator>(const key_val & other) const
	{
		return key > other.key;
	}
};
| 68 | |
| 69 | |
// forward declaration (key_val is defined alongside this proxy)
template<typename key_t, typename val_t>
struct key_val;

/*! \brief Proxy reference to one element stored in two parallel arrays
 *
 * Returned when a key_val_it is dereferenced. It holds references into
 * the key and value arrays, so assigning to it writes through to the
 * arrays; comparisons look at the key only.
 */
template<typename key_t, typename val_t>
struct key_val_ref
{
	//! reference into the key array
	key_t & key;
	//! reference into the value array
	val_t & val;

	//! bind the proxy to one key slot and one value slot
	key_val_ref(key_t & k, val_t & v)
	:key(k),val(v)
	{}

	//! move construction re-binds to the same underlying slots
	key_val_ref(key_val_ref<key_t,val_t> && tmp)
	:key(tmp.key),val(tmp.val)
	{}

	//! write a materialized pair back into the arrays
	key_val_ref & operator=(const key_val<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	//! copy the element another proxy refers to into this slot
	key_val_ref & operator=(const key_val_ref<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	// comparisons are const (fixed: were non-const, inconsistent with
	// key_val and unusable through a const reference)

	//! ordering by key only, against another proxy
	bool operator<(const key_val_ref<key_t,val_t> & tmp) const
	{
		return key < tmp.key;
	}

	//! reverse ordering by key only, against another proxy
	bool operator>(const key_val_ref<key_t,val_t> & tmp) const
	{
		return key > tmp.key;
	}

	//! ordering by key only, against a materialized pair
	bool operator<(const key_val<key_t,val_t> & tmp) const
	{
		return key < tmp.key;
	}

	//! reverse ordering by key only, against a materialized pair
	bool operator>(const key_val<key_t,val_t> & tmp) const
	{
		return key > tmp.key;
	}
};
| 120 | |
| 121 | |
// forward declaration (key_val_ref is defined alongside this iterator)
template<typename key_t, typename val_t>
struct key_val_ref;

/*! \brief Random-access iterator over two parallel arrays (keys and values)
 *
 * Dereferencing yields a key_val_ref proxy so std::sort can permute both
 * arrays in lockstep. The two pointers always advance together.
 */
template<typename key_t, typename val_t>
struct key_val_it
{
	//! current position in the key array
	key_t * key;
	//! current position in the value array
	val_t * val;

	//! two iterators are equal when both pointers match
	bool operator==(const key_val_it & tmp) const
	{
		return (key == tmp.key && val == tmp.val);
	}

	//! proxy to the current element
	key_val_ref<key_t,val_t> operator*()
	{
		return key_val_ref<key_t,val_t>(*key,*val);
	}

	//! proxy to the element i positions after the current one
	//! (fixed: previously ignored i and always returned element 0)
	key_val_ref<key_t,val_t> operator[](int i)
	{
		return key_val_ref<key_t,val_t>(key[i],val[i]);
	}

	//! iterator advanced by count positions
	key_val_it operator+(size_t count) const
	{
		key_val_it tmp(key+count,val+count);

		return tmp;
	}

	//! distance between two iterators (measured on the key pointer)
	size_t operator-(const key_val_it & tmp) const
	{
		return key - tmp.key;
	}

	//! iterator moved back by count positions
	key_val_it operator-(size_t count) const
	{
		key_val_it tmp(key-count,val-count);

		return tmp;
	}

	//! advance both pointers
	key_val_it & operator++()
	{
		++key;
		++val;

		return *this;
	}

	//! step both pointers back
	key_val_it & operator--()
	{
		--key;
		--val;

		return *this;
	}

	//! negation of operator== (fixed: used &&, which disagreed with
	//! operator== whenever exactly one of the two pointers differed)
	bool operator!=(const key_val_it & tmp) const
	{
		return key != tmp.key || val != tmp.val;
	}

	//! ordering by position in the key array
	bool operator<(const key_val_it & tmp) const
	{
		return key < tmp.key;
	}

	//! copy assignment (fixed: now accepts a const source)
	key_val_it<key_t,val_t> & operator=(const key_val_it<key_t,val_t> & tmp)
	{
		key = tmp.key;
		val = tmp.val;

		return *this;
	}

	//! default construction yields a null (singular) iterator
	//! (fixed: pointers were left uninitialized)
	key_val_it()
	:key(nullptr),val(nullptr)
	{}

	//! copy construction
	key_val_it(const key_val_it<key_t,val_t> & tmp)
	:key(tmp.key),val(tmp.val)
	{}

	//! bind the iterator to the start of the two arrays
	key_val_it(key_t * key, val_t * val)
	:key(key),val(val)
	{}
};
| 207 | |
/*! \brief Swap the two array elements referenced by a pair of proxies
 *
 * Exchanges the underlying key and value entries that the proxies point
 * at; picked up by argument-dependent lookup when std::sort swaps
 * elements through key_val_it.
 */
template<typename key_t, typename val_t>
void swap(key_val_ref<key_t,val_t> a, key_val_ref<key_t,val_t> b)
{
	key_t a_key_old = a.key;
	val_t a_val_old = a.val;

	a.key = b.key;
	a.val = b.val;

	b.key = a_key_old;
	b.val = a_val_old;
}
| 219 | |
namespace std
{
	/*! \brief iterator_traits specialization so key_val_it works with std::sort
	 *
	 * key_val_it does not expose the nested typedefs itself, so the traits
	 * are supplied here.
	 */
	template<typename key_t, typename val_t>
	struct iterator_traits<key_val_it<key_t,val_t>>
	{
		// NOTE(review): conventionally the signed ptrdiff_t; kept as size_t to
		// match key_val_it::operator-'s return type — confirm std::sort never
		// needs a negative distance here
		typedef size_t difference_type;
		// element type materialized when the algorithm copies one out
		typedef key_val<key_t,val_t> value_type;
		// the iterator actually yields a key_val_ref proxy on dereference
		typedef key_val<key_t,val_t> & reference;
		// fixed: was `key_val<key_t,val_t> &`, a reference is not a pointer type
		typedef key_val<key_t,val_t> * pointer;
		// key_val_it supports +, -, [], and relational comparison
		typedef std::random_access_iterator_tag iterator_category;
	};
}
| 232 | |
| 233 | |
namespace openfpm
{
	/*! \brief Sort keys_input and apply the same permutation to vals_input, in place
	 *
	 * \param keys_input keys to sort (device memory on GPU builds)
	 * \param vals_input values permuted together with the keys
	 * \param count number of key/value pairs
	 * \param comp comparator deciding the order
	 * \param context context supplying reusable temporary buffers
	 *
	 * NOTE(review): on the CUB path only mgpu::less_t / mgpu::greater_t are
	 * supported (radix sort has no custom comparator); any other comp_t now
	 * leaves the arrays untouched instead of overwriting them with
	 * uninitialized temporary data. CUDA API return codes are not checked.
	 */
	template<typename key_t, typename val_t,
		typename comp_t>
	void sort(key_t* keys_input, val_t* vals_input, int count,
		comp_t comp, mgpu::ofp_context_t& context)
	{
#ifdef CUDA_ON_CPU

	// sequential fallback: sort both arrays in lockstep through the proxy iterator
	key_val_it<key_t,val_t> kv(keys_input,vals_input);

	std::sort(kv,kv+count,comp);

#else

#ifdef SORT_WITH_CUB

	void *d_temp_storage = NULL;
	size_t temp_storage_bytes = 0;

	// output buffers for the out-of-place radix sort, recycled from the context
	auto & temporal2 = context.getTemporalCUB2();
	temporal2.resize(sizeof(key_t)*count);

	auto & temporal3 = context.getTemporalCUB3();
	temporal3.resize(sizeof(val_t)*count);

	// radix sort takes no comparator: map less_t/greater_t onto the
	// ascending/descending variants and remember whether we sorted at all
	bool sorted = false;

	if (std::is_same<mgpu::template less_t<key_t>,comp_t>::value == true)
	{
		// first call only queries the required temporary storage size
		cub::DeviceRadixSort::SortPairs(d_temp_storage,
										temp_storage_bytes,
										keys_input,
										(key_t *)temporal2.template getDeviceBuffer<0>(),
										vals_input,
										(val_t *)temporal3.template getDeviceBuffer<0>(),
										count);

		auto & temporal = context.getTemporalCUB();
		temporal.resize(temp_storage_bytes);

		d_temp_storage = temporal.template getDeviceBuffer<0>();

		// Run
		cub::DeviceRadixSort::SortPairs(d_temp_storage,
										temp_storage_bytes,
										keys_input,
										(key_t *)temporal2.template getDeviceBuffer<0>(),
										vals_input,
										(val_t *)temporal3.template getDeviceBuffer<0>(),
										count);

		sorted = true;
	}
	else if (std::is_same<mgpu::template greater_t<key_t>,comp_t>::value == true)
	{
		// first call only queries the required temporary storage size
		cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
										temp_storage_bytes,
										keys_input,
										(key_t *)temporal2.template getDeviceBuffer<0>(),
										vals_input,
										(val_t *)temporal3.template getDeviceBuffer<0>(),
										count);

		auto & temporal = context.getTemporalCUB();
		temporal.resize(temp_storage_bytes);

		d_temp_storage = temporal.template getDeviceBuffer<0>();

		// Run
		cub::DeviceRadixSort::SortPairsDescending(d_temp_storage,
										temp_storage_bytes,
										keys_input,
										(key_t *)temporal2.template getDeviceBuffer<0>(),
										vals_input,
										(val_t *)temporal3.template getDeviceBuffer<0>(),
										count);

		sorted = true;
	}

	// copy the sorted result back over the input arrays; skipped for an
	// unsupported comparator so the input is not clobbered
	// (fixed: the two calls below were missing the `template` disambiguator
	// on the dependent getDeviceBuffer<0>() call, unlike every other call site)
	if (sorted == true)
	{
		cudaMemcpy(keys_input,temporal2.template getDeviceBuffer<0>(),sizeof(key_t)*count,cudaMemcpyDeviceToDevice);
		cudaMemcpy(vals_input,temporal3.template getDeviceBuffer<0>(),sizeof(val_t)*count,cudaMemcpyDeviceToDevice);
	}

#else
	// pre-CUDA-11 path: moderngpu mergesort honours an arbitrary comparator
	mgpu::mergesort(keys_input,vals_input,count,comp,context);
#endif

#endif
	}
}
| 319 | |
| 320 | #endif |
| 321 | |
| 322 | |
| 323 | #endif /* SORT_OFP_CUH_ */ |
| 324 | |