AsyncCache.java
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.implementation.caches;
import com.azure.cosmos.implementation.DocumentCollection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
public class AsyncCache<TKey, TValue> {
    private static final Logger logger = LoggerFactory.getLogger(AsyncCache.class);

    private final ConcurrentHashMap<TKey, AsyncLazy<TValue>> values;
    private final IEqualityComparer<TValue> equalityComparer;

    /**
     * Creates a cache backed by the given map.
     *
     * @param equalityComparer comparer used by {@link #getAsync} to detect an obsolete cached value.
     * @param values backing map; it is used directly (not copied), so the caller shares it with this cache.
     */
    public AsyncCache(IEqualityComparer<TValue> equalityComparer, ConcurrentHashMap<TKey, AsyncLazy<TValue>> values) {
        this.equalityComparer = equalityComparer;
        this.values = values;
    }

    /**
     * Creates an empty cache that uses the given comparer.
     *
     * @param equalityComparer comparer used by {@link #getAsync} to detect an obsolete cached value.
     */
    public AsyncCache(IEqualityComparer<TValue> equalityComparer) {
        this(equalityComparer, new ConcurrentHashMap<>());
    }

    /**
     * Creates an empty cache whose comparer treats two values as equal when they are the same
     * reference, or when both are non-null and {@code value1.equals(value2)}.
     */
    public AsyncCache() {
        this((value1, value2) -> {
            if (value1 == value2)
                return true;
            if (value1 == null || value2 == null)
                return false;
            return value1.equals(value2);
        });
    }

    /**
     * Unconditionally stores an already-computed value for {@code key}, replacing any existing entry.
     *
     * @param key cache key.
     * @param value value to store.
     */
    public void set(TKey key, TValue value) {
        logger.debug("set cache[{}]={}", key, value);
        values.put(key, new AsyncLazy<>(value));
    }

    /**
     * Gets the value corresponding to <code>key</code>.
     *
     * <p>
     * If another initialization function is already running, a new initialization function will not be started.
     * The result will be the result of the currently running initialization function.
     * </p>
     *
     * <p>
     * If the previous initialization function completed successfully, the value it produced is returned unless
     * it is equal (per this cache's comparer) to <code>obsoleteValue</code>, in which case a new initialization
     * function is started.
     * </p>
     * <p>
     * If the previous initialization function failed, a new one is launched.
     * </p>
     *
     * @param key Key for which to get a value.
     * @param obsoleteValue Value which is obsolete and needs to be refreshed.
     * @param singleValueInitFunc Initialization function.
     * @return Cached value or value returned by initialization function.
     */
    public Mono<TValue> getAsync(
        TKey key,
        TValue obsoleteValue,
        Callable<Mono<TValue>> singleValueInitFunc) {

        AsyncLazy<TValue> initialLazyValue = values.get(key);
        if (initialLazyValue == null) {
            logger.debug("cache[{}] doesn't exist, computing new value", key);
            // expected == null: install our lazy only if nobody raced us; otherwise reuse theirs.
            return replaceIfUnchanged(key, null, singleValueInitFunc).single();
        }

        logger.debug("cache[{}] exists", key);
        // The three-arg flatMap lets the error mapper handle only failures of the *initial*
        // lazy value; failures of a freshly recomputed value propagate to the caller unchanged.
        return initialLazyValue.single().flux().flatMap(value -> {
            if (!equalityComparer.areEqual(value, obsoleteValue)) {
                logger.debug("Returning cache[{}] as it is different from obsoleteValue", key);
                return Flux.just(value);
            }
            logger.debug("cache[{}] result value is obsolete ({}), computing new value", key, obsoleteValue);
            return replaceIfUnchanged(key, initialLazyValue, singleValueInitFunc).single().flux();
        }, err -> {
            logger.debug("cache[{}] resulted in error, computing new value", key, err);
            return replaceIfUnchanged(key, initialLazyValue, singleValueInitFunc).single().flux();
        }, Flux::empty).single();
    }

    /**
     * Removes the entry for {@code key}, if any.
     *
     * @param key cache key.
     */
    public void remove(TKey key) {
        values.remove(key);
    }

    /**
     * Removes the value from the cache and returns it if present.
     *
     * @param key cache key.
     * @return the removed entry's (possibly still pending) value, or an empty {@link Mono} when
     *         no entry existed for {@code key}.
     */
    public Mono<TValue> removeAsync(TKey key) {
        AsyncLazy<TValue> lazy = values.remove(key);
        if (lazy == null) {
            // Previously this dereferenced null and threw NullPointerException.
            return Mono.empty();
        }
        // TODO: .Net returns default value on failure of single why?
        return lazy.single();
    }

    /**
     * Removes every entry from the cache.
     */
    public void clear() {
        this.values.clear();
    }

    /**
     * Forces refresh of the cached item if it is not being refreshed at the moment.
     *
     * <p>Nothing happens when there is no entry for {@code key}, or when the current entry is
     * neither succeeded nor faulted (per {@code AsyncLazy#isSucceeded} / {@code isFaulted}).</p>
     *
     * @param key cache key to refresh.
     * @param singleValueInitFunc function producing the replacement value.
     */
    public void refresh(
        TKey key,
        Callable<Mono<TValue>> singleValueInitFunc) {
        logger.debug("refreshing cache[{}]", key);
        AsyncLazy<TValue> initialLazyValue = values.get(key);
        if (initialLazyValue != null && (initialLazyValue.isSucceeded() || initialLazyValue.isFaulted())) {
            // UPDATE the new task in the cache (only if no one else replaced it in the meantime).
            replaceIfUnchanged(key, initialLazyValue, singleValueInitFunc);
        }
    }

