diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index 5c2f75230..0750c534e 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -954,6 +954,12 @@ impl Model { .unwrap() .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) } + pgrx_pg_sys::NUMERICOID => { + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + element + .unwrap() + .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string()) + } _ => error!( "Unsupported type for categorical column: {:?}. oid: {:?}", column.name, attribute.atttypid @@ -992,6 +998,10 @@ impl Model { let element: Result, TryFromDatumError> = tuple.get_by_index(index); features.push(element.unwrap().map_or(f32::NAN, |v| v as f32)); } + pgrx_pg_sys::NUMERICOID => { + let element: Result, TryFromDatumError> = tuple.get_by_index(index); + features.push(element.unwrap().map_or(f32::NAN, |v| v.try_into().unwrap())); + } // TODO handle NULL to NaN for arrays pgrx_pg_sys::BOOLARRAYOID => { let element: Result>, TryFromDatumError> = @@ -1035,6 +1045,13 @@ impl Model { features.push(*j as f32); } } + pgrx_pg_sys::NUMERICARRAYOID => { + let element: Result>, TryFromDatumError> = + tuple.get_by_index(index); + for j in element.as_ref().unwrap().as_ref().unwrap() { + features.push(j.clone().try_into().unwrap()); + } + } _ => error!( "Unsupported type for quantitative column: {:?}. oid: {:?}", column.name, attribute.atttypid diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs index 6a5973148..1bc27911c 100644 --- a/pgml-extension/src/orm/snapshot.rs +++ b/pgml-extension/src/orm/snapshot.rs @@ -990,6 +990,7 @@ impl Snapshot { "int8" => row[column.position].value::().unwrap().map(|v| v.to_string()), "float4" => row[column.position].value::().unwrap().map(|v| v.to_string()), "float8" => row[column.position].value::().unwrap().map(|v| v.to_string()), + "numeric" => row[column.position].value::().unwrap().map(|v| v.to_string()), "bpchar" | "text" | "varchar" => { row[column.position].value::().unwrap().map(|v| v.to_string()) } @@ -1078,6 +1079,14 @@ impl Snapshot { vector.push(j as f32) } } + "numeric[]" => { + let vec = row[column.position].value::>().unwrap().unwrap(); + check_column_size(column, vec.len()); + + for j in vec { + vector.push(j.rescale::<6,0>().unwrap().try_into().unwrap()) + } + } _ => error!( "Unhandled type for quantitative array column: {} {:?}", column.name, column.pg_type @@ -1092,6 +1101,7 @@ impl Snapshot { "int8" => row[column.position].value::().unwrap().map(|v| v as f32), "float4" => row[column.position].value::().unwrap(), "float8" => row[column.position].value::().unwrap().map(|v| v as f32), + "numeric" => row[column.position].value::().unwrap().map(|v| v.rescale::<6,0>().unwrap().try_into().unwrap()), _ => error!( "Unhandled type for quantitative scalar column: {} {:?}", column.name, column.pg_type