Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,10 @@ def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = N
return visit_pyarrow(schema, visitor)


def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
    """Convert a PyArrow schema to an Iceberg ``Schema`` using the
    ``_ConvertToIcebergWithoutIDs`` visitor (i.e. without relying on
    field IDs stored in the Arrow metadata)."""
    visitor = _ConvertToIcebergWithoutIDs()
    return visit_pyarrow(schema, visitor)


@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
"""Apply a pyarrow schema visitor to any point within a schema.
Expand Down Expand Up @@ -1725,6 +1729,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:

collected_metrics: List[pq.FileMetaData] = []
fo = table.io.new_output(file_path)

with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema, version="1.0", metadata_collector=collected_metrics) as writer:
writer.write_table(task.df)
Expand Down
11 changes: 11 additions & 0 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dataclasses import dataclass
from functools import cached_property, partial, singledispatch
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -62,6 +63,11 @@
UUIDType,
)

if TYPE_CHECKING:
from pyiceberg.table.name_mapping import (
NameMapping,
)

T = TypeVar("T")
P = TypeVar("P")

Expand Down Expand Up @@ -221,6 +227,11 @@ def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) ->
def highest_field_id(self) -> int:
return max(self._lazy_id_to_name.keys(), default=0)

def name_mapping(self) -> NameMapping:
    """Build a ``NameMapping`` from this schema's fields.

    Returns:
        A NameMapping produced by ``create_mapping_from_schema``.
    """
    # Deferred import — presumably to break a circular dependency between
    # pyiceberg.schema and pyiceberg.table.name_mapping (NameMapping itself
    # is only imported under TYPE_CHECKING in this module).
    from pyiceberg.table.name_mapping import create_mapping_from_schema

    return create_mapping_from_schema(self)

def find_column_name(self, column_id: int) -> Optional[str]:
"""Find a column name given a column ID.

Expand Down
44 changes: 42 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
from pyiceberg.table.name_mapping import (
SCHEMA_NAME_MAPPING_DEFAULT,
NameMapping,
create_mapping_from_schema,
parse_mapping_from_json,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
Expand Down Expand Up @@ -133,6 +132,43 @@
_JAVA_LONG_MAX = 9223372036854775807


def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
    """Validate that a PyArrow schema is write-compatible with a table schema.

    Args:
        table_schema: The Iceberg schema of the target table.
        other_schema: The PyArrow schema of the dataframe being written.

    Raises:
        ValueError: If the dataframe has columns the table does not know about,
            or if any field's type or nullability does not match the table
            schema. The message embeds a field-by-field diff rendered with Rich.
    """
    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

    name_mapping = table_schema.name_mapping()
    try:
        task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
    except ValueError as e:
        # Name-based conversion failed: figure out which dataframe columns
        # are not covered by the table's name mapping and report them.
        names = itertools.chain(*[field.names for field in name_mapping])
        other_schema = _pyarrow_to_schema_without_ids(other_schema)
        other_names = itertools.chain(*[field.names for field in other_schema.name_mapping()])
        additional_names = set(other_names) - set(names)
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
        ) from e

    if table_schema.as_struct() != task_schema.as_struct():
        from rich.console import Console
        from rich.table import Table as RichTable

        # record=True lets us export the rendered table as text for the error.
        console = Console(record=True)

        rich_table = RichTable(show_header=True, header_style="bold")
        rich_table.add_column("")
        rich_table.add_column("Table field")
        rich_table.add_column("Dataframe field")

        for lhs in table_schema.fields:
            try:
                rhs = task_schema.find_field(lhs.field_id)
                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
            except ValueError:
                # The field ID exists in the table but not in the dataframe.
                rich_table.add_row("❌", str(lhs), "Missing")

        console.print(rich_table)
        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")


class Transaction:
_table: Table
_updates: Tuple[TableUpdate, ...]
Expand Down Expand Up @@ -923,7 +959,7 @@ def name_mapping(self) -> NameMapping:
if name_mapping_json := self.properties.get(SCHEMA_NAME_MAPPING_DEFAULT):
return parse_mapping_from_json(name_mapping_json)
else:
return create_mapping_from_schema(self.schema())
return self.schema().name_mapping()
Comment thread
Fokko marked this conversation as resolved.
Outdated

def append(self, df: pa.Table) -> None:
"""
Expand All @@ -943,6 +979,8 @@ def append(self, df: pa.Table) -> None:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema(self.schema(), other_schema=df.schema)

data_files = _dataframe_to_data_files(self, df=df)
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
for data_file in data_files:
Expand Down Expand Up @@ -973,6 +1011,8 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema(self.schema(), other_schema=df.schema)

data_files = _dataframe_to_data_files(self, df=df)
merge = _MergingSnapshotProducer(
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
Expand Down
6 changes: 5 additions & 1 deletion pyiceberg/table/name_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from abc import ABC, abstractmethod
from collections import ChainMap
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, TypeVar, Union
from typing import Any, Dict, Generic, Iterator, List, TypeVar, Union

from pydantic import Field, conlist, field_validator, model_serializer

Expand Down Expand Up @@ -87,6 +87,10 @@ def __len__(self) -> int:
"""Return the number of mappings."""
return len(self.root)

def __iter__(self) -> Iterator[MappedField]:
    """Yield each mapped field in order."""
    yield from self.root

def __str__(self) -> str:
"""Convert the name-mapping into a nicely formatted string."""
if len(self.root) == 0:
Expand Down
78 changes: 78 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from copy import copy
from typing import Dict

import pyarrow as pa
import pytest
from sortedcontainers import SortedList

Expand Down Expand Up @@ -58,6 +59,7 @@
Table,
UpdateSchema,
_apply_table_update,
_check_schema,
_generate_snapshot_id,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
Expand Down Expand Up @@ -982,3 +984,79 @@ def test_correct_schema() -> None:
_ = t.scan(snapshot_id=-1).projection()

assert "Snapshot not found: -1" in str(exc_info.value)


def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
    """A dataframe column with a mismatched type is flagged in the rendered diff."""
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.decimal128(18, 6), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    # Raw string: the parentheses in decimal(18, 6) are escaped because the
    # message is matched as a regex by pytest.raises.
    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field                 ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
    """A required table column that is nullable in the dataframe is flagged."""
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
    """A table column absent from the dataframe shows up as Missing in the diff."""
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ Missing                  │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
    """An extra dataframe column raises a hint to update the table schema first."""
    schema_with_extra = pa.schema([
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
        pa.field("new_field", pa.date32(), nullable=True),
    ])

    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

    with pytest.raises(ValueError, match=expected):
        _check_schema(table_schema_simple, schema_with_extra)