Add support for a callback on recursion for memoized functions
This commit is contained in:
@@ -163,7 +163,8 @@ public class StorageManagerTest extends TestCase {
|
||||
fail();
|
||||
}
|
||||
catch (AssertionError e) {
|
||||
assertTrue(e.getMessage().startsWith("Recursion detected on input: !!!"));
|
||||
String message = e.getMessage();
|
||||
assertTrue("Expected message starting with \"Recursion detected\", got: " + message, message.startsWith("Recursion detected on input: !!!"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +214,7 @@ public class StorageManagerTest extends TestCase {
|
||||
new C().rec.invoke();
|
||||
fail();
|
||||
}
|
||||
catch (IllegalStateException e) {
|
||||
catch (AssertionError e) {
|
||||
// OK
|
||||
}
|
||||
}
|
||||
@@ -232,7 +233,7 @@ public class StorageManagerTest extends TestCase {
|
||||
new C().rec.invoke();
|
||||
fail();
|
||||
}
|
||||
catch (IllegalStateException e) {
|
||||
catch (AssertionError e) {
|
||||
// OK
|
||||
}
|
||||
}
|
||||
@@ -284,7 +285,7 @@ public class StorageManagerTest extends TestCase {
|
||||
new C().rec.invoke();
|
||||
fail();
|
||||
}
|
||||
catch (IllegalStateException e) {
|
||||
catch (AssertionError e) {
|
||||
// OK
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.jetbrains.kotlin.storage;
|
||||
import kotlin.Unit;
|
||||
import kotlin.jvm.functions.Function0;
|
||||
import kotlin.jvm.functions.Function1;
|
||||
import kotlin.jvm.functions.Function2;
|
||||
import kotlin.text.StringsKt;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
@@ -56,7 +57,7 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, EmptySimpleLock.INSTANCE) {
|
||||
@NotNull
|
||||
@Override
|
||||
protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
|
||||
protected <K, V> RecursionDetectedResult<V> recursionDetectedDefault(@NotNull String source, K input) {
|
||||
return RecursionDetectedResult.fallThrough();
|
||||
}
|
||||
};
|
||||
@@ -123,6 +124,15 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
|
||||
@NotNull Function1<? super K, ? extends V> compute,
|
||||
@NotNull Function2<? super K, ? super Boolean, ? extends V> onRecursiveCall
|
||||
) {
|
||||
return createMemoizedFunction(compute, onRecursiveCall, LockBasedStorageManager.<K>createConcurrentHashMap());
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
|
||||
@@ -132,6 +142,22 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
|
||||
@NotNull Function1<? super K, ? extends V> compute,
|
||||
@NotNull final Function2<? super K, ? super Boolean, ? extends V> onRecursiveCall,
|
||||
@NotNull ConcurrentMap<K, Object> map
|
||||
) {
|
||||
return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute) {
|
||||
@NotNull
|
||||
@Override
|
||||
protected RecursionDetectedResult<V> recursionDetected(K input, boolean firstTime) {
|
||||
return RecursionDetectedResult.value(onRecursiveCall.invoke(input, firstTime));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@NotNull
|
||||
@Override
|
||||
public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
|
||||
@@ -278,8 +304,15 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
}
|
||||
|
||||
@NotNull
|
||||
protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
|
||||
throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
|
||||
protected <K, V> RecursionDetectedResult<V> recursionDetectedDefault(@NotNull String source, K input) {
|
||||
throw sanitizeStackTrace(
|
||||
new AssertionError("Recursion detected " + source +
|
||||
(input == null
|
||||
? ""
|
||||
: "on input: " + input
|
||||
) + " under " + this
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static class RecursionDetectedResult<T> {
|
||||
@@ -406,7 +439,7 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
*/
|
||||
@NotNull
|
||||
protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
|
||||
return storageManager.recursionDetectedDefault();
|
||||
return storageManager.recursionDetectedDefault("in a lazy value", null);
|
||||
}
|
||||
|
||||
protected void postCompute(T value) {
|
||||
@@ -521,9 +554,22 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
storageManager.lock.lock();
|
||||
try {
|
||||
value = cache.get(input);
|
||||
|
||||
if (value == NotValue.COMPUTING) {
|
||||
throw recursionDetected(input);
|
||||
value = NotValue.RECURSION_WAS_DETECTED;
|
||||
RecursionDetectedResult<V> result = recursionDetected(input, /*firstTime = */ true);
|
||||
if (!result.isFallThrough()) {
|
||||
return result.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
if (value == NotValue.RECURSION_WAS_DETECTED) {
|
||||
RecursionDetectedResult<V> result = recursionDetected(input, /*firstTime = */ false);
|
||||
if (!result.isFallThrough()) {
|
||||
return result.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
|
||||
|
||||
AssertionError error = null;
|
||||
@@ -567,10 +613,8 @@ public class LockBasedStorageManager implements StorageManager {
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private AssertionError recursionDetected(K input) {
|
||||
return sanitizeStackTrace(
|
||||
new AssertionError("Recursion detected on input: " + input + " under " + storageManager)
|
||||
);
|
||||
protected RecursionDetectedResult<V> recursionDetected(K input, boolean firstTime) {
|
||||
return storageManager.recursionDetectedDefault("", input);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
|
||||
@@ -26,6 +26,10 @@ abstract class ObservableStorageManager(private val delegate: StorageManager) :
|
||||
return delegate.createMemoizedFunction(compute.observable)
|
||||
}
|
||||
|
||||
override fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V): MemoizedFunctionToNotNull<K, V> {
|
||||
return delegate.createMemoizedFunction(compute.observable, onRecursiveCall)
|
||||
}
|
||||
|
||||
override fun <K, V: Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V?): MemoizedFunctionToNullable<K, V> {
|
||||
return delegate.createMemoizedFunctionWithNullableValues(compute.observable)
|
||||
}
|
||||
@@ -34,6 +38,10 @@ abstract class ObservableStorageManager(private val delegate: StorageManager) :
|
||||
return delegate.createMemoizedFunction(compute.observable, map)
|
||||
}
|
||||
|
||||
override fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V> {
|
||||
return delegate.createMemoizedFunction(compute.observable, onRecursiveCall, map)
|
||||
}
|
||||
|
||||
override fun <K, V: Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNullable<K, V> {
|
||||
return delegate.createMemoizedFunctionWithNullableValues(compute.observable, map)
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ interface StorageManager {
|
||||
*/
|
||||
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V): MemoizedFunctionToNotNull<K, V>
|
||||
|
||||
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V): MemoizedFunctionToNotNull<K, V>
|
||||
|
||||
fun <K, V : Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V?): MemoizedFunctionToNullable<K, V>
|
||||
|
||||
fun <K, V : Any> createCacheWithNullableValues(): CacheWithNullableValues<K, V>
|
||||
@@ -36,6 +38,8 @@ interface StorageManager {
|
||||
|
||||
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V>
|
||||
|
||||
fun <K, V : Any> createMemoizedFunction(compute: (K) -> V, onRecursiveCall: (K, Boolean) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNotNull<K, V>
|
||||
|
||||
fun <K, V : Any> createMemoizedFunctionWithNullableValues(compute: (K) -> V, map: ConcurrentMap<K, Any>): MemoizedFunctionToNullable<K, V>
|
||||
|
||||
fun <T : Any> createLazyValue(computable: () -> T): NotNullLazyValue<T>
|
||||
|
||||
Reference in New Issue
Block a user