Этот вопрос НЕ о том, как использовать ThreadLocal. Мой вопрос касается побочного эффекта продолжения ForkJoinPool для ForkJoinTask.compute(), который нарушает контракт ThreadLocal.
В ForkJoinTask.compute()
я вытаскиваю произвольный статический ThreadLocal.
Значение представляет собой некоторый произвольный объект с состоянием, но не с состоянием после окончания вызова compute()
. Другими словами, я подготавливаю локальный объект/состояние потока, использую его, а затем удаляю.
В принципе, вы бы поместили это состояние в ForkJoinTasK, но просто предположите, что это локальное значение потока находится в сторонней библиотеке, которую я не могу изменить. Следовательно, статический threadlocal, поскольку это ресурс, который будут совместно использовать все экземпляры задач.
Я предвидел, протестировал и доказал, что простой ThreadLocal, конечно же, инициализируется только один раз. Это означает, что из-за продолжения потока после вызова ForkJoinTask.join()
мой метод compute()
может быть вызван снова еще до того, как он завершится. Это показывает состояние объекта, используемого при предыдущем вызове вычислений, на много кадров стека выше.
Как вы решаете эту проблему нежелательного воздействия?
Единственный способ, который я сейчас вижу, — это обеспечить новые потоки для каждого вызова compute()
, но это нарушает продолжение пула F/J и может опасно увеличить количество потоков.
Разве в ядре JRE не нужно что-то делать для резервного копирования TL, который изменился с момента первого ForkJoinTask, и восстановить всю карту threadlocal, как если бы каждый task.compute был первым, запущенным в потоке?
Спасибо.
package jdk8tests;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;
public class TestForkJoin3 {
static AtomicInteger nextId = new AtomicInteger();
static long T0 = System.currentTimeMillis();
static int NTHREADS = 5;
static final ThreadLocal<StringBuilder> myTL = ThreadLocal.withInitial( () -> new StringBuilder());
static void log(Object msg) {
System.out.format("%09.3f %-10s %s%n", new Double(0.001*(System.currentTimeMillis()-T0)), Thread.currentThread().getName(), " : "+msg);
}
public static void main(String[] args) throws Exception {
ForkJoinPool p = new ForkJoinPool(
NTHREADS,
pool -> {
int id = nextId.incrementAndGet(); //count new threads
log("new FJ thread "+ id);
ForkJoinWorkerThread t = new ForkJoinWorkerThread(pool) {/**/};
t.setName("My FJThread "+id);
return t;
},
Thread.getDefaultUncaughtExceptionHandler(),
false
);
LowercasingTask t = new LowercasingTask("ROOT", 3);
p.invoke(t);
int nt = nextId.get();
log("number of threads was "+nt);
if(nt > NTHREADS)
log(">>>>>>> more threads than prescribed <<<<<<<<");
}
//=====================
static class LowercasingTask extends RecursiveTask<String> {
String name;
int level;
public LowercasingTask(String name, int level) {
this.name = name;
this.level = level;
}
@Override
protected String compute() {
StringBuilder sbtl = myTL.get();
String initialValue = sbtl.toString();
if(!initialValue.equals(""))
log("!!!!!! BROKEN ON START!!!!!!! value = "+ initialValue);
sbtl.append(":START");
if(level>0) {
log(name+": compute level "+level);
try {Thread.sleep(10);} catch (InterruptedException e) {e.printStackTrace();}
List<LowercasingTask> tasks = new ArrayList<>();
for(int i=1; i<=9; i++) {
LowercasingTask lt = new LowercasingTask(name+"."+i, level-1);
tasks.add(lt);
lt.fork();
}
for(int i=0; i<tasks.size(); i++) { //this can lead to compensation threads due to l1.join() method running lifo task lN
//for(int i=tasks.size()-1; i>=0; i--) { //this usually has the lN.join() method running task lN, without compensation threads.
tasks.get(i).join();
}
log(name+": returning from joins");
}
sbtl.append(":END");
String val = sbtl.toString();
if(!val.equals(":START:END"))
log("!!!!!! BROKEN AT END !!!!!!! value = "+val);
sbtl.setLength(0);
return "done";
}
}
}