We saw in Part 1 the basic structure of a decision tree. In Part 2 we created a class to handle the samples and labels of a data set. And in Part 3 we saw how to compute the leaves’ values to fit a data set. In this part, we are going to combine the previous results to build a decision tree predictor.

Building a decision tree predictor from a data set

Below is the code from the previous parts, with a slight difference: a combineLabels method has been added to the LabeledData class.

/** The building blocks of a binary decision tree are leaves and nodes,
  * so we create two Scala classes, Leaf and Node, that extend a DecisionTree trait.
  */
trait DecisionTree[-A, +B] {
  def predict(sample: A): B
}

/** A leaf is a simple element that provides a prediction value (or decision). */
case class Leaf[A, B](decision: B) extends DecisionTree[A, B] {
  def predict(sample: A): B = decision
}

/** A node stores a test function, along with the node or leaf to go to next
  * depending on the result of the test:
  *   - *right* is the leaf or node to go to when the test is true,
  *   - *left* is the leaf or node to go to when the test is false.
  */
case class Node[A, B](test: A => Boolean, left: DecisionTree[A, B], right: DecisionTree[A, B])
    extends DecisionTree[A, B] {

  def predict(sample: A): B = test(sample) match {
    case true => right.predict(sample)
    case false => left.predict(sample)
  }
}


case class LabelCombiner[B](combine: Vector[B] => B) {

  /* Combine two elements rather than the elements of a vector */   
  def combine(left: B, right: B): B = combine(Vector(left, right))
}


class LabeledData[A, B] private (
    private val referenceSamples: Vector[A],
    private val referenceLabels: Vector[B],
    private val indices: Vector[Int]) {

  def size: Int = indices.length

  def isEmpty: Boolean = size == 0

  def subset(indices: Vector[Int]): LabeledData[A, B] = 
    new LabeledData(referenceSamples, referenceLabels, indices)

  def emptySet: LabeledData[A, B] = subset(Vector())

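  def groupBy[C](f: A => C): Map[C, LabeledData[A, B]] =
    // Build each subset eagerly; mapValues returns a lazy view (a MapView in Scala 2.13)
    (indices groupBy {idx => f(referenceSamples(idx))}) map {
      case (key, groupIndices) => key -> subset(groupIndices)
    }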

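  def partition(f: A => Boolean): (LabeledData[A, B], LabeledData[A, B]) = {
    val groups = groupBy(f)
    // A key is absent from the map when no sample falls into that group
    (groups.getOrElse(true, emptySet), groups.getOrElse(false, emptySet))
  }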

  def union(that: LabeledData[A, B]): LabeledData[A, B] = {
    require(
      referenceSamples == that.referenceSamples && 
        referenceLabels == that.referenceLabels,
      "Union is only allowed for subsets of the same superset.")
    subset((indices ++ that.indices).distinct)  
  }

  def countDistinctLabels: Int = 
    (indices map referenceLabels).distinct.length

  def mapSamples[C](f: A => C): Vector[C] = 
    indices map {idx => f(referenceSamples(idx))}
    
  def combineLabels(labelCombiner: LabelCombiner[B]): B = 
    labelCombiner.combine(indices map referenceLabels)
}

    
object LabeledData {

  def apply[A, B](samples: Vector[A], labels: Vector[B]): 
      LabeledData[A, B] = {
    require(
      samples.length == labels.length, 
      "Samples and labels should have the same number of elements.")
    new LabeledData[A, B](samples, labels, samples.indices.toVector)
  }
}
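
The combineLabels method simply applies a LabelCombiner to the labels of the current subset. The standalone sketch below shows two possible combiners: the mean, as used in the example later in this part, and a majority vote for classification labels, which is our own assumption and not code from this series. It uses a trimmed copy of LabelCombiner without the two-argument combine, and the names meanCombiner and majorityVote are hypothetical. Like the rest of the code here, it is written in script style for a REPL or notebook.

```scala
// Trimmed copy of LabelCombiner from the text (two-argument overload omitted)
case class LabelCombiner[B](combine: Vector[B] => B)

// Regression labels: the leaf value is the mean of the labels.
// An empty vector gives NaN (0.0 / 0).
val meanCombiner: LabelCombiner[Double] =
  LabelCombiner[Double](v => v.sum / v.length)

// Hypothetical classification combiner: the most frequent label wins.
// Ties are broken arbitrarily (Map iteration order), and maxBy throws
// on an empty vector.
def majorityVote[B]: LabelCombiner[B] =
  LabelCombiner[B](v => v.groupBy(identity).maxBy(_._2.length)._1)

println(meanCombiner.combine(Vector(-1.0, 0.0, 0.5)))         // mean of three labels
println(majorityVote[String].combine(Vector("a", "b", "a")))  // a
```

Choosing the combiner is what distinguishes a regression tree from a classification tree here: the tree-building code stays the same, only the way leaf values are computed changes.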

We create an Id class to identify the position and depth of a node or leaf:

/** We associate to each leaf or node a unique id *number*,
  * and also store the depth of the leaf or node.
  * 
  * For each node, we compute the ids of the next left and right nodes
  * or leaves by assuming that each level of the tree is full, so that
  * the id numbers cover all the possible branchings.
  * By starting with a root node id number of 0, we get the 
  * following numbering for the first levels of the tree:
  *                  0
  *         1                2
  *      3     4          5     6
  *     7 8   9 10      11 12 13 14
  */
case class Id(number: Long, depth: Int) {
  require(number >= 0, "Id number should be nonnegative.")

  def nextIdLeft: Id = Id(number * 2 + 1, depth + 1)

  def nextIdRight: Id = Id(number * 2 + 2, depth + 1)
}
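
As a quick check of the numbering, the standalone sketch below (with a copy of Id so it runs on its own) expands the root level by level and recovers the ids shown in the diagram. The parentId helper is hypothetical and is not used by the builder; it only shows that the numbering can be inverted with integer division.

```scala
// Copy of the Id class from the text
case class Id(number: Long, depth: Int) {
  require(number >= 0, "Id number should be nonnegative.")
  def nextIdLeft: Id = Id(number * 2 + 1, depth + 1)
  def nextIdRight: Id = Id(number * 2 + 2, depth + 1)
}

val root = Id(0, 0)

// Replace every id of a level by its two children, three times in a row
val level3: Vector[Id] = (1 to 3).foldLeft(Vector(root)) { (level, _) =>
  level.flatMap(id => Vector(id.nextIdLeft, id.nextIdRight))
}
println(level3.map(_.number))  // Vector(7, 8, 9, 10, 11, 12, 13, 14)

// Hypothetical helper: the parent of a non-root id
def parentId(id: Id): Id = {
  require(id.depth > 0, "The root has no parent.")
  Id((id.number - 1) / 2, id.depth - 1)
}

println(parentId(root.nextIdRight.nextIdLeft))  // Id(2,1)
```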

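To build a decision tree, we decide at each node how to split the current subset of the data set. The splitting logic is delegated to a testBuilder function: given the subset and the Id of the current position, it returns either Right(test), in which case the subset is split in two according to test and the builder recurses on both halves, or Left(reason), in which case a leaf is created whose value combines the labels of the subset. If a test leaves one of the two halves empty, the builder prints a warning and creates a leaf instead.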

case class DecisionTreeBuilder[A, B](dataSet: LabeledData[A, B])
    (combiner: LabelCombiner[B]) {

  def build(testBuilder: (LabeledData[A, B], Id) => Either[String, A => Boolean]):
      DecisionTree[A, B] = {
    val rootId: Id = Id(0, 0)
    buildStep(dataSet, rootId, testBuilder)
  }

  private def buildStep(
    subSet: LabeledData[A, B], 
    id: Id, 
    testBuilder: (LabeledData[A, B], Id) => Either[String, A => Boolean]):
  DecisionTree[A, B] = {
      
    testBuilder(subSet, id) match {
      case Left(_) => Leaf[A, B](subSet.combineLabels(combiner))
      case Right(test) =>
        val Seq(leftSubSet, rightSubSet) = nextSubsets(subSet, test)
        if (leftSubSet.isEmpty || rightSubSet.isEmpty) {
          println("Warning: one of right or left indices is empty.")
          Leaf[A, B](subSet.combineLabels(combiner))
        }
        else Node(
          test,
          buildStep(leftSubSet, id.nextIdLeft, testBuilder),
          buildStep(rightSubSet, id.nextIdRight, testBuilder)
        )
    }
  }

  private def nextSubsets(subSet: LabeledData[A, B], test: A => Boolean):
      Seq[LabeledData[A, B]] = {
    val groups = subSet.groupBy(test) withDefaultValue subSet.emptySet
    Seq(groups(false), groups(true))
  }
}
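
In practice, a testBuilder usually checks several stopping criteria before proposing a split. The sketch below is a hypothetical example, not code from this series: it expresses the Left/Right convention on plain values, where size, distinctLabels and depth stand in for subSet.size, subSet.countDistinctLabels and id.depth, and the fixed threshold test on Double samples is only a placeholder for a real split search.

```scala
// Hypothetical stopping rules in the style of testBuilder.
// Left(reason) means "make a leaf", Right(test) means "split with test".
def chooseTest(
    size: Int,            // stands in for subSet.size
    distinctLabels: Int,  // stands in for subSet.countDistinctLabels
    depth: Int,           // stands in for id.depth
    maxDepth: Int,
    threshold: Double): Either[String, Double => Boolean] =
  if (depth >= maxDepth) Left("Maximal depth reached.")
  else if (size < 2) Left("Not enough samples to split.")
  else if (distinctLabels == 1) Left("All labels are identical.")
  else Right((x: Double) => x < threshold)

// Inspect the decision for a root with four samples and four distinct labels
chooseTest(size = 4, distinctLabels = 4, depth = 0, maxDepth = 2, threshold = 1.5) match {
  case Left(reason) => println(s"Leaf: $reason")
  case Right(test)  => println(s"Split; test(1.0) = ${test(1.0)}")
}
```

To plug such rules into DecisionTreeBuilder, one could wrap them as (subSet, id) => chooseTest(subSet.size, subSet.countDistinctLabels, id.depth, maxDepth, threshold) for a data set with Double samples.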

Below is a short example constructing a tree with only one node and two leaves:

val samples = Vector[Int](1, 2, 3, 4)
val labels = Vector[Double](5.0, -1.0, 0.0, 0.5)
val dataSet = LabeledData(samples, labels)

val combiner: LabelCombiner[Double] = 
  LabelCombiner(v => v.sum / v.length)
  
val treeBuilder = DecisionTreeBuilder(dataSet)(combiner)

def testBuilder(subset: LabeledData[Int, Double], id: Id): Either[String, Int => Boolean] = 
  if (id.depth > 0) Left("Maximal depth reached.") else Right((x: Int) => x < 1.5)

val tree = treeBuilder.build(testBuilder)

// Check the predictions at two points (each line prints true):
println(tree.predict(0) == 5.0)
println(tree.predict(5) == -0.5 / 3)

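To see where the two expected values come from, the standalone sketch below routes the samples of the example by hand: samples for which x < 1.5 is true go to the right leaf, the others to the left leaf, and each leaf value is the mean of the labels that reach it, as computed by the mean combiner of the example. The names goesRight and mean are illustrative only.

```scala
val samples = Vector(1, 2, 3, 4)
val labels = Vector(5.0, -1.0, 0.0, 0.5)
val goesRight = (x: Int) => x < 1.5  // same test as in testBuilder

// partition puts the pairs where the test is true first (right leaf)
val (rightPairs, leftPairs) = samples.zip(labels).partition { case (x, _) => goesRight(x) }

def mean(v: Vector[Double]): Double = v.sum / v.length

val rightValue = mean(rightPairs.map(_._2))  // only sample 1
val leftValue = mean(leftPairs.map(_._2))    // samples 2, 3 and 4

println(rightValue)  // 5.0
println(leftValue)
```

The left leaf therefore predicts (-1.0 + 0.0 + 0.5) / 3 = -1/6, roughly -0.167, for any sample greater than 1.5, which is why tree.predict(5) equals -0.5 / 3.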

In this part, we saw how to build a decision tree predictor, and we created an elementary tree with only two leaves. In Part 5, we will build a predictor from a classic machine learning data set, the Iris data set.