diff --git a/BUILD.md b/BUILD.md index 7c45218d..bc0c61ce 100644 --- a/BUILD.md +++ b/BUILD.md @@ -36,11 +36,11 @@ To build the Semantic Kernel for Java, you will need: 1. Clone this repository - git clone -b java-v1 https://github.com/microsoft/semantic-kernel/ + git clone https://github.com/microsoft/semantic-kernel-java 2. Build the project with the Maven Wrapper - cd semantic-kernel/java + cd semantic-kernel ./mvnw install 3. (Optional) To run a FULL build including static analysis and end-to-end tests that might require a valid OpenAI key, @@ -104,8 +104,7 @@ Also ensure that: - All new code is covered by unit tests - All new code is covered by integration tests -Once your proposal is ready, submit a pull request to the `java-v1` branch. The pull request will be reviewed by the -project maintainers. +Once your proposal is ready, submit a pull request. The pull request will be reviewed by the project maintainers. Make sure your pull request has an objective title and a clear description explaining the problem and solution. diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ad56c82..8942e497 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,23 @@ +# 1.2.2 + +- Fix bug in `FunctionInvocation` not using per-invocation type conversion when calling `withResultType`. +- Fix bug in Global Hooks not being invoked under certain circumstances. +- Add fluent returns to `ChatHistory` `addXMessage` methods. +- Add user agent opt-out for OpenAI requests by setting the property `semantic-kernel.useragent-disable` to `true`. +- Add several convenience `invokePromptAsync` methods to `Kernel`. +- Allow Handlebars templates to call Javabean getters to extract data from invocation arguments. +- Improve thread safety of `ChatHistory`. + +#### Experimental Changes + +- Add JDBC vector store + +#### Non-API Changes + +- Add custom type Conversion example, `CustomTypes_Example` +- Dependency updates and pom cleanup +- Documentation updates + # 1.2.0 - Add ability to use image_url as content for a OpenAi chat completion diff --git a/README.md b/README.md index 715b979a..b748c348 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,12 @@ # Semantic Kernel for Java -Welcome to the Semantic Kernel for Java. For detailed documentation, visit [Microsoft Learn](https://learn.microsoft.com/en-us/semantic-kernel/overview/?tabs=Java). +Welcome to the Semantic Kernel for Java. For detailed documentation, visit [Microsoft Learn](https://learn.microsoft.com/en-us/semantic-kernel/overview/?tabs=Java&pivots=programming-language-java). [Semantic Kernel](https://learn.microsoft.com/en-us/semantic-kernel/overview/) is an SDK that integrates Large Language Models (LLMs) like [OpenAI](https://platform.openai.com/docs/introduction), [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service), and [Hugging Face](https://huggingface.co/) -with conventional programming languages like C#, Python, and Java. Semantic Kernel achieves this by allowing you to define [plugins](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/plugins) that can be chained together in just a [few lines of code](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chaining-functions?tabs=Java#using-the-runasync-method-to-simplify-your-code). +with conventional programming languages like C#, Python, and Java. Semantic Kernel achieves this by allowing you to define [plugins](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/plugins??tabs=Java&pivots=programming-language-java) that can be chained together in just a [few lines of code](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chaining-functions?tabs=Java&pivots=programming-language-java#using-the-runasync-method-to-simplify-your-code). -What makes Semantic Kernel _special_, however, is its ability to _automatically_ orchestrate plugins with AI. With Semantic Kernel [planners](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/planner), you can ask an LLM to generate a plan that achieves a user's unique goal. Afterwards, Semantic Kernel will execute the plan for the user. +What makes Semantic Kernel _special_, however, is its ability to _automatically_ orchestrate plugins with AI. With Semantic Kernel [planners](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/planner?tabs=Java&pivots=programming-language-java), you can ask an LLM to generate a plan that achieves a user's unique goal. Afterwards, Semantic Kernel will execute the plan for the user. For C#, Python and other language support, see [microsoft/semantic-kernel](https://github.com/microsoft/semantic-kernel). @@ -23,20 +23,20 @@ For C#, Python and other language support, see [microsoft/semantic-kernel](https The quickest way to get started with the basics is to get an API key from either OpenAI or Azure OpenAI and to run one of the Java console applications/scripts below. 1. Clone the repository: `git clone https://github.com/microsoft/semantic-kernel-java.git` -2. Follow the instructions [Start learning how to use Semantic Kernel](https://learn.microsoft.com/en-us/semantic-kernel/get-started/quick-start-guide?tabs=Java). +2. Follow the instructions [Start learning how to use Semantic Kernel](https://learn.microsoft.com/en-us/semantic-kernel/get-started/quick-start-guide?tabs=Java&pivots=programming-language-java). ## Documentation: Learning how to use Semantic Kernel The fastest way to learn how to use Semantic Kernel is with our walkthroughs on our Learn site. -1. 📖 [Overview of the kernel](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/?tabs=Java) -1. 🔌 [Understanding AI plugins](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/plugins?tabs=Java) -1. 👄 [Creating semantic functions](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/semantic-functions?tabs=Java) -1. 💽 [Creating native functions](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/native-functions?tabs=Java) -1. ⛓️ [Chaining functions together](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chaining-functions?tabs=Java) -1. 🤖 [Auto create plans with planner](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/planner?tabs=Java) -1. 💡 [Create and run a ChatGPT plugin](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chatgpt-plugins?tabs=Java) +1. 📖 [Overview of the kernel](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/?tabs=Java&pivots=programming-language-java) +1. 🔌 [Understanding AI plugins](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/plugins?tabs=Java&pivots=programming-language-java) +1. 👄 [Creating semantic functions](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/semantic-functions?tabs=Java&pivots=programming-language-java) +1. 💽 [Creating native functions](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/native-functions?tabs=Java&pivots=programming-language-java) +1. ⛓️ [Chaining functions together](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chaining-functions?tabs=Java&pivots=programming-language-java) +1. 🤖 [Auto create plans with planner](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/planner?tabs=Java&pivots=programming-language-java) +1. 💡 [Create and run a ChatGPT plugin](https://learn.microsoft.com/en-us/semantic-kernel/ai-orchestration/chatgpt-plugins?tabs=Java&pivots=programming-language-java) ## Join the community @@ -50,8 +50,8 @@ in a different direction, but also to consider the impact on the larger ecosyste To learn more and get started: -- Read the [documentation](https://learn.microsoft.com/en-us/semantic-kernel/overview/?tabs=Java) -- Learn how to [contribute](https://learn.microsoft.com/en-us/semantic-kernel/get-started/contributing) to the project +- Read the [documentation](https://learn.microsoft.com/en-us/semantic-kernel/overview/?tabs=Java&pivots=programming-language-java) +- Learn how to [contribute](https://learn.microsoft.com/en-us/semantic-kernel/get-started/contributing?tabs=Java&pivots=programming-language-java) to the project - Join the [Discord community](https://aka.ms/SKDiscord) - Attend [regular office hours and SK community events](COMMUNITY.md) - Follow the team on our [blog](https://aka.ms/sk/blog) diff --git a/aiservices/google/pom.xml b/aiservices/google/pom.xml index ef878373..f9ff65a4 100644 --- a/aiservices/google/pom.xml +++ b/aiservices/google/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../../pom.xml diff --git a/aiservices/huggingface/pom.xml b/aiservices/huggingface/pom.xml index f53983c9..ccfbb689 100644 --- a/aiservices/huggingface/pom.xml +++ b/aiservices/huggingface/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../../pom.xml diff --git a/aiservices/openai/pom.xml b/aiservices/openai/pom.xml index a7b12c09..d6a3762e 100644 --- a/aiservices/openai/pom.xml +++ b/aiservices/openai/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../../pom.xml diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/OpenAiService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/OpenAiService.java index 4da7b67b..4a7a196b 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/OpenAiService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/OpenAiService.java @@ -1,23 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.aiservices.openai; -import com.azure.ai.openai.OpenAIAsyncClient; import com.microsoft.semantickernel.services.AIService; import javax.annotation.Nullable; /** * Provides OpenAI service. */ -public abstract class OpenAiService implements AIService { +public abstract class OpenAiService implements AIService { - private final OpenAIAsyncClient client; + private final Client client; @Nullable private final String serviceId; private final String modelId; private final String deploymentName; protected OpenAiService( - OpenAIAsyncClient client, + Client client, @Nullable String serviceId, String modelId, String deploymentName) { @@ -39,7 +38,7 @@ public String getServiceId() { return serviceId; } - protected OpenAIAsyncClient getClient() { + protected Client getClient() { return client; } diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java index 07fdc76f..631f2cac 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiAudioToTextService.java @@ -18,7 +18,8 @@ /** * Provides OpenAi implementation of audio to text service. */ -public class OpenAiAudioToTextService extends OpenAiService implements AudioToTextService { +public class OpenAiAudioToTextService extends OpenAiService + implements AudioToTextService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiAudioToTextService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java index b4f4dafd..c698fab3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/audio/OpenAiTextToAudioService.java @@ -17,7 +17,8 @@ /** * Provides OpenAi implementation of text to audio service. */ -public class OpenAiTextToAudioService extends OpenAiService implements TextToAudioService { +public class OpenAiTextToAudioService extends OpenAiService + implements TextToAudioService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiTextToAudioService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 6bdb4f1c..5442e51e 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -79,7 +79,8 @@ /** * OpenAI chat completion service. */ -public class OpenAIChatCompletion extends OpenAiService implements ChatCompletionService { +public class OpenAIChatCompletion extends OpenAiService + implements ChatCompletionService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class); @@ -183,7 +184,7 @@ private static class ChatMessages { private final List newMessages; private final List allMessages; - private final List newChatMessageContent; + private final List> newChatMessageContent; public ChatMessages(List allMessages) { this.allMessages = Collections.unmodifiableList(allMessages); @@ -194,7 +195,7 @@ public ChatMessages(List allMessages) { private ChatMessages( List allMessages, List newMessages, - List newChatMessageContent) { + List> newChatMessageContent) { this.allMessages = Collections.unmodifiableList(allMessages); this.newMessages = Collections.unmodifiableList(newMessages); this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent); @@ -218,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) { } @CheckReturnValue - public ChatMessages addChatMessage(List chatMessageContent) { - ArrayList tmpChatMessageContent = new ArrayList<>( + public ChatMessages addChatMessage(List> chatMessageContent) { + ArrayList> tmpChatMessageContent = new ArrayList<>( newChatMessageContent); tmpChatMessageContent.addAll(chatMessageContent); @@ -311,6 +312,7 @@ private Mono internalChatMessageContentsAsync( ChatCompletionsOptions options = executeHook( invocationContext, + kernel, new PreChatCompletionEvent( getCompletionsOptions( this, @@ -349,25 +351,23 @@ private Mono internalChatMessageContentsAsync( .collect(Collectors.toList()); // execute post chat completion hook - executeHook(invocationContext, new PostChatCompletionEvent(completions)); + executeHook(invocationContext, kernel, new PostChatCompletionEvent(completions)); // Just return the result: // If we don't want to attempt to invoke any functions // Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested if (autoInvokeAttempts == 0 || responseMessages.size() != 1) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync( + completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } // Or if there are no tool calls to be done ChatResponseMessage response = responseMessages.get(0); List toolCalls = response.getToolCalls(); if (toolCalls == null || toolCalls.isEmpty()) { - return getChatMessageContentsAsync(completions) - .flatMap(m -> { - return Mono.just(messages.addChatMessage(m)); - }); + List> chatMessageContents = getChatMessageContentsAsync( + completions); + return Mono.just(messages.addChatMessage(chatMessageContents)); } ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage( @@ -517,11 +517,12 @@ private Mono> invokeFunctionTool( pluginName, openAIFunctionToolCall.getFunctionName()); - PreToolCallEvent hookResult = executeHook(invocationContext, new PreToolCallEvent( - openAIFunctionToolCall.getFunctionName(), - openAIFunctionToolCall.getArguments(), - function, - contextVariableTypes)); + PreToolCallEvent hookResult = executeHook(invocationContext, kernel, + new PreToolCallEvent( + openAIFunctionToolCall.getFunctionName(), + openAIFunctionToolCall.getArguments(), + function, + contextVariableTypes)); function = hookResult.getFunction(); KernelFunctionArguments arguments = hookResult.getArguments(); @@ -537,12 +538,21 @@ private Mono> invokeFunctionTool( private static T executeHook( @Nullable InvocationContext invocationContext, + @Nullable Kernel kernel, T event) { - KernelHooks kernelHooks = invocationContext != null - && invocationContext.getKernelHooks() != null - ? invocationContext.getKernelHooks() - : new KernelHooks(); - + KernelHooks kernelHooks = null; + if (kernel == null) { + if (invocationContext != null) { + kernelHooks = invocationContext.getKernelHooks(); + } + } else { + kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + invocationContext != null ? invocationContext.getKernelHooks() : null); + } + if (kernelHooks == null) { + return event; + } return kernelHooks.executeHooks(event); } @@ -580,7 +590,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall( arguments); } - private Mono> getChatMessageContentsAsync( + private List> getChatMessageContentsAsync( ChatCompletions completions) { FunctionResultMetadata completionMetadata = FunctionResultMetadata.build( completions.getId(), @@ -594,22 +604,27 @@ private Mono> getChatMessageContentsAsync( .filter(Objects::nonNull) .collect(Collectors.toList()); - return Flux.fromIterable(responseMessages) - .flatMap(response -> { + List> chatMessageContent = responseMessages + .stream() + .map(response -> { try { - return Mono.just(new OpenAIChatMessageContent( + return new OpenAIChatMessageContent<>( AuthorRole.ASSISTANT, response.getContent(), this.getModelId(), null, null, completionMetadata, - formOpenAiToolCalls(response))); - } catch (Exception e) { - return Mono.error(e); + formOpenAiToolCalls(response)); + } catch (SKCheckedException e) { + LOGGER.warn("Failed to form chat message content", e); + return null; } }) - .collectList(); + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + return chatMessageContent; } private List> toOpenAIChatMessageContent( @@ -919,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List chatRequ } private static List getChatRequestMessages( - List messages) { + List> messages) { if (messages == null || messages.isEmpty()) { return new ArrayList<>(); } @@ -1044,7 +1059,8 @@ static ChatRequestMessage getChatRequestMessage( /** * Builder for creating a new instance of {@link OpenAIChatCompletion}. */ - public static class Builder extends OpenAiServiceBuilder { + public static class Builder + extends OpenAiServiceBuilder { @Override public OpenAIChatCompletion build() { diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java index f2cbf858..89f45014 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java @@ -36,7 +36,7 @@ public OpenAIChatMessageContent( @Nullable String modelId, @Nullable T innerContent, @Nullable Charset encoding, - @Nullable FunctionResultMetadata metadata, + @Nullable FunctionResultMetadata metadata, @Nullable List toolCall) { super(authorRole, content, modelId, innerContent, encoding, metadata); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/implementation/OpenAIRequestSettings.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/implementation/OpenAIRequestSettings.java index 8da85ed9..35d0d3f3 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/implementation/OpenAIRequestSettings.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/implementation/OpenAIRequestSettings.java @@ -20,16 +20,26 @@ public final class OpenAIRequestSettings { private static final String SEMANTIC_KERNEL_VERSION_PROPERTY_NAME = "semantic-kernel.version"; private static final String SEMANTIC_KERNEL_VERSION_PROPERTIES_FILE = "semantic-kernel-version.properties"; - private static final String useragent; + private static final String useragent; private static final String header; + public static final String SEMANTIC_KERNEL_DISABLE_USERAGENT_PROPERTY = "semantic-kernel.useragent-disable"; + + private static final boolean disabled; + static { + disabled = isDisabled(); String version = loadVersion(); useragent = "semantic-kernel-java/" + version; header = "java/" + version; } + private static boolean isDisabled() { + return Boolean.parseBoolean( + System.getProperty(SEMANTIC_KERNEL_DISABLE_USERAGENT_PROPERTY, "false")); + } + private static String loadVersion() { String version = "unknown"; @@ -58,9 +68,14 @@ private static String loadVersion() { * @return The request options */ public static RequestOptions getRequestOptions() { - return new RequestOptions() + RequestOptions requestOptions = new RequestOptions(); + + if (disabled) { + return requestOptions; + } + + return requestOptions .setHeader(HttpHeaderName.fromString("Semantic-Kernel-Version"), header) - .setContext( - new Context(UserAgentPolicy.APPEND_USER_AGENT_CONTEXT_KEY, useragent)); + .setContext(new Context(UserAgentPolicy.APPEND_USER_AGENT_CONTEXT_KEY, useragent)); } } diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java index 57e3dd1f..13783229 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textcompletion/OpenAITextGenerationService.java @@ -30,7 +30,8 @@ /** * An OpenAI implementation of a {@link TextGenerationService}. */ -public class OpenAITextGenerationService extends OpenAiService implements TextGenerationService { +public class OpenAITextGenerationService extends OpenAiService + implements TextGenerationService { private static final Logger LOGGER = LoggerFactory.getLogger(OpenAITextGenerationService.class); diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textembedding/OpenAITextEmbeddingGenerationService.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textembedding/OpenAITextEmbeddingGenerationService.java index a46540c0..cd2c7aa8 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textembedding/OpenAITextEmbeddingGenerationService.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/textembedding/OpenAITextEmbeddingGenerationService.java @@ -23,7 +23,7 @@ * An OpenAI implementation of a {@link TextEmbeddingGenerationService}. * */ -public class OpenAITextEmbeddingGenerationService extends OpenAiService +public class OpenAITextEmbeddingGenerationService extends OpenAiService implements TextEmbeddingGenerationService { private static final Logger LOGGER = LoggerFactory .getLogger(OpenAITextEmbeddingGenerationService.class); @@ -87,7 +87,7 @@ protected Mono> internalGenerateTextEmbeddingsAsync(List * A builder for creating a {@link OpenAITextEmbeddingGenerationService}. */ public static class Builder extends - OpenAiServiceBuilder { + OpenAiServiceBuilder { private int dimensions = DEFAULT_DIMENSIONS; /** diff --git a/api-test/integration-tests/pom.xml b/api-test/integration-tests/pom.xml index 48f5b608..b3ec6563 100644 --- a/api-test/integration-tests/pom.xml +++ b/api-test/integration-tests/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel api-test - 1.2.0 + 1.2.2 ../pom.xml @@ -65,14 +65,19 @@ org.xerial sqlite-jdbc - 3.44.1.0 + 3.46.0.0 com.mysql mysql-connector-j - 8.2.0 + 9.0.0 test + + org.postgresql + postgresql + 42.7.3 + org.testcontainers @@ -98,7 +103,6 @@ org.wiremock wiremock - 3.3.1 test diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java index e43842fc..ad10ad64 100644 --- a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java @@ -18,7 +18,7 @@ public class Hotel { @VectorStoreRecordVectorAttribute(dimensions = 3) private final List descriptionEmbedding; @VectorStoreRecordDataAttribute - private final double rating; + private double rating; public Hotel() { this(null, null, 0, null, null, 0.0); @@ -56,4 +56,8 @@ public List getDescriptionEmbedding() { public double getRating() { return rating; } + + public void setRating(double rating) { + this.rating = rating; + } } diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java new file mode 100644 index 00000000..8bee5a76 --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java @@ -0,0 +1,309 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollection; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + + +@Testcontainers +public class JDBCVectorStoreRecordCollectionTest { + + @Container + private static final MySQLContainer MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + + private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres"); + @Container + private static final PostgreSQLContainer POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR); + + public enum QueryProvider { + MySQL, + PostgreSQL + } + + private JDBCVectorStoreRecordCollection buildRecordCollection(QueryProvider provider, @Nonnull String collectionName) { + JDBCVectorStoreQueryProvider queryProvider; + DataSource dataSource; + + switch (provider) { + case MySQL: + MysqlDataSource mysqlDataSource = new MysqlDataSource(); + mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl()); + mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername()); + mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword()); + dataSource = mysqlDataSource; + queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + case PostgreSQL: + PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource(); + pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl()); + pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername()); + pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword()); + dataSource = pgSimpleDataSource; + queryProvider = PostgreSQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + default: + throw new IllegalArgumentException("Unknown query provider: " + provider); + } + + + JDBCVectorStoreRecordCollection recordCollection = new JDBCVectorStoreRecordCollection<>( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(Hotel.class) + .withQueryProvider(queryProvider) + .build()); + + recordCollection.prepareAsync().block(); + recordCollection.createCollectionIfNotExistsAsync().block(); + return recordCollection; + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void buildRecordCollection(QueryProvider provider) { + assertNotNull(buildRecordCollection(provider, "buildTest")); + } + + private List getHotels() { + return List.of( + new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_2", "Hotel 2", 2, "Hotel 2 description", Arrays.asList(1.0f, 2.0f, 3.0f), 3.0), + new Hotel("id_3", "Hotel 3", 3, "Hotel 3 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0), + new Hotel("id_4", "Hotel 4", 4, "Hotel 4 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0), + new Hotel("id_5", "Hotel 5", 5, "Hotel 5 description", Arrays.asList(1.0f, 2.0f, 3.0f), 5.0) + ); + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void upsertAndGetRecordAsync(QueryProvider provider) { + String collectionName = "upsertAndGetRecordAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordCollection.upsertAsync(hotel, null).block(); + } + + // Upsert the first time + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertEquals(hotel.getRating(), retrievedHotel.getRating()); + + // Update the rating + hotel.setRating(1.0); + } + + // Upsert the second time with updated rating + for (Hotel hotel : hotels) { + recordCollection.upsertAsync(hotel, null).block(); + } + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertEquals(1.0, retrievedHotel.getRating()); + } + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getBatchAsync(QueryProvider provider) { + String collectionName = "getBatchAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + for (Hotel hotel : hotels) { + recordCollection.upsertAsync(hotel, null).block(); + } + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void upsertBatchAndGetBatchAsync(QueryProvider provider) { + String collectionName = "upsertBatchAndGetBatchAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void insertAndReplaceAsync(QueryProvider provider) { + String collectionName = "insertAndReplaceAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordCollection.getBatchAsync(keys, null).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void deleteRecordAsync(QueryProvider provider) { + String collectionName = "deleteRecordAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + for (Hotel hotel : hotels) { + recordCollection.deleteAsync(hotel.getId(), null).block(); + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block(); + assertNull(retrievedHotel); + } + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void deleteBatchAsync(QueryProvider provider) { + String collectionName = "deleteBatchAsync"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + recordCollection.deleteBatchAsync(keys, null).block(); + + for (String key : keys) { + Hotel retrievedHotel = recordCollection.getAsync(key, null).block(); + assertNull(retrievedHotel); + } + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getWithNoVectors(QueryProvider provider) { + String collectionName = "getWithNoVectors"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNull(retrievedHotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + for (Hotel hotel : hotels) { + Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), options).block(); + assertNotNull(retrievedHotel); + assertEquals(hotel.getId(), retrievedHotel.getId()); + assertNotNull(retrievedHotel.getDescriptionEmbedding()); + } + } + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getBatchWithNoVectors(QueryProvider provider) { + String collectionName = "getBatchWithNoVectors"; + JDBCVectorStoreRecordCollection recordCollection = buildRecordCollection(provider, collectionName); + + List hotels = getHotels(); + recordCollection.upsertBatchAsync(hotels, null).block(); + + GetRecordOptions options = GetRecordOptions.builder() + .includeVectors(false) + .build(); + + List keys = new ArrayList<>(); + for (Hotel hotel : hotels) { + keys.add(hotel.getId()); + } + + List retrievedHotels = recordCollection.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNull(hotel.getDescriptionEmbedding()); + } + + options = GetRecordOptions.builder() + .includeVectors(true) + .build(); + + retrievedHotels = recordCollection.getBatchAsync(keys, options).block(); + assertNotNull(retrievedHotels); + assertEquals(hotels.size(), retrievedHotels.size()); + + for (Hotel hotel : retrievedHotels) { + assertNotNull(hotel.getDescriptionEmbedding()); + } + } +} diff --git a/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java new file mode 100644 index 00000000..8c2fbfd0 --- /dev/null +++ b/api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java @@ -0,0 +1,102 @@ +package com.microsoft.semantickernel.tests.connectors.memory.jdbc; + +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.tests.connectors.memory.Hotel; +import com.mysql.cj.jdbc.MysqlDataSource; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.postgresql.ds.PGSimpleDataSource; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.util.Arrays; +import java.util.List; + +import com.microsoft.semantickernel.tests.connectors.memory.jdbc.JDBCVectorStoreRecordCollectionTest.QueryProvider; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Testcontainers +public class JDBCVectorStoreTest { + @Container + private static final MySQLContainer MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34"); + + private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres"); + @Container + private static final PostgreSQLContainer POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR); + + private JDBCVectorStore buildVectorStore(QueryProvider provider) { + JDBCVectorStoreQueryProvider queryProvider; + DataSource dataSource; + + switch (provider) { + case MySQL: + MysqlDataSource mysqlDataSource = new MysqlDataSource(); + mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl()); + mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername()); + mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword()); + dataSource = mysqlDataSource; + queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + case PostgreSQL: + PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource(); + pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl()); + pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername()); + pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword()); + dataSource = pgSimpleDataSource; + queryProvider = PostgreSQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + break; + default: + throw new IllegalArgumentException("Unknown query provider: " + provider); + } + + + JDBCVectorStore vectorStore = JDBCVectorStore.builder() + .withDataSource(dataSource) + .withOptions( + JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build() + ) + .build(); + + vectorStore.prepareAsync().block(); + return vectorStore; + } + + + @ParameterizedTest + @EnumSource(QueryProvider.class) + public void getCollectionNamesAsync(QueryProvider provider) { + JDBCVectorStore vectorStore = buildVectorStore(provider); + + vectorStore.getCollectionNamesAsync().block(); + + List collectionNames = Arrays.asList("collection1", "collection2", "collection3"); + + for (String collectionName : collectionNames) { + vectorStore.getCollection(collectionName, Hotel.class, null).createCollectionAsync().block(); + } + + List retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block(); + assertNotNull(retrievedCollectionNames); + assertEquals(collectionNames.size(), retrievedCollectionNames.size()); + for (String collectionName : collectionNames) { + assertTrue(retrievedCollectionNames.contains(collectionName)); + } + } +} diff --git a/api-test/pom.xml b/api-test/pom.xml index 7b49dd77..47d10b04 100644 --- a/api-test/pom.xml +++ b/api-test/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/pom.xml b/pom.xml index 313a7869..84e20d0c 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 pom https://www.github.com/microsoft/semantic-kernel @@ -147,6 +147,19 @@ spotbugs-annotations ${spotbugs.version} + + + org.wiremock + wiremock + 3.9.1 + test + + + org.mockito + mockito-junit-jupiter + 5.12.0 + test + @@ -812,6 +825,6 @@ https://github.com/microsoft/semantic-kernel scm:git:https://github.com/microsoft/semantic-kernel.git scm:git:https://github.com/microsoft/semantic-kernel.git - java-1.2.0 + java-1.2.2 diff --git a/samples/pom.xml b/samples/pom.xml index 61a1dada..a62aa7c4 100644 --- a/samples/pom.xml +++ b/samples/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-concepts/pom.xml b/samples/semantickernel-concepts/pom.xml index a79a2371..cd4b855c 100644 --- a/samples/semantickernel-concepts/pom.xml +++ b/samples/semantickernel-concepts/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-samples-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml index 9cedbbf0..27c3420d 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-concepts - 1.2.0 + 1.2.2 ../pom.xml @@ -85,9 +85,15 @@ com.google.cloud google-cloud-vertexai - 1.1.0 + 1.6.0 compile + + + com.mysql + mysql-connector-j + 9.0.0 + diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java index 08fb2697..08a8f3b5 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java @@ -10,6 +10,7 @@ import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion; import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.hooks.KernelHooks; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.InvocationContext.Builder; import com.microsoft.semantickernel.orchestration.InvocationReturnMode; @@ -73,6 +74,27 @@ public static void main(String[] args) throws Exception { .toPromptString(new Gson()::toJson) .build()); + KernelHooks hook = new KernelHooks(); + + hook.addPreToolCallHook((context) -> { + System.out.println("Pre-tool call hook"); + return context; + }); + + hook.addPreChatCompletionHook( + (context) -> { + System.out.println("Pre-chat completion hook"); + return context; + }); + + hook.addPostChatCompletionHook( + (context) -> { + System.out.println("Post-chat completion hook"); + return context; + }); + + kernel.getGlobalKernelHooks().addHooks(hook); + // Enable planning InvocationContext invocationContext = new Builder() .withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY) diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/java/CustomTypes_Example.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/java/CustomTypes_Example.java new file mode 100644 index 00000000..e809474d --- /dev/null +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/java/CustomTypes_Example.java @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.samples.syntaxexamples.java; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.microsoft.semantickernel.Kernel; +import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion; +import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; +import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.contextvariables.converters.ContextVariableJacksonConverter; +import com.microsoft.semantickernel.exceptions.ConfigurationException; +import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments; +import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService; +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class CustomTypes_Example { + + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); + private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); + + // Only required if AZURE_CLIENT_KEY is set + private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); + private static final String MODEL_ID = System.getenv() + .getOrDefault("MODEL_ID", "gpt-35-turbo-2"); + + public static void main(String[] args) throws ConfigurationException, IOException { + + OpenAIAsyncClient client; + + if (AZURE_CLIENT_KEY != null) { + client = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) + .endpoint(CLIENT_ENDPOINT) + .buildAsyncClient(); + } else { + client = new OpenAIClientBuilder() + .credential(new KeyCredential(CLIENT_KEY)) + .buildAsyncClient(); + } + + ChatCompletionService chatCompletionService = OpenAIChatCompletion.builder() + .withOpenAIAsyncClient(client) + .withModelId(MODEL_ID) + .build(); + + exampleBuildingCustomConverter(chatCompletionService); + exampleUsingJackson(chatCompletionService); + exampleUsingGlobalTypes(chatCompletionService); + } + + public record Pet(String name, int age, String species) { + + @JsonCreator + public Pet( + @JsonProperty("name") String name, + @JsonProperty("age") int age, + @JsonProperty("species") String species) { + this.name = name; + this.age = age; + this.species = species; + } + + @Override + public String toString() { + return name + " " + species + " " + age; + } + } + + private static void exampleBuildingCustomConverter( + ChatCompletionService chatCompletionService) { + Pet sandy = new Pet("Sandy", 3, "Dog"); + + Kernel kernel = Kernel.builder() + .withAIService(ChatCompletionService.class, chatCompletionService) + .build(); + + // Format: + // name: Sandy + // age: 3 + // species: Dog + + // Custom serializer + Function petToString = pet -> "name: " + pet.name() + "\n" + + "age: " + pet.age() + "\n" + + "species: " + pet.species() + "\n"; + + // Custom deserializer + Function stringToPet = prompt -> { + Map properties = Arrays.stream(prompt.split("\n")) + .collect(Collectors.toMap( + line -> line.split(":")[0].trim(), + line -> line.split(":")[1].trim())); + + return new Pet( + properties.get("name"), + Integer.parseInt(properties.get("age")), + properties.get("species")); + }; + + // create custom converter + ContextVariableTypeConverter typeConverter = ContextVariableTypeConverter.builder( + Pet.class) + .toPromptString(petToString) + .fromPromptString(stringToPet) + .build(); + + Pet updated = kernel.invokePromptAsync( + "Change Sandy's name to Daisy:\n{{$Sandy}}", + KernelFunctionArguments.builder() + .withVariable("Sandy", sandy, typeConverter) + .build()) + .withTypeConverter(typeConverter) + .withResultType(Pet.class) + .block() + .getResult(); + + System.out.println("Sandy's updated record: " + updated); + } + + public static void exampleUsingJackson(ChatCompletionService chatCompletionService) { + Pet sandy = new Pet("Sandy", 3, "Dog"); + + Kernel kernel = Kernel.builder() + .withAIService(ChatCompletionService.class, chatCompletionService) + .build(); + + // Create a converter that defaults to using jackson for serialization + ContextVariableTypeConverter typeConverter = ContextVariableJacksonConverter.create( + Pet.class); + + // Invoke the prompt with the custom converter + Pet updated = kernel.invokePromptAsync( + "Increase Sandy's age by a year:\n{{$Sandy}}", + KernelFunctionArguments.builder() + .withVariable("Sandy", sandy, typeConverter) + .build()) + .withTypeConverter(typeConverter) + .withResultType(Pet.class) + .block() + .getResult(); + + System.out.println("Sandy's updated record: " + updated); + } + + public static void exampleUsingGlobalTypes(ChatCompletionService chatCompletionService) { + Pet sandy = new Pet("Sandy", 3, "Dog"); + + Kernel kernel = Kernel.builder() + .withAIService(ChatCompletionService.class, chatCompletionService) + .build(); + + // Create a converter that defaults to using jackson for serialization + ContextVariableTypeConverter typeConverter = ContextVariableJacksonConverter.create( + Pet.class); + + // Add converter to global types + ContextVariableTypes.addGlobalConverter(typeConverter); + + // No need to explicitly tell the invocation how to convert the type + Pet updated = kernel.invokePromptAsync( + "Sandy's is actually a cat correct this:\n{{$Sandy}}", + KernelFunctionArguments.builder() + .withVariable("Sandy", sandy) + .build()) + .withResultType(Pet.class) + .block() + .getResult(); + + System.out.println("Sandy's updated record: " + updated); + } + +} diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/AzureAISearch_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/AzureAISearch_DataStorage.java index 92316385..03cc93e3 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/AzureAISearch_DataStorage.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/AzureAISearch_DataStorage.java @@ -13,13 +13,10 @@ import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore; import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions; -import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreRecordCollection; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Base64; @@ -27,8 +24,11 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; public class AzureAISearch_DataStorage { + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); @@ -45,6 +45,7 @@ public class AzureAISearch_DataStorage { private static final int EMBEDDING_DIMENSIONS = 1536; static class GitHubFile { + @VectorStoreRecordKeyAttribute() private final String id; @VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding") @@ -120,7 +121,9 @@ public static void dataStorageWithAzureAISearch( .build(); String collectionName = "skgithubfiles"; - var collection = azureAISearchVectorStore.getCollection(collectionName, GitHubFile.class, + var collection = azureAISearchVectorStore.getCollection( + collectionName, + GitHubFile.class, null); // Create collection if it does not exist and store data @@ -140,7 +143,7 @@ public static void dataStorageWithAzureAISearch( } private static Mono> storeData( - AzureAISearchVectorStoreRecordCollection recordStore, + VectorStoreRecordCollection recordStore, OpenAITextEmbeddingGenerationService embeddingGeneration, Map data) { diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java index b189081b..c74d16f8 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java @@ -5,32 +5,22 @@ import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import com.azure.core.credential.KeyCredential; -import com.azure.core.util.ClientOptions; -import com.azure.core.util.MetricsOptions; -import com.azure.core.util.TracingOptions; -import com.azure.search.documents.indexes.SearchIndexAsyncClient; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; -import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore; -import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions; -import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreRecordCollection; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.VolatileVectorStore; -import com.microsoft.semantickernel.data.VolatileVectorStoreRecordCollection; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import java.nio.charset.StandardCharsets; import java.util.Arrays; -import java.util.Base64; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; public class InMemory_DataStorage { + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); @@ -43,6 +33,7 @@ public class InMemory_DataStorage { private static final int EMBEDDING_DIMENSIONS = 1536; static class GitHubFile { + @VectorStoreRecordKeyAttribute() private final String id; @VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding") @@ -72,8 +63,7 @@ public String getDescription() { } static String encodeId(String realId) { - byte[] bytes = Base64.getUrlEncoder().encode(realId.getBytes(StandardCharsets.UTF_8)); - return new String(bytes, StandardCharsets.UTF_8); + return AzureAISearch_DataStorage.GitHubFile.encodeId(realId); } } @@ -105,7 +95,8 @@ public static void main(String[] args) { inMemoryDataStorage(embeddingGeneration); } - public static void inMemoryDataStorage(OpenAITextEmbeddingGenerationService embeddingGeneration) { + public static void inMemoryDataStorage( + OpenAITextEmbeddingGenerationService embeddingGeneration) { // Create a new Volatile vector store var volatileVectorStore = new VolatileVectorStore(); @@ -125,7 +116,7 @@ public static void inMemoryDataStorage(OpenAITextEmbeddingGenerationService embe } private static Mono> storeData( - VolatileVectorStoreRecordCollection recordCollection, + VectorStoreRecordCollection recordCollection, OpenAITextEmbeddingGenerationService embeddingGeneration, Map data) { diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java new file mode 100644 index 00000000..ed1d8bd5 --- /dev/null +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.samples.syntaxexamples.memory; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; +import com.azure.core.credential.KeyCredential; +import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; +import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; +import com.mysql.cj.jdbc.MysqlDataSource; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import javax.sql.DataSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class JDBC_DataStorage { + + private static final String CLIENT_KEY = System.getenv("CLIENT_KEY"); + private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY"); + + // Only required if AZURE_CLIENT_KEY is set + private static final String CLIENT_ENDPOINT = System.getenv("CLIENT_ENDPOINT"); + private static final String MODEL_ID = System.getenv() + .getOrDefault("EMBEDDING_MODEL_ID", "text-embedding-3-large"); + private static final int EMBEDDING_DIMENSIONS = 1536; + + // Run a MySQL server with: + // docker run -d --name mysql-container -e MYSQL_ROOT_PASSWORD=root -e MYSQL_DATABASE=sk -p 3306:3306 mysql:latest + + static class GitHubFile { + + @VectorStoreRecordKeyAttribute() + private final String id; + @VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding") + private final String description; + @VectorStoreRecordDataAttribute + private final String link; + @VectorStoreRecordVectorAttribute(dimensions = EMBEDDING_DIMENSIONS, indexKind = "Hnsw") + private final List embedding; + + public GitHubFile() { + this(null, null, null, Collections.emptyList()); + } + + public GitHubFile( + String id, + String description, + String link, + List embedding) { + this.id = id; + this.description = description; + this.link = link; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getDescription() { + return description; + } + + static String encodeId(String realId) { + byte[] bytes = Base64.getUrlEncoder().encode(realId.getBytes(StandardCharsets.UTF_8)); + return new String(bytes, StandardCharsets.UTF_8); + } + } + + public static void main(String[] args) throws SQLException { + System.out.println("=============================================================="); + System.out.println("========== JDBC Vector Store Example =============="); + System.out.println("=============================================================="); + + OpenAIAsyncClient client; + + if (AZURE_CLIENT_KEY != null) { + client = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(AZURE_CLIENT_KEY)) + .endpoint(CLIENT_ENDPOINT) + .buildAsyncClient(); + + } else { + client = new OpenAIClientBuilder() + .credential(new KeyCredential(CLIENT_KEY)) + .buildAsyncClient(); + } + + var embeddingGeneration = OpenAITextEmbeddingGenerationService.builder() + .withOpenAIAsyncClient(client) + .withModelId(MODEL_ID) + .withDimensions(EMBEDDING_DIMENSIONS) + .build(); + + var dataSource = new MysqlDataSource(); + dataSource.setUrl("jdbc:mysql://localhost:3306/sk"); + dataSource.setPassword("root"); + dataSource.setUser("root"); + + dataStorageWithMySQL(dataSource, embeddingGeneration); + } + + public static void dataStorageWithMySQL( + DataSource dataSource, + OpenAITextEmbeddingGenerationService embeddingGeneration) { + + // Build a query provider + var queryProvider = MySQLVectorStoreQueryProvider.builder() + .withDataSource(dataSource) + .build(); + + // Create a new vector store + var jdbcVectorStore = JDBCVectorStore.builder() + .withDataSource(dataSource) + .withOptions(JDBCVectorStoreOptions.builder() + .withQueryProvider(queryProvider) + .build()) + .build(); + + String collectionName = "skgithubfiles"; + var collection = jdbcVectorStore.getCollection(collectionName, + GitHubFile.class, + null); + + // Create collection if it does not exist and store data + List ids = collection + .createCollectionIfNotExistsAsync() + .then(storeData(collection, embeddingGeneration, sampleData())) + .block(); + + List data = collection.getBatchAsync(ids, null).block(); + + data.forEach(gitHubFile -> System.out.println("Retrieved: " + gitHubFile.getDescription())); + } + + private static Mono> storeData( + VectorStoreRecordCollection recordStore, + OpenAITextEmbeddingGenerationService embeddingGeneration, + Map data) { + + return Flux.fromIterable(data.entrySet()) + .flatMap(entry -> { + System.out.println("Save '" + entry.getKey() + "' to memory."); + + return embeddingGeneration + .generateEmbeddingsAsync(Collections.singletonList(entry.getValue())) + .flatMap(embeddings -> { + GitHubFile gitHubFile = new GitHubFile( + GitHubFile.encodeId(entry.getKey()), + entry.getValue(), + entry.getKey(), + embeddings.get(0).getVector()); + return recordStore.upsertAsync(gitHubFile, null); + }); + }) + .collectList(); + } + + private static Map sampleData() { + return Arrays.stream(new String[][] { + { "https://github.com/microsoft/semantic-kernel/blob/main/README.md", + "README: Installation, getting started with Semantic Kernel, and how to contribute" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/notebooks/dotnet/02-running-prompts-from-file.ipynb", + "Jupyter notebook describing how to pass prompts from a file to a semantic skill or function" }, + { "https://github.com/microsoft/semantic-kernel/tree/main/samples/skills/ChatSkill/ChatGPT", + "Sample demonstrating how to create a chat skill interfacing with ChatGPT" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/dotnet/src/SemanticKernel/Memory/VolatileMemoryStore.cs", + "C# class that defines a volatile embedding store" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/dotnet/KernelHttpServer/README.md", + "README: How to set up a Semantic Kernel Service API using Azure Function Runtime v4" }, + { "https://github.com/microsoft/semantic-kernel/blob/main/samples/apps/chat-summary-webapp-react/README.md", + "README: README associated with a sample chat summary react-based webapp" }, + }).collect(Collectors.toMap(element -> element[0], element -> element[1])); + } +} diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/Redis_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/Redis_DataStorage.java index cde31a3a..6cd1db3d 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/Redis_DataStorage.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/Redis_DataStorage.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import com.microsoft.semantickernel.samples.syntaxexamples.memory.AzureAISearch_DataStorage.GitHubFile; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import redis.clients.jedis.JedisPooled; @@ -77,8 +78,7 @@ public String getDescription() { } static String encodeId(String realId) { - byte[] bytes = Base64.getUrlEncoder().encode(realId.getBytes(StandardCharsets.UTF_8)); - return new String(bytes, StandardCharsets.UTF_8); + return AzureAISearch_DataStorage.GitHubFile.encodeId(realId); } } diff --git a/samples/semantickernel-demos/booking-agent-m365/pom.xml b/samples/semantickernel-demos/booking-agent-m365/pom.xml index c427bb4e..a0b84de8 100644 --- a/samples/semantickernel-demos/booking-agent-m365/pom.xml +++ b/samples/semantickernel-demos/booking-agent-m365/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-demos - 1.2.0 + 1.2.2 ../pom.xml @@ -52,7 +52,7 @@ com.microsoft.graph microsoft-graph - 6.5.1 + 6.13.0 diff --git a/samples/semantickernel-demos/pom.xml b/samples/semantickernel-demos/pom.xml index 4ef911aa..b5b4fe7b 100644 --- a/samples/semantickernel-demos/pom.xml +++ b/samples/semantickernel-demos/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-samples-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-demos/semantickernel-spring-starter/pom.xml b/samples/semantickernel-demos/semantickernel-spring-starter/pom.xml index acd56827..e9fab05f 100644 --- a/samples/semantickernel-demos/semantickernel-spring-starter/pom.xml +++ b/samples/semantickernel-demos/semantickernel-spring-starter/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-demos - 1.2.0 + 1.2.2 ../pom.xml @@ -39,29 +39,29 @@ org.springframework.boot spring-boot-test - 3.2.1 + 3.3.2 test org.assertj assertj-core - 3.25.1 + 3.26.3 test org.springframework.boot spring-boot-autoconfigure - 3.2.1 + 3.3.2 org.springframework.boot spring-boot - 3.2.1 + 3.3.1 org.springframework spring-test - 6.1.2 + 6.1.10 test @@ -72,7 +72,7 @@ org.junit.jupiter junit-jupiter-api - 5.10.1 + 5.10.3 test diff --git a/samples/semantickernel-demos/sk-presidio-sample/pom.xml b/samples/semantickernel-demos/sk-presidio-sample/pom.xml index 11dd2ae6..8300dcec 100644 --- a/samples/semantickernel-demos/sk-presidio-sample/pom.xml +++ b/samples/semantickernel-demos/sk-presidio-sample/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-demos - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-learn-resources/pom.xml b/samples/semantickernel-learn-resources/pom.xml index 63ee1288..c8fb38ba 100644 --- a/samples/semantickernel-learn-resources/pom.xml +++ b/samples/semantickernel-learn-resources/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-samples-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-sample-plugins/pom.xml b/samples/semantickernel-sample-plugins/pom.xml index 31bd11cc..71d41a94 100644 --- a/samples/semantickernel-sample-plugins/pom.xml +++ b/samples/semantickernel-sample-plugins/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-samples-parent - 1.2.0 + 1.2.2 ../pom.xml diff --git a/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/pom.xml b/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/pom.xml index 6ac975d2..e1804b6f 100644 --- a/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/pom.xml +++ b/samples/semantickernel-sample-plugins/semantickernel-openapi-plugin/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-sample-plugins - 1.2.0 + 1.2.2 ../pom.xml @@ -68,21 +68,6 @@ com.microsoft.semantic-kernel semantickernel-aiservices-openai - - org.apache.logging.log4j - log4j-api - test - - - org.apache.logging.log4j - log4j-core - test - - - org.apache.logging.log4j - log4j-slf4j2-impl - test - org.junit.jupiter junit-jupiter-api diff --git a/samples/semantickernel-sample-plugins/semantickernel-presidio-plugin/pom.xml b/samples/semantickernel-sample-plugins/semantickernel-presidio-plugin/pom.xml index cd3c0431..e5ce6b0f 100644 --- a/samples/semantickernel-sample-plugins/semantickernel-presidio-plugin/pom.xml +++ b/samples/semantickernel-sample-plugins/semantickernel-presidio-plugin/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-sample-plugins - 1.2.0 + 1.2.2 ../pom.xml diff --git a/semantickernel-api/pom.xml b/semantickernel-api/pom.xml index 2bb0a050..c6bbae4a 100644 --- a/semantickernel-api/pom.xml +++ b/semantickernel-api/pom.xml @@ -6,7 +6,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 ../pom.xml @@ -67,7 +67,6 @@ org.wiremock wiremock - 3.3.1 test @@ -80,7 +79,6 @@ org.mockito mockito-junit-jupiter - 5.11.0 test diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java index 24b838bc..4ab1553f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/Kernel.java @@ -163,6 +163,24 @@ public FunctionInvocation invokePromptAsync(@Nonnull String prompt) { return invokeAsync(KernelFunction.createFromPrompt(prompt).build()); } + public FunctionInvocation invokePromptAsync(@Nonnull String prompt, + @Nonnull KernelFunctionArguments arguments) { + KernelFunction function = KernelFunction.createFromPrompt(prompt).build(); + + return function.invokeAsync(this) + .withArguments(arguments); + } + + public FunctionInvocation invokePromptAsync(@Nonnull String prompt, + @Nonnull KernelFunctionArguments arguments, @Nonnull InvocationContext invocationContext) { + + KernelFunction function = KernelFunction.createFromPrompt(prompt).build(); + + return function.invokeAsync(this) + .withArguments(arguments) + .withInvocationContext(invocationContext); + } + /** * Invokes a {@code KernelFunction}. * diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/ContextVariableTypeConverter.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/ContextVariableTypeConverter.java index 0a4bc758..687ef302 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/ContextVariableTypeConverter.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/ContextVariableTypeConverter.java @@ -309,9 +309,7 @@ public static class Builder { @SuppressFBWarnings("CT_CONSTRUCTOR_THROW") public Builder(Class clazz) { this.clazz = clazz; - fromObject = x -> { - throw new UnsupportedOperationException("fromObject not implemented"); - }; + fromObject = x -> ContextVariableTypes.convert(x, clazz); toPromptString = (a, b) -> { throw new UnsupportedOperationException("toPromptString not implemented"); }; diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/ContextVariableJacksonConverter.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/ContextVariableJacksonConverter.java new file mode 100644 index 00000000..ca4bfc44 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/ContextVariableJacksonConverter.java @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.contextvariables.converters; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; +import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter.Builder; +import com.microsoft.semantickernel.exceptions.SKException; + +/** + * A utility class for creating {@link ContextVariableTypeConverter} instances that use Jackson for + * serialization and deserialization. + */ +public final class ContextVariableJacksonConverter { + + /** + * Creates a new {@link ContextVariableTypeConverter} that uses Jackson for serialization and + * deserialization. + * + * @param type the type of the context variable + * @param mapper the {@link ObjectMapper} to use for serialization and deserialization + * @param the type of the context variable + * @return a new {@link ContextVariableTypeConverter} + */ + public static ContextVariableTypeConverter create(Class type, ObjectMapper mapper) { + return builder(type, mapper).build(); + } + + /** + * Creates a new {@link ContextVariableTypeConverter} that uses Jackson for serialization and + * deserialization. + * + * @param type the type of the context variable + * @param the type of the context variable + * @return a new {@link ContextVariableTypeConverter} + */ + public static ContextVariableTypeConverter create(Class type) { + return create(type, new ObjectMapper()); + } + + /** + * Creates a new {@link Builder} for a {@link ContextVariableTypeConverter} that uses Jackson + * for serialization and deserialization. + * + * @param type the type of the context variable + * @param the type of the context variable + * @return a new {@link Builder} + */ + public static Builder builder(Class type) { + return builder(type, new ObjectMapper()); + } + + /** + * Creates a new {@link Builder} for a {@link ContextVariableTypeConverter} that uses Jackson + * for serialization and deserialization. + * + * @param type the type of the context variable + * @param mapper the {@link ObjectMapper} to use for serialization and deserialization + * @param the type of the context variable + * @return a new {@link Builder} + */ + public static Builder builder(Class type, ObjectMapper mapper) { + return ContextVariableTypeConverter.builder(type) + .fromPromptString(str -> { + try { + return mapper.readValue(str, type); + } catch (JsonProcessingException e) { + throw new SKException("Failed to deserialize object", e); + } + }) + .toPromptString(obj -> { + try { + return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(obj); + } catch (JsonProcessingException e) { + throw new SKException("Failed to serialize object", e); + } + }); + } +} diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/DateTimeContextVariableTypeConverter.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/DateTimeContextVariableTypeConverter.java index 6a64ccd2..34fc741f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/DateTimeContextVariableTypeConverter.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/contextvariables/converters/DateTimeContextVariableTypeConverter.java @@ -33,9 +33,7 @@ public DateTimeContextVariableTypeConverter() { return null; }, Object::toString, - o -> { - return ZonedDateTime.parse(o).toOffsetDateTime(); - }, + o -> ZonedDateTime.parse(o).toOffsetDateTime(), Arrays.asList( new DefaultConverter(OffsetDateTime.class, Instant.class) { @Override diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java index f690fe26..959dda5a 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java @@ -66,7 +66,7 @@ public UnmodifiableKernelHooks unmodifiableClone() { * * @return an unmodifiable map of the hooks */ - private Map> getHooks() { + protected Map> getHooks() { return Collections.unmodifiableMap(hooks); } @@ -224,6 +224,31 @@ public boolean isEmpty() { return hooks.isEmpty(); } + /** + * Builds the list of hooks to be invoked for the given context, by merging the hooks in this + * collection with the hooks in the context. Duplicate hooks in b will override hooks in a. + * + * @param a hooks to merge + * @param b hooks to merge + * @return the list of hooks to be invoked + */ + public static KernelHooks merge(@Nullable KernelHooks a, @Nullable KernelHooks b) { + KernelHooks hooks = a; + if (hooks == null) { + hooks = new KernelHooks(); + } + + if (b == null) { + return hooks; + } else if (hooks.isEmpty()) { + return b; + } else { + HashMap> merged = new HashMap<>(hooks.getHooks()); + merged.putAll(b.getHooks()); + return new KernelHooks(merged); + } + } + /** * A wrapper for KernelHooks that disables mutating methods. */ diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java index 276f2c18..4cd6ae2a 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/orchestration/FunctionInvocation.java @@ -133,7 +133,7 @@ private static BiConsumer, SynchronousSink FunctionInvocation withResultType(ContextVariableType resultTyp * @return A new {@code FunctionInvocation} for fluent chaining. */ public FunctionInvocation withResultType(Class resultType) { - return withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(resultType)); + try { + return withResultType(contextVariableTypes.getVariableTypeForSuperClass(resultType)); + } catch (SKException e) { + return withResultType(ContextVariableTypes.getGlobalVariableTypeForClass(resultType)); + } } /** diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java index 6d9d1166..9a7b09dc 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java @@ -157,10 +157,9 @@ public static ImplementationFunc getFunction(Method method, Object instan } // kernelHooks must be effectively final for lambda - KernelHooks kernelHooks = context.getKernelHooks() != null - ? context.getKernelHooks() - : kernel.getGlobalKernelHooks(); - assert kernelHooks != null : "getGlobalKernelHooks() should never return null!"; + KernelHooks kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + context.getKernelHooks()); FunctionInvokingEvent updatedState = kernelHooks .executeHooks( diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java index 1d754d65..5d124d84 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java @@ -20,6 +20,7 @@ import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService; import com.microsoft.semantickernel.services.textcompletion.TextGenerationService; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -102,10 +103,9 @@ private Flux> invokeInternalAsync( : InvocationContext.builder().build(); // must be effectively final for lambda - KernelHooks kernelHooks = context.getKernelHooks() != null - ? context.getKernelHooks() - : kernel.getGlobalKernelHooks(); - assert kernelHooks != null : "getGlobalKernelHooks() should never return null"; + KernelHooks kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + context.getKernelHooks()); PromptRenderingEvent preRenderingHookResult = kernelHooks .executeHooks(new PromptRenderingEvent(this, argumentsIn)); @@ -440,6 +440,7 @@ public KernelFunction build() { name, template, templateFormat, + Collections.emptySet(), description, inputVariables, outputVariable, diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java index 02adb289..5f69ad46 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java @@ -11,9 +11,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import javax.annotation.Nullable; /** @@ -43,6 +45,7 @@ public class PromptTemplateConfig { @Nullable private final String template; private final String templateFormat; + private final Set promptTemplateOptions; @Nullable private final String description; private final List inputVariables; @@ -61,6 +64,7 @@ protected PromptTemplateConfig(String template) { DEFAULT_CONFIG_NAME, template, SEMANTIC_KERNEL_TEMPLATE_FORMAT, + Collections.emptySet(), "", Collections.emptyList(), new OutputVariable(String.class.getName(), "out"), @@ -70,14 +74,15 @@ protected PromptTemplateConfig(String template) { /** * Constructor for a prompt template config * - * @param schema Schema version - * @param name Name of the template - * @param template Template string - * @param templateFormat Template format - * @param description Description of the template - * @param inputVariables Input variables - * @param outputVariable Output variable - * @param executionSettings Execution settings + * @param schema Schema version + * @param name Name of the template + * @param template Template string + * @param templateFormat Template format + * @param promptTemplateOptions Prompt template options + * @param description Description of the template + * @param inputVariables Input variables + * @param outputVariable Output variable + * @param executionSettings Execution settings */ @JsonCreator public PromptTemplateConfig( @@ -85,6 +90,7 @@ public PromptTemplateConfig( @Nullable @JsonProperty("name") String name, @Nullable @JsonProperty("template") String template, @Nullable @JsonProperty(value = "template_format", defaultValue = SEMANTIC_KERNEL_TEMPLATE_FORMAT) String templateFormat, + @Nullable @JsonProperty(value = "prompt_template_options") Set promptTemplateOptions, @Nullable @JsonProperty("description") String description, @Nullable @JsonProperty("input_variables") List inputVariables, @Nullable @JsonProperty("output_variable") OutputVariable outputVariable, @@ -96,6 +102,10 @@ public PromptTemplateConfig( templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT; } this.templateFormat = templateFormat; + if (promptTemplateOptions == null) { + promptTemplateOptions = new HashSet<>(); + } + this.promptTemplateOptions = promptTemplateOptions; this.description = description; if (inputVariables == null) { this.inputVariables = new ArrayList<>(); @@ -127,6 +137,7 @@ protected PromptTemplateConfig( @Nullable String name, @Nullable String template, @Nullable String templateFormat, + @Nullable Set promptTemplateOptions, @Nullable String description, @Nullable List inputVariables, @Nullable OutputVariable outputVariable, @@ -136,6 +147,7 @@ protected PromptTemplateConfig( name, template, templateFormat, + promptTemplateOptions, description, inputVariables, outputVariable, @@ -152,6 +164,7 @@ public PromptTemplateConfig(PromptTemplateConfig promptTemplate) { promptTemplate.name, promptTemplate.template, promptTemplate.templateFormat, + promptTemplate.promptTemplateOptions, promptTemplate.description, promptTemplate.inputVariables, promptTemplate.outputVariable, @@ -300,6 +313,15 @@ public int getSchema() { return schema; } + /** + * Get the prompt template options of the prompt template config. + * + * @return The prompt template options of the prompt template config. + */ + public Set getPromptTemplateOptions() { + return Collections.unmodifiableSet(promptTemplateOptions); + } + /** * Create a builder for a prompt template config which is a clone of the current object. * @@ -358,6 +380,7 @@ public static class Builder { @Nullable private String template; private String templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT; + private final Set promptTemplateOptions = new HashSet<>(); @Nullable private String description = null; private List inputVariables = new ArrayList<>(); @@ -433,6 +456,11 @@ public Builder withTemplateFormat(String templateFormat) { return this; } + public Builder addPromptTemplateOption(PromptTemplateOption option) { + promptTemplateOptions.add(option); + return this; + } + /** * Set the inputVariables of the prompt template config. * @@ -477,6 +505,7 @@ public PromptTemplateConfig build() { name, template, templateFormat, + promptTemplateOptions, description, inputVariables, outputVariable, diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java new file mode 100644 index 00000000..5d244613 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.semanticfunctions; + +public enum PromptTemplateOption { + /** + * Allow methods on objects provided as arguments to an invocation, to be invoked when rendering + * a template and its return value used. Typically, this would be used to call a getter on an + * object i.e. {@code {{#each users}} {{userName}} {{/each}} } on a handlebars template will + * call the method {@code getUserName()} on each object in {@code users}. + *

