diff --git a/jvm/src/test/scala/scala/xml/XMLTest.scala b/jvm/src/test/scala/scala/xml/XMLTest.scala index 43dd182e2..5008bd6e6 100644 --- a/jvm/src/test/scala/scala/xml/XMLTest.scala +++ b/jvm/src/test/scala/scala/xml/XMLTest.scala @@ -550,6 +550,24 @@ class XMLTestJVM { XML.loadString(broken) } + @UnitTest + def issueSI9047AttributeFromSingleChildElementWorks: Unit = { + val x = + + val b = x \ "a" \ "@b" + + assertEquals(List("1"), b map (_.text)) + } + + @UnitTest + def issueSI9047AttributeMultipleChildElementsWorks: Unit = { + val x = + + val b = x \ "a" \ "@b" + + assertEquals(List("1", "2"), b map (_.text)) + } + @UnitTest def nodeSeqNs: Unit = { val x = { diff --git a/shared/src/main/scala/scala/xml/NodeSeq.scala b/shared/src/main/scala/scala/xml/NodeSeq.scala index c498279e9..df43f2553 100644 --- a/shared/src/main/scala/scala/xml/NodeSeq.scala +++ b/shared/src/main/scala/scala/xml/NodeSeq.scala @@ -95,8 +95,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S def \(that: String): NodeSeq = { def fail = throw new IllegalArgumentException(that) def atResult = { - lazy val y = this(0) - val attr = + this flatMap (y => ( if (that.length == 1) fail else if (that(1) == '{') { val i = that indexOf '}' @@ -105,10 +104,9 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S if (uri == "" || key == "") fail else y.attribute(uri, key) } else y.attribute(that drop 1) - - attr match { - case Some(x) => Group(x) - case _ => NodeSeq.Empty + ).getOrElse(Nil)) match { + case NodeSeq.Empty => NodeSeq.Empty + case x => Group(x) } } @@ -118,7 +116,7 @@ abstract class NodeSeq extends AbstractSeq[Node] with immutable.Seq[Node] with S that match { case "" => fail case "_" => makeSeq(!_.isAtom) - case _ if (that(0) == '@' && this.length == 1) => atResult + case _ if that(0) == '@' => atResult case _ => makeSeq(_.label == that) } } diff --git a/shared/src/test/scala/scala/xml/AttributeTest.scala b/shared/src/test/scala/scala/xml/AttributeTest.scala index 8943a0a00..b7a0d3115 100644 --- a/shared/src/test/scala/scala/xml/AttributeTest.scala +++ b/shared/src/test/scala/scala/xml/AttributeTest.scala @@ -147,9 +147,9 @@ class AttributeTest { val b = xml \ "b" assertEquals(2, b.length) assertEquals(NodeSeq.fromSeq(Seq(, )), b) - val barFail = b \ "@bar" + val barAttributesDirect = b \ "@bar" val barList = b.map(_ \ "@bar") - assertEquals(NodeSeq.Empty, barFail) + assertEquals(Group(Seq(Text("1"), Text("2"))), barAttributesDirect) assertEquals(List(Group(Seq(Text("1"))), Group(Seq(Text("2")))), barList) }