Generate the id for the Cache with a unique context if possible
andreadimaio committed Jun 7, 2024
1 parent 631e904 commit a438abf
Showing 10 changed files with 102 additions and 76 deletions.
@@ -23,8 +23,8 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
+import io.quarkiverse.langchain4j.runtime.cache.BasicFixedAiCache;
import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore;
-import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;

/**
* Used to create LangChain4j's {@link AiServices} in a declarative manner that the application can then use simply by using the
@@ -101,7 +101,7 @@
* Configures the way to obtain the {@link AiCacheProvider}.
* <p>
 * By default, Quarkus configures an {@link AiCacheProvider} bean that uses an {@link InMemoryAiCacheStore} bean as the
- * backing store. The default type for the actual {@link AiCache} is {@link MessageWindowAiCache} and it is configured with
+ * backing store. The default type for the actual {@link AiCache} is {@link BasicFixedAiCache} and it is configured with
 * the value of the {@code quarkus.langchain4j.cache.max-size} configuration property (which defaults to
* 1) as a way of limiting the number of messages in each cache.
* <p>
@@ -8,7 +8,7 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
-import io.quarkiverse.langchain4j.runtime.cache.MessageWindowAiCache;
+import io.quarkiverse.langchain4j.runtime.cache.BasicFixedAiCache;
import io.quarkiverse.langchain4j.runtime.cache.config.AiCacheConfig;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.annotations.Recorder;
@@ -41,7 +41,7 @@ public AiCacheProvider apply(SyntheticCreationalContext<AiCacheProvider> context
return new AiCacheProvider() {
@Override
public AiCache get(Object memoryId) {
-return MessageWindowAiCache.Builder
+return BasicFixedAiCache.Builder
.create(memoryId)
.ttl(ttl)
.maxSize(maxSize)
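The recorder hunk above is where the new default implementation gets wired in. As a minimal sketch, this is what an equivalent hand-written provider could look like; the builder calls create/ttl/maxSize/build are taken from this diff, while store(...) is assumed from the Builder's store field, and the real recorder also supplies an embedding model:

import java.time.Duration;

import io.quarkiverse.langchain4j.runtime.cache.AiCache;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheProvider;
import io.quarkiverse.langchain4j.runtime.cache.BasicFixedAiCache;
import io.quarkiverse.langchain4j.runtime.cache.InMemoryAiCacheStore;

// Illustrative only, not code from this commit: one BasicFixedAiCache per cache id.
public class ExampleCacheProvider implements AiCacheProvider {

    private final InMemoryAiCacheStore store = new InMemoryAiCacheStore();

    @Override
    public AiCache get(Object cacheId) {
        return BasicFixedAiCache.Builder
                .create(cacheId)            // builder entry point shown in the diff
                .ttl(Duration.ofMinutes(5)) // assumed value; quarkus.langchain4j.cache.ttl
                .maxSize(1)                 // quarkus.langchain4j.cache.max-size defaults to 1
                .store(store)               // store(...) assumed from the Builder's store field
                .build();
    }
}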
Expand Up @@ -252,7 +252,8 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

-aiServiceContext.aiCaches = new ConcurrentHashMap<>();
+if (aiServiceContext.aiCaches == null)
+aiServiceContext.aiCaches = new ConcurrentHashMap<>();
}

return (T) aiServiceContext;
@@ -122,13 +122,9 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
AiCache cache = null;

-// TODO: REMOVE THIS COMMENT BEFORE MERGING THE PR.
-// - Understand how to implement the concept of cache for the stream responses.
-// - What do we have to do when we have the tools?

if (methodCreateInfo.isRequiresCache()) {
-Object cacheId = cacheId(methodCreateInfo, methodArgs);
-cache = context.aiCacheProvider.get(cacheId);
+Object cacheId = cacheId(methodCreateInfo);
+cache = context.cache(cacheId);
}
if (context.retrievalAugmentor != null) { // TODO extract method/class
List<ChatMessage> chatMemory = context.hasChatMemory()
@@ -200,6 +196,7 @@ public void accept(Response<AiMessage> message) {

if (cacheResponse.isPresent()) {
log.debug("Return cached response");
+System.out.println("RESPONSE WITH CACHE");
response = Response.from(cacheResponse.get());
} else {
response = executeLLMCall(context, messages, moderationFuture, toolSpecifications);
@@ -396,17 +393,15 @@ private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] me
return "default";
}

-private static Object cacheId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
+private static Object cacheId(AiServiceMethodCreateInfo createInfo) {
for (DefaultMemoryIdProvider provider : DEFAULT_MEMORY_ID_PROVIDERS) {
Object memoryId = provider.getMemoryId();
if (memoryId != null) {
String perServiceSuffix = "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
return memoryId + perServiceSuffix;
}
}

-// fallback to the default since there is nothing else we can really use here
-return "default";
+return "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
}

// TODO: share these methods with LangChain4j
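The core of this commit is the new cacheId(...) above: instead of a single shared "default" id, each cache id now carries a per-service, per-method suffix, prefixed with a memory id whenever a DefaultMemoryIdProvider can supply one (the "unique context" from the commit message). A tiny runnable illustration of the resulting id shapes, with hypothetical names:

// Hypothetical names, mirroring the id logic of cacheId(...) above.
public class CacheIdShapes {
    public static void main(String[] args) {
        String interfaceName = "org.acme.LLMService"; // createInfo.getInterfaceName()
        String methodName = "chat";                   // createInfo.getMethodName()
        String perServiceSuffix = "#" + interfaceName + "." + methodName;

        // With a memory id from a DefaultMemoryIdProvider (e.g. a request scope id):
        System.out.println("req-42" + perServiceSuffix); // req-42#org.acme.LLMService.chat

        // Fallback when no provider has an id to offer:
        System.out.println(perServiceSuffix);            // #org.acme.LLMService.chat
    }
}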
@@ -29,8 +29,8 @@ public boolean hasCache() {
return aiCaches != null;
}

-public AiCache cache(Object memoryId) {
-return aiCaches.computeIfAbsent(memoryId, ignored -> aiCacheProvider.get(memoryId));
+public AiCache cache(Object cacheId) {
+return aiCaches.computeIfAbsent(cacheId, ignored -> aiCacheProvider.get(cacheId));
}

/**
@@ -58,7 +58,7 @@ private void clearAiCache() {
if (aiCaches != null) {
aiCaches.forEach(new BiConsumer<>() {
@Override
-public void accept(Object memoryId, AiCache aiCache) {
+public void accept(Object cacheId, AiCache aiCache) {
aiCache.clear();
}
});
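The reworked cache(Object cacheId) above creates at most one AiCache per id via computeIfAbsent, which pairs with the null guard added to the recorder earlier: existing caches survive re-configuration, and concurrent first lookups of the same id share a single instance. A self-contained sketch of the idiom, with CacheStub standing in for AiCache:

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class LazyPerIdCache {

    record CacheStub(Object id) { }

    private final ConcurrentMap<Object, CacheStub> caches = new ConcurrentHashMap<>();

    CacheStub cache(Object cacheId) {
        // The mapping function runs at most once per id, even under races.
        return caches.computeIfAbsent(cacheId, ignored -> new CacheStub(cacheId));
    }

    public static void main(String[] args) {
        LazyPerIdCache ctx = new LazyPerIdCache();
        System.out.println(ctx.cache("a") == ctx.cache("a")); // true: same id, same instance
        System.out.println(ctx.cache("a") == ctx.cache("b")); // false: different id
    }
}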
@@ -2,11 +2,11 @@

import java.time.Duration;
import java.util.Date;
-import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.locks.ReentrantLock;
+import java.util.stream.Collectors;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
@@ -16,9 +16,9 @@
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore.CacheRecord;

/**
- * This {@link AiCache} implementation operates as a sliding window of messages.
+ * The default {@link AiCache} implementation.
*/
-public class MessageWindowAiCache implements AiCache {
+public class BasicFixedAiCache implements AiCache {

private final Object id;
private final Integer maxMessages;
@@ -30,7 +30,7 @@ public class MessageWindowAiCache {
private final EmbeddingModel embeddingModel;
private final ReentrantLock lock;

-public MessageWindowAiCache(Builder builder) {
+public BasicFixedAiCache(Builder builder) {
this.id = builder.id;
this.maxMessages = builder.maxSize;
this.store = builder.store;
@@ -64,25 +64,17 @@ public void add(SystemMessage systemMessage, UserMessage userMessage, AiMessage

lock.lock();

-List<CacheRecord> elements = store.getAll(id);
-if (elements.size() == maxMessages) {
-elements.remove(0);
-}
-
-List<CacheRecord> items = new LinkedList<>();
-for (int i = 0; i < elements.size(); i++) {
-
-var expiredTime = Date.from(elements.get(i).creation().plus(ttl));
-var currentTime = new Date();
-
-if (currentTime.after(expiredTime))
-continue;
-
-items.add(elements.get(i));
-
-items.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
-store.updateCache(id, items);
+List<CacheRecord> elements = store.getAll(id)
+.stream()
+.filter(this::checkTTL)
+.collect(Collectors.toList());
+
+if (elements.size() == maxMessages) {
+return;
+}
+
+elements.add(CacheRecord.of(embeddingModel.embed(query).content(), aiResponse));
+store.updateCache(id, elements);

} finally {
lock.unlock();
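Note the semantic change in add(...): the old sliding-window code evicted the oldest record (elements.remove(0)) to make room, while BasicFixedAiCache now filters out expired records via checkTTL and simply declines to cache a new entry once maxMessages live records exist. The updated max-size test further down reflects this: FIRST, SECOND and THIRD stay cached and FOURTH is never stored.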
@@ -101,30 +93,49 @@ public Optional<AiMessage> search(SystemMessage systemMessage, UserMessage userM
else
query = "%s%s%s".formatted(queryPrefix, systemMessage.text(), userMessage.text());

-var elements = store.getAll(id);
-double maxScore = 0;
-AiMessage result = null;
-
-for (var cacheRecord : elements) {
-
-if (ttl != null) {
-var expiredTime = Date.from(cacheRecord.creation().plus(ttl));
-var currentTime = new Date();
-
-if (currentTime.after(expiredTime))
-continue;
-}
-
-var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), cacheRecord.embedded());
-var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
-
-if (score >= threshold.doubleValue() && score >= maxScore) {
-maxScore = score;
-result = cacheRecord.response();
-}
-}
-
-return Optional.ofNullable(result);
+try {
+
+lock.lock();
+
+double maxScore = 0;
+AiMessage result = null;
+List<CacheRecord> records = store.getAll(id)
+.stream()
+.filter(this::checkTTL)
+.collect(Collectors.toList());
+
+for (var record : records) {
+
+var relevanceScore = CosineSimilarity.between(embeddingModel.embed(query).content(), record.embedded());
+var score = (float) CosineSimilarity.fromRelevanceScore(relevanceScore);
+
+if (score >= threshold.doubleValue() && score >= maxScore) {
+maxScore = score;
+result = record.response();
+}
+}
+
+store.updateCache(id, records);
+return Optional.ofNullable(result);
+
+} finally {
+lock.unlock();
+}
}

+private boolean checkTTL(CacheRecord record) {
+
+if (ttl == null)
+return true;
+
+var expiredTime = Date.from(record.creation().plus(ttl));
+var currentTime = new Date();
+
+if (currentTime.after(expiredTime)) {
+return false;
+}
+
+return true;
+}

@Override
@@ -187,7 +198,7 @@ public Builder embeddingModel(EmbeddingModel embeddingModel) {
}

public AiCache build() {
-return new MessageWindowAiCache(this);
+return new BasicFixedAiCache(this);
}
}
}
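The reworked search(...) above is the heart of the semantic cache: the query is embedded, every live (non-expired) record is scored with cosine similarity, and the best response at or above the configured threshold is returned; updateCache also persists the TTL-pruned record list as a side effect. A self-contained sketch of that lookup, using plain float arrays in place of langchain4j's Embedding and EmbeddingModel types:

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

public class SemanticLookupSketch {

    record CachedEntry(float[] embedding, String response) { }

    static double cosine(float[] a, float[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    // Same shape as the diff: keep the best-scoring entry at or above the threshold.
    static Optional<String> search(float[] query, List<CachedEntry> entries, double threshold) {
        double maxScore = 0;
        String result = null;
        for (CachedEntry entry : entries) {
            double score = cosine(query, entry.embedding());
            if (score >= threshold && score >= maxScore) {
                maxScore = score;
                result = entry.response();
            }
        }
        return Optional.ofNullable(result);
    }

    public static void main(String[] args) {
        List<CachedEntry> entries = new ArrayList<>();
        entries.add(new CachedEntry(new float[] { 1f, 0f }, "cached answer"));
        System.out.println(search(new float[] { 0.9f, 0.1f }, entries, 0.8)); // Optional[cached answer]
        System.out.println(search(new float[] { 0f, 1f }, entries, 0.8));    // Optional.empty
    }
}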
@@ -76,16 +76,22 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
@Test
void cache_test() {

-String cacheId = "default";
+String chatCacheId = "#" + LLMService.class.getName() + ".chat";
+String chat2CacheId = "#" + LLMService.class.getName() + ".chat2";
+
+assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());

-assertEquals(0, aiCacheStore.getAll(cacheId).size());
service.chat("chat");
-assertEquals(1, aiCacheStore.getAll(cacheId).size());
+assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
+assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
+assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
+assertEquals(0, aiCacheStore.getAll(chat2CacheId).size());

service.chat2("chat2");
-assertEquals(1, aiCacheStore.getAll(cacheId).size());
-assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
-assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
+assertEquals(1, aiCacheStore.getAll(chat2CacheId).size());
+assertEquals("result", aiCacheStore.getAll(chat2CacheId).get(0).response().text());
+assertEquals(es, aiCacheStore.getAll(chat2CacheId).get(0).embedded().vector());
}

@Test
@@ -26,6 +26,9 @@
import io.quarkiverse.langchain4j.CacheResult;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.runtime.cache.AiCacheStore;
+import io.quarkus.arc.Arc;
+import io.quarkus.arc.ArcContainer;
+import io.quarkus.arc.ManagedContext;
import io.quarkus.test.QuarkusUnitTest;

public class CacheConfigTest {
@@ -85,7 +88,7 @@ else if (textSegments.get(0).text().equals("TESTFOURTH"))
@Order(1)
void cache_ttl_test() throws InterruptedException {

-String cacheId = "default";
+String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);

service.chat("FIRST");
@@ -107,7 +110,7 @@ void cache_ttl_test() throws InterruptedException {
@Order(2)
void cache_max_size_test() {

-String cacheId = "default";
+String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);

service.chat("FIRST");
@@ -119,12 +122,18 @@ void cache_max_size_test() {
service.chat("THIRD");
service.chat("FOURTH");
assertEquals(3, aiCacheStore.getAll(cacheId).size());
assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(0).response().text());
assertEquals(second, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(1).response().text());
assertEquals(third, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
assertEquals("cache: TESTFOURTH", aiCacheStore.getAll(cacheId).get(2).response().text());
assertEquals(fourth, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
assertEquals("cache: TESTFIRST", aiCacheStore.getAll(cacheId).get(0).response().text());
assertEquals(first, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
assertEquals("cache: TESTSECOND", aiCacheStore.getAll(cacheId).get(1).response().text());
assertEquals(second, aiCacheStore.getAll(cacheId).get(1).embedded().vector());
assertEquals("cache: TESTTHIRD", aiCacheStore.getAll(cacheId).get(2).response().text());
assertEquals(third, aiCacheStore.getAll(cacheId).get(2).embedded().vector());
}

+private String getContext(String methodName) {
+ArcContainer container = Arc.container();
+ManagedContext requestContext = container.requestContext();
+return requestContext.getState() + "#" + LLMService.class.getName() + "." + methodName;
+}
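This new test helper mirrors the id format introduced by cacheId(...): requestContext.getState() is presumably what the DefaultMemoryIdProvider contributes as the prefix when a request scope is active, followed by the same "#Interface.method" suffix asserted throughout these tests.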

static float[] first = {
@@ -79,16 +79,20 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
@Test
void cache_test() {

-String cacheId = "default";
+String chatCacheId = "#" + LLMService.class.getName() + ".chat";
+String chatNoCacheCacheId = "#" + LLMService.class.getName() + ".chatNoCache";

-assertEquals(0, aiCacheStore.getAll(cacheId).size());
+assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
service.chatNoCache("noCache");
-assertEquals(0, aiCacheStore.getAll(cacheId).size());
+assertEquals(0, aiCacheStore.getAll(chatCacheId).size());
+assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());

service.chat("cache");
-assertEquals(1, aiCacheStore.getAll(cacheId).size());
-assertEquals("result", aiCacheStore.getAll(cacheId).get(0).response().text());
-assertEquals(es, aiCacheStore.getAll(cacheId).get(0).embedded().vector());
+assertEquals(1, aiCacheStore.getAll(chatCacheId).size());
+assertEquals("result", aiCacheStore.getAll(chatCacheId).get(0).response().text());
+assertEquals(es, aiCacheStore.getAll(chatCacheId).get(0).embedded().vector());
+assertEquals(0, aiCacheStore.getAll(chatNoCacheCacheId).size());
}

@Test
@@ -94,7 +94,7 @@ public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
@Order(1)
void cache_prefix_test() throws InterruptedException {

-String cacheId = "default";
+String cacheId = "#" + LLMService.class.getName() + ".chat";
aiCacheStore.deleteCache(cacheId);

service.chat("firstMessage");
