import logging
import os
import shutil
from typing import Any

from pyspark.sql.session import SparkSession

from spark_pipeline_framework.utilities.spark_data_frame_helpers import (
    spark_list_catalog_table_names,
)

# make sure env variables are set correctly
if "SPARK_HOME" not in os.environ:
    os.environ["SPARK_HOME"] = "/usr/local/opt/spark"


def quiet_py4j() -> None:
    """Turn down Spark's py4j logging for the test context."""
    logger = logging.getLogger("py4j")
    logger.setLevel(logging.ERROR)


def clean_spark_dir() -> None:
    """
    Remove the local Derby log, metastore and warehouse directories
    that a local Spark session with Hive support leaves behind.
    """
    try:
        if os.path.exists("./derby.log"):
            os.remove("./derby.log")
        if os.path.exists("./metastore_db"):
            shutil.rmtree("./metastore_db")
        if os.path.exists("./spark-warehouse"):
            shutil.rmtree("./spark-warehouse")
    except OSError as e:
        print(f"Error cleaning spark directories: {e.strerror}")


def clean_spark_session(session: SparkSession) -> None:
    """
    Drop all tables and temp views registered in the session's catalog
    and clear the cache so each test starts from a clean catalog.

    :param session: the SparkSession to clean
    """
    table_names = spark_list_catalog_table_names(session)

    for table_name in table_names:
        print(f"clear_tables() is dropping table/view: {table_name}")
        # Drop the table if it exists
        if session.catalog.tableExists(f"default.{table_name}"):
            # noinspection SqlNoDataSourceInspection
            session.sql(f"DROP TABLE default.{table_name}")
        # Drop the view if it exists in the default database
        if session.catalog.tableExists(f"default.{table_name}"):
            session.catalog.dropTempView(f"default.{table_name}")
        # Drop the view if it exists in the global context
        if session.catalog.tableExists(f"{table_name}"):
            session.catalog.dropTempView(f"{table_name}")

    session.catalog.clearCache()


def clean_close(session: SparkSession) -> None:
    """
    Fully tear down the session: drop catalog objects, clean up local
    Spark files and stop the SparkSession.

    :param session: the SparkSession to clean up and stop
    """
    clean_spark_session(session)
    clean_spark_dir()
    session.stop()


def create_spark_session(request: Any) -> SparkSession:
    """
    Create a local SparkSession configured for tests and register a
    finalizer on the pytest ``request`` that cleans it up afterwards.
    """
    logging.getLogger("org.apache.spark.deploy.SparkSubmit").setLevel(logging.ERROR)
    logging.getLogger("org.apache.ivy").setLevel(logging.ERROR)
    logging.getLogger("org.apache.hadoop.hive.metastore.ObjectStore").setLevel(
        logging.ERROR
    )
    logging.getLogger("org.apache.hadoop.hive.conf.HiveConf").setLevel(logging.ERROR)
    logging.getLogger("org.apache.hadoop.util.NativeCodeLoader").setLevel(logging.ERROR)

    # make sure env variables are set correctly
    if "SPARK_HOME" not in os.environ:
        os.environ["SPARK_HOME"] = "/usr/local/opt/spark"

    clean_spark_dir()

    master = "local[2]"

    # These jar files are already contained in the imranq2/helix.spark image
    # jars = [
    #     "mysql:mysql-connector-java:8.0.33",
    #     "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.1",
    #     "io.delta:delta-spark_2.12:3.2.0",
    #     "io.delta:delta-storage:3.2.0",
    #     "com.johnsnowlabs.nlp:spark-nlp_2.12:5.3.3",
    #     "org.apache.spark:spark-hadoop-cloud_2.12:3.5.1",
    #     "com.amazonaws:aws-java-sdk-bundle:1.12.262",
    #     "com.databricks:spark-xml_2.12:0.18.0",
    # ]

    session = (
        SparkSession.builder.appName("pytest-pyspark-local-testing")
        .master(master)
        .config("spark.ui.showConsoleProgress", "false")
        .config("spark.executor.instances", "2")
        .config("spark.executor.cores", "1")
        .config("spark.executor.memory", "2g")
        .config("spark.sql.shuffle.partitions", "2")
        .config("spark.default.parallelism", "4")
        .config("spark.sql.broadcastTimeout", "2400")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog",
            "org.apache.spark.sql.delta.catalog.DeltaCatalog",
        )
        # .config("spark.jars.packages", ",".join(jars))
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
        .config("spark.sql.execution.arrow.maxRecordsPerBatch", "2048")
        .enableHiveSupport()
        .getOrCreate()
    )

    if os.environ.get("LOGLEVEL") == "DEBUG":
        configurations = session.sparkContext.getConf().getAll()
        for item in configurations:
            print(item)

    # Verify that Arrow is enabled
    # arrow_enabled = session.conf.get("spark.sql.execution.arrow.pyspark.enabled")
    # print(f"Arrow Enabled: {arrow_enabled}")

    request.addfinalizer(lambda: clean_close(session))

    quiet_py4j()
    return session
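

# Usage sketch: a test suite would typically expose create_spark_session()
# through a pytest fixture in conftest.py. The fixture name, scope and import
# path below are illustrative assumptions; adjust them to your package layout.
#
#     import pytest
#     from pyspark.sql.session import SparkSession
#     from create_spark_session import create_spark_session
#
#     @pytest.fixture(scope="session")
#     def spark_session(request: pytest.FixtureRequest) -> SparkSession:
#         # create_spark_session() registers clean_close() as a finalizer, so
#         # tables, temp views and the local metastore/warehouse files are
#         # removed when the fixture is torn down.
#         return create_spark_session(request)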