From 48ada0558f187aaa6f98649a8093714c622c2d58 Mon Sep 17 00:00:00 2001 From: knQzx <75641500+knQzx@users.noreply.github.com> Date: Sat, 4 Apr 2026 18:08:23 +0200 Subject: [PATCH] use explicit col_types dict instead of auto-casting --- speechbrain/dataio/dataio.py | 16 +++++++++++++++- speechbrain/dataio/dataset.py | 27 ++++++++++++++++++++++++--- tests/unittests/test_data_io.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/speechbrain/dataio/dataio.py b/speechbrain/dataio/dataio.py index 0385ade1c3..30222a8b4f 100644 --- a/speechbrain/dataio/dataio.py +++ b/speechbrain/dataio/dataio.py @@ -99,7 +99,7 @@ def _recursive_format(data, replacements): # If not dict, list or str, do nothing -def load_data_csv(csv_path, replacements=None): +def load_data_csv(csv_path, replacements=None, col_types=None): """Loads CSV and formats string values. Uses the SpeechBrain legacy CSV data format, where the CSV must have an @@ -108,6 +108,11 @@ def load_data_csv(csv_path, replacements=None): The rest of the fields are left as they are (legacy _format and _opts fields are not used to load the data in any special way). + When ``col_types`` is provided, the specified columns are explicitly cast + to the given types. This is useful when numeric columns (e.g. duration, + speaker id) need to be stored as a specific Python type rather than as + strings. + Bash-like string replacements with $to_replace are supported. Arguments @@ -117,6 +122,10 @@ def load_data_csv(csv_path, replacements=None): replacements : dict (Optional dict), e.g., {"data_folder": "/home/speechbrain/data"} This is used to recursively format all string values in the data. + col_types : dict, optional + A mapping from column name to a callable type, e.g. + ``{"duration": float, "spk_id": int}``. Only the listed columns + are cast; all others are left as strings. Returns ------- @@ -169,6 +178,11 @@ def load_data_csv(csv_path, replacements=None): # Duration: if "duration" in row: row["duration"] = float(row["duration"]) + # Explicit column type casting: + if col_types: + for col, typ in col_types.items(): + if col in row: + row[col] = typ(row[col]) result[data_id] = row return result diff --git a/speechbrain/dataio/dataset.py b/speechbrain/dataio/dataset.py index 1ec508385d..17eebb6f18 100644 --- a/speechbrain/dataio/dataset.py +++ b/speechbrain/dataio/dataset.py @@ -453,10 +453,31 @@ def from_json( @classmethod def from_csv( - cls, csv_path, replacements={}, dynamic_items=[], output_keys=[] + cls, + csv_path, + replacements={}, + dynamic_items=[], + output_keys=[], + col_types=None, ): - """Load a data prep CSV file and create a Dataset based on it.""" - data = load_data_csv(csv_path, replacements) + """Load a data prep CSV file and create a Dataset based on it. + + Parameters + ---------- + csv_path : str + Path to the CSV file. + replacements : dict + Substitutions for $-variables in the CSV values. + dynamic_items : list + Dynamic items to add to the dataset. + output_keys : list + Keys to include in the output. + col_types : dict, optional + Mapping from column name to type, e.g. + ``{"duration": float, "spk_id": int}``. + Passed through to :func:`load_data_csv`. + """ + data = load_data_csv(csv_path, replacements, col_types=col_types) return cls(data, dynamic_items, output_keys) @classmethod diff --git a/tests/unittests/test_data_io.py b/tests/unittests/test_data_io.py index f2043e38bb..80f0163830 100644 --- a/tests/unittests/test_data_io.py +++ b/tests/unittests/test_data_io.py @@ -3,6 +3,34 @@ import torch +def test_load_data_csv_col_types(tmpdir): + """col_types should cast only the specified columns.""" + from speechbrain.dataio.dataio import load_data_csv + + csv_content = "ID,duration,spk_id,transcript\nutt1,1.45,007,hello world\nutt2,2.0,42,bye\n" + csv_path = os.path.join(tmpdir, "test_col_types.csv") + with open(csv_path, "w", encoding="utf-8") as f: + f.write(csv_content) + + # without col_types, duration is float (legacy), rest are strings + data = load_data_csv(csv_path) + assert isinstance(data["utt1"]["duration"], float) + assert isinstance(data["utt1"]["spk_id"], str) + assert data["utt1"]["spk_id"] == "007" # leading zero preserved + + # with col_types, specified columns get cast + data = load_data_csv(csv_path, col_types={"spk_id": int}) + assert data["utt1"]["spk_id"] == 7 + assert isinstance(data["utt1"]["spk_id"], int) + assert isinstance(data["utt1"]["duration"], float) # still float (legacy) + assert isinstance(data["utt1"]["transcript"], str) + + # col_types can override duration type too + data = load_data_csv(csv_path, col_types={"duration": str}) + # col_types is applied after the legacy duration cast, so it re-casts + assert isinstance(data["utt1"]["duration"], str) + + def test_read_audio_info(tmpdir, device): from speechbrain.dataio.dataio import read_audio_info, write_audio