
feat: add columns parameters support #2814

Merged: 5 commits, May 14, 2024

2 changes: 2 additions & 0 deletions awswrangler/catalog/__init__.py
@@ -30,6 +30,7 @@
_get_table_input,
databases,
get_columns_comments,
get_columns_parameters,
get_connection,
get_csv_partitions,
get_databases,
@@ -83,6 +84,7 @@
"_get_table_input",
"databases",
"get_columns_comments",
"get_columns_parameters",
"get_connection",
"get_csv_partitions",
"get_databases",
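
With the re-export above, the new helper sits next to get_columns_comments in the public catalog namespace. A minimal sketch of what that enables (only the names shown in the diff are assumed):

import awswrangler as wr
from awswrangler.catalog import get_columns_parameters  # same callable as wr.catalog.get_columns_parameters
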
42 changes: 41 additions & 1 deletion awswrangler/catalog/_create.py
@@ -25,7 +25,9 @@
_logger: logging.Logger = logging.getLogger(__name__)


def _update_if_necessary(dic: dict[str, str], key: str, value: str | None, mode: str) -> str:
def _update_if_necessary(
dic: dict[str, str | dict[str, str]], key: str, value: str | dict[str, str] | None, mode: str
) -> str:
if value is not None:
if key not in dic or dic[key] != value:
dic[key] = value
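
The widened annotation matters because column parameters are nested dicts rather than plain strings. A small illustration of the check the helper relies on (a sketch only; the rest of the function is collapsed in the diff):

# Dict equality in Python compares contents, so the same "has this changed?" test
# works for nested column Parameters exactly as it does for string comments.
col = {"Name": "col0", "Parameters": {"par0": "Param 0"}}
new_value = {"par0": "Param 0", "par1": "Param 1"}
if "Parameters" not in col or col["Parameters"] != new_value:
    col["Parameters"] = new_value  # mirrors dic[key] = value above
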
@@ -46,6 +48,7 @@ def _create_table( # noqa: PLR0912,PLR0915
table_exist: bool,
partitions_types: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
catalog_id: str | None,
) -> None:
@@ -130,6 +133,19 @@ def _create_table( # noqa: PLR0912,PLR0915
if name in columns_comments:
mode = _update_if_necessary(dic=par, key="Comment", value=columns_comments[name], mode=mode)

# Column parameters
columns_parameters = columns_parameters if columns_parameters else {}
columns_parameters = {sanitize_column_name(k): v for k, v in columns_parameters.items()}
if columns_parameters:
for col in table_input["StorageDescriptor"]["Columns"]:
name: str = col["Name"] # type: ignore[no-redef]
if name in columns_parameters:
mode = _update_if_necessary(dic=col, key="Parameters", value=columns_parameters[name], mode=mode)
for par in table_input["PartitionKeys"]:
name = par["Name"]
if name in columns_parameters:
mode = _update_if_necessary(dic=par, key="Parameters", value=columns_parameters[name], mode=mode)

_logger.debug("table_input: %s", table_input)

client_glue = _utils.client(service_name="glue", session=boto3_session)
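
For readers skimming the hunk above: the new block mirrors the existing columns_comments handling, sanitizing the provided column names and attaching a Parameters dict to every matching column and partition key. A self-contained sketch of that merge (simplified: no mode tracking and no name sanitization; the sample table_input is hypothetical but follows the Glue layout used above):

from typing import Any

def attach_column_parameters(table_input: dict[str, Any], columns_parameters: dict[str, dict[str, str]]) -> None:
    # Simplified version of the loop added in _create_table.
    for col in table_input["StorageDescriptor"]["Columns"] + table_input["PartitionKeys"]:
        params = columns_parameters.get(col["Name"])
        if params is not None:
            col["Parameters"] = params

table_input = {
    "StorageDescriptor": {"Columns": [{"Name": "col0", "Type": "int"}]},
    "PartitionKeys": [{"Name": "col1", "Type": "string"}],
}
attach_column_parameters(table_input, {"col0": {"par0": "Param 0"}, "col1": {"par1": "Param 1"}})
# table_input["StorageDescriptor"]["Columns"][0]["Parameters"] is now {"par0": "Param 0"}
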
@@ -275,6 +291,7 @@ def _create_parquet_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -311,6 +328,7 @@
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
@@ -335,6 +353,7 @@ def _create_orc_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -371,6 +390,7 @@
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
@@ -394,6 +414,7 @@ def _create_csv_table(
compression: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
schema_evolution: bool,
@@ -444,6 +465,7 @@
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
@@ -467,6 +489,7 @@ def _create_json_table(
compression: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
schema_evolution: bool,
@@ -512,6 +535,7 @@
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
boto3_session=boto3_session,
@@ -713,6 +737,7 @@ def create_parquet_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
@@ -751,6 +776,8 @@ def create_parquet_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode: str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
@@ -848,6 +875,7 @@ def create_parquet_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -870,6 +898,7 @@ def create_orc_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None = None,
@@ -908,6 +937,8 @@ def create_orc_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode: str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
@@ -1005,6 +1036,7 @@ def create_orc_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -1026,6 +1058,7 @@ def create_csv_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
schema_evolution: bool = False,
@@ -1072,6 +1105,8 @@ def create_csv_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode : str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
@@ -1188,6 +1223,7 @@ def create_csv_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
schema_evolution=schema_evolution,
@@ -1214,6 +1250,7 @@ def create_json_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: Literal["overwrite", "append"] = "overwrite",
catalog_versioning: bool = False,
schema_evolution: bool = False,
@@ -1253,6 +1290,8 @@ def create_json_table(
Key/value pairs to tag the table.
columns_comments: Dict[str, str], optional
Columns names and the related comments (e.g. {'col0': 'Column 0.', 'col1': 'Column 1.', 'col2': 'Partition.'}).
columns_parameters: Dict[str, Dict[str, str]], optional
Columns names and the related parameters (e.g. {'col0': {'par0': 'Param 0', 'par1': 'Param 1'}}).
mode : str
'overwrite' to recreate any possible existing table or 'append' to keep any possible existing table.
catalog_versioning : bool
@@ -1361,6 +1400,7 @@ def create_json_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
schema_evolution=schema_evolution,
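
Taken together, every create_*_table function (Parquet, ORC, CSV, JSON) now accepts columns_parameters alongside columns_comments and forwards it to _create_table. A hedged usage sketch for the Parquet variant (database, table, path and column names are placeholders; the argument format follows the docstring example added above):

import awswrangler as wr

wr.catalog.create_parquet_table(
    database="my_db",                     # hypothetical database
    table="my_table",                     # hypothetical table
    path="s3://my-bucket/my_table/",      # hypothetical S3 location
    columns_types={"col0": "int", "col1": "string"},
    partitions_types={"col2": "date"},
    columns_comments={"col0": "Column 0.", "col2": "Partition."},
    columns_parameters={"col0": {"par0": "Param 0", "par1": "Param 1"}},  # new argument
)
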
45 changes: 44 additions & 1 deletion awswrangler/catalog/_get.py
@@ -5,7 +5,7 @@
import base64
import itertools
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterator, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, Mapping, cast

import boto3
import botocore.exceptions
@@ -887,6 +887,49 @@ def get_columns_comments(
return comments


@apply_configs
def get_columns_parameters(
database: str,
table: str,
catalog_id: str | None = None,
boto3_session: boto3.Session | None = None,
) -> dict[str, Mapping[str, str] | None]:
"""Get all columns parameters.

Parameters
----------
database : str
Database name.
table : str
Table name.
catalog_id : str, optional
The ID of the Data Catalog from which to retrieve Databases.
If none is provided, the AWS account ID is used by default.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receive None.

Returns
-------
Dict[str, Optional[Dict[str, str]]]
Columns parameters.

Examples
--------
>>> import awswrangler as wr
>>> pars = wr.catalog.get_columns_parameters(database="...", table="...")

"""
client_glue = _utils.client("glue", session=boto3_session)
response = client_glue.get_table(**_catalog_id(catalog_id=catalog_id, DatabaseName=database, Name=table))
parameters = {}
for c in response["Table"]["StorageDescriptor"]["Columns"]:
parameters[c["Name"]] = c.get("Parameters")
if "PartitionKeys" in response["Table"]:
for p in response["Table"]["PartitionKeys"]:
parameters[p["Name"]] = p.get("Parameters")
return parameters


@apply_configs
def get_table_versions(
database: str, table: str, catalog_id: str | None = None, boto3_session: boto3.Session | None = None
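
The new getter is the read-side counterpart: it fetches the table from Glue and collects the Parameters of every column and partition key. A sketch of a call and the shape of its result (names and values are illustrative):

import awswrangler as wr

pars = wr.catalog.get_columns_parameters(database="my_db", table="my_table")
# Possible shape; columns without parameters map to None, partition keys are included:
# {
#     "col0": {"par0": "Param 0", "par1": "Param 1"},
#     "col1": None,
#     "col2": {"par2": "Param 2"},
# }
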
8 changes: 6 additions & 2 deletions awswrangler/s3/_write.py
@@ -70,6 +70,7 @@ def _validate_args(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
execution_engine: Enum,
) -> None:
if df.empty is True:
@@ -87,11 +88,11 @@
raise exceptions.InvalidArgumentCombination("Please, pass dataset=True to be able to use bucketing_info.")
if mode is not None:
raise exceptions.InvalidArgumentCombination("Please pass dataset=True to be able to use mode.")
if any(arg is not None for arg in (table, description, parameters, columns_comments)):
if any(arg is not None for arg in (table, description, parameters, columns_comments, columns_parameters)):
raise exceptions.InvalidArgumentCombination(
"Please pass dataset=True to be able to use any one of these "
"arguments: database, table, description, parameters, "
"columns_comments."
"columns_comments, columns_parameters."
)
elif (database is None) != (table is None):
raise exceptions.InvalidArgumentCombination(
@@ -214,6 +215,7 @@ def _create_glue_table(
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
mode: str,
catalog_versioning: bool,
athena_partition_projection_settings: typing.AthenaPartitionProjectionSettings | None,
@@ -262,6 +264,7 @@ def write( # noqa: PLR0912,PLR0913
description: str | None,
parameters: dict[str, str] | None,
columns_comments: dict[str, str] | None,
columns_parameters: dict[str, dict[str, str]] | None,
regular_partitions: bool,
table_type: str | None,
dtype: dict[str, str] | None,
@@ -361,6 +364,7 @@ def write( # noqa: PLR0912,PLR0913
"description": description,
"parameters": parameters,
"columns_comments": columns_comments,
"columns_parameters": columns_parameters,
"boto3_session": boto3_session,
"mode": mode,
"catalog_versioning": catalog_versioning,
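
The _validate_args change extends the existing rule: catalog-related arguments are only meaningful for dataset writes. A hedged sketch of the failure mode (paths and names are placeholders, and it assumes columns_parameters is passed through glue_table_settings, as the to_orc hunk below does):

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"col0": [1, 2, 3]})

# Expected to raise wr.exceptions.InvalidArgumentCombination, because
# columns_parameters (like columns_comments) requires dataset=True.
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/file.parquet",  # hypothetical path
    dataset=False,
    glue_table_settings={"columns_parameters": {"col0": {"par0": "Param 0"}}},
)
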
5 changes: 5 additions & 0 deletions awswrangler/s3/_write_orc.py
@@ -253,6 +253,7 @@ def _create_glue_table(
description: str | None = None,
parameters: dict[str, str] | None = None,
columns_comments: dict[str, str] | None = None,
columns_parameters: dict[str, dict[str, str]] | None = None,
mode: str = "overwrite",
catalog_versioning: bool = False,
athena_partition_projection_settings: AthenaPartitionProjectionSettings | None = None,
@@ -272,6 +273,7 @@ def _create_glue_table(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
mode=mode,
catalog_versioning=catalog_versioning,
athena_partition_projection_settings=athena_partition_projection_settings,
@@ -629,6 +631,7 @@ def to_orc(
description = glue_table_settings.get("description")
parameters = glue_table_settings.get("parameters")
columns_comments = glue_table_settings.get("columns_comments")
columns_parameters = glue_table_settings.get("columns_parameters")
regular_partitions = glue_table_settings.get("regular_partitions", True)

_validate_args(
@@ -643,6 +646,7 @@ def to_orc(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
execution_engine=engine.get(),
)

@@ -682,6 +686,7 @@ def to_orc(
description=description,
parameters=parameters,
columns_comments=columns_comments,
columns_parameters=columns_parameters,
table_type=table_type,
regular_partitions=regular_partitions,
dtype=dtype,
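
Putting the write path together: to_orc now reads columns_parameters out of glue_table_settings, validates it, and forwards it to the Glue table creation. An end-to-end sketch under those assumptions (bucket, database and table names are placeholders):

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"col0": [1, 2], "col1": ["a", "b"]})

wr.s3.to_orc(
    df=df,
    path="s3://my-bucket/my_table/",  # hypothetical S3 prefix
    dataset=True,
    database="my_db",                 # hypothetical Glue database
    table="my_table",                 # hypothetical Glue table
    glue_table_settings={
        "columns_comments": {"col0": "Column 0."},
        "columns_parameters": {"col0": {"par0": "Param 0"}},  # new key
    },
)

# Read the parameters back with the getter added in catalog/_get.py.
print(wr.catalog.get_columns_parameters(database="my_db", table="my_table"))
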