Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion speechbrain/dataio/dataio.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _recursive_format(data, replacements):
# If not dict, list or str, do nothing


def load_data_csv(csv_path, replacements=None):
def load_data_csv(csv_path, replacements=None, col_types=None):
"""Loads CSV and formats string values.

Uses the SpeechBrain legacy CSV data format, where the CSV must have an
Expand All @@ -108,6 +108,11 @@ def load_data_csv(csv_path, replacements=None):
The rest of the fields are left as they are (legacy _format and _opts fields
are not used to load the data in any special way).

When ``col_types`` is provided, the specified columns are explicitly cast
to the given types. This is useful when numeric columns (e.g. duration,
speaker id) need to be stored as a specific Python type rather than as
strings.

Bash-like string replacements with $to_replace are supported.

Arguments
Expand All @@ -117,6 +122,10 @@ def load_data_csv(csv_path, replacements=None):
replacements : dict
(Optional dict), e.g., {"data_folder": "/home/speechbrain/data"}
This is used to recursively format all string values in the data.
col_types : dict, optional
A mapping from column name to a callable type, e.g.
``{"duration": float, "spk_id": int}``. Only the listed columns
are cast; all others are left as strings.

Returns
-------
Expand Down Expand Up @@ -169,6 +178,11 @@ def load_data_csv(csv_path, replacements=None):
# Duration:
if "duration" in row:
row["duration"] = float(row["duration"])
# Explicit column type casting:
if col_types:
for col, typ in col_types.items():
if col in row:
row[col] = typ(row[col])
result[data_id] = row
return result

Expand Down
27 changes: 24 additions & 3 deletions speechbrain/dataio/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,31 @@ def from_json(

@classmethod
def from_csv(
cls, csv_path, replacements={}, dynamic_items=[], output_keys=[]
cls,
csv_path,
replacements={},
dynamic_items=[],
output_keys=[],
col_types=None,
):
"""Load a data prep CSV file and create a Dataset based on it."""
data = load_data_csv(csv_path, replacements)
"""Load a data prep CSV file and create a Dataset based on it.

Parameters
----------
csv_path : str
Path to the CSV file.
replacements : dict
Substitutions for $-variables in the CSV values.
dynamic_items : list
Dynamic items to add to the dataset.
output_keys : list
Keys to include in the output.
col_types : dict, optional
Mapping from column name to type, e.g.
``{"duration": float, "spk_id": int}``.
Passed through to :func:`load_data_csv`.
"""
data = load_data_csv(csv_path, replacements, col_types=col_types)
return cls(data, dynamic_items, output_keys)

@classmethod
Expand Down
28 changes: 28 additions & 0 deletions tests/unittests/test_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@
import torch


def test_load_data_csv_col_types(tmpdir):
"""col_types should cast only the specified columns."""
from speechbrain.dataio.dataio import load_data_csv

csv_content = "ID,duration,spk_id,transcript\nutt1,1.45,007,hello world\nutt2,2.0,42,bye\n"
csv_path = os.path.join(tmpdir, "test_col_types.csv")
with open(csv_path, "w", encoding="utf-8") as f:
f.write(csv_content)

# without col_types, duration is float (legacy), rest are strings
data = load_data_csv(csv_path)
assert isinstance(data["utt1"]["duration"], float)
assert isinstance(data["utt1"]["spk_id"], str)
assert data["utt1"]["spk_id"] == "007" # leading zero preserved

# with col_types, specified columns get cast
data = load_data_csv(csv_path, col_types={"spk_id": int})
assert data["utt1"]["spk_id"] == 7
assert isinstance(data["utt1"]["spk_id"], int)
assert isinstance(data["utt1"]["duration"], float) # still float (legacy)
assert isinstance(data["utt1"]["transcript"], str)

# col_types can override duration type too
data = load_data_csv(csv_path, col_types={"duration": str})
# col_types is applied after the legacy duration cast, so it re-casts
assert isinstance(data["utt1"]["duration"], str)


def test_read_audio_info(tmpdir, device):
from speechbrain.dataio.dataio import read_audio_info, write_audio

Expand Down