+ * WARNING: If this option is used, ensure that your template is trusted, and that objects added + * as arguments to an invocation, do not contain methods that are unsafe to be invoked when + * rendering a template. + */ + ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE +} \ No newline at end of file diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java index 91e14466..871d4cb4 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/AudioToTextService.java @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.services.audio; +import com.azure.ai.openai.OpenAIAsyncClient; import com.microsoft.semantickernel.implementation.ServiceLoadUtil; import com.microsoft.semantickernel.services.AIService; import com.microsoft.semantickernel.services.openai.OpenAiServiceBuilder; @@ -32,7 +33,8 @@ static Builder builder() { /** * Builder for the AudioToTextService. */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/TextToAudioService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/TextToAudioService.java index 80a42436..ff2cd40a 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/TextToAudioService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/audio/TextToAudioService.java @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.services.audio; +import com.azure.ai.openai.OpenAIAsyncClient; import com.microsoft.semantickernel.implementation.ServiceLoadUtil; import com.microsoft.semantickernel.services.AIService; import com.microsoft.semantickernel.services.openai.OpenAiServiceBuilder; @@ -36,7 +37,7 @@ static Builder builder() { * Builder for the TextToAudioService. */ abstract class Builder extends - OpenAiServiceBuilder { + OpenAiServiceBuilder { } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java index d2f391ff..ea910c01 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java @@ -33,7 +33,7 @@ public ChatHistory() { * @param instructions The instructions to add to the chat history */ public ChatHistory(@Nullable String instructions) { - this.chatMessageContents = new ArrayList<>(); + this.chatMessageContents = Collections.synchronizedList(new ArrayList<>()); if (instructions != null) { this.chatMessageContents.add( ChatMessageTextContent.systemMessage(instructions)); @@ -45,8 +45,9 @@ public ChatHistory(@Nullable String instructions) { * * @param chatMessageContents The chat message contents to add to the chat history */ - public ChatHistory(List chatMessageContents) { - this.chatMessageContents = new ArrayList(chatMessageContents); + public ChatHistory(List> chatMessageContents) { + this.chatMessageContents = Collections + .synchronizedList(new ArrayList<>(chatMessageContents)); } /** @@ -55,7 +56,7 @@ public ChatHistory(List chatMessageContents) { * @return List of messages in the chat */ public List> getMessages() { - return Collections.unmodifiableList(chatMessageContents); + return Collections.unmodifiableList(new ArrayList<>(chatMessageContents)); } /** @@ -67,7 +68,8 @@ public Optional> getLastMessage() { if (chatMessageContents.isEmpty()) { return Optional.empty(); } - return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1)); + return Optional + .of(chatMessageContents.get(chatMessageContents.size() - 1)); } /** @@ -113,8 +115,8 @@ public Spliterator> spliterator() { * @param encoding The encoding of the message * @param metadata The metadata of the message */ - public void addMessage(AuthorRole authorRole, String content, Charset encoding, - FunctionResultMetadata metadata) { + public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding, + FunctionResultMetadata metadata) { chatMessageContents.add( ChatMessageTextContent.builder() .withAuthorRole(authorRole) @@ -122,6 +124,7 @@ public void addMessage(AuthorRole authorRole, String content, Charset encoding, .withEncoding(encoding) .withMetadata(metadata) .build()); + return this; } /** @@ -130,12 +133,13 @@ public void addMessage(AuthorRole authorRole, String content, Charset encoding, * @param authorRole The role of the author of the message * @param content The content of the message */ - public void addMessage(AuthorRole authorRole, String content) { + public ChatHistory addMessage(AuthorRole authorRole, String content) { chatMessageContents.add( ChatMessageTextContent.builder() .withAuthorRole(authorRole) .withContent(content) .build()); + return this; } /** @@ -143,8 +147,9 @@ public void addMessage(AuthorRole authorRole, String content) { * * @param content The content of the message */ - public void addMessage(ChatMessageContent content) { + public ChatHistory addMessage(ChatMessageContent content) { chatMessageContents.add(content); + return this; } /** @@ -152,8 +157,8 @@ public void addMessage(ChatMessageContent content) { * * @param content The content of the user message */ - public void addUserMessage(String content) { - addMessage(AuthorRole.USER, content); + public ChatHistory addUserMessage(String content) { + return addMessage(AuthorRole.USER, content); } /** @@ -161,8 +166,8 @@ public void addUserMessage(String content) { * * @param content The content of the assistant message */ - public void addAssistantMessage(String content) { - addMessage(AuthorRole.ASSISTANT, content); + public ChatHistory addAssistantMessage(String content) { + return addMessage(AuthorRole.ASSISTANT, content); } /** @@ -170,11 +175,12 @@ public void addAssistantMessage(String content) { * * @param content The content of the system message */ - public void addSystemMessage(String content) { - addMessage(AuthorRole.SYSTEM, content); + public ChatHistory addSystemMessage(String content) { + return addMessage(AuthorRole.SYSTEM, content); } - public void addAll(List> messages) { + public ChatHistory addAll(List> messages) { chatMessageContents.addAll(messages); + return this; } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java index 0197a55a..5386cd83 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/openai/OpenAiServiceBuilder.java @@ -1,20 +1,25 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.services.openai; -import com.azure.ai.openai.OpenAIAsyncClient; +import com.microsoft.semantickernel.services.AIService; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; import javax.annotation.Nullable; /** * Builder for an OpenAI service. - */ -public abstract class OpenAiServiceBuilder> implements + * @param The client type + * @param The service type + * @param The builder type +*/ +public abstract class OpenAiServiceBuilder> + implements + SemanticKernelBuilder { @Nullable protected String modelId; @Nullable - protected OpenAIAsyncClient client; + protected C client; @Nullable protected String serviceId; @Nullable @@ -51,7 +56,7 @@ public U withDeploymentName(String deploymentName) { * @param client The OpenAI client * @return The builder */ - public U withOpenAIAsyncClient(OpenAIAsyncClient client) { + public U withOpenAIAsyncClient(C client) { this.client = client; return (U) this; } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java index ab592c86..0ab08f5f 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/TextGenerationService.java @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.services.textcompletion; +import com.azure.ai.openai.OpenAIAsyncClient; import com.microsoft.semantickernel.Kernel; import com.microsoft.semantickernel.implementation.ServiceLoadUtil; import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; @@ -60,6 +61,7 @@ Flux getStreamingTextContentsAsync( /** * Builder for a TextGenerationService */ - abstract class Builder extends OpenAiServiceBuilder { + abstract class Builder + extends OpenAiServiceBuilder { } } diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java index 2e7c260e..ec701ce3 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java @@ -9,6 +9,7 @@ import com.github.jknack.handlebars.Helper; import com.github.jknack.handlebars.Options; import com.github.jknack.handlebars.ValueResolver; +import com.github.jknack.handlebars.context.JavaBeanValueResolver; import com.microsoft.semantickernel.Kernel; import com.microsoft.semantickernel.contextvariables.ContextVariable; import com.microsoft.semantickernel.contextvariables.ContextVariableType; @@ -21,6 +22,7 @@ import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments; import com.microsoft.semantickernel.semanticfunctions.PromptTemplate; import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig; +import com.microsoft.semantickernel.semanticfunctions.PromptTemplateOption; import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; @@ -35,7 +37,6 @@ import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.apache.commons.text.StringEscapeUtils; import reactor.core.publisher.Mono; /** @@ -168,7 +169,7 @@ public Set> propertySet(Object context) { } } - private static class HandleBarsPromptTemplateHandler { + private class HandleBarsPromptTemplateHandler { private final String template; private final Handlebars handlebars; @@ -181,7 +182,7 @@ public HandleBarsPromptTemplateHandler( this.template = template; this.handlebars = new Handlebars(); this.handlebars - .registerHelper("message", HandleBarsPromptTemplateHandler::handleMessage) + .registerHelper("message", this::handleMessage) .registerHelper("each", handleEach(context)) .with(EscapingStrategy.XML); @@ -190,7 +191,7 @@ public HandleBarsPromptTemplateHandler( // TODO: 1.0 Add more helpers } - private static Helper handleEach(InvocationContext invocationContext) { + private Helper handleEach(InvocationContext invocationContext) { return (context, options) -> { if (context instanceof ContextVariable) { return ((ContextVariable) context) @@ -227,7 +228,7 @@ private static Helper handleEach(InvocationContext invocationContext) { } @Nullable - private static CharSequence handleMessage(Object context, Options options) + private CharSequence handleMessage(Object context, Options options) throws IOException { String role = options.hash("role"); String content = (String) options.fn(context); @@ -258,7 +259,10 @@ public Mono render(KernelFunctionArguments variables) { resolvers.add(new MessageResolver()); resolvers.add(new ContextVariableResolver()); - // resolvers.addAll(ValueResolver.defaultValueResolvers()); + if (promptTemplate.getPromptTemplateOptions() + .contains(PromptTemplateOption.ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE)) { + resolvers.add(JavaBeanValueResolver.INSTANCE); + } Context context = Context .newBuilder(variables) diff --git a/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java b/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java index d4141dcf..ea34400a 100644 --- a/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java +++ b/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java @@ -5,6 +5,7 @@ import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import org.junit.jupiter.api.Test; @@ -28,6 +29,7 @@ void testInstanceMadeWithBuilderEqualsInstanceMadeWithConstructor() { name, template, "semantic-kernel", + Collections.emptySet(), description, inputVariables, outputVariable, diff --git a/semantickernel-bom/pom.xml b/semantickernel-bom/pom.xml index b7f52070..1322666f 100644 --- a/semantickernel-bom/pom.xml +++ b/semantickernel-bom/pom.xml @@ -5,7 +5,7 @@ com.microsoft.semantic-kernel semantickernel-bom - 1.2.0 + 1.2.2 pom Semantic Kernel Java BOM @@ -256,6 +256,6 @@ https://github.com/microsoft/semantic-kernel scm:git:https://github.com/microsoft/semantic-kernel.git scm:git:https://github.com/microsoft/semantic-kernel.git - java-1.2.0 + java-1.2.2 diff --git a/semantickernel-experimental/pom.xml b/semantickernel-experimental/pom.xml index 327ff1e6..44772dc2 100644 --- a/semantickernel-experimental/pom.xml +++ b/semantickernel-experimental/pom.xml @@ -4,7 +4,7 @@ com.microsoft.semantic-kernel semantickernel-parent - 1.2.0 + 1.2.2 semantickernel-experimental @@ -64,7 +64,6 @@ org.wiremock wiremock - 3.3.1 test @@ -77,7 +76,6 @@ org.mockito mockito-junit-jupiter - 5.11.0 test @@ -109,6 +107,11 @@ + + org.postgresql + postgresql + 42.7.3 + diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStore.java index 39d13f75..a7b7a7c8 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStore.java @@ -4,16 +4,15 @@ import com.azure.search.documents.indexes.SearchIndexAsyncClient; import com.azure.search.documents.indexes.models.SearchIndex; import com.microsoft.semantickernel.data.VectorStore; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import reactor.core.publisher.Mono; - +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import java.util.List; +import reactor.core.publisher.Mono; -public class AzureAISearchVectorStore - implements VectorStore> { +public class AzureAISearchVectorStore implements VectorStore { private final SearchIndexAsyncClient client; private final AzureAISearchVectorStoreOptions options; @@ -21,7 +20,7 @@ public class AzureAISearchVectorStore /** * Creates a new instance of {@link AzureAISearchVectorStore}. * - * @param client The Azure AI Search client. + * @param client The Azure AI Search client. * @param options The options for the vector store. */ @SuppressFBWarnings("EI_EXPOSE_REP2") @@ -34,17 +33,29 @@ public AzureAISearchVectorStore(@Nonnull SearchIndexAsyncClient client, /** * Gets a new instance of {@link AzureAISearchVectorStoreRecordCollection} * - * @param collectionName The name of the collection. - * @param recordClass The class type of the record. + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. * @param recordDefinition The record definition. * @return The collection. */ @Override - public AzureAISearchVectorStoreRecordCollection getCollection( + public final VectorStoreRecordCollection getCollection( @Nonnull String collectionName, + @Nonnull Class keyClass, @Nonnull Class recordClass, - VectorStoreRecordDefinition recordDefinition) { + @Nullable VectorStoreRecordDefinition recordDefinition) { + if (!keyClass.equals(String.class)) { + throw new IllegalArgumentException("Azure AI Search only supports string keys"); + } + return (VectorStoreRecordCollection) getCollection( + collectionName, recordClass, recordDefinition); + } + + public AzureAISearchVectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { if (options.getVectorStoreRecordCollectionFactory() != null) { return options.getVectorStoreRecordCollectionFactory() .createVectorStoreRecordCollection( @@ -56,7 +67,9 @@ public AzureAISearchVectorStoreRecordCollection getCollect .build()); } - return new AzureAISearchVectorStoreRecordCollection<>(client, collectionName, + return new AzureAISearchVectorStoreRecordCollection<>( + client, + collectionName, AzureAISearchVectorStoreRecordCollectionOptions.builder() .withRecordClass(recordClass) .withRecordDefinition(recordDefinition) @@ -86,6 +99,7 @@ public static Builder builder() { * Builder for {@link AzureAISearchVectorStore}. */ public static class Builder { + @Nullable private SearchIndexAsyncClient client; @Nullable @@ -109,7 +123,8 @@ public Builder withClient(@Nonnull SearchIndexAsyncClient client) { * @param options The options for the Azure AI Search vector store. * @return The updated builder instance. */ - public Builder withOptions(@Nonnull AzureAISearchVectorStoreOptions options) { + public Builder withOptions( + @Nonnull AzureAISearchVectorStoreOptions options) { this.options = options; return this; } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java index e077a510..c57e6f9d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreCollectionCreateMapping.java @@ -13,12 +13,9 @@ import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDataField; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; import java.time.OffsetDateTime; import java.util.List; -import java.util.Objects; +import javax.annotation.Nonnull; public class AzureAISearchVectorStoreCollectionCreateMapping { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreOptions.java index d7bb0314..db3b7ab6 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreOptions.java @@ -1,10 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.azureaisearch; -import javax.annotation.Nonnull; import javax.annotation.Nullable; public class AzureAISearchVectorStoreOptions { + @Nullable private final AzureAISearchVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; @@ -49,6 +49,7 @@ public AzureAISearchVectorStoreRecordCollectionFactory getVectorStoreRecordColle * */ public static class Builder { + @Nullable private AzureAISearchVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java index 9576b122..5937d928 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollection.java @@ -11,22 +11,18 @@ import com.azure.search.documents.indexes.models.VectorSearchProfile; import com.azure.search.documents.models.IndexDocumentsResult; import com.azure.search.documents.models.IndexingResult; -import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; -import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; -import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; -import com.microsoft.semantickernel.exceptions.SKException; import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDataField; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import javax.annotation.Nonnull; import java.time.OffsetDateTime; import java.util.ArrayList; import java.util.Arrays; @@ -38,9 +34,12 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; -public class AzureAISearchVectorStoreRecordCollection - implements VectorStoreRecordCollection { +public class AzureAISearchVectorStoreRecordCollection implements + VectorStoreRecordCollection { private static final HashSet> supportedKeyTypes = new HashSet<>( Collections.singletonList( @@ -90,12 +89,16 @@ public AzureAISearchVectorStoreRecordCollection( : options.getRecordDefinition(); // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(this.options.getRecordClass(), - this.recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedDataTypes(this.options.getRecordClass(), - this.recordDefinition, supportedDataTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(this.options.getRecordClass(), - this.recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(this.options.getRecordClass()), + supportedDataTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // Add non-vector fields to the list nonVectorFields.add(this.recordDefinition.getKeyField().getName()); @@ -120,7 +123,7 @@ public Mono collectionExistsAsync() { } @Override - public Mono createCollectionAsync() { + public Mono> createCollectionAsync() { List searchFields = new ArrayList<>(); List algorithms = new ArrayList<>(); List profiles = new ArrayList<>(); @@ -147,18 +150,19 @@ public Mono createCollectionAsync() { .setAlgorithms(algorithms) .setProfiles(profiles)); - return client.createIndex(newIndex).then(); + return client.createIndex(newIndex).then(Mono.just(this)); } @Override - public Mono createCollectionIfNotExistsAsync() { + public Mono> createCollectionIfNotExistsAsync() { return collectionExistsAsync().flatMap( exists -> { if (!exists) { return createCollectionAsync(); } return Mono.empty(); - }); + }) + .then(Mono.just(this)); } @Override @@ -187,11 +191,11 @@ public Mono getAsync( } return client.getDocumentWithResponse(key, this.options.getRecordClass(), selectedFields) - .map(response -> { + .flatMap(response -> { if (response.getStatusCode() == 404) { - throw new SKException("Record not found: " + key); + return Mono.error(new SKException("Record not found: " + key)); } - return response.getValue(); + return Mono.just(response.getValue()); }); } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionFactory.java index c5041284..ec08ba03 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionFactory.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionFactory.java @@ -5,16 +5,15 @@ /** * Factory for creating Azure AI Search vector store record collections. - * */ public interface AzureAISearchVectorStoreRecordCollectionFactory { /** * Creates a new Azure AI Search vector store record collection. * - * @param client The Azure AI Search client. + * @param client The Azure AI Search client. * @param collectionName The name of the collection. - * @param options The options for the collection. + * @param options The options for the collection. * @return The new Azure AI Search vector store record collection. */ AzureAISearchVectorStoreRecordCollection createVectorStoreRecordCollection( diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionOptions.java index 45fb410c..7275dd43 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionOptions.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/azureaisearch/AzureAISearchVectorStoreRecordCollectionOptions.java @@ -4,7 +4,6 @@ import com.azure.search.documents.SearchDocument; import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; - import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -14,6 +13,7 @@ * @param the record type */ public class AzureAISearchVectorStoreRecordCollectionOptions { + private final Class recordClass; @Nullable private final VectorStoreRecordMapper vectorStoreRecordMapper; @@ -75,6 +75,7 @@ private AzureAISearchVectorStoreRecordCollectionOptions( * @param the record type */ public static class Builder { + @Nullable private VectorStoreRecordMapper vectorStoreRecordMapper; @Nullable diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java new file mode 100644 index 00000000..66ad995b --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStore.java @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStoreRecordCollection; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import javax.sql.DataSource; +import java.util.List; + +/** + * A JDBC vector store. + */ +public class JDBCVectorStore implements SQLVectorStore { + private final DataSource dataSource; + private final JDBCVectorStoreOptions options; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the {@link JDBCVectorStore}. + * If using this constructor, call {@link #prepareAsync()} before using the vector store. + * + * @param dataSource the connection + * @param options the options + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public JDBCVectorStore(@Nonnull DataSource dataSource, + @Nullable JDBCVectorStoreOptions options) { + this.dataSource = dataSource; + this.options = options; + + if (this.options != null && this.options.getQueryProvider() != null) { + this.queryProvider = this.options.getQueryProvider(); + } else { + this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() + .withDataSource(dataSource) + .build(); + } + } + + /** + * Creates a new builder for the vector store. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets a collection from the vector store. + * + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. + * @param recordDefinition The record definition. + * @return The collection. + */ + @Override + public VectorStoreRecordCollection getCollection( + @Nonnull String collectionName, @Nonnull Class keyClass, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + if (keyClass != String.class) { + throw new IllegalArgumentException("Redis only supports string keys"); + } + + return (VectorStoreRecordCollection) getCollection( + collectionName, + recordClass, + recordDefinition); + } + + /** + * Gets a collection from the vector store. + * + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. + * @param recordDefinition The record definition. + * @return The collection. + */ + public JDBCVectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + if (this.options != null && this.options.getVectorStoreRecordCollectionFactory() != null) { + return this.options.getVectorStoreRecordCollectionFactory() + .createVectorStoreRecordCollection( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + return new JDBCVectorStoreRecordCollection<>( + dataSource, + collectionName, + JDBCVectorStoreRecordCollectionOptions.builder() + .withRecordClass(recordClass) + .withRecordDefinition(recordDefinition) + .withQueryProvider(this.queryProvider) + .build()); + } + + /** + * Gets the names of all collections in the vector store. + * + * @return A list of collection names. + */ + @Override + public Mono> getCollectionNamesAsync() { + return Mono.fromCallable(queryProvider::getCollectionNames) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Prepares the vector store. + */ + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Builder for creating a {@link JDBCVectorStore}. + */ + public static class Builder { + private DataSource dataSource; + private JDBCVectorStoreOptions options; + + /** + * Sets the data source. + * + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the options. + * + * @param options the options + * @return the builder + */ + public Builder withOptions(JDBCVectorStoreOptions options) { + this.options = options; + return this; + } + + /** + * Builds the {@link JDBCVectorStore}. + * + * @return the {@link JDBCVectorStore} + */ + public JDBCVectorStore build() { + return buildAsync().block(); + } + + /** + * Builds the {@link JDBCVectorStore} asynchronously. + * + * @return the {@link Mono} with the {@link JDBCVectorStore} + */ + public Mono buildAsync() { + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); + } + + JDBCVectorStore vectorStore = new JDBCVectorStore(dataSource, options); + return vectorStore.prepareAsync().thenReturn(vectorStore); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java new file mode 100644 index 00000000..f1795083 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java @@ -0,0 +1,453 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class JDBCVectorStoreDefaultQueryProvider + implements JDBCVectorStoreQueryProvider { + + private Map, String> supportedKeyTypes; + private Map, String> supportedDataTypes; + private Map, String> supportedVectorTypes; + private final DataSource dataSource; + private final String collectionsTable; + private final String prefixForCollectionTables; + + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + protected JDBCVectorStoreDefaultQueryProvider( + @Nonnull DataSource dataSource, + @Nonnull String collectionsTable, + @Nonnull String prefixForCollectionTables) { + this.dataSource = dataSource; + this.collectionsTable = collectionsTable; + this.prefixForCollectionTables = prefixForCollectionTables; + + supportedKeyTypes = new HashMap<>(); + supportedKeyTypes.put(String.class, "VARCHAR(255)"); + + supportedDataTypes = new HashMap<>(); + supportedDataTypes.put(String.class, "TEXT"); + supportedDataTypes.put(Integer.class, "INTEGER"); + supportedDataTypes.put(int.class, "INTEGER"); + supportedDataTypes.put(Long.class, "BIGINT"); + supportedDataTypes.put(long.class, "BIGINT"); + supportedDataTypes.put(Float.class, "REAL"); + supportedDataTypes.put(float.class, "REAL"); + supportedDataTypes.put(Double.class, "DOUBLE"); + supportedDataTypes.put(double.class, "DOUBLE"); + supportedDataTypes.put(Boolean.class, "BOOLEAN"); + supportedDataTypes.put(boolean.class, "BOOLEAN"); + supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ"); + + supportedVectorTypes = new HashMap<>(); + supportedVectorTypes.put(String.class, "TEXT"); + supportedVectorTypes.put(List.class, "TEXT"); + supportedVectorTypes.put(Collection.class, "TEXT"); + } + + /** + * Creates a new builder. + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Formats a wildcard string for a query. + * @param wildcards the number of wildcards + * @return the formatted wildcard string + */ + protected String getWildcardString(int wildcards) { + return Stream.generate(() -> "?") + .limit(wildcards) + .collect(Collectors.joining(", ")); + } + + /** + * Formats the query columns from a record definition. + * @param fields the fields to get the columns from + * @return the formatted query columns + */ + protected String getQueryColumnsFromFields(List fields) { + return fields.stream().map(VectorStoreRecordField::getName) + .collect(Collectors.joining(", ")); + } + + /** + * Formats the column names and types for a table. + * @param fields the fields + * @param types the types + * @return the formatted column names and types + */ + protected String getColumnNamesAndTypes(List fields, Map, String> types) { + List columns = fields.stream() + .map(field -> field.getName() + " " + types.get(field.getType())) + .collect(Collectors.toList()); + + return String.join(", ", columns); + } + + protected String getCollectionTableName(String collectionName) { + return validateSQLidentifier(prefixForCollectionTables + collectionName); + } + + /** + * Gets the supported key types and their corresponding SQL types. + * + * @return the supported key types + */ + @Override + public Map, String> getSupportedKeyTypes() { + return new HashMap<>(this.supportedKeyTypes); + } + + /** + * Gets the supported data types and their corresponding SQL types. + * + * @return the supported data types + */ + @Override + public Map, String> getSupportedDataTypes() { + return new HashMap<>(this.supportedDataTypes); + } + + /** + * Gets the supported vector types and their corresponding SQL types. + * + * @return the supported vector types + */ + @Override + public Map, String> getSupportedVectorTypes() { + return new HashMap<>(this.supportedVectorTypes); + } + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + * + * @throws SKException if an error occurs while preparing the vector store + */ + @Override + public void prepareVectorStore() { + String createCollectionsTable = "CREATE TABLE IF NOT EXISTS " + + validateSQLidentifier(collectionsTable) + + " (collectionId VARCHAR(255) PRIMARY KEY);"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createCollectionsTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store", e); + } + } + + /** + * Checks if the types of the record class fields are supported. + * + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws IllegalArgumentException if the types are not supported + */ + @Override + public void validateSupportedTypes(Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + VectorStoreRecordDefinition.validateSupportedTypes( + Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)), + getSupportedKeyTypes().keySet()); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet()); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(recordClass), + getSupportedVectorTypes().keySet()); + } + + /** + * Checks if a collection exists. + * + * @param collectionName the collection name + * @return true if the collection exists, false otherwise + * @throws SKException if an error occurs while checking if the collection exists + */ + @Override + public boolean collectionExists(String collectionName) { + String query = "SELECT 1 FROM " + validateSQLidentifier(collectionsTable) + + " WHERE collectionId = ?"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + statement.setObject(1, collectionName); + + return statement.executeQuery().next(); + } catch (SQLException e) { + throw new SKException("Failed to check if collection exists", e); + } + } + + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws SKException if an error occurs while creating the collection + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); + List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); + List vectorDeclaredFields = recordDefinition.getVectorDeclaredFields(recordClass); + + String createStorageTable = "CREATE TABLE IF NOT EXISTS " + + getCollectionTableName(collectionName) + + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " + + getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", " + + getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");"; + + String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable) + + " (collectionId) VALUES (?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to create collection", e); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { + insert.setObject(1, collectionName); + insert.execute(); + } catch (SQLException e) { + throw new SKException("Failed to insert collection", e); + } + } + + /** + * Deletes a collection. + * + * @param collectionName the collection name + * @throws SKException if an error occurs while deleting the collection + */ + @Override + public void deleteCollection(String collectionName) { + String deleteCollectionOperation = "DELETE FROM " + validateSQLidentifier(collectionsTable) + + " WHERE collectionId = ?"; + String dropTableOperation = "DROP TABLE " + getCollectionTableName(collectionName); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement deleteCollection = connection + .prepareStatement(deleteCollectionOperation)) { + deleteCollection.setObject(1, collectionName); + deleteCollection.execute(); + } catch (SQLException e) { + throw new SKException("Failed to delete collection", e); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) { + dropTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to drop table", e); + } + } + + /** + * Gets the collection names. + * + * @return the collection names + * @throws SKException if an error occurs while getting the collection names + */ + @Override + public List getCollectionNames() { + String query = "SELECT collectionId FROM " + validateSQLidentifier(collectionsTable); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + List collectionNames = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + collectionNames.add(resultSet.getString(1)); + } + + return Collections.unmodifiableList(collectionNames); + } catch (SQLException e) { + throw new SKException("Failed to get collection names", e); + } + } + + /** + * Gets a list of records from the store. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param mapper the mapper + * @param options the options + * @return the records + * @param the record type + * @throws SKException if an error occurs while getting the records + */ + @Override + public List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, + VectorStoreRecordMapper mapper, + GetRecordOptions options) { + List fields; + if (options == null || options.includeVectors()) { + fields = recordDefinition.getAllFields(); + } else { + fields = recordDefinition.getNonVectorFields(); + } + + String query = "SELECT " + getQueryColumnsFromFields(fields) + + " FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { + statement.setObject(i + 1, keys.get(i)); + } + + List records = new ArrayList<>(); + ResultSet resultSet = statement.executeQuery(); + + while (resultSet.next()) { + records.add(mapper.mapStorageModeltoRecord(resultSet)); + } + + return Collections.unmodifiableList(records); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + @Override + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + throw new UnsupportedOperationException( + "Upsert is not supported. Try with a specific query provider."); + } + + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + * @throws SKException if an error occurs while deleting the records + */ + @Override + public void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options) { + String query = "DELETE FROM " + getCollectionTableName(collectionName) + + " WHERE " + recordDefinition.getKeyField().getName() + + " IN (" + getWildcardString(keys.size()) + ")"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (int i = 0; i < keys.size(); ++i) { + statement.setObject(i + 1, keys.get(i)); + } + + statement.execute(); + } catch (SQLException e) { + throw new SKException("Failed to set statement values", e); + } + } + + /** + * Validates an SQL identifier. + * + * @param identifier the identifier + * @return the identifier if it is valid + * @throws IllegalArgumentException if the identifier is invalid + */ + public static String validateSQLidentifier(String identifier) { + if (identifier.matches("[a-zA-Z_][a-zA-Z0-9_]*")) { + return identifier; + } + throw new IllegalArgumentException("Invalid SQL identifier: " + identifier); + } + + /** + * The builder for {@link JDBCVectorStoreDefaultQueryProvider}. + */ + public static class Builder + implements JDBCVectorStoreQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + /** + * Sets the data source. + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + @Override + public JDBCVectorStoreDefaultQueryProvider build() { + if (dataSource == null) { + throw new IllegalArgumentException("DataSource is required"); + } + + return new JDBCVectorStoreDefaultQueryProvider(dataSource, collectionsTable, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java new file mode 100644 index 00000000..adb6e13c --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreOptions.java @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.annotation.Nullable; + +public class JDBCVectorStoreOptions { + @Nullable + private final JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + @Nullable + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the JDBC vector store options. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public JDBCVectorStoreOptions( + @Nullable JDBCVectorStoreQueryProvider queryProvider, + @Nullable JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.queryProvider = queryProvider; + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + } + + /** + * Creates a new instance of the JDBC vector store options. + */ + public JDBCVectorStoreOptions() { + this(null, null); + } + + /** + * Gets the query provider. + * + * @return the query provider + */ + @Nullable + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + /** + * Creates a new builder. + * + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Gets the vector store record collection factory. + * + * @return the vector store record collection factory + */ + @Nullable + public JDBCVectorStoreRecordCollectionFactory getVectorStoreRecordCollectionFactory() { + return vectorStoreRecordCollectionFactory; + } + + /** + * Builder for JDBC vector store options. + * + */ + public static class Builder { + @Nullable + private JDBCVectorStoreQueryProvider queryProvider; + @Nullable + private JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory; + + /** + * Sets the query provider. + * + * @param queryProvider The query provider. + * @return The updated builder instance. + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Sets the vector store record collection factory. + * + * @param vectorStoreRecordCollectionFactory The vector store record collection factory. + * @return The updated builder instance. + */ + public Builder withVectorStoreRecordCollectionFactory( + JDBCVectorStoreRecordCollectionFactory vectorStoreRecordCollectionFactory) { + this.vectorStoreRecordCollectionFactory = vectorStoreRecordCollectionFactory; + return this; + } + + /** + * Builds the JDBC vector store options. + * + * @return The JDBC vector store options. + */ + public JDBCVectorStoreOptions build() { + return new JDBCVectorStoreOptions(queryProvider, vectorStoreRecordCollectionFactory); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java new file mode 100644 index 00000000..6009b885 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; + +import java.sql.ResultSet; +import java.util.List; +import java.util.Map; + +/** + * The JDBC vector store query provider. + * Provides the necessary methods to interact with a JDBC vector store and vector store collections. + */ +public interface JDBCVectorStoreQueryProvider { + /** + * The default name for the collections table. + */ + String DEFAULT_COLLECTIONS_TABLE = "SKCollections"; + + /** + * The prefix for collection tables. + */ + String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_"; + + /** + * Gets the supported key types and their corresponding SQL types. + * + * @return the supported key types + */ + Map, String> getSupportedKeyTypes(); + + /** + * Gets the supported data types and their corresponding SQL types. + * + * @return the supported data types + */ + Map, String> getSupportedDataTypes(); + + /** + * Gets the supported vector types and their corresponding SQL types. + * + * @return the supported vector types + */ + Map, String> getSupportedVectorTypes(); + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + */ + void prepareVectorStore(); + + /** + * Checks if the types of the record class fields are supported. + * + * @param recordClass the record class + * @param recordDefinition the record definition + */ + void validateSupportedTypes(Class recordClass, VectorStoreRecordDefinition recordDefinition); + + /** + * Checks if a collection exists. + * + * @param collectionName the collection name + * @return true if the collection exists, false otherwise + */ + boolean collectionExists(String collectionName); + + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + */ + void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition); + + /** + * Deletes a collection. + * + * @param collectionName the collection name + */ + void deleteCollection(String collectionName); + + /** + * Gets the collection names. + * + * @return the collection names + */ + List getCollectionNames(); + + /** + * Gets records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param mapper the mapper + * @param options the options + * @return the records + */ + List getRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, + VectorStoreRecordMapper mapper, + GetRecordOptions options); + + /** + * Upserts records. + * + * @param collectionName the collection name + * @param records the records + * @param vectorStoreRecordDefinition the record definition + * @param options the options + */ + void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition vectorStoreRecordDefinition, UpsertRecordOptions options); + + /** + * Deletes records. + * + * @param collectionName the collection name + * @param keys the keys + * @param recordDefinition the record definition + * @param options the options + */ + void deleteRecords(String collectionName, List keys, + VectorStoreRecordDefinition recordDefinition, DeleteRecordOptions options); + + /** + * The builder for the JDBC vector store query provider. + */ + interface Builder extends SemanticKernelBuilder { + + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java new file mode 100644 index 00000000..44fc2338 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollection.java @@ -0,0 +1,345 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider; +import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreRecordMapper; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.reflect.Field; +import java.sql.ResultSet; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; +import javax.sql.DataSource; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +public class JDBCVectorStoreRecordCollection + implements SQLVectorStoreRecordCollection { + + private final String collectionName; + private final VectorStoreRecordDefinition recordDefinition; + private final VectorStoreRecordMapper vectorStoreRecordMapper; + private final JDBCVectorStoreRecordCollectionOptions options; + private final JDBCVectorStoreQueryProvider queryProvider; + + /** + * Creates a new instance of the {@link JDBCVectorStoreRecordCollection}. + * + * @param dataSource the data source + * @param collectionName the name of the collection + * @param options the options + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public JDBCVectorStoreRecordCollection( + @Nonnull DataSource dataSource, + @Nonnull String collectionName, + @Nonnull JDBCVectorStoreRecordCollectionOptions options) { + this.collectionName = collectionName; + this.options = options; + + // If record definition is not provided, create one from the record class + recordDefinition = options.getRecordDefinition() == null + ? VectorStoreRecordDefinition.fromRecordClass(options.getRecordClass()) + : options.getRecordDefinition(); + + // If the query provider is not provided, set a default one + if (options.getQueryProvider() == null) { + this.queryProvider = JDBCVectorStoreDefaultQueryProvider.builder() + .withDataSource(dataSource) + .build(); + } else { + this.queryProvider = options.getQueryProvider(); + } + + // If mapper is not provided, set a default one + if (options.getVectorStoreRecordMapper() == null) { + // Default mapper for PostgreSQL + if (this.queryProvider instanceof PostgreSQLVectorStoreQueryProvider) { + vectorStoreRecordMapper = PostgreSQLVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + // Default mapper for MySQL + } else if (this.queryProvider instanceof MySQLVectorStoreQueryProvider) { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + // Default mapper for other databases + } else { + vectorStoreRecordMapper = JDBCVectorStoreRecordMapper.builder() + .withRecordClass(options.getRecordClass()) + .withVectorStoreRecordDefinition(recordDefinition) + .build(); + } + } else { + vectorStoreRecordMapper = options.getVectorStoreRecordMapper(); + } + + // Check if the types are supported + queryProvider.validateSupportedTypes(options.getRecordClass(), recordDefinition); + } + + /** + * Gets the name of the collection. + * + * @return The name of the collection. + */ + @Override + public String getCollectionName() { + return collectionName; + } + + /** + * Checks if the collection exists in the store. + * + * @return A Mono emitting a boolean indicating if the collection exists. + * @throws SKException if the operation fails + */ + @Override + public Mono collectionExistsAsync() { + return Mono.fromCallable( + () -> queryProvider.collectionExists(this.collectionName)) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Creates the collection in the store. + * + * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono> createCollectionAsync() { + return Mono.fromRunnable( + () -> queryProvider.createCollection(this.collectionName, options.getRecordClass(), + recordDefinition)) + .subscribeOn(Schedulers.boundedElastic()) + .then(Mono.just(this)); + } + + /** + * Creates the collection in the store if it does not exist. + * + * @return A Mono representing the completion of the creation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono> createCollectionIfNotExistsAsync() { + return collectionExistsAsync().map( + exists -> { + if (!exists) { + return createCollectionAsync(); + } + return Mono.empty(); + }) + .flatMap(mono -> mono) + .then(Mono.just(this)); + } + + /** + * Deletes the collection from the store. + * + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteCollectionAsync() { + return Mono.fromRunnable( + () -> { + queryProvider.deleteCollection(this.collectionName); + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Gets a record from the store. + * + * @param key The key of the record to get. + * @param options The options for getting the record. + * @return A Mono emitting the record. + * @throws SKException if the operation fails + */ + @Override + public Mono getAsync(String key, GetRecordOptions options) { + return this.getBatchAsync(Collections.singletonList(key), options) + .mapNotNull(records -> { + if (records.isEmpty()) { + return null; + } + return records.get(0); + }); + } + + /** + * Gets a batch of records from the store. + * + * @param keys The keys of the records to get. + * @param options The options for getting the records. + * @return A Mono emitting a collection of records. + * @throws SKException if the operation fails + */ + @Override + public Mono> getBatchAsync(List keys, GetRecordOptions options) { + return Mono.fromCallable( + () -> { + return queryProvider.getRecords(this.collectionName, keys, recordDefinition, + vectorStoreRecordMapper, options); + }).subscribeOn(Schedulers.boundedElastic()); + } + + protected String getKeyFromRecord(Record data) { + try { + Field keyField = data.getClass() + .getDeclaredField(recordDefinition.getKeyField().getName()); + keyField.setAccessible(true); + return (String) keyField.get(data); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new SKException("Failed to get key from record", e); + } + } + + /** + * Inserts or updates a record in the store. + * + * @param data The record to upsert. + * @param options The options for upserting the record. + * @return A Mono emitting the key of the upserted record. + * @throws SKException if the operation fails + */ + @Override + public Mono upsertAsync(Record data, UpsertRecordOptions options) { + return this.upsertBatchAsync(Collections.singletonList(data), options) + .mapNotNull(keys -> { + if (keys.isEmpty()) { + return null; + } + return keys.get(0); + }); + } + + /** + * Inserts or updates a batch of records in the store. + * + * @param data The records to upsert. + * @param options The options for upserting the records. + * @return A Mono emitting a collection of keys of the upserted records. + * @throws SKException if the operation fails + */ + @Override + public Mono> upsertBatchAsync(List data, UpsertRecordOptions options) { + return Mono.fromCallable( + () -> { + queryProvider.upsertRecords(this.collectionName, data, recordDefinition, options); + return data.stream().map(this::getKeyFromRecord).collect(Collectors.toList()); + }) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Deletes a record from the store. + * + * @param key The key of the record to delete. + * @param options The options for deleting the record. + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteAsync(String key, DeleteRecordOptions options) { + return this.deleteBatchAsync(Collections.singletonList(key), options); + } + + /** + * Deletes a batch of records from the store. + * + * @param keys The keys of the records to delete. + * @param options The options for deleting the records. + * @return A Mono representing the completion of the deletion operation. + * @throws SKException if the operation fails + */ + @Override + public Mono deleteBatchAsync(List keys, DeleteRecordOptions options) { + return Mono.fromRunnable( + () -> { + queryProvider.deleteRecords(this.collectionName, keys, recordDefinition, options); + }).subscribeOn(Schedulers.boundedElastic()).then(); + } + + /** + * Prepares the collection for use. + * + * @return A Mono representing the completion of the preparation operation. + * @throws SKException if the operation fails + */ + @Override + public Mono prepareAsync() { + return Mono.fromRunnable(queryProvider::prepareVectorStore) + .subscribeOn(Schedulers.boundedElastic()).then(); + } + + public static class Builder + implements SemanticKernelBuilder> { + + private DataSource dataSource; + private String collectionName; + private JDBCVectorStoreRecordCollectionOptions options; + + /** + * Sets the data source. + * + * @param dataSource the data source + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collection name. + * + * @param collectionName the collection name + * @return the builder + */ + public Builder withCollectionName(String collectionName) { + this.collectionName = collectionName; + return this; + } + + /** + * Sets the options. + * + * @param options the options + * @return the builder + */ + public Builder withOptions(JDBCVectorStoreRecordCollectionOptions options) { + this.options = options; + return this; + } + + @Override + public JDBCVectorStoreRecordCollection build() { + if (dataSource == null) { + throw new IllegalArgumentException("dataSource is required"); + } + if (collectionName == null) { + throw new IllegalArgumentException("collectionName is required"); + } + if (options == null) { + throw new IllegalArgumentException("options is required"); + } + + return new JDBCVectorStoreRecordCollection<>(dataSource, collectionName, options); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java new file mode 100644 index 00000000..6cfcdcad --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionFactory.java @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import javax.sql.DataSource; + +/** + * Factory for creating JDBC vector store record collections. + */ +public interface JDBCVectorStoreRecordCollectionFactory { + /** + * Creates a new JDBC vector store record collection. + * + * @param options The options for the collection. + * @return The new JDBC vector store record collection. + */ + JDBCVectorStoreRecordCollection createVectorStoreRecordCollection( + DataSource dataSource, + String collectionName, + JDBCVectorStoreRecordCollectionOptions options); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java new file mode 100644 index 00000000..f6aa871d --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordCollectionOptions.java @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import java.sql.ResultSet; + +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider.validateSQLidentifier; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_COLLECTIONS_TABLE; +import static com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider.DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + +public class JDBCVectorStoreRecordCollectionOptions { + private final Class recordClass; + private final VectorStoreRecordMapper vectorStoreRecordMapper; + private final VectorStoreRecordDefinition recordDefinition; + private final JDBCVectorStoreQueryProvider queryProvider; + private final String collectionsTableName; + private final String prefixForCollectionTables; + + private JDBCVectorStoreRecordCollectionOptions( + Class recordClass, + VectorStoreRecordDefinition recordDefinition, + VectorStoreRecordMapper vectorStoreRecordMapper, + JDBCVectorStoreQueryProvider queryProvider, + String collectionsTableName, + String prefixForCollectionTables) { + this.recordClass = recordClass; + this.recordDefinition = recordDefinition; + this.vectorStoreRecordMapper = vectorStoreRecordMapper; + this.queryProvider = queryProvider; + this.collectionsTableName = collectionsTableName; + this.prefixForCollectionTables = prefixForCollectionTables; + } + + /** + * Creates a new builder. + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Gets the record class. + * @return the record class + */ + public Class getRecordClass() { + return recordClass; + } + + /** + * Gets the record definition. + * @return the record definition + */ + public VectorStoreRecordDefinition getRecordDefinition() { + return recordDefinition; + } + + /** + * Gets the vector store record mapper. + * @return the vector store record mapper + */ + public VectorStoreRecordMapper getVectorStoreRecordMapper() { + return vectorStoreRecordMapper; + } + + /** + * Gets the collections table. + * @return the collections table + */ + public String getCollectionsTableName() { + return collectionsTableName; + } + + /** + * Gets the prefix for collection tables. + * @return the prefix for collection tables + */ + public String getPrefixForCollectionTables() { + return prefixForCollectionTables; + } + + /** + * Gets the query provider. + * @return the query provider + */ + @SuppressFBWarnings("EI_EXPOSE_REP") // DataSource in queryProvider is not exposed + public JDBCVectorStoreQueryProvider getQueryProvider() { + return queryProvider; + } + + public static class Builder { + private Class recordClass; + private VectorStoreRecordDefinition recordDefinition; + private VectorStoreRecordMapper vectorStoreRecordMapper; + private JDBCVectorStoreQueryProvider queryProvider; + private String collectionsTableName = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + /** + * Sets the record class. + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the record definition. + * @param recordDefinition the record definition + * @return the builder + */ + public Builder withRecordDefinition(VectorStoreRecordDefinition recordDefinition) { + this.recordDefinition = recordDefinition; + return this; + } + + /** + * Sets the vector store record mapper. + * @param vectorStoreRecordMapper the vector store record mapper + * @return the builder + */ + public Builder withVectorStoreRecordMapper( + VectorStoreRecordMapper vectorStoreRecordMapper) { + this.vectorStoreRecordMapper = vectorStoreRecordMapper; + return this; + } + + /** + * Sets the query provider. + * @param queryProvider the query provider + * @return the builder + */ + @SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource in queryProvider is not exposed + public Builder withQueryProvider(JDBCVectorStoreQueryProvider queryProvider) { + this.queryProvider = queryProvider; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTableName the collections table name + * @return the builder + */ + public Builder withCollectionsTableName(String collectionsTableName) { + this.collectionsTableName = validateSQLidentifier(collectionsTableName); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + /** + * Builds the options. + * @return the options + */ + public JDBCVectorStoreRecordCollectionOptions build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + + return new JDBCVectorStoreRecordCollectionOptions<>( + recordClass, + recordDefinition, + vectorStoreRecordMapper, + queryProvider, + collectionsTableName, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java new file mode 100644 index 00000000..6eff0c7d --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreRecordMapper.java @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.exceptions.SKException; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; + +import java.sql.ResultSetMetaData; +import java.util.List; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.function.Function; + +public class JDBCVectorStoreRecordMapper + extends VectorStoreRecordMapper { + + /** + * Constructs a new instance of the VectorStoreRecordMapper. + * + * @param storageModelToRecordMapper the function to convert a storage model to a record + */ + protected JDBCVectorStoreRecordMapper(Function storageModelToRecordMapper) { + super(null, storageModelToRecordMapper); + } + + /** + * Creates a new builder. + * + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + /** + * Operation not supported. + */ + @Override + public ResultSet mapRecordToStorageModel(Record record) { + throw new UnsupportedOperationException("Not implemented"); + } + + public static class Builder + implements SemanticKernelBuilder> { + private Class recordClass; + private VectorStoreRecordDefinition vectorStoreRecordDefinition; + + /** + * Sets the record class. + * + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the vector store record definition. + * + * @param vectorStoreRecordDefinition the vector store record definition + * @return the builder + */ + public Builder withVectorStoreRecordDefinition( + VectorStoreRecordDefinition vectorStoreRecordDefinition) { + this.vectorStoreRecordDefinition = vectorStoreRecordDefinition; + return this; + } + + /** + * Builds the {@link JDBCVectorStoreRecordMapper}. + * + * @return the {@link JDBCVectorStoreRecordMapper} + */ + public JDBCVectorStoreRecordMapper build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + if (vectorStoreRecordDefinition == null) { + throw new IllegalArgumentException("vectorStoreRecordDefinition is required"); + } + + return new JDBCVectorStoreRecordMapper<>( + resultSet -> { + try { + Constructor constructor = recordClass.getDeclaredConstructor(); + constructor.setAccessible(true); + Record record = (Record) constructor.newInstance(); + + // Select fields from the record definition. + // Check if vector fields are present in the result set. + List fields; + ResultSetMetaData metaData = resultSet.getMetaData(); + if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields() + .size()) { + fields = vectorStoreRecordDefinition.getAllFields(); + } else { + fields = vectorStoreRecordDefinition.getNonVectorFields(); + } + + for (VectorStoreRecordField field : fields) { + Object value = resultSet.getObject(field.getName()); + Field recordField = recordClass.getDeclaredField(field.getName()); + recordField.setAccessible(true); + + // If the field is a vector field, deserialize the JSON string + if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = recordField.getType(); + + // If the vector type is a string, set the value directly + if (vectorType.equals(String.class)) { + recordField.set(record, value); + } else { + // Deserialize the JSON string to the vector type + recordField.set(record, + new ObjectMapper().readValue((String) value, vectorType)); + } + } else { + recordField.set(record, value); + } + } + + return record; + } catch (NoSuchMethodException e) { + throw new SKException("Default constructor not found.", e); + } catch (InstantiationException | InvocationTargetException e) { + throw new SKException(String.format( + "SK cannot instantiate %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (JsonProcessingException e) { + throw new SKException(String.format( + "SK cannot deserialize %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (SQLException | NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java new file mode 100644 index 00000000..046f9941 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStore.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStore; +import reactor.core.publisher.Mono; + +public interface SQLVectorStore + extends VectorStore { + + /** + * Prepares the vector store. + * + * @return A {@link Mono} that completes when the vector store is prepared to be used. + */ + Mono prepareAsync(); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java new file mode 100644 index 00000000..ff12c88b --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/SQLVectorStoreRecordCollection.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.jdbc; + +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; +import reactor.core.publisher.Mono; + +public interface SQLVectorStoreRecordCollection + extends VectorStoreRecordCollection { + + /** + * Prepares the vector store record collection. + * + * @return A {@link Mono} that completes when the vector store record collection is prepared to be used. + */ + Mono prepareAsync(); +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java new file mode 100644 index 00000000..ff19017c --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/mysql/MySQLVectorStoreQueryProvider.java @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.mysql; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.util.List; + +public class MySQLVectorStoreQueryProvider extends + JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { + + private final DataSource dataSource; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + private MySQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable, + String prefixForCollectionTables) { + super(dataSource, collectionsTable, prefixForCollectionTables); + this.dataSource = dataSource; + } + + /** + * Creates a new builder. + * @return the builder + */ + public static Builder builder() { + return new Builder(); + } + + private void setStatementValues(PreparedStatement statement, Object record, + List fields) { + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + try { + Field recordField = record.getClass().getDeclaredField(field.getName()); + recordField.setAccessible(true); + Object value = recordField.get(record); + + if (field instanceof VectorStoreRecordKeyField) { + statement.setObject(i + 1, (String) value); + } else if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = record.getClass().getDeclaredField(field.getName()) + .getType(); + + // If the vector field is other than String, serialize it to JSON + if (vectorType.equals(String.class)) { + statement.setObject(i + 1, value); + } else { + // Serialize the vector to JSON + statement.setObject(i + 1, new ObjectMapper().writeValueAsString(value)); + } + } else { + statement.setObject(i + 1, value); + } + } catch (NoSuchFieldException | IllegalAccessException | SQLException e) { + throw new SKException("Failed to set statement values", e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * Upserts records into the collection. + * @param collectionName the collection name + * @param records the records to upsert + * @param recordDefinition the record definition + * @param options the upsert options + * @throws SKException if the upsert fails + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + + List fields = recordDefinition.getAllFields(); + + StringBuilder onDuplicateKeyUpdate = new StringBuilder(); + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + if (i > 0) { + onDuplicateKeyUpdate.append(", "); + } + + onDuplicateKeyUpdate.append(field.getName()).append(" = VALUES(") + .append(field.getName()).append(")"); + } + + String query = "INSERT INTO " + getCollectionTableName(collectionName) + + " (" + getQueryColumnsFromFields(fields) + ")" + + " VALUES (" + getWildcardString(fields.size()) + ")" + + " ON DUPLICATE KEY UPDATE " + onDuplicateKeyUpdate; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (Object record : records) { + setStatementValues(statement, record, recordDefinition.getAllFields()); + statement.addBatch(); + } + + statement.executeBatch(); + } catch (SQLException e) { + throw new SKException("Failed to upsert records", e); + } + } + + public static class Builder + extends JDBCVectorStoreDefaultQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + public Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public Builder withCollectionsTable(String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public Builder withPrefixForCollectionTables(String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + public MySQLVectorStoreQueryProvider build() { + if (dataSource == null) { + throw new SKException("DataSource is required"); + } + + return new MySQLVectorStoreQueryProvider(dataSource, collectionsTable, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java new file mode 100644 index 00000000..d9d5deff --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreQueryProvider.java @@ -0,0 +1,338 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.postgres; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreDefaultQueryProvider; +import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; +import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; +import com.microsoft.semantickernel.exceptions.SKException; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; + +import javax.sql.DataSource; +import java.lang.reflect.Field; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class PostgreSQLVectorStoreQueryProvider extends + JDBCVectorStoreDefaultQueryProvider implements JDBCVectorStoreQueryProvider { + + private Map, String> supportedKeyTypes; + private Map, String> supportedDataTypes; + private Map, String> supportedVectorTypes; + + private final DataSource dataSource; + private final String collectionsTable; + private final String prefixForCollectionTables; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + private PostgreSQLVectorStoreQueryProvider(DataSource dataSource, String collectionsTable, + String prefixForCollectionTables) { + super(dataSource, collectionsTable, prefixForCollectionTables); + this.dataSource = dataSource; + this.collectionsTable = collectionsTable; + this.prefixForCollectionTables = prefixForCollectionTables; + + supportedKeyTypes = new HashMap<>(); + supportedKeyTypes.put(String.class, "VARCHAR(255)"); + + supportedDataTypes = new HashMap<>(); + supportedDataTypes.put(String.class, "TEXT"); + supportedDataTypes.put(Integer.class, "INTEGER"); + supportedDataTypes.put(int.class, "INTEGER"); + supportedDataTypes.put(Long.class, "BIGINT"); + supportedDataTypes.put(long.class, "BIGINT"); + supportedDataTypes.put(Float.class, "REAL"); + supportedDataTypes.put(float.class, "REAL"); + supportedDataTypes.put(Double.class, "DOUBLE PRECISION"); + supportedDataTypes.put(double.class, "DOUBLE PRECISION"); + supportedDataTypes.put(Boolean.class, "BOOLEAN"); + supportedDataTypes.put(boolean.class, "BOOLEAN"); + supportedDataTypes.put(OffsetDateTime.class, "TIMESTAMPTZ"); + + supportedVectorTypes = new HashMap<>(); + supportedDataTypes.put(String.class, "TEXT"); + supportedVectorTypes.put(List.class, "VECTOR(%d)"); + supportedVectorTypes.put(Collection.class, "VECTOR(%d)"); + } + + /** + * Gets the supported key types and their corresponding SQL types. + * + * @return the supported key types + */ + @Override + public Map, String> getSupportedKeyTypes() { + return new HashMap<>(this.supportedKeyTypes); + } + + /** + * Gets the supported data types and their corresponding SQL types. + * + * @return the supported data types + */ + @Override + public Map, String> getSupportedDataTypes() { + return new HashMap<>(this.supportedDataTypes); + } + + /** + * Gets the supported vector types and their corresponding SQL types. + * + * @return the supported vector types + */ + @Override + public Map, String> getSupportedVectorTypes() { + return new HashMap<>(this.supportedVectorTypes); + } + + /** + * Creates a new builder. + * @return the builder + */ + public static PostgreSQLVectorStoreQueryProvider.Builder builder() { + return new PostgreSQLVectorStoreQueryProvider.Builder(); + } + + /** + * Prepares the vector store. + * Executes any necessary setup steps for the vector store. + * + * @throws SKException if an error occurs while preparing the vector store + */ + @Override + public void prepareVectorStore() { + super.prepareVectorStore(); + + // Create the vector extension + String pgVector = "CREATE EXTENSION IF NOT EXISTS vector"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createPgVector = connection.prepareStatement(pgVector)) { + createPgVector.execute(); + } catch (SQLException e) { + throw new SKException("Failed to prepare vector store", e); + } + } + + private String getColumnNamesAndTypesForVectorFields(List fields, + Class recordClass) { + StringBuilder columnNames = new StringBuilder(); + for (VectorStoreRecordVectorField field : fields) { + try { + Field declaredField = recordClass.getDeclaredField(field.getName()); + if (columnNames.length() > 0) { + columnNames.append(", "); + } + + if (declaredField.getType().equals(String.class)) { + columnNames.append(field.getName()).append(" ") + .append(supportedVectorTypes.get(String.class)); + } else { + // Get the vector type and dimensions + String type = String.format(supportedVectorTypes.get(declaredField.getType()), + field.getDimensions()); + columnNames.append(field.getName()).append(" ").append(type); + } + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + return columnNames.toString(); + } + + /** + * Creates a collection. + * + * @param collectionName the collection name + * @param recordClass the record class + * @param recordDefinition the record definition + * @throws SKException if an error occurs while creating the collection + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void createCollection(String collectionName, Class recordClass, + VectorStoreRecordDefinition recordDefinition) { + Field keyDeclaredField = recordDefinition.getKeyDeclaredField(recordClass); + List dataDeclaredFields = recordDefinition.getDataDeclaredFields(recordClass); + + String createStorageTable = "CREATE TABLE IF NOT EXISTS " + + getCollectionTableName(collectionName) + + " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, " + + getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", " + + getColumnNamesAndTypesForVectorFields(recordDefinition.getVectorFields(), recordClass) + + ");"; + + String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable) + + " (collectionId) VALUES (?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement createTable = connection.prepareStatement(createStorageTable)) { + createTable.execute(); + } catch (SQLException e) { + throw new SKException("Failed to create collection", e); + } + + try (Connection connection = dataSource.getConnection(); + PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) { + insert.setObject(1, collectionName); + insert.execute(); + } catch (SQLException e) { + throw new SKException("Failed to insert collection", e); + } + } + + private void setStatementValues(PreparedStatement statement, Object record, + List fields) { + for (int i = 0; i < fields.size(); ++i) { + VectorStoreRecordField field = fields.get(i); + try { + Field recordField = record.getClass().getDeclaredField(field.getName()); + recordField.setAccessible(true); + Object value = recordField.get(record); + + if (field instanceof VectorStoreRecordKeyField) { + statement.setObject(i + 1, (String) value); + } else if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = record.getClass().getDeclaredField(field.getName()) + .getType(); + + // If the vector field is other than String, serialize it to JSON + if (vectorType.equals(String.class)) { + statement.setObject(i + 1, value); + } else { + // Serialize the vector to JSON + statement.setString(i + 1, new ObjectMapper().writeValueAsString(value)); + } + } else { + statement.setObject(i + 1, value); + } + } catch (NoSuchFieldException | IllegalAccessException | SQLException e) { + throw new SKException("Failed to set statement values", e); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + private String getWildcardStringWithCast(List fields) { + StringBuilder wildcardString = new StringBuilder(); + int wildcards = fields.size(); + for (int i = 0; i < wildcards; ++i) { + if (i > 0) { + wildcardString.append(", "); + } + wildcardString.append("?"); + // Add casting for vector fields + if (fields.get(i) instanceof VectorStoreRecordVectorField) { + wildcardString.append("::vector"); + } + } + return wildcardString.toString(); + } + + /** + * Upserts records into the collection. + * @param collectionName the collection name + * @param records the records to upsert + * @param recordDefinition the record definition + * @param options the upsert options + * @throws SKException if the upsert fails + */ + @Override + @SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING") // SQL query is generated dynamically with valid identifiers + public void upsertRecords(String collectionName, List records, + VectorStoreRecordDefinition recordDefinition, UpsertRecordOptions options) { + validateSQLidentifier(getCollectionTableName(collectionName)); + List fields = recordDefinition.getAllFields(); + + StringBuilder onDuplicateKeyUpdate = new StringBuilder(); + for (VectorStoreRecordField field : fields) { + if (field instanceof VectorStoreRecordKeyField) { + continue; + } + if (onDuplicateKeyUpdate.length() > 0) { + onDuplicateKeyUpdate.append(", "); + } + onDuplicateKeyUpdate.append(field.getName()) + .append(" = EXCLUDED.") + .append(field.getName()); + } + + String query = "INSERT INTO " + getCollectionTableName(collectionName) + + " (" + getQueryColumnsFromFields(fields) + ")" + + " VALUES (" + getWildcardStringWithCast(fields) + ")" + + " ON CONFLICT (" + recordDefinition.getKeyField().getName() + ") DO UPDATE SET " + + onDuplicateKeyUpdate; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + for (Object record : records) { + setStatementValues(statement, record, recordDefinition.getAllFields()); + statement.addBatch(); + } + + statement.executeBatch(); + } catch (SQLException e) { + throw new SKException("Failed to upsert records", e); + } + } + + public static class Builder + extends JDBCVectorStoreDefaultQueryProvider.Builder { + private DataSource dataSource; + private String collectionsTable = DEFAULT_COLLECTIONS_TABLE; + private String prefixForCollectionTables = DEFAULT_PREFIX_FOR_COLLECTION_TABLES; + + @SuppressFBWarnings("EI_EXPOSE_REP2") + public PostgreSQLVectorStoreQueryProvider.Builder withDataSource(DataSource dataSource) { + this.dataSource = dataSource; + return this; + } + + /** + * Sets the collections table name. + * @param collectionsTable the collections table name + * @return the builder + */ + public PostgreSQLVectorStoreQueryProvider.Builder withCollectionsTable( + String collectionsTable) { + this.collectionsTable = validateSQLidentifier(collectionsTable); + return this; + } + + /** + * Sets the prefix for collection tables. + * @param prefixForCollectionTables the prefix for collection tables + * @return the builder + */ + public PostgreSQLVectorStoreQueryProvider.Builder withPrefixForCollectionTables( + String prefixForCollectionTables) { + this.prefixForCollectionTables = validateSQLidentifier(prefixForCollectionTables); + return this; + } + + public PostgreSQLVectorStoreQueryProvider build() { + if (dataSource == null) { + throw new SKException("DataSource is required"); + } + + return new PostgreSQLVectorStoreQueryProvider(dataSource, collectionsTable, + prefixForCollectionTables); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java new file mode 100644 index 00000000..83b821c3 --- /dev/null +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/postgres/PostgreSQLVectorStoreRecordMapper.java @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.connectors.data.postgres; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import com.microsoft.semantickernel.data.VectorStoreRecordMapper; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField; +import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField; +import com.microsoft.semantickernel.exceptions.SKException; +import org.postgresql.util.PGobject; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.List; +import java.util.function.Function; + +public class PostgreSQLVectorStoreRecordMapper + extends VectorStoreRecordMapper { + + /** + * Constructs a new instance of the VectorStoreRecordMapper. + * + * @param storageModelToRecordMapper the function to convert a storage model to a record + */ + protected PostgreSQLVectorStoreRecordMapper( + Function storageModelToRecordMapper) { + super(null, storageModelToRecordMapper); + } + + /** + * Creates a new builder. + * + * @param the record type + * @return the builder + */ + public static Builder builder() { + return new Builder<>(); + } + + public static class Builder + implements SemanticKernelBuilder> { + private Class recordClass; + private VectorStoreRecordDefinition vectorStoreRecordDefinition; + + /** + * Sets the record class. + * + * @param recordClass the record class + * @return the builder + */ + public Builder withRecordClass(Class recordClass) { + this.recordClass = recordClass; + return this; + } + + /** + * Sets the vector store record definition. + * + * @param vectorStoreRecordDefinition the vector store record definition + * @return the builder + */ + public Builder withVectorStoreRecordDefinition( + VectorStoreRecordDefinition vectorStoreRecordDefinition) { + this.vectorStoreRecordDefinition = vectorStoreRecordDefinition; + return this; + } + + /** + * Builds the {@link PostgreSQLVectorStoreRecordMapper}. + * + * @return the {@link PostgreSQLVectorStoreRecordMapper} + */ + public PostgreSQLVectorStoreRecordMapper build() { + if (recordClass == null) { + throw new IllegalArgumentException("recordClass is required"); + } + if (vectorStoreRecordDefinition == null) { + throw new IllegalArgumentException("vectorStoreRecordDefinition is required"); + } + + return new PostgreSQLVectorStoreRecordMapper<>( + resultSet -> { + try { + Constructor constructor = recordClass.getDeclaredConstructor(); + constructor.setAccessible(true); + Record record = (Record) constructor.newInstance(); + + // Select fields from the record definition. + // Check if vector fields are present in the result set. + List fields; + ResultSetMetaData metaData = resultSet.getMetaData(); + if (metaData.getColumnCount() == vectorStoreRecordDefinition.getAllFields() + .size()) { + fields = vectorStoreRecordDefinition.getAllFields(); + } else { + fields = vectorStoreRecordDefinition.getNonVectorFields(); + } + + for (VectorStoreRecordField field : fields) { + Object value = resultSet.getObject(field.getName()); + Field recordField = recordClass.getDeclaredField(field.getName()); + recordField.setAccessible(true); + + // If the field is a vector field, deserialize the JSON string + if (field instanceof VectorStoreRecordVectorField) { + Class vectorType = recordField.getType(); + + // If the vector type is a string, set the value directly + if (vectorType.equals(String.class)) { + recordField.set(record, value); + } else { + // Deserialize the pgvector string to the vector type + PGobject pgObject = (PGobject) value; + recordField.set(record, + new ObjectMapper().readValue(pgObject.getValue(), + vectorType)); + } + } else { + recordField.set(record, value); + } + } + + return record; + } catch (NoSuchMethodException e) { + throw new SKException("Default constructor not found.", e); + } catch (InstantiationException | InvocationTargetException e) { + throw new SKException(String.format( + "SK cannot instantiate %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (JsonProcessingException e) { + throw new SKException(String.format( + "SK cannot deserialize %s. A custom mapper is required.", + recordClass.getName()), e); + } catch (SQLException | NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + }); + } + } +} diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStore.java index 7e561e43..02f728f0 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStore.java @@ -3,6 +3,7 @@ import com.microsoft.semantickernel.builders.SemanticKernelBuilder; import com.microsoft.semantickernel.data.VectorStore; +import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.util.ArrayList; @@ -12,7 +13,7 @@ import reactor.core.publisher.Mono; import redis.clients.jedis.JedisPooled; -public class RedisVectorStore implements VectorStore> { +public class RedisVectorStore implements VectorStore { private final JedisPooled client; private final RedisVectorStoreOptions options; @@ -30,6 +31,22 @@ public RedisVectorStore(@Nonnull JedisPooled client, this.options = options; } + @Override + public VectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class keyClass, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + if (keyClass != String.class) { + throw new IllegalArgumentException("Redis only supports string keys"); + } + + return (VectorStoreRecordCollection) getCollection( + collectionName, + recordClass, + recordDefinition); + } + /** * Gets a collection from the vector store. * @@ -38,8 +55,7 @@ public RedisVectorStore(@Nonnull JedisPooled client, * @param recordDefinition The record definition. * @return The collection. */ - @Override - public RedisVectorStoreRecordCollection getCollection( + public RedisVectorStoreRecordCollection getCollection( @Nonnull String collectionName, @Nonnull Class recordClass, @Nullable VectorStoreRecordDefinition recordDefinition) { @@ -74,7 +90,6 @@ public Mono> getCollectionNamesAsync() { /** * Builder for the Redis vector store. - * */ public static Builder builder() { return new Builder(); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java index 8783320b..b0f8858b 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollection.java @@ -12,6 +12,17 @@ import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.stream.Collectors; +import javax.annotation.Nonnull; import org.json.JSONArray; import org.json.JSONObject; import reactor.core.publisher.Mono; @@ -25,18 +36,6 @@ import redis.clients.jedis.search.IndexOptions; import redis.clients.jedis.search.Schema; -import javax.annotation.Nonnull; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.AbstractMap.SimpleEntry; -import java.util.stream.Collectors; - public class RedisVectorStoreRecordCollection implements VectorStoreRecordCollection { @@ -81,10 +80,13 @@ public RedisVectorStoreRecordCollection( } // Validate supported types - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); - VectorStoreRecordDefinition.validateSupportedVectorTypes(options.getRecordClass(), - recordDefinition, supportedVectorTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(this.options.getRecordClass())), + supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + recordDefinition.getVectorDeclaredFields(this.options.getRecordClass()), + supportedVectorTypes); // If mapper is not provided, set a default one if (options.getVectorStoreRecordMapper() == null) { @@ -140,7 +142,7 @@ public Mono collectionExistsAsync() { * @return A Mono representing the completion of the creation operation. */ @Override - public Mono createCollectionAsync() { + public Mono> createCollectionAsync() { return Mono.fromRunnable(() -> { Schema schema = RedisVectorStoreCollectionCreateMapping .mapToSchema(recordDefinition.getAllFields()); @@ -152,17 +154,19 @@ public Mono createCollectionAsync() { collectionName, IndexOptions.defaultOptions().setDefinition(indexDefinition), schema); - }).subscribeOn(Schedulers.boundedElastic()).then(); + }) + .subscribeOn(Schedulers.boundedElastic()) + .then(Mono.just(this)); } @Override - public Mono createCollectionIfNotExistsAsync() { + public Mono> createCollectionIfNotExistsAsync() { return collectionExistsAsync().flatMap(exists -> { if (!exists) { return createCollectionAsync(); } - return Mono.empty(); + return Mono.just(this); }); } @@ -200,7 +204,7 @@ private JsonNode removeRedisPathPrefix(JSONObject object) { /** * Gets a record from the store. * - * @param key The key of the record to get. + * @param key The key of the record to get. * @param options The options for getting the record. * @return A Mono emitting the record. */ @@ -240,7 +244,7 @@ public Mono getAsync(String key, GetRecordOptions options) { /** * Gets a batch of records from the store. * - * @param keys The keys of the records to get. + * @param keys The keys of the records to get. * @param options The options for getting the records. * @return A Mono emitting a list of records. */ @@ -333,7 +337,7 @@ public Mono> upsertBatchAsync(List data, UpsertRecordOption /** * Deletes a record from the store. * - * @param key The key of the record to delete. + * @param key The key of the record to delete. * @param options The options for deleting the record. * @return A Mono representing the completion of the deletion operation. */ diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollectionFactory.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollectionFactory.java index 45417980..df1e7544 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollectionFactory.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordCollectionFactory.java @@ -1,19 +1,20 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.redis; +import com.azure.search.documents.indexes.SearchIndexAsyncClient; import redis.clients.jedis.JedisPooled; /** * Factory for creating Redis vector store record collections. - * */ public interface RedisVectorStoreRecordCollectionFactory { + /** * Creates a new vector store record collection. * - * @param client The Redis client. + * @param client The Redis client. * @param collectionName The name of the collection. - * @param options The options for the collection. + * @param options The options for the collection. * @return The collection. */ RedisVectorStoreRecordCollection createVectorStoreRecordCollection( diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordMapper.java index a4f5f798..f33a6b1d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordMapper.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/redis/RedisVectorStoreRecordMapper.java @@ -1,17 +1,15 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.connectors.data.redis; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.microsoft.semantickernel.builders.SemanticKernelBuilder; import com.microsoft.semantickernel.data.VectorStoreRecordMapper; import com.microsoft.semantickernel.exceptions.SKException; - -import javax.annotation.Nullable; import java.util.AbstractMap; import java.util.Map.Entry; import java.util.function.Function; +import javax.annotation.Nullable; public class RedisVectorStoreRecordMapper extends VectorStoreRecordMapper> { diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStore.java index d778829a..a51f044c 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStore.java @@ -1,31 +1,28 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.data; -import com.microsoft.semantickernel.data.VectorStoreRecordCollection; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; -import reactor.core.publisher.Mono; - +import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import java.util.List; +import reactor.core.publisher.Mono; /** * Represents a vector store. - * - * @param The type of the record collection. */ -public interface VectorStore> { +public interface VectorStore { /** * Gets a collection from the vector store. * - * @param collectionName The name of the collection. - * @param recordClass The class type of the record. + * @param collectionName The name of the collection. + * @param recordClass The class type of the record. * @param recordDefinition The record definition. * @return The collection. */ - RecordCollection getCollection( + VectorStoreRecordCollection getCollection( @Nonnull String collectionName, + @Nonnull Class keyClass, @Nonnull Class recordClass, @Nullable VectorStoreRecordDefinition recordDefinition); diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java index 867cbf16..0b7319c0 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordCollection.java @@ -4,9 +4,8 @@ import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions; import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; -import reactor.core.publisher.Mono; - import java.util.List; +import reactor.core.publisher.Mono; public interface VectorStoreRecordCollection { @@ -15,40 +14,40 @@ public interface VectorStoreRecordCollection { * * @return The name of the collection. */ - public String getCollectionName(); + String getCollectionName(); /** * Checks if the collection exists in the store. * * @return A Mono emitting a boolean indicating if the collection exists. */ - public Mono collectionExistsAsync(); + Mono collectionExistsAsync(); /** * Creates the collection in the store. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionAsync(); + Mono> createCollectionAsync(); /** * Creates the collection in the store if it does not exist. * * @return A Mono representing the completion of the creation operation. */ - public Mono createCollectionIfNotExistsAsync(); + Mono> createCollectionIfNotExistsAsync(); /** * Deletes the collection from the store. * * @return A Mono representing the completion of the deletion operation. */ - public Mono deleteCollectionAsync(); + Mono deleteCollectionAsync(); /** * Gets a record from the store. * - * @param key The key of the record to get. + * @param key The key of the record to get. * @param options The options for getting the record. * @return A Mono emitting the record. */ @@ -57,7 +56,7 @@ public interface VectorStoreRecordCollection { /** * Gets a batch of records from the store. * - * @param keys The keys of the records to get. + * @param keys The keys of the records to get. * @param options The options for getting the records. * @return A Mono emitting a list of records. */ @@ -66,7 +65,7 @@ public interface VectorStoreRecordCollection { /** * Inserts or updates a record in the store. * - * @param data The record to upsert. + * @param data The record to upsert. * @param options The options for upserting the record. * @return A Mono emitting the key of the upserted record. */ @@ -75,7 +74,7 @@ public interface VectorStoreRecordCollection { /** * Inserts or updates a batch of records in the store. * - * @param data The records to upsert. + * @param data The records to upsert. * @param options The options for upserting the records. * @return A Mono emitting a list of keys of the upserted records. */ @@ -84,7 +83,7 @@ public interface VectorStoreRecordCollection { /** * Deletes a record from the store. * - * @param key The key of the record to delete. + * @param key The key of the record to delete. * @param options The options for deleting the record. * @return A Mono representing the completion of the deletion operation. */ @@ -93,7 +92,7 @@ public interface VectorStoreRecordCollection { /** * Deletes a batch of records from the store. * - * @param keys The keys of the records to delete. + * @param keys The keys of the records to delete. * @param options The options for deleting the records. * @return A Mono representing the completion of the deletion operation. */ diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordMapper.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordMapper.java index 09420cf8..100c7e6d 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordMapper.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VectorStoreRecordMapper.java @@ -3,6 +3,7 @@ import com.microsoft.semantickernel.builders.SemanticKernelBuilder; +import javax.annotation.Nullable; import java.util.function.Function; /** @@ -12,6 +13,7 @@ * @param the storage model type */ public class VectorStoreRecordMapper { + @Nullable private final Function recordToStorageModelMapper; private final Function storageModelToRecordMapper; @@ -22,7 +24,7 @@ public class VectorStoreRecordMapper { * @param storageModelToRecordMapper the function to convert a storage model to a record */ protected VectorStoreRecordMapper( - Function recordToStorageModelMapper, + @Nullable Function recordToStorageModelMapper, Function storageModelToRecordMapper) { this.recordToStorageModelMapper = recordToStorageModelMapper; this.storageModelToRecordMapper = storageModelToRecordMapper; @@ -33,6 +35,7 @@ protected VectorStoreRecordMapper( * * @return the function to convert a record to a storage model */ + @Nullable public Function getRecordToStorageModelMapper() { return recordToStorageModelMapper; } diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStore.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStore.java index ef074247..25e90ad9 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStore.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStore.java @@ -2,15 +2,16 @@ package com.microsoft.semantickernel.data; import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition; -import reactor.core.publisher.Mono; - -import javax.annotation.Nonnull; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import reactor.core.publisher.Mono; + +public class VolatileVectorStore implements VectorStore { -public class VolatileVectorStore implements VectorStore> { private final Map> collections; public VolatileVectorStore() { @@ -25,9 +26,25 @@ public VolatileVectorStore() { * @return The collection. */ @Override - public VolatileVectorStoreRecordCollection getCollection( - @Nonnull String collectionName, @Nonnull Class recordClass, - VectorStoreRecordDefinition recordDefinition) { + public VectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class keyClass, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { + if (keyClass != String.class) { + throw new IllegalArgumentException("Volatile only supports string keys"); + } + + return (VectorStoreRecordCollection) getCollection( + collectionName, + recordClass, + recordDefinition); + } + + public VectorStoreRecordCollection getCollection( + @Nonnull String collectionName, + @Nonnull Class recordClass, + @Nullable VectorStoreRecordDefinition recordDefinition) { return new VolatileVectorStoreRecordCollection<>( collectionName, collections, diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java index e675d4cd..07d55cdd 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollection.java @@ -8,18 +8,17 @@ import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions; import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions; import com.microsoft.semantickernel.exceptions.SKException; -import reactor.core.publisher.Mono; - import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import reactor.core.publisher.Mono; + +public class VolatileVectorStoreRecordCollection implements + VectorStoreRecordCollection { -public class VolatileVectorStoreRecordCollection - implements VectorStoreRecordCollection { private static final HashSet> supportedKeyTypes = new HashSet<>( Collections.singletonList(String.class)); private Map> collections; @@ -43,8 +42,10 @@ public VolatileVectorStoreRecordCollection(String collectionName, } // Validate the key type - VectorStoreRecordDefinition.validateSupportedKeyTypes(options.getRecordClass(), - recordDefinition, supportedKeyTypes); + VectorStoreRecordDefinition.validateSupportedTypes( + Collections + .singletonList(recordDefinition.getKeyDeclaredField(options.getRecordClass())), + supportedKeyTypes); } VolatileVectorStoreRecordCollection(String collectionName, @@ -80,8 +81,9 @@ public Mono collectionExistsAsync() { * @return A Mono representing the completion of the creation operation. */ @Override - public Mono createCollectionAsync() { - return Mono.fromRunnable(() -> collections.put(collectionName, new ConcurrentHashMap<>())); + public Mono> createCollectionAsync() { + return Mono.fromRunnable(() -> collections.put(collectionName, new ConcurrentHashMap<>())) + .then(Mono.just(this)); } /** @@ -90,9 +92,10 @@ public Mono createCollectionAsync() { * @return A Mono representing the completion of the creation operation. */ @Override - public Mono createCollectionIfNotExistsAsync() { + public Mono> createCollectionIfNotExistsAsync() { return Mono - .fromRunnable(() -> collections.putIfAbsent(collectionName, new ConcurrentHashMap<>())); + .fromRunnable(() -> collections.putIfAbsent(collectionName, new ConcurrentHashMap<>())) + .then(Mono.just(this)); } /** @@ -108,7 +111,7 @@ public Mono deleteCollectionAsync() { /** * Gets a record from the store. * - * @param key The key of the record to get. + * @param key The key of the record to get. * @param options The options for getting the record. * @return A Mono emitting the record. */ @@ -120,7 +123,7 @@ public Mono getAsync(String key, GetRecordOptions options) { /** * Gets a batch of records from the store. * - * @param keys The keys of the records to get. + * @param keys The keys of the records to get. * @param options The options for getting the records. * @return A Mono emitting a list of records. */ @@ -186,7 +189,7 @@ public Mono> upsertBatchAsync(List data, UpsertRecordOption /** * Deletes a record from the store. * - * @param key The key of the record to delete. + * @param key The key of the record to delete. * @param options The options for deleting the record. * @return A Mono representing the completion of the deletion operation. */ diff --git a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java index a1914d2c..39e04a3f 100644 --- a/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java +++ b/semantickernel-experimental/src/main/java/com/microsoft/semantickernel/data/recorddefinition/VectorStoreRecordDefinition.java @@ -5,13 +5,12 @@ import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; /** @@ -50,6 +49,56 @@ public List getAllFields() { return fields; } + public List getNonVectorFields() { + List fields = new ArrayList<>(); + fields.add(keyField); + fields.addAll(dataFields); + return fields; + } + + private enum DeclaredFieldType { + KEY, DATA, VECTOR + } + + private List getDeclaredFields(Class recordClass, List fields, + DeclaredFieldType fieldType) { + List declaredFields = new ArrayList<>(); + for (VectorStoreRecordField field : fields) { + try { + Field declaredField = recordClass.getDeclaredField(field.getName()); + declaredFields.add(declaredField); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + String.format("%s field not found in record class: %s", fieldType, + field.getName())); + } + } + return declaredFields; + } + + public Field getKeyDeclaredField(Class recordClass) { + try { + return recordClass.getDeclaredField(keyField.getName()); + } catch (NoSuchFieldException e) { + throw new IllegalArgumentException( + "Key field not found in record class: " + keyField.getName()); + } + } + + public List getDataDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + DeclaredFieldType.DATA); + } + + public List getVectorDeclaredFields(Class recordClass) { + return getDeclaredFields( + recordClass, + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), + DeclaredFieldType.VECTOR); + } + private VectorStoreRecordDefinition( VectorStoreRecordKeyField keyField, List dataFields, @@ -148,71 +197,20 @@ public static VectorStoreRecordDefinition fromRecordClass(Class recordClass) return checkFields(keyFields, dataFields, vectorFields); } - private static String getSupportedTypesString(@Nullable HashSet> types) { - if (types == null || types.isEmpty()) { - return ""; - } - return types.stream().map(Class::getName).collect(Collectors.joining(", ")); - } - - public static void validateSupportedKeyTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - try { - Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName()); - + public static void validateSupportedTypes(List declaredFields, + Set> supportedTypes) { + Set> unsupportedTypes = new HashSet<>(); + for (Field declaredField : declaredFields) { if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported key field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); + unsupportedTypes.add(declaredField.getType()); } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Key field not found in record class: " + recordDefinition.keyField.getName()); } - } - - public static void validateSupportedDataTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordDataField field : recordDefinition.dataFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported data field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Data field not found in record class: " + field.getName()); - } - } - } - - public static void validateSupportedVectorTypes(@Nonnull Class recordClass, - @Nonnull VectorStoreRecordDefinition recordDefinition, - @Nonnull HashSet> supportedTypes) { - String supportedTypesString = getSupportedTypesString(supportedTypes); - - for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) { - try { - Field declaredField = recordClass.getDeclaredField(field.getName()); - - if (!supportedTypes.contains(declaredField.getType())) { - throw new IllegalArgumentException( - "Unsupported vector field type: " + declaredField.getType().getName() - + ". Supported types are: " + supportedTypesString); - } - } catch (NoSuchFieldException e) { - throw new IllegalArgumentException( - "Vector field not found in record class: " + field.getName()); - } + if (!unsupportedTypes.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "Unsupported field types found in record class: %s. Supported types: %s", + unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")), + supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")))); } } } diff --git a/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollectionTest.java b/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollectionTest.java index 9a087adb..915b2166 100644 --- a/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollectionTest.java +++ b/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreRecordCollectionTest.java @@ -1,20 +1,18 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.data; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Order; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + public class VolatileVectorStoreRecordCollectionTest { private static VolatileVectorStoreRecordCollection recordCollection; diff --git a/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreTest.java b/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreTest.java index cfd52757..99c643e5 100644 --- a/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreTest.java +++ b/semantickernel-experimental/src/test/java/com/microsoft/semantickernel/data/VolatileVectorStoreTest.java @@ -1,17 +1,17 @@ // Copyright (c) Microsoft. All rights reserved. package com.microsoft.semantickernel.data; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.List; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + public class VolatileVectorStoreTest { + private static VolatileVectorStore vectorStore; @BeforeAll