/// Helpers for invoking arbitrary functions on other threads and for iterating in parallel.
2 module tern.concurrency;
3 
4 import tern.functional : barter;
5 import tern.object : loadLength;
6 import tern.traits;
7 import std.parallelism;
8 import std.concurrency;
9 import std.range : iota;
10 
11 public:
/**
 * Runs `F` on a freshly spawned thread and blocks the caller until the
 * result has been sent back.
 *
 * The spawned thread calls `F(args)` and sends the result to its owner
 * thread; the caller then waits on `receiveOnly` for a message of type
 * `ReturnType!F`.
 *
 * Params:
 *  F = Function to be invoked.
 *  args = Arguments to invoke `F` on.
 *
 * Returns:
 *  The value returned by `F`.
 */
auto await(alias F, ARGS...)(ARGS args)
    if (!isNoReturn!F)
{
    alias Result = ReturnType!F;

    // Thread entry point: evaluate `F` first, then hand the value to the owner.
    void function(ARGS) entry = function void(ARGS forwarded) {
        auto outcome = F(forwarded);
        send(ownerTid, outcome);
    };

    spawn(entry, args);
    return receiveOnly!Result;
}
29 
/// ditto
auto await(alias F, ARGS...)(ARGS args)
    if (isCallable!F && !__traits(compiles, isNoReturn!F))
{
    // The result type is taken from the call expression itself rather than
    // from `ReturnType`.
    alias Result = typeof(F(args));

    // Thread entry point: evaluate `F` first, then hand the value to the owner.
    void function(ARGS) entry = function void(ARGS forwarded) {
        auto outcome = F(forwarded);
        send(ownerTid, outcome);
    };

    spawn(entry, args);
    return receiveOnly!Result;
}
38 
/// ditto
bool await(alias F, ARGS...)(ARGS args)
    if (isNoReturn!F)
{
    // `F` yields no value here, so the spawned thread sends `true` to its
    // owner once the call has finished; the caller blocks until it arrives.
    void function(ARGS) entry = function void(ARGS forwarded) {
        F(forwarded);
        send(ownerTid, true);
    };

    spawn(entry, args);
    return receiveOnly!bool;
}
47 
/**
 * Spawns a thread that invokes `F` and returns without waiting for it.
 *
 * Whatever `F` returns is discarded and nothing is sent back to the caller.
 *
 * Params:
 *  F = Function to be invoked.
 *  args = Arguments to invoke `F` on.
 */
void async(alias F, ARGS...)(ARGS args)
{
    // Thread entry point that simply forwards the arguments to `F`.
    void function(ARGS) entry = function void(ARGS forwarded) {
        F(forwarded);
    };

    spawn(entry, args);
}
60 
/**
 * Calls `F(args, worker)` once for every worker id in `0 .. WORKERS`,
 * distributing the calls with `std.parallelism.parallel`.
 *
 * Params:
 *  WORKERS = The number of workers to delegate the task to.
 *  F = The function to be invoked; it receives `args` followed by the worker id.
 *  args = Arguments to invoke `F` on.
 */
void spinGroup(size_t WORKERS, alias F, ARGS...)(ARGS args)
{
    // One id per worker: 0, 1, ..., WORKERS - 1.
    auto workerIds = iota(0, WORKERS);

    foreach (id; parallel(workerIds))
    {
        F(args, id);
    }
}
75 
/**
 * Spins up a group of four workers to iterate across all elements in `range`.
 *
 * The index space `0 .. range.loadLength` is split into at most four
 * contiguous chunks; each worker walks its own chunk in ascending order and
 * calls `F` through `barter` with the index and the element at that index.
 * The length is read once up front, so `F` must not change the length of
 * `range` while the iteration is running.
 *
 * Params:
 *  F = The function to be invoked.
 *  range = The range to iterate across.
 */
