Skip to content

feat: add upsert mode to sqlserver.py and corresponding tests #2835

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions awswrangler/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,12 @@ def to_sql(
con: "pyodbc.Connection",
table: str,
schema: str,
mode: Literal["append", "overwrite"] = "append",
mode: Literal["append", "overwrite", "upsert"] = "append",
index: bool = False,
dtype: dict[str, str] | None = None,
varchar_lengths: dict[str, int] | None = None,
use_column_names: bool = False,
upsert_conflict_columns: list[str] | None = None,
chunksize: int = 200,
fast_executemany: bool = False,
) -> None:
Expand All @@ -457,7 +458,12 @@ def to_sql(
schema : str
Schema name
mode : str
Append or overwrite.
Append, overwrite or upsert.

- append: Inserts new records into table.
- overwrite: Drops table and recreates.
- upsert: Perform an upsert which checks for conflicts on columns given by ``upsert_conflict_columns`` and sets the new values on conflicts. Note that column names of the DataFrame will be used for this operation, as if ``use_column_names`` was set to True.

index : bool
True to store the DataFrame index as a column in the table,
otherwise False to ignore it.
Expand All @@ -471,6 +477,8 @@ def to_sql(
If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
inserted into the database columns `col1` and `col3`.
upsert_conflict_columns: List[str], optional
List of columns to be used as conflict columns in the upsert operation.
chunksize: int
Number of rows which are inserted with each SQL query. Defaults to inserting 200 rows per query.
fast_executemany: bool
Expand Down Expand Up @@ -506,6 +514,8 @@ def to_sql(
if df.empty is True:
raise exceptions.EmptyDataFrame("DataFrame cannot be empty.")
_validate_connection(con=con)
if mode == "upsert" and not upsert_conflict_columns:
raise exceptions.InvalidArgumentValue("<upsert_conflict_columns> need to be set when using upsert mode.")
try:
with con.cursor() as cursor:
if fast_executemany:
Expand All @@ -524,15 +534,28 @@ def to_sql(
df.reset_index(level=df.index.names, inplace=True)
column_placeholders: str = ", ".join(["?"] * len(df.columns))
table_identifier = _get_table_identifier(schema, table)
column_names = [identifier(col, sql_mode="mssql") for col in df.columns]
quoted_columns = ", ".join(column_names)
insertion_columns = ""
if use_column_names:
quoted_columns = ", ".join(f"{identifier(col, sql_mode='mssql')}" for col in df.columns)
insertion_columns = f"({quoted_columns})"
placeholder_parameter_pair_generator = _db_utils.generate_placeholder_parameter_pairs(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
)
for placeholders, parameters in placeholder_parameter_pair_generator:
sql: str = f"INSERT INTO {table_identifier} {insertion_columns} VALUES {placeholders}"
if mode == "upsert" and upsert_conflict_columns:
merge_on_columns = [identifier(col, sql_mode="mssql") for col in upsert_conflict_columns]
sql = f"MERGE INTO {table_identifier}\nUSING (VALUES {placeholders}) AS source ({quoted_columns})\n"
sql += f"ON {' AND '.join(f'{table_identifier}.{col}=source.{col}' for col in merge_on_columns)}\n"
sql += (
f"WHEN MATCHED THEN\n UPDATE "
f"SET {', '.join(f'{col}=source.{col}' for col in column_names)}\n"
)
sql += (
f"WHEN NOT MATCHED THEN\n INSERT "
f"({quoted_columns}) VALUES ({', '.join([f'source.{col}' for col in column_names])});"
)
_logger.debug("sql: %s", sql)
cursor.executemany(sql, (parameters,))
con.commit()
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,74 @@ def test_dfs_are_equal_for_different_chunksizes(sqlserver_table, sqlserver_con,
df["c1"] = df["c1"].astype("string")

assert df.equals(df2)


def test_upsert(sqlserver_table, sqlserver_con):
    """Exercise ``mode="upsert"`` of ``wr.sqlserver.to_sql`` against a live table.

    The test checks, in order:

    1. Upsert without conflict columns raises ``InvalidArgumentValue``.
    2. Upserting the same two rows repeatedly keeps the table at two rows.
    3. Upserting one new key plus one existing key grows the table to three rows.
    4. Upserting existing keys with new values updates them in place
       (row count stays at three, values reflect the latest write).
    """

    def _upsert(frame, **extra):
        # Every call targets the same table/connection in upsert mode;
        # only the frame and the optional keyword arguments vary.
        wr.sqlserver.to_sql(
            df=frame,
            con=sqlserver_con,
            schema="dbo",
            table=sqlserver_table,
            mode="upsert",
            **extra,
        )

    def _read_back():
        return wr.sqlserver.read_sql_table(con=sqlserver_con, schema="dbo", table=sqlserver_table)

    initial = pd.DataFrame({"c0": ["foo", "bar"], "c2": [1, 2]})

    # Missing conflict columns must be rejected up front.
    with pytest.raises(wr.exceptions.InvalidArgumentValue):
        _upsert(initial, upsert_conflict_columns=None, use_column_names=True)

    # Writing the same rows twice must not duplicate them.
    for _ in range(2):
        _upsert(initial, upsert_conflict_columns=["c0"])
    after_repeat = _read_back()
    assert len(after_repeat) == 2

    # One more identical write, then a frame with one new key ("baz")
    # and one already-present key ("bar").
    _upsert(initial, upsert_conflict_columns=["c0"])
    with_new_key = pd.DataFrame({"c0": ["baz", "bar"], "c2": [3, 2]})
    _upsert(with_new_key, upsert_conflict_columns=["c0"], use_column_names=True)
    after_new_key = _read_back()
    assert len(after_new_key) == 3

    # Existing keys with changed values: rows are updated, not added.
    changed_values = pd.DataFrame({"c0": ["foo", "bar"], "c2": [4, 5]})
    _upsert(changed_values, upsert_conflict_columns=["c0"], use_column_names=True)

    final = _read_back()
    assert len(final) == 3
    assert len(final.loc[(final["c0"] == "foo") & (final["c2"] == 4)]) == 1
    assert len(final.loc[(final["c0"] == "bar") & (final["c2"] == 5)]) == 1
Loading