diff --git a/src/main/scala/dregex/Universe.scala b/src/main/scala/dregex/Universe.scala index 5c2734b..5929071 100644 --- a/src/main/scala/dregex/Universe.scala +++ b/src/main/scala/dregex/Universe.scala @@ -1,8 +1,7 @@ package dregex -import dregex.impl.RegexTree -import dregex.impl.CharInterval -import dregex.impl.Normalization +import dregex.impl.{CharInterval, Normalization, RegexTree} +import scala.jdk.CollectionConverters._ /** * The purpose of this class is to enforce that set operation between regular expressions are only done when it is @@ -22,8 +21,8 @@ class Universe(parsedTrees: Seq[RegexTree.Node], val normalization: Normalizatio import RegexTree._ - private[dregex] val alphabet: Map[AbstractRange, Seq[CharInterval]] = { - CharInterval.calculateNonOverlapping(parsedTrees.flatMap(t => collect(t))) + private[dregex] val alphabet: java.util.Map[AbstractRange, java.util.List[CharInterval]] = { + CharInterval.calculateNonOverlapping(parsedTrees.flatMap(t => collect(t)).asJava) } /** diff --git a/src/main/scala/dregex/extra/DotFormatter.scala b/src/main/scala/dregex/extra/DotFormatter.scala index 54e1613..3e6ec44 100644 --- a/src/main/scala/dregex/extra/DotFormatter.scala +++ b/src/main/scala/dregex/extra/DotFormatter.scala @@ -21,7 +21,7 @@ object DotFormatter { } val transitions = for (transition <- nfa.transitions) yield { val weight = - if (transition.char == Epsilon) + if (transition.char == new Epsilon()) 1 else 2 diff --git a/src/main/scala/dregex/impl/AtomPart.java b/src/main/scala/dregex/impl/AtomPart.java new file mode 100644 index 0000000..d3f9945 --- /dev/null +++ b/src/main/scala/dregex/impl/AtomPart.java @@ -0,0 +1,7 @@ +package dregex.impl; + +/** + * A single or null char, i.e., including epsilon values + */ +public interface AtomPart { +} diff --git a/src/main/scala/dregex/impl/CharInterval.java b/src/main/scala/dregex/impl/CharInterval.java new file mode 100644 index 0000000..19193c4 --- /dev/null +++ 
b/src/main/scala/dregex/impl/CharInterval.java @@ -0,0 +1,90 @@ +package dregex.impl; + +import scala.math.Ordered; + +import java.util.*; + +public final class CharInterval implements AtomPart, Ordered<CharInterval> { + + public final UnicodeChar from; + public final UnicodeChar to; + + public CharInterval(UnicodeChar from, UnicodeChar to) { + if (from == null) { + throw new NullPointerException("from is null"); + } + if (to == null) { + throw new NullPointerException("to is null"); + } + + if (from.compare(to) > 0) { + throw new IllegalArgumentException("from value cannot be larger than to"); + } + this.from = from; + this.to = to; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CharInterval that = (CharInterval) o; + return Objects.equals(from, that.from) && Objects.equals(to, that.to); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public int compare(CharInterval that) { + return this.from.compare(that.from); + } + + public String toString() { + if (from.equals(to)) { + return from.toString(); + } else { + return String.format("[%s-%s]", from, to); + } + } + + /** + * Splits the given ranges at every boundary of any of them. + * + * @param ranges the ranges to split; they may overlap + * @return a map from each given range to the intervals, in ascending order, that it was split into + */ + public static Map<RegexTree.AbstractRange, List<CharInterval>> calculateNonOverlapping(List<RegexTree.AbstractRange> ranges) { + Set<UnicodeChar> startSet = new HashSet<>(); + Set<UnicodeChar> endSet = new HashSet<>(); + for (var range : ranges) { + startSet.add(range.from()); + if (range.from().compare(UnicodeChar.min()) > 0) { + endSet.add(range.from().$minus(1)); + } + endSet.add(range.to()); + if (range.to().compare(UnicodeChar.max()) < 0) { + startSet.add(range.to().$plus(1)); + } + } + Map<RegexTree.AbstractRange, List<CharInterval>> ret = new HashMap<>(); + for (var range : ranges) { + var startCopySet = new java.util.TreeSet<>(startSet); + var endCopySet = new java.util.TreeSet<>(endSet); + var startSubSet = startCopySet.subSet(range.from(), true, range.to(), true); + var endSubSet = endCopySet.subSet(range.from(), true, range.to(), true); + assert startSubSet.size() == endSubSet.size(); + List<CharInterval> res = 
new ArrayList<>(startSubSet.size()); + do { + var start = startSubSet.pollFirst(); + var end = endSubSet.pollFirst(); + res.add(new CharInterval(start, end)); + } while (!startSubSet.isEmpty()); + ret.put(range, res); + } + return ret; + } + +} diff --git a/src/main/scala/dregex/impl/Compiler.scala b/src/main/scala/dregex/impl/Compiler.scala index e3d990d..f57a1a1 100644 --- a/src/main/scala/dregex/impl/Compiler.scala +++ b/src/main/scala/dregex/impl/Compiler.scala @@ -2,14 +2,16 @@ package dregex.impl import dregex.InvalidRegexException +import java.util.stream.Collectors import scala.collection.mutable.Buffer +import scala.jdk.CollectionConverters._ /** * Take a regex AST and produce a NFA. * Except when noted the Thompson-McNaughton-Yamada algorithm is used. * Reference: http://stackoverflow.com/questions/11819185/steps-to-creating-an-nfa-from-a-regular-expression */ -class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) { +class Compiler(intervalMapping: java.util.Map[RegexTree.AbstractRange, java.util.List[CharInterval]]) { import RegexTree._ @@ -30,8 +32,8 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) // base case case range: AbstractRange => - val intervals = intervalMapping(range) - intervals.map(interval => NfaTransition(from, to, interval)) + val intervals = intervalMapping.get(range) + intervals.stream().map(interval => NfaTransition(from, to, interval)).collect(Collectors.toList()).asScala.toSeq // recurse @@ -141,7 +143,7 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) private def processJuxtNoLookaround(juxt: Juxt, from: SimpleState, to: SimpleState): Seq[NfaTransition] = { juxt match { case Juxt(Seq()) => - Seq(NfaTransition(from, to, Epsilon)) + Seq(NfaTransition(from, to, new Epsilon())) case Juxt(Seq(head)) => fromTreeImpl(head, from, to) @@ -178,7 +180,7 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) 
fromTreeImpl(value, from, to) case Rep(0, Some(0), value) => - Seq(NfaTransition(from, to, Epsilon)) + Seq(NfaTransition(from, to, new Epsilon())) // infinite repetitions @@ -190,18 +192,18 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) val int1 = new SimpleState val int2 = new SimpleState fromTreeImpl(value, int1, int2) :+ - NfaTransition(from, int1, Epsilon) :+ - NfaTransition(int2, to, Epsilon) :+ - NfaTransition(int2, int1, Epsilon) + NfaTransition(from, int1, new Epsilon()) :+ + NfaTransition(int2, to, new Epsilon()) :+ + NfaTransition(int2, int1, new Epsilon()) case Rep(0, None, value) => val int1 = new SimpleState val int2 = new SimpleState fromTreeImpl(value, int1, int2) :+ - NfaTransition(from, int1, Epsilon) :+ - NfaTransition(int2, to, Epsilon) :+ - NfaTransition(from, to, Epsilon) :+ - NfaTransition(int2, int1, Epsilon) + NfaTransition(from, int1, new Epsilon()) :+ + NfaTransition(int2, to, new Epsilon()) :+ + NfaTransition(from, to, new Epsilon()) :+ + NfaTransition(int2, int1, new Epsilon()) // finite repetitions @@ -219,11 +221,11 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) for (i <- 1 until m - 1) { val int = new SimpleState transitions ++= fromTreeImpl(value, prev, int) - transitions += NfaTransition(prev, to, Epsilon) + transitions += NfaTransition(prev, to, new Epsilon()) prev = int } transitions ++= fromTreeImpl(value, prev, to) - transitions += NfaTransition(prev, to, Epsilon) + transitions += NfaTransition(prev, to, new Epsilon()) transitions.to(Seq) case Rep(0, Some(m), value) if m > 0 => @@ -233,11 +235,11 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) for (i <- 0 until m - 1) { val int = new SimpleState transitions ++= fromTreeImpl(value, prev, int) - transitions += NfaTransition(prev, to, Epsilon) + transitions += NfaTransition(prev, to, new Epsilon()) prev = int } transitions ++= fromTreeImpl(value, prev, to) - transitions += 
NfaTransition(prev, to, Epsilon) + transitions += NfaTransition(prev, to, new Epsilon()) transitions.to(Seq) } @@ -254,16 +256,16 @@ class Compiler(intervalMapping: Map[RegexTree.AbstractRange, Seq[CharInterval]]) val result = DfaAlgorithms.toNfa(operation(leftDfa, rightDfa)) result.transitions ++ - result.accepting.to(Seq).map(acc => NfaTransition(acc, to, Epsilon)) :+ - NfaTransition(from, result.initial, Epsilon) + result.accepting.to(Seq).map(acc => NfaTransition(acc, to, new Epsilon())) :+ + NfaTransition(from, result.initial, new Epsilon()) } def processCaptureGroup(value: Node, from: SimpleState, to: SimpleState): Seq[NfaTransition] = { val int1 = new SimpleState val int2 = new SimpleState fromTreeImpl(value, int1, int2) :+ - NfaTransition(from, int1, Epsilon) :+ - NfaTransition(int2, to, Epsilon) + NfaTransition(from, int1, new Epsilon()) :+ + NfaTransition(int2, to, new Epsilon()) } } diff --git a/src/main/scala/dregex/impl/DfaAlgorithms.scala b/src/main/scala/dregex/impl/DfaAlgorithms.scala index 9db4b87..be8746d 100644 --- a/src/main/scala/dregex/impl/DfaAlgorithms.scala +++ b/src/main/scala/dregex/impl/DfaAlgorithms.scala @@ -185,7 +185,7 @@ object DfaAlgorithms { @tailrec def followEpsilonImpl(current: Set[State]): MultiState = { val immediate = for (state <- current) yield { - transitionMap.getOrElse(state, Map()).getOrElse(Epsilon, Set()) + transitionMap.getOrElse(state, Map()).getOrElse(new Epsilon(), Set()) } val expanded = immediate.fold(current)(_ union _) if (expanded == current) @@ -229,7 +229,7 @@ object DfaAlgorithms { def reverse[A <: State](dfa: Dfa[A]): Nfa = { val initial: State = new SimpleState - val first = dfa.accepting.to(Seq).map(s => NfaTransition(initial, s, Epsilon)) + val first = dfa.accepting.to(Seq).map(s => NfaTransition(initial, s, new Epsilon())) val rest = for { (from, fn) <- dfa.defTransitions (char, to) <- fn @@ -282,7 +282,7 @@ object DfaAlgorithms { val char = UnicodeChar(codePoint) val currentTrans = 
dfa.defTransitions.getOrElse(current, SortedMap[CharInterval, A]()) // O(log transitions) search in the range tree - val newState = Util.floorEntry(currentTrans, CharInterval(from = char, to = char)).flatMap { + val newState = Util.floorEntry(currentTrans, new CharInterval(char, char)).flatMap { case (interval, state) => if (interval.to >= char) { Some(state) diff --git a/src/main/scala/dregex/impl/Epsilon.java b/src/main/scala/dregex/impl/Epsilon.java new file mode 100644 index 0000000..5c9eadd --- /dev/null +++ b/src/main/scala/dregex/impl/Epsilon.java @@ -0,0 +1,22 @@ +package dregex.impl; + +/** + * The epsilon (null char) atom part, used to label NFA transitions. Instances carry no state, so any two of them are equal. + */ +public final class Epsilon implements AtomPart { + + public String toString() { + return "ε"; + } + + @Override + public boolean equals(Object other) { + // instanceof is false for null, so equals(null) returns false instead of throwing; the class is final, so this is still an exact type check + return other instanceof Epsilon; + } + + @Override + public int hashCode() { + return getClass().hashCode(); + } +} diff --git a/src/main/scala/dregex/impl/atoms.scala b/src/main/scala/dregex/impl/atoms.scala deleted file mode 100644 index 82cfdbc..0000000 --- a/src/main/scala/dregex/impl/atoms.scala +++ /dev/null @@ -1,65 +0,0 @@ -package dregex.impl - -import dregex.impl.RegexTree.AbstractRange -import scala.collection.mutable.ArrayBuffer -import scala.jdk.CollectionConverters._ - -/** - * A single or null char, i.e., including epsilon values - */ -sealed trait AtomPart - -case class CharInterval(from: UnicodeChar, to: UnicodeChar) extends AtomPart with Ordered[CharInterval] { - - if (from > to) - throw new IllegalArgumentException("from value cannot be larger than to") - - def compare(that: CharInterval): Int = this.from compare that.from - - override def toString = { - if (from == to) { - from.toString - } else { - s"[$from-$to]" - } - } - -} - -case object Epsilon extends AtomPart { - override def toString = "ε" -} - -object CharInterval { - - def calculateNonOverlapping(ranges: Seq[AbstractRange]): Map[AbstractRange, Seq[CharInterval]] = { - val startSet = collection.mutable.Set[UnicodeChar]() - val endSet 
= collection.mutable.Set[UnicodeChar]() - for (range <- ranges) { - startSet.add(range.from) - if (range.from > UnicodeChar.min) { - endSet.add(range.from - 1) - } - endSet.add(range.to) - if (range.to < UnicodeChar.max) { - startSet.add(range.to + 1) - } - } - val pairs = for (range <- ranges) yield { - val startCopySet = new java.util.TreeSet[UnicodeChar](startSet.asJava) - val endCopySet = new java.util.TreeSet[UnicodeChar](endSet.asJava) - val startSubSet = startCopySet.subSet(range.from, true, range.to, true) - val endSubSet = endCopySet.subSet(range.from, true, range.to, true) - assert(startSubSet.size == endSubSet.size) - val res = new ArrayBuffer[CharInterval](initialSize = startSubSet.size) - do { - val start = startSubSet.pollFirst() - val end = endSubSet.pollFirst() - res += CharInterval(from = start, to = end) - } while (!startSubSet.isEmpty()) - range -> res.to(Seq) - } - pairs.toMap - } - -} diff --git a/src/test/scala/dregex/CharIntervalTest.scala b/src/test/scala/dregex/CharIntervalTest.scala index 64a183a..3d1b9a5 100644 --- a/src/test/scala/dregex/CharIntervalTest.scala +++ b/src/test/scala/dregex/CharIntervalTest.scala @@ -3,7 +3,9 @@ package dregex import dregex.impl.UnicodeChar import dregex.impl.CharInterval import dregex.impl.RegexTree.CharRange +import dregex.impl.RegexTree.AbstractRange import org.scalatest.funsuite.AnyFunSuite +import scala.jdk.CollectionConverters._ class CharIntervalTest extends AnyFunSuite { @@ -17,34 +19,51 @@ class CharIntervalTest extends AnyFunSuite { implicit def pairToInterval(pair: (Int, Int)): CharInterval = { pair match { - case (from, to) => CharInterval(UnicodeChar(from), UnicodeChar(to)) + case (from, to) => new CharInterval(UnicodeChar(from), UnicodeChar(to)) } } test("non-overlapping") { - val ranges = Seq[CharRange]((10, 20), (21, 30), (0, 100), (9, 9), (10, 11), (9, 10), (10, 12), (17, 25)) - val nonOverlapping = CharInterval.calculateNonOverlapping(ranges) - val expected = Map[CharRange, 
Seq[CharInterval]]( - CharRange(10, 20) -> Seq((10, 10), (11, 11), (12, 12), (13, 16), (17, 20)), - CharRange(21, 30) -> Seq((21, 25), (26, 30)), - CharRange(0, 100) -> Seq( - (0, 8), - (9, 9), - (10, 10), - (11, 11), - (12, 12), - (13, 16), - (17, 20), - (21, 25), - (26, 30), - (31, 100)), - CharRange(9, 9) -> Seq((9, 9)), - CharRange(10, 11) -> Seq((10, 10), (11, 11)), - CharRange(9, 10) -> Seq((9, 9), (10, 10)), - CharRange(10, 12) -> Seq((10, 10), (11, 11), (12, 12)), - CharRange(17, 25) -> Seq((17, 20), (21, 25)) + val ranges = Seq[AbstractRange]((10, 20), (21, 30), (0, 100), (9, 9), (10, 11), (9, 10), (10, 12), (17, 25)) + val nonOverlapping = CharInterval.calculateNonOverlapping(ranges.asJava) + val expected = java.util.Map.of( + CharRange(10, 20), java.util.List.of( + new CharInterval(10, 10), + new CharInterval(11, 11), + new CharInterval(12, 12), + new CharInterval(13, 16), + new CharInterval(17, 20)), + CharRange(21, 30), java.util.List.of( + new CharInterval(21, 25), + new CharInterval(26, 30)), + CharRange(0, 100), java.util.List.of( + new CharInterval(0, 8), + new CharInterval(9, 9), + new CharInterval(10, 10), + new CharInterval(11, 11), + new CharInterval(12, 12), + new CharInterval(13, 16), + new CharInterval(17, 20), + new CharInterval(21, 25), + new CharInterval(26, 30), + new CharInterval(31, 100)), + CharRange(9, 9), java.util.List.of( + new CharInterval(9, 9)), + CharRange(10, 11), java.util.List.of( + new CharInterval(10, 10), + new CharInterval(11, 11)), + CharRange(9, 10), java.util.List.of( + new CharInterval(9, 9), + new CharInterval(10, 10)), + CharRange(10, 12), java.util.List.of( + new CharInterval(10, 10), + new CharInterval(11, 11), + new CharInterval(12, 12)), + CharRange(17, 25), java.util.List.of( + new CharInterval(17, 20), + new CharInterval(21, 25)) ) assertResult(expected)(nonOverlapping) } } diff --git 
a/src/test/scala/dregex/UnicodeTest.scala b/src/test/scala/dregex/UnicodeTest.scala index 9746b5a..f791a00 100644 --- a/src/test/scala/dregex/UnicodeTest.scala +++ b/src/test/scala/dregex/UnicodeTest.scala @@ -5,7 +5,7 @@ import dregex.impl.{PredefinedCharSets, UnicodeChar} import org.scalatest.funsuite.AnyFunSuite import org.slf4j.LoggerFactory -import java.lang.Character.UnicodeScript +import java.lang.Character.{UnicodeBlock, UnicodeScript} import scala.util.control.Breaks._ class UnicodeTest extends AnyFunSuite { @@ -112,7 +112,8 @@ class UnicodeTest extends AnyFunSuite { val javaRegex = java.util.regex.Pattern.compile(regexString) for (codePoint <- UnicodeChar.min.codePoint to UnicodeChar.max.codePoint) { val codePointAsString = new String(Array(codePoint), 0, 1) - assert(regex.matches(codePointAsString) == javaRegex.matcher(codePointAsString).matches()) + assert(regex.matches(codePointAsString) == javaRegex.matcher(codePointAsString).matches(), + s"- block: $block; java block: ${UnicodeBlock.of(codePoint)}; code point: ${String.format("0x%04X", Int.box(codePoint))}") } } else { logger.debug("skipping Unicode block {} as it's not present in the current Java version", block)