Skip to content

Commit 5162cb2

Browse files
committed
Ensure Scala 2 libs get "_2" repo names in Scala 3
Without this change, it was possible for Scala 3 core library versions to become overwritten with Scala 2 versions specified as dependencies of other jars. Specifically, all the `@io_bazel_rules_scala_scala_*_2` deps explicitly added to the `scala_3_{1,2,3,4,5}.bzl` files in bazel-contrib#1631 for Scalafmt get stripped of the `_2` suffix before this change. Also computed `is_scala_3` in `create_file` and passed it through where needed. At some point it might be worth refactoring the script into a proper object instead.
1 parent 772344b commit 5162cb2

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

scripts/create_repository.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,9 @@ class ResolvedArtifact:
4747
checksum: str
4848
direct_dependencies: List[MavenCoordinates]
4949

50-
def select_root_artifacts(scala_version) -> List[str]:
51-
scala_major = ".".join(scala_version.split(".")[:2])
52-
scalatest_major = "3" if scala_major >= "3.0" else scala_major
53-
scalafmt_major = "2.13" if scala_major >= "3.0" else scala_major
50+
def select_root_artifacts(scala_version, scala_major, is_scala_3) -> List[str]:
51+
scalatest_major = "3" if is_scala_3 else scala_major
52+
scalafmt_major = "2.13" if is_scala_3 else scala_major
5453
kind_projector_version = "0.13.2" if scala_major < "2.12" else "0.13.3"
5554
scalafmt_version = "2.7.5" if scala_major == "2.11" else SCALAFMT_VERSION
5655

@@ -141,13 +140,14 @@ def get_json_dependencies(artifact) -> List[MavenCoordinates]:
141140
]),
142141
]
143142

144-
def get_label(coordinates) -> str:
143+
def get_label(coordinates, is_scala_3) -> str:
145144
group = coordinates.group
146145
group_label = group.replace('.', '_').replace('-', '_')
147146
artifact_label = coordinates.artifact.split('_')[0].replace('-', '_')
148147

149148
if group in COORDINATE_GROUPS[0] or group.startswith("org.scala-lang."):
150-
return f'io_bazel_rules_scala_{artifact_label}'
149+
s = '_2' if is_scala_3 and coordinates.version.startswith("2.") else ''
150+
return f'io_bazel_rules_scala_{artifact_label}' + s
151151
if group in COORDINATE_GROUPS[1]:
152152
return f'io_bazel_rules_scala_{group_label}_{artifact_label}'
153153
if group in COORDINATE_GROUPS[2]:
@@ -195,25 +195,22 @@ def resolve_artifacts_with_checksums_and_direct_dependencies(
195195
proc.stdout.splitlines(), current_artifacts,
196196
)
197197

198-
def to_rules_scala_compatible_dict(artifacts) -> Dict[str, Dict]:
198+
def to_rules_scala_compatible_dict(artifacts, is_scala_3) -> Dict[str, Dict]:
199199
result = {}
200200

201201
for a in artifacts:
202+
coordinates = a.coordinates
202203
label = (
203-
get_label(a.coordinates)
204+
get_label(coordinates, is_scala_3)
204205
.replace('scala3_', 'scala_')
205206
.replace('scala_tasty_core', 'scala_scala_tasty_core')
206207
)
207-
deps = sorted([
208-
f'@{get_label(dep)}_2'
209-
if "scala3-library_3" in a.coordinates.artifact
210-
else f'@{get_label(dep)}'
211-
for dep in a.direct_dependencies
212-
])
213208
result[label] = {
214-
"artifact": f"{a.coordinates.coordinate}",
209+
"artifact": f"{coordinates.coordinate}",
215210
"sha256": f"{a.checksum}",
216-
"deps": deps,
211+
"deps": sorted([
212+
f'@{get_label(d, is_scala_3)}' for d in a.direct_dependencies
213+
]),
217214
}
218215

219216
return result
@@ -248,7 +245,9 @@ def create_file(version):
248245
with file.open('r', encoding='utf-8') as data:
249246
read_data = data.read()
250247

251-
root_artifacts = select_root_artifacts(version)
248+
scala_major = ".".join(version.split(".")[:2])
249+
is_scala_3 = scala_major.startswith("3.")
250+
root_artifacts = select_root_artifacts(version, scala_major, is_scala_3)
252251
replaced_data = read_data[read_data.find('{'):]
253252

254253
original_artifacts = ast.literal_eval(replaced_data)
@@ -259,7 +258,9 @@ def create_file(version):
259258
{a["artifact"] for a in original_artifacts.values()},
260259
)
261260
)
262-
generated_artifacts = to_rules_scala_compatible_dict(transitive_artifacts)
261+
generated_artifacts = to_rules_scala_compatible_dict(
262+
transitive_artifacts, is_scala_3
263+
)
263264

264265
for label, metadata in original_artifacts.items():
265266
generated_metadata = generated_artifacts.get(label, None)

0 commit comments

Comments
 (0)