diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs index db023b951..652bf0b8c 100644 --- a/pgml-sdks/pgml/src/builtins.rs +++ b/pgml-sdks/pgml/src/builtins.rs @@ -5,7 +5,7 @@ use tracing::instrument; /// Provides access to builtin database methods #[derive(alias, Debug, Clone)] pub struct Builtins { - pub database_url: Option, + database_url: Option, } use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json}; diff --git a/pgml-sdks/pgml/src/collection.rs b/pgml-sdks/pgml/src/collection.rs index a343920b1..69dd8574a 100644 --- a/pgml-sdks/pgml/src/collection.rs +++ b/pgml-sdks/pgml/src/collection.rs @@ -102,12 +102,10 @@ pub(crate) struct CollectionDatabaseData { /// A collection of documents #[derive(alias, Debug, Clone)] pub struct Collection { - pub name: String, - pub database_url: Option, - pub pipelines_table_name: String, - pub documents_table_name: String, - pub chunks_table_name: String, - pub documents_tsvectors_table_name: String, + pub(crate) name: String, + pub(crate) database_url: Option, + pub(crate) pipelines_table_name: String, + pub(crate) documents_table_name: String, pub(crate) database_data: Option, } @@ -137,16 +135,21 @@ impl Collection { /// Creates a new [Collection] /// /// # Arguments - /// /// * `name` - The name of the collection. /// * `database_url` - An optional database_url. If passed, this url will be used instead of - /// the `DATABASE_URL` environment variable. + /// the `PGML_DATABASE_URL` environment variable. 
/// - /// # Example + /// # Errors + /// * If the `name` is not composed of alphanumeric characters, whitespace, or '-' and '_' /// + /// # Example /// ``` /// use pgml::Collection; - /// let collection = Collection::new("my_collection", None); + /// use anyhow::Result; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// Ok(()) + /// } /// ``` pub fn new(name: &str, database_url: Option) -> anyhow::Result { if !name @@ -157,19 +160,12 @@ impl Collection { "Name must only consist of letters, numebers, white space, and '-' or '_'" ) } - let ( - pipelines_table_name, - documents_table_name, - chunks_table_name, - documents_tsvectors_table_name, - ) = Self::generate_table_names(name); + let (pipelines_table_name, documents_table_name) = Self::generate_table_names(name); Ok(Self { name: name.to_string(), database_url, pipelines_table_name, documents_table_name, - chunks_table_name, - documents_tsvectors_table_name, database_data: None, }) } @@ -261,6 +257,26 @@ impl Collection { } /// Adds a new [Pipeline] to the [Collection] + /// + /// # Arguments + /// * `pipeline` - The [Pipeline] to add to the [Collection] + /// + /// # Errors + /// * If the [Pipeline] does not have schema + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use anyhow::Result; + /// use serde_json::json; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", Some(json!({}).into()))?; + /// collection.add_pipeline(&mut pipeline).await?; + /// Ok(()) + /// } + /// ``` #[instrument(skip(self))] pub async fn add_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -305,6 +321,23 @@ impl Collection { } /// Removes a [Pipeline] from the [Collection] + /// + /// # Arguments + /// * `pipeline` - The [Pipeline] to remove from the [Collection] + /// + /// # 
Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use anyhow::Result; + /// use serde_json::json; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// collection.remove_pipeline(&mut pipeline).await?; + /// Ok(()) + /// } + /// ``` #[instrument(skip(self))] pub async fn remove_pipeline(&mut self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -334,6 +367,26 @@ impl Collection { } /// Enables a [Pipeline] on the [Collection] + /// + /// # Arguments + /// * `pipeline` - The [Pipeline] to enable + /// + /// # Errors + /// * If the pipeline has not already been added to the [Collection] + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use anyhow::Result; + /// use serde_json::json; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// collection.enable_pipeline(&mut pipeline).await?; + /// Ok(()) + /// } + /// ``` #[instrument(skip(self))] pub async fn enable_pipeline(&mut self, pipeline: &mut Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -356,6 +409,26 @@ impl Collection { } /// Disables a [Pipeline] on the [Collection] + /// + /// # Arguments + /// * `pipeline` - The [Pipeline] to remove + /// + /// # Errors + /// * If the pipeline has not already been added to the [Collection] + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use anyhow::Result; + /// use serde_json::json; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// collection.disable_pipeline(&pipeline).await?; + /// Ok(()) + /// } + /// ``` #[instrument(skip(self))] pub async fn 
disable_pipeline(&self, pipeline: &Pipeline) -> anyhow::Result<()> { // The flow for this function: @@ -390,7 +463,23 @@ impl Collection { Ok(()) } - /// Upserts documents into the database + /// Upserts documents into [Collection] + /// + /// # Arguments + /// * `documents` - A vector of [Json] documents to upsert + /// * `args` - A [Json] object containing arguments for the upsert + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use anyhow::Result; + /// use serde_json::json; + /// async fn doc() -> Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// collection.upsert_documents(vec![json!({"id": "1", "name": "one"}).into()], None).await?; + /// Ok(()) + /// } + /// ``` #[instrument(skip(self, documents))] pub async fn upsert_documents( &mut self, @@ -558,6 +647,31 @@ impl Collection { } /// Gets the documents on a [Collection] + /// + /// # Arguments + /// + /// * `args` - A JSON object containing the following keys: + /// * `limit` - The maximum number of documents to return. Defaults to 1000. + /// * `order_by` - A JSON array of objects that specify the order of the documents to return. + /// Each object must have a `field` key with the name of the field to order by, and a `direction` + /// key with the value `asc` or `desc`. + /// * `last_row_id` - The id of the last document returned + /// * `offset` - The number of documents to skip before returning results. + /// * `filter` - A JSON object specifying the filter to apply to the documents. 
+ /// + /// # Example + /// + /// ``` + /// use pgml::Collection; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let collection = Collection::new("my_collection", None)?; + /// let documents = collection.get_documents(Some(json!({ + /// "limit": 2, + /// }).into())); + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn get_documents(&self, args: Option) -> anyhow::Result> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -617,6 +731,26 @@ impl Collection { } /// Deletes documents in a [Collection] + /// + /// # Arguments + /// + /// * `filter` - A JSON object specifying the filter to apply to the documents. + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let collection = Collection::new("my_collection", None)?; + /// collection.delete_documents(json!({ + /// "id": { + /// "$eq": 1 + /// } + /// }).into()); + /// Ok(()) + /// } + /// ``` #[instrument(skip(self))] pub async fn delete_documents(&self, filter: Json) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -633,6 +767,34 @@ impl Collection { } #[instrument(skip(self))] + /// Performs search over the documents in a [Collection] + /// + /// # Arguments + /// + /// * `query` - A JSON object specifying the query to perform. + /// * `pipeline` - The [Pipeline] to use for the search. 
+ /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// let results = collection.search(json!({ + /// "query": { + /// "semantic_search": { + /// "title": { + /// "query": "This is an example query string", + /// }, + /// } + /// } + /// }).into(), &mut pipeline).await?; + /// Ok(()) + /// } + /// ``` pub async fn search(&mut self, query: Json, pipeline: &mut Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; @@ -676,6 +838,7 @@ impl Collection { } #[instrument(skip(self))] + /// Same as search but the [Collection] is not mutable. This will not work with [Pipeline]s that use remote embeddings pub async fn search_local(&self, query: Json, pipeline: &Pipeline) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let (built_query, values) = build_search_query(self, query.clone(), pipeline).await?; @@ -689,6 +852,29 @@ impl Collection { Ok(results) } + /// Adds a search event to the database + /// + /// # Arguments + /// + /// * `search_id` - The id of the search + /// * `search_result` - The index of the search result + /// * `event` - The event to add + /// * `pipeline` - The [Pipeline] used for the search + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// collection.add_search_event(1, 1, json!({ + /// "event": "click", + /// }).into(), &mut pipeline).await?; + /// Ok(()) + /// } 
#[instrument(skip(self))] pub async fn add_search_event( &self, @@ -723,6 +909,31 @@ impl Collection { } /// Performs vector search on the [Collection] + /// + /// # Arguments + /// * `query` - The query to search for + /// * `pipeline` - The [Pipeline] to use for the search + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// let results = collection.vector_search(json!({ + /// "query": { + /// "fields": { + /// "title": { + /// "query": "This is an example query string" + /// } + /// } + /// } + /// }).into(), &mut pipeline).await?; + /// Ok(()) + /// } #[instrument(skip(self))] #[allow(clippy::type_complexity)] pub async fn vector_search( @@ -784,6 +995,20 @@ impl Collection { } } + /// Archives a [Collection] + /// This will free up the name to be reused. It does not delete it. + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// collection.archive().await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn archive(&mut self) -> anyhow::Result<()> { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -822,12 +1047,26 @@ impl Collection { Ok(()) } + /// A legacy query builder. 
+ #[deprecated(since = "1.0.0", note = "please use `vector_search` instead")] #[instrument(skip(self))] pub fn query(&self) -> QueryBuilder { QueryBuilder::new(self.clone()) } /// Gets all pipelines for the [Collection] + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let pipelines = collection.get_pipelines().await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn get_pipelines(&mut self) -> anyhow::Result> { self.verify_in_database(false).await?; @@ -842,6 +1081,21 @@ impl Collection { } /// Gets a [Pipeline] by name + /// + /// # Arguments + /// * `name` - The name of the [Pipeline] + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let pipeline = collection.get_pipeline("my_pipeline").await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn get_pipeline(&mut self, name: &str) -> anyhow::Result { self.verify_in_database(false).await?; @@ -857,6 +1111,18 @@ impl Collection { } /// Check if the [Collection] exists in the database + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let exists = collection.exists().await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn exists(&self) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -869,6 +1135,29 @@ impl Collection { Ok(collection.is_some()) } + /// Upsert all files in a directory that match the file_types + /// + /// # Arguments + 
/// * `path` - The path to the directory to upsert + /// * `args` - A [Json](serde_json::Value) object with the following keys: + /// * `file_types` - An array of file extensions to match. E.G. ['md', 'txt'] + /// * `file_batch_size` - The number of files to upsert at a time. Defaults to 10. + /// * `follow_links` - Whether to follow symlinks. Defaults to false. + /// * `ignore_paths` - An array of regexes to ignore. E.G. ['.*ignore.*'] + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use serde_json::json; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// collection.upsert_directory("/path/to/my/files", json!({ + /// "file_types": ["md", "txt"] + /// }).into()).await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn upsert_directory(&mut self, path: &str, args: Json) -> anyhow::Result<()> { self.verify_in_database(false).await?; @@ -944,6 +1233,22 @@ impl Collection { Ok(()) } + /// Gets the sync status of a [Pipeline] + /// + /// # Arguments + /// * `pipeline` - The [Pipeline] to get the sync status of + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// let status = collection.get_pipeline_status(&mut pipeline).await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn get_pipeline_status(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; @@ -952,6 +1257,20 @@ impl Collection { pipeline.get_status(project_info, &pool).await } + #[instrument(skip(self))] + /// Generates a PlantUML ER Diagram for a [Collection] and [Pipeline] tables + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use pgml::Pipeline; + /// use 
anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// let mut pipeline = Pipeline::new("my_pipeline", None)?; + /// let er_diagram = collection.generate_er_diagram(&mut pipeline).await?; + /// Ok(()) + /// } #[instrument(skip(self))] pub async fn generate_er_diagram(&mut self, pipeline: &mut Pipeline) -> anyhow::Result { self.verify_in_database(false).await?; @@ -1074,6 +1393,21 @@ entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ Ok(uml_entites) } + /// Upserts a file into a [Collection] + /// + /// # Arguments + /// * `path` - The path to the file to upsert + /// + /// # Example + /// ``` + /// use pgml::Collection; + /// use anyhow::Result; + /// async fn run() -> anyhow::Result<()> { + /// let mut collection = Collection::new("my_collection", None)?; + /// collection.upsert_file("my_file.txt").await?; + /// Ok(()) + /// } + #[instrument(skip(self))] pub async fn upsert_file(&mut self, path: &str) -> anyhow::Result<()> { self.verify_in_database(false).await?; let path = Path::new(path); @@ -1085,16 +1419,11 @@ entity "{schema}.{key}_tsvectors" as {nice_name_key}_tsvectors {{ self.upsert_documents(vec![document.into()], None).await } - fn generate_table_names(name: &str) -> (String, String, String, String) { - [ - ".pipelines", - ".documents", - ".chunks", - ".documents_tsvectors", - ] - .into_iter() - .map(|s| format!("{}{}", name, s)) - .collect_tuple() - .unwrap() + fn generate_table_names(name: &str) -> (String, String) { + [".pipelines", ".documents"] + .into_iter() + .map(|s| format!("{}{}", name, s)) + .collect_tuple() + .unwrap() } } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index 50665ed93..34b02ce53 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -20,7 +20,7 @@ mod filter_builder; mod languages; pub mod migrations; mod model; -pub mod models; +mod models; mod open_source_ai; mod order_by_builder; mod 
pipeline; diff --git a/pgml-sdks/pgml/src/model.rs b/pgml-sdks/pgml/src/model.rs index ff320c0de..432654298 100644 --- a/pgml-sdks/pgml/src/model.rs +++ b/pgml-sdks/pgml/src/model.rs @@ -54,10 +54,10 @@ pub(crate) struct ModelDatabaseData { /// A model used for embedding, inference, etc... #[derive(alias, Debug, Clone)] pub struct Model { - pub name: String, - pub runtime: ModelRuntime, - pub parameters: Json, - pub(crate) database_data: Option, + pub(crate) name: String, + pub(crate) runtime: ModelRuntime, + pub(crate) parameters: Json, + database_data: Option, } impl Default for Model { @@ -69,19 +69,6 @@ impl Default for Model { #[alias_methods(new, transform)] impl Model { /// Creates a new [Model] - /// - /// # Arguments - /// - /// * `name` - The name of the model. - /// * `source` - The source of the model. Defaults to `pgml`, but can be set to providers like `openai`. - /// * `parameters` - The parameters to the model. Defaults to None - /// - /// # Example - /// - /// ``` - /// use pgml::Model; - /// let model = Model::new(Some("intfloat/e5-small".to_string()), None, None, None); - /// ``` pub fn new(name: Option, source: Option, parameters: Option) -> Self { let name = name.unwrap_or("intfloat/e5-small".to_string()); let parameters = parameters.unwrap_or(Json(serde_json::json!({}))); diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index d4c02215e..e21397a31 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -13,6 +13,7 @@ use crate::{ #[cfg(feature = "python")] use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython}; +/// A drop in replacement for OpenAI #[derive(alias, Debug, Clone)] pub struct OpenSourceAI { database_url: Option, @@ -169,6 +170,20 @@ impl Iterator for AsyncToSyncJsonIterator { chat_completions_create_stream_async )] impl OpenSourceAI { + /// Creates a new [OpenSourceAI] + /// + /// # Arguments + /// + /// * 
`database_url`: The database url to use. If `None`, `PGML_DATABASE_URL` environment variable will be used. + /// + /// # Example + /// ``` + /// use pgml::OpenSourceAI; + /// async fn run() -> anyhow::Result<()> { + /// let ai = OpenSourceAI::new(None); + /// Ok(()) + /// } + /// ``` pub fn new(database_url: Option) -> Self { Self { database_url } } @@ -216,6 +231,7 @@ mistralai/Mistral-7B-v0.1 } } + /// Returns an async iterator of completions #[allow(clippy::too_many_arguments)] pub async fn chat_completions_create_stream_async( &self, @@ -278,6 +294,7 @@ mistralai/Mistral-7B-v0.1 Ok(GeneralJsonAsyncIterator(Box::pin(iter))) } + /// Returns an iterator of completions #[allow(clippy::too_many_arguments)] pub fn chat_completions_create_stream( &self, @@ -302,6 +319,7 @@ mistralai/Mistral-7B-v0.1 )))) } + /// An async function that returns completions #[allow(clippy::too_many_arguments)] pub async fn chat_completions_create_async( &self, @@ -371,6 +389,7 @@ mistralai/Mistral-7B-v0.1 .into()) } + /// A function that returns completions #[allow(clippy::too_many_arguments)] pub fn chat_completions_create( &self, diff --git a/pgml-sdks/pgml/src/pipeline.rs b/pgml-sdks/pgml/src/pipeline.rs index 6dada5159..02b059db3 100644 --- a/pgml-sdks/pgml/src/pipeline.rs +++ b/pgml-sdks/pgml/src/pipeline.rs @@ -175,11 +175,12 @@ pub struct PipelineDatabaseData { pub created_at: DateTime, } +/// A pipeline that describes transformations to documents #[derive(alias, Debug, Clone)] pub struct Pipeline { - pub name: String, - pub schema: Option, - pub parsed_schema: Option, + pub(crate) name: String, + pub(crate) schema: Option, + pub(crate) parsed_schema: Option, database_data: Option, } @@ -203,6 +204,11 @@ fn json_to_schema(schema: &Json) -> anyhow::Result { #[alias_methods(new)] impl Pipeline { + /// Creates a [Pipeline] + /// + /// # Arguments + /// * `name` - The name of the pipeline + /// * `schema` - The schema of the pipeline. 
This is a JSON object where the keys are the field names and the values are the field actions. pub fn new(name: &str, schema: Option) -> anyhow::Result { let parsed_schema = schema.as_ref().map(json_to_schema).transpose()?; Ok(Self { @@ -215,7 +221,7 @@ impl Pipeline { /// Gets the status of the [Pipeline] #[instrument(skip(self))] - pub async fn get_status( + pub(crate) async fn get_status( &mut self, project_info: &ProjectInfo, pool: &Pool, diff --git a/pgml-sdks/pgml/src/splitter.rs b/pgml-sdks/pgml/src/splitter.rs index 96b1ed9da..a0847c879 100644 --- a/pgml-sdks/pgml/src/splitter.rs +++ b/pgml-sdks/pgml/src/splitter.rs @@ -21,8 +21,8 @@ pub(crate) struct SplitterDatabaseData { /// A text splitter #[derive(alias, Debug, Clone)] pub struct Splitter { - pub name: String, - pub parameters: Json, + pub(crate) name: String, + pub(crate) parameters: Json, pub(crate) database_data: Option, } diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index d20089463..43154615b 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -117,6 +117,13 @@ impl Stream for TransformerStream { #[alias_methods(new, transform, transform_stream)] impl TransformerPipeline { + /// Creates a new [TransformerPipeline] + /// + /// # Arguments + /// * `task` - The task to run + /// * `model` - The model to use + /// * `args` - The arguments to pass to the task + /// * `database_url` - The database url to use. 
If None, the `PGML_DATABASE_URL` environment variable will be used pub fn new( task: &str, model: Option, @@ -141,6 +148,11 @@ impl TransformerPipeline { } } + /// Calls transform + /// + /// # Arguments + /// * `inputs` - The inputs to the task + /// * `args` - The arguments to pass to the task #[instrument(skip(self))] pub async fn transform(&self, inputs: Vec, args: Option) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; @@ -176,6 +188,9 @@ impl TransformerPipeline { Ok(Json(results)) } + /// Calls transform + /// The same as transformer but it returns an iterator + /// The `batch_size` argument can be used to control the number of results returned in each batch #[instrument(skip(self))] pub async fn transform_stream( &self, diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 1a51e4f20..34d93be5c 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -6,8 +6,7 @@ use sea_query::Iden; use serde::{Deserialize, Serialize}; use std::ops::{Deref, DerefMut}; -/// A wrapper around serde_json::Value -// #[derive(sqlx::Type, sqlx::FromRow, Debug)] +/// A wrapper around `serde_json::Value` #[derive(alias_manual, sqlx::Type, Debug, Clone, Deserialize, PartialEq, Eq)] #[sqlx(transparent)] pub struct Json(pub serde_json::Value); @@ -80,7 +79,7 @@ impl TryToNumeric for serde_json::Value { } } -/// A wrapper around sqlx::types::PrimitiveDateTime +/// A wrapper around `sqlx::types::PrimitiveDateTime` #[derive(sqlx::Type, Debug, Clone)] #[sqlx(transparent)] pub struct DateTime(pub sqlx::types::time::PrimitiveDateTime); @@ -124,6 +123,7 @@ impl IntoTableNameAndSchema for String { } } +/// A wrapper around `std::pin::Pin> + Send>>` #[derive(alias_manual)] pub struct GeneralJsonAsyncIterator( pub std::pin::Pin> + Send>>, @@ -140,6 +140,7 @@ impl Stream for GeneralJsonAsyncIterator { } } +/// A wrapper around `Box> + Send>` #[derive(alias_manual)] pub struct GeneralJsonIterator(pub Box> + Send>);