From ab4bcb749ff92ed768cfcab1c8e3c7ab917fd48c Mon Sep 17 00:00:00 2001 From: Adam Gent Date: Tue, 28 Nov 2023 10:35:02 -0500 Subject: [PATCH] Make partial loading threadsafe across platforms --- .../com/samskivert/mustache/Mustache.java | 21 ++++-- .../mustache/PartialThreadSafeTest.java | 69 +++++++++++++++++++ 2 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 src/test/java/com/samskivert/mustache/PartialThreadSafeTest.java diff --git a/src/main/java/com/samskivert/mustache/Mustache.java b/src/main/java/com/samskivert/mustache/Mustache.java index 21d560d..58b7142 100644 --- a/src/main/java/com/samskivert/mustache/Mustache.java +++ b/src/main/java/com/samskivert/mustache/Mustache.java @@ -13,6 +13,8 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; /** * Provides Mustache templating services. @@ -1057,10 +1059,20 @@ private IncludedTemplateSegment (Compiler compiler, String name, String indent) protected Template getTemplate () { // we compile our template lazily to avoid infinie recursion if a template includes // itself (see issue #13) - if (_template == null) { - _template = _comp.loadTemplate(_name).indent(_indent); + Template t = _template; + if (t == null) { + // We cannot use synchronized or a CAS operation here since loadTemplate might be an IO call + // and virtual threads prefer regular locks. + lock.lock(); + try { + if ((t = _template) == null) { + _template = t = _comp.loadTemplate(_name).indent(_indent); + } + } finally { + lock.unlock(); + } } - return _template; + return t; } protected IncludedTemplateSegment indent(String indent, boolean first, boolean last) { // Indent this partial based on the spacing provided. @@ -1084,7 +1096,8 @@ public String toString() { protected final Compiler _comp; protected final String _name; private final String _indent; - private Template _template; + private final Lock lock = new ReentrantLock(); + private volatile Template _template; protected boolean _standalone = false; } diff --git a/src/test/java/com/samskivert/mustache/PartialThreadSafeTest.java b/src/test/java/com/samskivert/mustache/PartialThreadSafeTest.java new file mode 100644 index 0000000..ce97644 --- /dev/null +++ b/src/test/java/com/samskivert/mustache/PartialThreadSafeTest.java @@ -0,0 +1,69 @@ +package com.samskivert.mustache; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.Test; + +import com.samskivert.mustache.Mustache.TemplateLoader; + +public class PartialThreadSafeTest { + + @Test + public void testPartialThreadSafe() throws Exception { + long t = System.currentTimeMillis(); + AtomicInteger loadCount = new AtomicInteger(); + TemplateLoader loader = new TemplateLoader() { + + @Override + public Reader getTemplate(String name) throws Exception { + if ("partial".equals(name)) { + loadCount.incrementAndGet(); + TimeUnit.MILLISECONDS.sleep(20); + return new StringReader("Hello"); + } + throw new IOException(name); + } + }; + + Template template = Mustache.compiler().withLoader(loader).compile("{{stuff}}\n\t{{> partial }}"); + ExecutorService executor = Executors.newFixedThreadPool(64); + ConcurrentLinkedDeque q = new ConcurrentLinkedDeque<>(); + + Map m = new HashMap<>(); + m.put("stuff", "Foo"); + for (int i = 100; i > 0; i--) { + int ii = i; + executor.execute(() -> { + try { + TimeUnit.MILLISECONDS.sleep(ii % 10); + template.execute(m); + } catch (Exception e) { + q.add(e); + } + }); + } + executor.shutdown(); + executor.awaitTermination(10_000, TimeUnit.MILLISECONDS); + if (!q.isEmpty()) { + System.out.println(q); + } + assertTrue(q.isEmpty()); + assertEquals(1, loadCount.get()); + System.out.println(loadCount); + System.out.println(System.currentTimeMillis() - t); + + } + +}