    /**
     * Atomically installs a new lazy value for {@code key} when the current mapping is still
     * {@code expected} (or when the key is absent). If another thread has already installed a
     * different lazy value, that one is kept and returned instead.
     *
     * @param key cache key.
     * @param expected lazy value the caller last observed; may be {@code null} when the key was absent.
     * @param singleValueInitFunc function backing the new lazy value.
     * @return the lazy value that is mapped to {@code key} after the merge.
     */
    private AsyncLazy<TValue> replaceIfUnchanged(
        TKey key,
        AsyncLazy<TValue> expected,
        Callable<Mono<TValue>> singleValueInitFunc) {
        AsyncLazy<TValue> candidate = new AsyncLazy<>(singleValueInitFunc);
        // merge() inserts `candidate` when the key is absent; otherwise the remapping keeps the
        // current value unless it is still the one the caller observed.
        return values.merge(key, candidate,
            (current, proposed) -> current == expected ? proposed : current);
    }

    /**
     * Java-serialization wrapper around an {@link AsyncCache}. Only entries whose lazy value has
     * completed (as reported by {@code AsyncLazy#tryGet}) are written; pending entries are skipped.
     */
    public abstract static class SerializableAsyncCache<TKey, TValue> implements Serializable {
        private static final long serialVersionUID = 2L;
        private static final Logger logger = LoggerFactory.getLogger(SerializableAsyncCache.class);

        protected transient AsyncCache<TKey, TValue> cache;

        protected SerializableAsyncCache() {}

        /**
         * Serializable cache specialised for {@code String -> DocumentCollection} entries.
         */
        public static class SerializableAsyncCollectionCache extends SerializableAsyncCache<String, DocumentCollection> {
            private static final long serialVersionUID = 2L;

            private SerializableAsyncCollectionCache() {}

            @Override
            protected void serializeKey(ObjectOutputStream oos, String s) throws IOException {
                oos.writeUTF(s);
            }

            @Override
            protected void serializeValue(ObjectOutputStream oos, DocumentCollection documentCollection) throws IOException {
                oos.writeObject(DocumentCollection.SerializableDocumentCollection.from(documentCollection));
            }

            @Override
            protected String deserializeKey(ObjectInputStream ois) throws IOException {
                return ois.readUTF();
            }

            @Override
            protected DocumentCollection deserializeValue(ObjectInputStream ois) throws IOException,
                ClassNotFoundException {
                return ((DocumentCollection.SerializableDocumentCollection) ois.readObject()).getWrappedItem();
            }
        }

        /**
         * Wraps {@code cache} in a serializable form.
         *
         * @param cache cache to wrap (not copied).
         * @param keyClass key type; only {@code String} is supported.
         * @param valueClass value type; only {@code DocumentCollection} is supported.
         * @return serializable wrapper around {@code cache}.
         * @throws UnsupportedOperationException for any other key/value type combination.
         */
        @SuppressWarnings("unchecked")
        public static <TKey, TValue> SerializableAsyncCache<TKey, TValue> from(AsyncCache<TKey,
            TValue> cache, Class<TKey> keyClass, Class<TValue> valueClass) {
            if (keyClass == String.class && valueClass == DocumentCollection.class) {
                SerializableAsyncCollectionCache sacc = new SerializableAsyncCollectionCache();
                // Safe: the type tokens were just checked to be String / DocumentCollection.
                sacc.cache = (AsyncCache<String, DocumentCollection>) cache;
                return (SerializableAsyncCache<TKey, TValue>) sacc;
            }
            throw new UnsupportedOperationException(
                "not supported: key type " + keyClass + ", value type " + valueClass);
        }

        protected abstract void serializeKey(ObjectOutputStream oos, TKey key) throws IOException;

        protected abstract void serializeValue(ObjectOutputStream oos, TValue value) throws IOException;

        protected abstract TKey deserializeKey(ObjectInputStream ois) throws IOException;

        protected abstract TValue deserializeValue(ObjectInputStream ois) throws IOException, ClassNotFoundException;

        /**
         * @return the wrapped (or, after deserialization, reconstructed) cache.
         */
        public AsyncCache<TKey, TValue> toAsyncCache() {
            return this.cache;
        }

        private void writeObject(ObjectOutputStream oos)
            throws IOException {
            logger.info("Serializing {}", this.getClass());
            // Snapshot first: the entry count must be written before the entries, and entries
            // without an available value are skipped.
            Map<TKey, TValue> snapshot = new HashMap<>();
            for (Map.Entry<TKey, AsyncLazy<TValue>> entry : cache.values.entrySet()) {
                TKey key = entry.getKey();
                entry.getValue().tryGet().ifPresent(value -> snapshot.put(key, value));
            }

            oos.writeInt(snapshot.size());
            for (Map.Entry<TKey, TValue> entry : snapshot.entrySet()) {
                serializeKey(oos, entry.getKey());
                serializeValue(oos, entry.getValue());
            }
            oos.writeObject(cache.equalityComparer);
        }

        private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
            logger.info("Deserializing {}", this.getClass());
            int size = ois.readInt();
            ConcurrentHashMap<TKey, AsyncLazy<TValue>> pairs = new ConcurrentHashMap<>();
            for (int i = 0; i < size; i++) {
                TKey key = deserializeKey(ois);
                TValue value = deserializeValue(ois);
                pairs.put(key, new AsyncLazy<>(value));
            }

            @SuppressWarnings("unchecked")
            IEqualityComparer<TValue> equalityComparer = (IEqualityComparer<TValue>) ois.readObject();
            this.cache = new AsyncCache<>(equalityComparer, pairs);
        }
    }
}