Saturday, 19 February 2011

Scala collections: Filtering every n-th element

Recently I had to process huge log files, and I took it as a great opportunity to test my Scala skills. One of the problems I had to solve was filtering every n-th element from an iterable collection. This is supposed to be a trivial task, but after thinking for a while I had five competing implementations ready. Which one should I pick? Perhaps the most concise one. To make the choice easier, I decided to measure how they perform.
Filtering implementations:
  1. Filter using groups
    This is my favorite one, the cleanest implementation.
    def filterUsingGroups[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] =
      in.drop(offset).grouped(step).map(_.head).toIterable
    How it works:
    in = (1,2,3,4,5,6,7,8,9,10), step=3, offset=2
    drop(2) => (3,4,5,6,7,8,9,10)
    grouped(3) => ((3,4,5),(6,7,8),(9,10))
    map(_.head) => (3,6,9)
    
  2. Filter using index
    def filterUsingIndex[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] =
      in.zipWithIndex.filter(_._2 % step == offset).unzip._1
    
    How it works:
    in = (1,2,3,4,5,6,7,8,9,10), step=3, offset=2
    zipWithIndex => ((1,0),(2,1),....,(10,9))
    filter(_._2 % 3 == 2) => ((3,2), (6,5), (9,8))
    unzip._1 => (3,6,9)
    
  3. Filter using recursion
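    def filterUsingRecursion[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] = {
      // Recurse while the current chunk still reaches index `offset` (assumes 0 <= offset < step),
      // so a trailing partial chunk is handled the same way as in the other implementations.
      if (in.size > offset) {
        val (l, r) = in.splitAt(step)
        List(l.drop(offset).head) ++ filterUsingRecursion(r, step, offset)
      } else List.empty[A]
    }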
    
    How it works:
    in = (1,2,3,4,5,6,7,8,9,10), step=3, offset=2
    splitAt(3) => l=(1,2,3); r=(4,5,6,7,8,9,10)
    l.drop(2).head == 3 => List(3) ++ filterUsingRecursion((4,5,6,7,8,9,10), 3, 2)
    
  4. Filter using loop (unbuffered)
    def filterUsingLoop[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] = {
      var out = List.empty[A]
      val it = in.iterator
      for (i <- 0 to in.size - 1) {
        if (i % step == offset) out = out :+ it.next
        else it.next
      }
      out
    }
    
  5. Filter using loop (buffered)
    def filterUsingLoopWithBuffer[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] = {
      val out = collection.mutable.ListBuffer.empty[A]
      val it = in.iterator
      for (i <- 0 to in.size - 1) {
        if (i % step == offset) out += it.next
        else it.next
      }
      out
    }
    
  6. Filter using tail recursion (added later)
    import scala.collection.mutable.ListBuffer

    def filterUsingTailRecursion[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] = {
      @annotation.tailrec
      def filter(acc: ListBuffer[A], in: Iterable[A]): ListBuffer[A] = {
        in.splitAt(step) match {
          // l.size > offset guarantees that l.drop(offset) is non-empty, so .head is safe
          case (l, r) if (l.size > offset) => filter(acc += l.drop(offset).head, r)
          case _ => acc
        }
      }
      filter(ListBuffer.empty[A], in)
    }
    

In my benchmark I called each method 10 times in a row and measured the average execution time (using parameters step=3, offset=2). During benchmarking I call foreach on the result of each execution of the method under test. I do this to materialize the returned collection in case it is backed by an iterator.

def benchmark[A](fn: => Iterable[A]): Double = {
  fn // warm-up
  val times = 10
  val start = System.currentTimeMillis
  for (i <- 1 to times) fn.foreach(_ => ())
  val stop = System.currentTimeMillis
  (stop - start).toDouble / times
}
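As a rough usage sketch, the helper above might be driven like this. The object name `BenchmarkSketch`, the input size, and the `hasNext`-based variant of the buffered loop are assumptions made for illustration, not the exact setup behind the published numbers.

```scala
import scala.collection.mutable.ListBuffer

object BenchmarkSketch {
  // Buffered-loop filter in the spirit of implementation 5
  // (assumption: iterates with hasNext instead of calling in.size).
  def filterUsingLoopWithBuffer[A](in: Iterable[A], step: Int, offset: Int): Iterable[A] = {
    val out = ListBuffer.empty[A]
    val it = in.iterator
    var i = 0
    while (it.hasNext) {
      val elem = it.next()
      if (i % step == offset) out += elem
      i += 1
    }
    out
  }

  // Same helper as in the post: average wall-clock milliseconds over 10 runs.
  def benchmark[A](fn: => Iterable[A]): Double = {
    fn // warm-up
    val times = 10
    val start = System.currentTimeMillis
    for (_ <- 1 to times) fn.foreach(_ => ())
    val stop = System.currentTimeMillis
    (stop - start).toDouble / times
  }

  def main(args: Array[String]): Unit = {
    val input = (1 to 100000).toList // hypothetical input size
    val avgMs = benchmark(filterUsingLoopWithBuffer(input, 3, 2))
    println(s"buffered loop: $avgMs ms per run")
  }
}
```

Because `fn` is a by-name parameter, the filter is re-evaluated on every one of the 10 runs rather than once up front.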
The chart below shows the results I've collected. Conclusions:
  • Buffered loop is the fastest implementation
  • Index is also quite good
  • Recursion works pretty well with small collections but fails to process huge collections (java.lang.StackOverflowError)
  • Unbuffered loop fails to process huge collections because of intensive data copying
  • Groups give the worst performance
  • Tail recursion works very well with huge collections
Sadly, the implementation using groups is not the most performant, although it's clean and concise. Still, it's good enough when processing fewer than 1000 elements, which makes it quite a good candidate for an initial implementation.

4 comments:

Seth Tisue said...

#4 is O(n^2) overall because :+ is O(n) on List. +: is O(1) though, so you could build up the list in reverse order and then reverse it at the end of the loop, reducing the overall runtime to O(n).
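
A minimal sketch of the suggestion above, assuming a hypothetical variant named `filterUsingPrepend` (not part of the original post): each match is prepended with `+:`, and the list is reversed once at the end.

```scala
object PrependThenReverse {
  // Hypothetical variant of the unbuffered loop:
  // prepend each match in O(1), then reverse once, for O(n) overall.
  def filterUsingPrepend[A](in: Iterable[A], step: Int, offset: Int): List[A] = {
    var out = List.empty[A]
    var i = 0
    for (elem <- in) {
      if (i % step == offset) out = elem +: out
      i += 1
    }
    out.reverse
  }
}
```

The single `reverse` at the end restores the original element order.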

Krzysztof Białek said...

Thanks for pointing this out. Using +: gives better results. I've also tried with :: and it's even faster.
Btw, I figured out one more implementation which is very promising although it requires the collection to be a Seq: Range(offset, seq.size, step).map(seq.apply(_))
Benchmarks will be published soon.

Seth Tisue said...

Good idea, I like it.

Note that map(seq.apply(_)) can be shortened to just map(seq).
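
A small sketch of the Range-based idea from this thread, wrapped in hypothetical helper names for illustration. It shows the long and the short form side by side and assumes step > 0. Indexed access is cheap on collections such as Vector; on a List each `apply` call walks the list from the head.

```scala
object RangeBasedFilter {
  // Long form: explicit call to apply for every selected index.
  def filterUsingRangeLong[A](seq: Seq[A], step: Int, offset: Int): Seq[A] =
    Range(offset, seq.size, step).map(seq.apply(_))

  // Short form: a Seq[A] is itself a function Int => A, so it can be passed to map directly.
  def filterUsingRange[A](seq: Seq[A], step: Int, offset: Int): Seq[A] =
    Range(offset, seq.size, step).map(seq)
}
```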

Laura Dietz said...

Yikes!

If you want to use a set of indices rather than Range(..), don't forget to use a toSeq before the map!

idxSet.toSeq.map(seq).sum
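
To illustrate the pitfall, here is a small sketch with made-up data (the values in `seq` and `idxSet` are assumptions). Mapping over a Set produces another Set, so equal values collapse before they are summed.

```scala
object SetMapPitfall {
  val seq = Vector(5, 5, 7)   // hypothetical data with a duplicate value
  val idxSet = Set(0, 1, 2)   // hypothetical set of indices

  // Mapping a Set yields a Set, so the two 5s collapse into one before summing.
  val viaSet = idxSet.map(seq).sum        // Set(5, 7).sum == 12

  // Converting to a Seq first keeps one value per index.
  val viaSeq = idxSet.toSeq.map(seq).sum  // 5 + 5 + 7 == 17
}
```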