Convert data according to a class schema.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cls_field` | `Attribute` | Field to apply conversion function. | required |
| `cls_field_exists` | `bool` | Whether field exists in data already. | required |
Returns:

| Type | Description |
| --- | --- |
| `Column` | Constructed PySpark Column expression. |
Source code in src/tidy_tools/model/convert.py
| def convert_field(cls_field: attrs.Attribute, cls_field_exists: bool) -> Column:
"""
Convert data according to a class schema.
Parameters
----------
cls_field : attrs.Attribute
Field to apply conversion function.
cls_field_exists : bool
Whether field exists in data already.
Returns
-------
Column
Constructed PySpark Column expression.
"""
if not cls_field.default:
if not cls_field_exists:
column = F.lit(None)
else:
column = F.col(cls_field.alias)
if cls_field.default:
if isinstance(cls_field.default, attrs.Factory):
return_type = typing.get_type_hints(cls_field.default.factory).get("return")
assert (
return_type is not None
), "Missing type hint for return value! Redefine function to include type hint `def func() -> pyspark.sql.Column: ...`"
assert return_type is Column, "Factory must return a pyspark.sql.Column!"
column = cls_field.default.factory()
elif not cls_field_exists:
column = F.lit(cls_field.default)
else:
column = F.when(
F.col(cls_field.alias).isNull(), cls_field.default
).otherwise(F.col(cls_field.alias))
else:
column = F.col(cls_field.alias)
if cls_field.name != cls_field.alias:
column = column.alias(cls_field.name)
cls_field_type = get_pyspark_type(cls_field)
match cls_field_type:
case T.DateType():
date_format = cls_field.metadata.get("format")
if date_format:
column = F.to_date(column, format=date_format)
else:
logger.warning(
f"No `format` provided for {cls_field.name}. Please add `field(..., metadata={{'format': ...}})` and rerun."
)
column = column.cast(cls_field_type)
case T.TimestampType():
timestamp_format = cls_field.metadata.get("format")
if timestamp_format:
column = F.to_timestamp(column, format=timestamp_format)
else:
logger.warning(
f"No `format` provided for {cls_field.name}. Please add `field(..., metadata={{'format': ...}})` and rerun."
)
column = column.cast(cls_field_type)
if cls_field.converter:
column = cls_field.converter(column)
return column.alias(cls_field.alias)
|