void parallelForeach(alias F, T)(auto ref T range)
{
    enum size_t workers = 4;

    immutable size_t total = range.loadLength;
    // Ceiling division: `workers * chunk >= total`, so every index is owned
    // by exactly one worker (the previous formula could leave the tail of
    // the range unvisited).
    immutable size_t chunk = total / workers + (total % workers != 0 ? 1 : 0);

    spinGroup!(workers, (size_t worker) {
        immutable size_t first = worker * chunk;
        // Workers whose chunk starts past the end have nothing to do; this
        // also covers the empty range, where `chunk` is 0.
        if (first >= total)
            return;

        // Clamp the final chunk to the end of the range.
        immutable size_t last = total - first < chunk ? total : first + chunk;

        foreach (i; first..last)
            barter!F(i, range[i]);
    })();
}
101 
/// ditto
void parallelForeachReverse(alias F, T)(auto ref T range)
{
    enum size_t workers = 4;

    // The length is read once; `F` must not change the length of `range`
    // while the iteration is running.
    immutable size_t total = range.loadLength;
    // Ceiling division: `workers * chunk >= total`, so every index is owned
    // by exactly one worker (the previous formula could leave the tail of
    // the range unvisited).
    immutable size_t chunk = total / workers + (total % workers != 0 ? 1 : 0);

    spinGroup!(workers, (size_t worker) {
        immutable size_t first = worker * chunk;
        // Workers whose chunk starts past the end have nothing to do; this
        // also covers the empty range, where `chunk` is 0.
        if (first >= total)
            return;

        // Clamp the final chunk to the end of the range.
        immutable size_t last = total - first < chunk ? total : first + chunk;

        // Each worker walks its own chunk from the highest index down.
        foreach_reverse (i; first..last)
            barter!F(i, range[i]);
    })();
}
121 
/**
 * Spins up a group of four workers to iterate from `start` to `end` with
 * increments of `step`.
 *
 * `F` is called (through `barter`) once for every value
 * `start, start ± |step|, start ± 2|step|, ...` that lies before `end`;
 * `end` itself is excluded. The direction is taken from the relative order
 * of `start` and `end`, so only the magnitude of `step` matters. The values
 * are split into at most four contiguous chunks, one per worker.
 *
 * Params:
 *  F = The function to be invoked; it receives the current value as a `ptrdiff_t`.
 *  start = The starting value (inclusive).
 *  end = The ending value (exclusive).
 *  step = The increment. A `step` of 0 performs no iterations.
 */
void parallelFor(alias F)(ptrdiff_t start, ptrdiff_t end, ptrdiff_t step)
{
    enum size_t workers = 4;

    // A zero step can never reach `end`; bail out instead of dividing by zero.
    if (step == 0)
        return;

    immutable bool ascending = end > start;
    immutable size_t span = ascending ? cast(size_t)(end - start) : cast(size_t)(start - end);
    immutable size_t stride = step < 0 ? cast(size_t)(-step) : cast(size_t) step;
    // Signed per-iteration delta pointing from `start` towards `end`.
    immutable ptrdiff_t delta = ascending ? cast(ptrdiff_t) stride : -cast(ptrdiff_t) stride;

    // Number of values visited, then the per-worker share (both rounded up so
    // that nothing is dropped).
    immutable size_t count = span / stride + (span % stride != 0 ? 1 : 0);
    immutable size_t chunk = count / workers + (count % workers != 0 ? 1 : 0);

    spinGroup!(workers, (size_t worker) {
        immutable size_t first = worker * chunk;
        // Workers whose chunk starts past the end have nothing to do; this
        // also covers `start == end`, where `chunk` is 0.
        if (first >= count)
            return;

        // Clamp the final chunk to the number of iterations.
        immutable size_t last = count - first < chunk ? count : first + chunk;

        foreach (i; first..last)
            barter!F(start + cast(ptrdiff_t) i * delta);
    })();
}
150 
/**
 * Spins up a group of four workers to call `F` while `W` holds.
 *
 * Every worker repeatedly claims the next unused iteration index, asks `W`
 * (through `barter`, with the index and the worker id) whether to continue,
 * and if so calls `F` (through `barter`) with that index. A worker stops the
 * first time `W` evaluates to false for it; the call returns once all
 * workers have stopped. Indices are claimed atomically, so each index is
 * handed to exactly one worker, but the order in which indices are processed
 * across workers is unspecified.
 *
 * Params:
 *  W = The conditional function.
 *  F = The function to be invoked.
 */
void parallelWhile(alias W, alias F)()
{
    import core.atomic : atomicOp;

    // Next unclaimed iteration index, shared by every worker.
    shared size_t next;
    spinGroup!(4, (size_t worker) {
        while (true)
        {
            // `atomicOp` returns the incremented value, so subtract one to get
            // the index this worker has just claimed. A plain `index++` here
            // would be a data race between the workers.
            size_t index = atomicOp!"+="(next, 1) - 1;
            if (!barter!W(index, worker))
                return;

            barter!F(index);
        }
    })();
}