diff --git a/jvm/src/test/scala/scala/xml/XMLTest.scala b/jvm/src/test/scala/scala/xml/XMLTest.scala
index 43dd182e2..5008bd6e6 100644
--- a/jvm/src/test/scala/scala/xml/XMLTest.scala
+++ b/jvm/src/test/scala/scala/xml/XMLTest.scala
@@ -550,6 +550,24 @@ class XMLTestJVM {
XML.loadString(broken)
}
+ @UnitTest
+ def issueSI9047AttributeFromSingleChildElementWorks: Unit = {
+ val x =
+
+ val b = x \ "a" \ "@b"
+
+ assertEquals(List("1"), b map (_.text))
+ }
+
+ @UnitTest
+ def issueSI9047AttributeMultipleChildElementsWorks: Unit = {
+ val x =
+
+ val b = x \ "a" \ "@b"
+
+ assertEquals(List("1", "2"), b map (_.text))
+ }
+
@UnitTest
def nodeSeqNs: Unit = {
val x = {
diff --git a/shared/src/main/scala/scala/xml/NodeSeq.scala b/shared/src/main/scala/scala/xml/NodeSeq.scala
index c498279e9..df43f2553 100644
--- a/shared/src/main/scala/scala/xml/NodeSeq.scala
+++ b/shared/src/main/scala/scala/xml/NodeSeq.scala
@@ -95,8 +95,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
def \(that: String): NodeSeq = {
def fail = throw new IllegalArgumentException(that)
def atResult = {
- lazy val y = this(0)
- val attr =
+ this flatMap (y => (
if (that.length == 1) fail
else if (that(1) == '{') {
val i = that indexOf '}'
@@ -105,10 +104,9 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
if (uri == "" || key == "") fail
else y.attribute(uri, key)
} else y.attribute(that drop 1)
-
- attr match {
- case Some(x) => Group(x)
- case _ => NodeSeq.Empty
+ ).getOrElse(Nil)) match {
+ case NodeSeq.Empty => NodeSeq.Empty
+ case x => Group(x)
}
}
@@ -118,7 +116,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S
that match {
case "" => fail
case "_" => makeSeq(!_.isAtom)
- case _ if (that(0) == '@' && this.length == 1) => atResult
+ case _ if that(0) == '@' => atResult
case _ => makeSeq(_.label == that)
}
}
diff --git a/shared/src/test/scala/scala/xml/AttributeTest.scala b/shared/src/test/scala/scala/xml/AttributeTest.scala
index 8943a0a00..b7a0d3115 100644
--- a/shared/src/test/scala/scala/xml/AttributeTest.scala
+++ b/shared/src/test/scala/scala/xml/AttributeTest.scala
@@ -147,9 +147,9 @@ class AttributeTest {
val b = xml \ "b"
assertEquals(2, b.length)
assertEquals(NodeSeq.fromSeq(Seq(, )), b)
- val barFail = b \ "@bar"
+ val barAttributesDirect = b \ "@bar"
val barList = b.map(_ \ "@bar")
- assertEquals(NodeSeq.Empty, barFail)
+ assertEquals(Group(Seq(Text("1"), Text("2"))), barAttributesDirect)
assertEquals(List(Group(Seq(Text("1"))), Group(Seq(Text("2")))), barList)
}