Skip to content

Commit

Permalink
feat: Implement postgres vector extension (#137)
Browse files Browse the repository at this point in the history
* change to fork

(cherry picked from commit 0806dba)

* add vector

(cherry picked from commit 6d101f1)

* parse size

(cherry picked from commit 99513fb)

* fmt

* feat: Implement PgVector Type
  • Loading branch information
28Smiles authored Dec 24, 2024
1 parent 8c82cc3 commit cf94ae9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ default = ["mysql", "postgres", "sqlite", "discovery", "writer", "probe"]
debug-print = ["log"]
mysql = ["sea-query/backend-mysql"]
postgres = ["sea-query/backend-postgres"]
postgres-vector = ["sea-query/postgres-vector", "sea-query-binder/postgres-vector"]
sqlite = ["sea-query/backend-sqlite"]
def = []
discovery = ["futures", "parser"]
Expand Down
17 changes: 17 additions & 0 deletions src/postgres/def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ pub enum Type {
/// Variable-length multidimensional array
Array(ArrayDef),

#[cfg(feature = "postgres-vector")]
/// The postgres vector type introduced by the vector extension.
Vector(VectorDef),

// TODO:
// /// The structure of a row or record; a list of field names and types
// Composite,
Expand Down Expand Up @@ -268,6 +272,14 @@ pub struct ArrayDef {
pub col_type: Option<RcOrArc<Type>>,
}

#[cfg(feature = "postgres-vector")]
/// Defines an enum for the PostgreSQL module
#[derive(Clone, Debug, PartialEq, Default)]
#[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))]
pub struct VectorDef {
pub length: Option<u32>,
}

impl Type {
pub fn has_numeric_attr(&self) -> bool {
matches!(self, Type::Numeric(_) | Type::Decimal(_))
Expand Down Expand Up @@ -302,4 +314,9 @@ impl Type {
pub fn has_array_attr(&self) -> bool {
matches!(self, Type::Array(_))
}

#[cfg(feature = "postgres-vector")]
pub fn has_vector_attr(&self) -> bool {
matches!(self, Type::Vector(_))
}
}
25 changes: 25 additions & 0 deletions src/postgres/parser/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ pub fn parse_column_type(result: &ColumnQueryResult, enums: &EnumVariantMap) ->
if ctype.has_array_attr() {
ctype = parse_array_attributes(result.udt_name_regtype.as_deref(), ctype, enums);
}
#[cfg(feature = "postgres-vector")]
if ctype.has_vector_attr() {
ctype = parse_vector_attributes(result.character_maximum_length, ctype);
}

ctype
}
Expand Down Expand Up @@ -240,3 +244,24 @@ pub fn parse_array_attributes(

ctype
}

#[cfg(feature = "postgres-vector")]
pub fn parse_vector_attributes(
character_maximum_length: Option<i32>,
mut ctype: ColumnType,
) -> ColumnType {
match ctype {
Type::Vector(ref mut attr) => {
attr.length = match character_maximum_length {
None => None,
Some(num) => match u32::try_from(num) {
Ok(num) => Some(num),
Err(_) => None,
},
};
}
_ => panic!("parse_vector_attributes(_) received a type that does not have StringAttr"),
};

ctype
}
5 changes: 5 additions & 0 deletions src/postgres/writer/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ impl ColumnInfo {
Type::TsTzRange => ColumnType::Custom(Alias::new("tstzrange").into_iden()),
Type::DateRange => ColumnType::Custom(Alias::new("daterange").into_iden()),
Type::PgLsn => ColumnType::Custom(Alias::new("pg_lsn").into_iden()),
#[cfg(feature = "postgres-vector")]
Type::Vector(vector_attr) => match vector_attr.length {
Some(length) => ColumnType::Vector(Some(length)),
None => ColumnType::Vector(None),
},
Type::Unknown(s) => ColumnType::Custom(Alias::new(s).into_iden()),
Type::Enum(enum_def) => {
let name = Alias::new(&enum_def.typename).into_iden();
Expand Down

0 comments on commit cf94ae9

Please sign in to comment.