Add support for a callback on recursion for memoized functions

This commit is contained in:
Ilya Chernikov
2020-12-16 19:37:00 +01:00
parent eef06cded3
commit 02c617468f
4 changed files with 70 additions and 13 deletions
@@ -163,7 +163,8 @@ public class StorageManagerTest extends TestCase {
fail();
}
catch (AssertionError e) {
assertTrue(e.getMessage().startsWith("Recursion detected on input: !!!"));
String message = e.getMessage();
assertTrue("Expected message starting with \"Recursion detected\", got: " + message, message.startsWith("Recursion detected on input: !!!"));
}
}
@@ -213,7 +214,7 @@ public class StorageManagerTest extends TestCase {
new C().rec.invoke();
fail();
}
catch (IllegalStateException e) {
catch (AssertionError e) {
// OK
}
}
@@ -232,7 +233,7 @@ public class StorageManagerTest extends TestCase {
new C().rec.invoke();
fail();
}
catch (IllegalStateException e) {
catch (AssertionError e) {
// OK
}
}
@@ -284,7 +285,7 @@ public class StorageManagerTest extends TestCase {
new C().rec.invoke();
fail();
}
catch (IllegalStateException e) {
catch (AssertionError e) {
// OK
}
}
@@ -19,6 +19,7 @@ package org.jetbrains.kotlin.storage;
import kotlin.Unit;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.functions.Function2;
import kotlin.text.StringsKt;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -56,7 +57,7 @@ public class LockBasedStorageManager implements StorageManager {
public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, EmptySimpleLock.INSTANCE) {
@NotNull
@Override
protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
protected <K, V> RecursionDetectedResult<V> recursionDetectedDefault(@NotNull String source, K input) {
return RecursionDetectedResult.fallThrough();
}
};
@@ -123,6 +124,15 @@ public class LockBasedStorageManager implements StorageManager {
return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
}
@NotNull
@Override
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
@NotNull Function1<? super K, ? extends V> compute,
@NotNull Function2<? super K, ? super Boolean, ? extends V> onRecursiveCall
) {
return createMemoizedFunction(compute, onRecursiveCall, LockBasedStorageManager.<K>createConcurrentHashMap());
}
@NotNull
@Override
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
@@ -132,6 +142,22 @@ public class LockBasedStorageManager implements StorageManager {
return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute);
}
@NotNull
@Override
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
@NotNull Function1<? super K, ? extends V> compute,
@NotNull final Function2<? super K, ? super Boolean, ? extends V> onRecursiveCall,
@NotNull ConcurrentMap<K, Object> map
) {
return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute) {
@NotNull
@Override
protected RecursionDetectedResult<V> recursionDetected(K input, boolean firstTime) {
return RecursionDetectedResult.value(onRecursiveCall.invoke(input, firstTime));
}
};
}
@NotNull
@Override
public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
@@ -278,8 +304,15 @@ public class LockBasedStorageManager implements StorageManager {
}
@NotNull
protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
protected <K, V> RecursionDetectedResult<V> recursionDetectedDefault(@NotNull String source, K input) {
throw sanitizeStackTrace(
new AssertionError("Recursion detected " + source +
(input == null
? ""
: "on input: " + input
) + " under " + this
)
);
}
private static class RecursionDetectedResult<T> {
@@ -406,7 +439,7 @@ public class LockBasedStorageManager implements StorageManager {
*/
@NotNull
protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
return storageManager.recursionDetectedDefault();
return storageManager.recursionDetectedDefault("in a lazy value", null);
}
protected void postCompute(T value) {
@@ -521,9 +554,22 @@ public class LockBasedStorageManager implements StorageManager {
storageManager.lock.lock();
try {
value = cache.get(input);
if (value == NotValue.COMPUTING) {
throw recursionDetected(input);
value = NotValue.RECURSION_WAS_DETECTED;
RecursionDetectedResult<V> result = recursionDetected(input, /*firstTime = */ true);
if (!result.isFallThrough()) {
return result.getValue();
}
}
if (value == NotValue.RECURSION_WAS_DETECTED) {
RecursionDetectedResult<V> result = recursionDetected(input, /*firstTime = */ false);
if (!result.isFallThrough()) {
return result.getValue();
}
}
if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
AssertionError error = null;
@@ -567,10 +613,8 @@ public class LockBasedStorageManager implements StorageManager {
}
@NotNull
private AssertionError recursionDetected(K input) {
return sanitizeStackTrace(
new AssertionError("Recursion detected on input: " + input + " under " + storageManager)
);
protected RecursionDetectedResult<V> recursionDetected(K input, boolean firstTime) {
return storageManager.recursionDetectedDefault("", input);
}
@NotNull
@@ -26,6 +26,10 @@ abstract class ObservableStorageManager(private val delegate: StorageManager) :
return delegate.createMemoizedFunction(compute.observable)
}
override fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V): MemoizedFunctionToNotNull<K, V> {
return delegate.createMemoizedFunction(compute.observable, onRecursiveCall)
}
override fun <K, V: Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V?): MemoizedFunctionToNullable<K, V> {
return delegate.createMemoizedFunctionWithNullableValues(compute.observable)
}
@@ -34,6 +38,10 @@ abstract class ObservableStorageManager(private val delegate: StorageManager) :
return delegate.createMemoizedFunction(compute.observable, map)
}
override fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V> {
return delegate.createMemoizedFunction(compute.observable, onRecursiveCall, map)
}
override fun <K, V: Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNullable<K, V> {
return delegate.createMemoizedFunctionWithNullableValues(compute.observable, map)
}
@@ -29,6 +29,8 @@ interface StorageManager {
*/
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V): MemoizedFunctionToNotNull<K, V>
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V): MemoizedFunctionToNotNull<K, V>
fun <K, V : Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V?): MemoizedFunctionToNullable<K, V>
fun <K, V : Any> createCacheWithNullableValues(): CacheWithNullableValues<K, V>
@@ -36,6 +38,8 @@ interface StorageManager {
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V>
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V>
fun <K, V : Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNullable<K, V>
fun <T : Any> createLazyValue(computable: () -> T): NotNullLazyValue<T>