Oleh Rybalchenko

Reputation: 8059

How to import predefined decision tree and use it for classification

Long story short

As input I have a file with a text representation of a simple decision tree:

Region in [ "someregion" ]
    Revenue <= 1020.30
        group in [ "audio" ] => 123.456
        group in [ "disc" ] => 123.456
            volume <= 1 => 734.25
...

The program should import it as a classifier and be able to predict an object's value. In other words, for an object like the following:

{"Region": "someregion", "Revenue": 100, "group": "disc", "volume": 0.5}

the prediction will be 734.25.
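To pin down the semantics implied here, below is a minimal sketch. It assumes that a record descends into the first child whose condition holds, and that the prediction is the target of the deepest matched node that carries one. The question does not state these rules explicitly. The Node class and predict function are hypothetical names, and the tree is written out by hand rather than parsed from the file.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Node:
    # hypothetical node type: a test on the record, an optional target, children
    test: Callable[[dict], bool]
    target: Optional[float] = None
    children: list = field(default_factory=list)


def predict(node: Node, record: dict) -> Optional[float]:
    """Descend into the first matching child at each level and return the
    target of the deepest matched node that carries one (None if none does)."""
    if not node.test(record):
        return None
    result = node.target
    while True:
        child = next((c for c in node.children if c.test(record)), None)
        if child is None:
            return result
        node = child
        if node.target is not None:
            result = node.target


# The tree from the question, written out by hand
tree = Node(lambda r: r["Region"] in ["someregion"], children=[
    Node(lambda r: r["Revenue"] <= 1020.30, children=[
        Node(lambda r: r["group"] in ["audio"], target=123.456),
        Node(lambda r: r["group"] in ["disc"], target=123.456, children=[
            Node(lambda r: r["volume"] <= 1, target=734.25),
        ]),
    ]),
])

obj = {"Region": "someregion", "Revenue": 100, "group": "disc", "volume": 0.5}
print(predict(tree, obj))  # 734.25
```

With volume = 2 the volume condition fails and the sketch falls back to the disc node's 123.456.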

What existing decision tree implementations can I use to create such a classifier? scikit-learn's trees are almost what I need, but I couldn't find a way to build a custom, predefined tree instead of fitting one on a dataset.

My attempt

For now I have implemented a simple tree parser:

import re

# Leaf rows: <field> <operator> <value> => <target>
LEAF_RE = re.compile(
    r'^(?P<field>.*?) (?P<statement>.*?) (?P<value>.*?) => (?P<target>\d*\.\d*)')
# Inner rows: <field> <operator> <value>; anchored at $ so value is not empty
NODE_RE = re.compile(r'^(?P<field>.*?) (?P<statement>.*?) (?P<value>.*)$')

def parse_condition(row):
    # try the leaf regex first, then fall back to the inner-node regex
    condition = LEAF_RE.search(row) or NODE_RE.search(row)
    return condition.groupdict()

with open('tree.txt') as f:
    for row in f:
        if not row.strip():
            continue  # skip blank lines
        # the nesting level is the number of leading tabs
        level = len(row) - len(row.lstrip('\t'))
        condition = parse_condition(row.strip())
        print((level, condition))

which extracts the node level, the condition and the target value:

(0, {'field': 'Region', 'statement': 'in', 'value': '[ "someregion" ]'})
(1, {'field': 'Revenue', 'statement': '<=', 'value': '1020.30'})
(2, {'field': 'group', 'statement': 'in', 'value': '[ "audio" ]', 'target': '123.456'})
(2, {'field': 'group', 'statement': 'in', 'value': '[ "disc" ]', 'target': '123.456'})
(3, {'field': 'volume', 'statement': '<=', 'value': '1', 'target': '734.25'})
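The flat (level, condition) pairs can then be folded into a nested tree with a stack that holds the current chain of ancestors. This is a minimal sketch. It assumes rows arrive in file order, so each row's parent is the closest preceding row with a smaller level. build_tree and the 'children' key are hypothetical names, and the sample rows are typed in by hand in the same shape as the parser output.

```python
def build_tree(rows):
    """Fold (level, condition) pairs, in file order, into nested dicts.

    Each node is a copy of its condition dict plus a 'children' list."""
    roots = []
    stack = []  # (level, node) pairs for the current chain of ancestors
    for level, condition in rows:
        node = dict(condition, children=[])
        # pop ancestors at the same depth or deeper than this row
        while stack and stack[-1][0] >= level:
            stack.pop()
        if stack:
            stack[-1][1]["children"].append(node)
        else:
            roots.append(node)
        stack.append((level, node))
    return roots


rows = [
    (0, {'field': 'Region', 'statement': 'in', 'value': '[ "someregion" ]'}),
    (1, {'field': 'Revenue', 'statement': '<=', 'value': '1020.30'}),
    (2, {'field': 'group', 'statement': 'in', 'value': '[ "audio" ]', 'target': '123.456'}),
    (2, {'field': 'group', 'statement': 'in', 'value': '[ "disc" ]', 'target': '123.456'}),
    (3, {'field': 'volume', 'statement': '<=', 'value': '1', 'target': '734.25'}),
]
tree = build_tree(rows)
print(tree[0]["children"][0]["children"][1]["children"][0]["target"])  # 734.25
```

A stack is enough here because the indentation already encodes the parent of every row. No lookahead is needed.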

Although I could develop a custom decision tree and condition parser from scratch, it seems like reinventing the wheel.

Upvotes: 1

Views: 966

Answers (1)

dani herrera

Reputation: 51705

There is a format called PMML (Predictive Model Markup Language). You can store decision trees in this format to avoid reinventing the wheel.

For example, the KNIME software can work with this format; see the KNIME workflow "Example for Learning a Decision Tree". A PMML decision tree looks like this:

<?xml version="1.0" encoding="UTF-8"?>
<PMML version="4.1" xmlns="http://www.dmg.org/PMML-4_1">
  <Header copyright="dani">
    <Application name="KNIME" version="2.7.2"/>
  </Header>
  <DataDictionary numberOfFields="7">
    <DataField name="nom_nivell" optype="categorical" dataType="string">
      <Value value="ESO"/>
      ...
      <Value value="CFGM Infor"/>
    </DataField>
    <DataField name="hora_inici" optype="categorical" dataType="string">
      <Value value="09:15:00"/>
      ...
      <Value value="13:45:00"/>
    </DataField>
    ...
  </DataDictionary>
  <TreeModel modelName="DecisionTree" functionName="classification" splitCharacteristic="multiSplit" missingValueStrategy="lastPrediction" noTrueChildStrategy="returnNullPrediction">
    <MiningSchema>
      <MiningField name="nom_nivell" invalidValueTreatment="asIs"/>
      <MiningField name="hora_inici" invalidValueTreatment="asIs"/>
      <MiningField name="assistenciaMateixaHora1WeekBefore" invalidValueTreatment="asIs"/>
      <MiningField name="assistencia" invalidValueTreatment="asIs" usageType="predicted"/>
    </MiningSchema>
    <Node id="0" score="Present" recordCount="244770.0">
      <True/>
      <ScoreDistribution value="Present" recordCount="211657.0"/>
      <ScoreDistribution value="Absent" recordCount="24925.0"/>
          ...

Graphically, it looks like this in KNIME:

[Image: the decision tree displayed in KNIME]

Then, the easy way to compute results from a PMML tree is a tree traversal. I posted a utility that does this, lightpmmlpredictor, in my GitHub repo. Its core is a simple while loop that walks down the nodes using lxml's etree:

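from lxml import etree

def predict(node, values):
    """Walk down a PMML TreeModel starting at its root <Node> element.

    values maps field names to the record's values. Returns the score of
    the deepest matching node and the share of that node's training
    records that had this score."""
    prediction, pct = node.get("score"), 0.5
    while True:
        try:
            # first child <Node> whose predicate (its first sub-element)
            # has a value equal to the record's value for that field
            node = next(e for e in node
                        if etree.QName(e).localname == 'Node' and
                           str(values.get(e[0].get('field'))) == e[0].get('value'))
        except StopIteration:
            break  # no child matches: stop at the current node

        prediction = node.get("score")
        n_tot = float(node.get("recordCount", 0))
        # recordCount is a string attribute, so convert before taking max()
        counts = [float(x.get('recordCount', 0))
                  for x in node
                  if etree.QName(x).localname == 'ScoreDistribution'
                     and x.get('value') == prediction]
        pct = max(counts) / n_tot if counts and n_tot else 0.5

    return prediction, pct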
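The loop above only follows a child when the record's value equals the predicate's value attribute. That covers KNIME's categorical splits, but not conditions like the question's Revenue <= 1020.30 or group in [ ... ]. PMML can express those with a SimplePredicate (operator="lessOrEqual") and a SimpleSetPredicate (booleanOperator="isIn") holding an Array. The sketch below is not lightpmmlpredictor's code. It hand-writes the question's tree as a hypothetical TreeModel fragment, with no Header or DataDictionary, so it is not a complete PMML document. It walks that fragment with the standard library's xml.etree.ElementTree instead of lxml. Only a few operators are handled, and the helper names local, matches and predict are made up for illustration.

```python
import shlex
import xml.etree.ElementTree as ET

NS = {"p": "http://www.dmg.org/PMML-4_1"}

# Hand-written, hypothetical TreeModel fragment encoding the question's tree.
# It omits Header, DataDictionary and MiningSchema, so it is not a complete
# PMML document.
TREE_XML = """
<TreeModel xmlns="http://www.dmg.org/PMML-4_1" functionName="regression">
  <Node>
    <True/>
    <Node>
      <SimpleSetPredicate field="Region" booleanOperator="isIn">
        <Array n="1" type="string">"someregion"</Array>
      </SimpleSetPredicate>
      <Node>
        <SimplePredicate field="Revenue" operator="lessOrEqual" value="1020.30"/>
        <Node score="123.456">
          <SimpleSetPredicate field="group" booleanOperator="isIn">
            <Array n="1" type="string">"audio"</Array>
          </SimpleSetPredicate>
        </Node>
        <Node score="123.456">
          <SimpleSetPredicate field="group" booleanOperator="isIn">
            <Array n="1" type="string">"disc"</Array>
          </SimpleSetPredicate>
          <Node score="734.25">
            <SimplePredicate field="volume" operator="lessOrEqual" value="1"/>
          </Node>
        </Node>
      </Node>
    </Node>
  </Node>
</TreeModel>
"""


def local(tag):
    # drop the "{namespace}" prefix that ElementTree adds to tag names
    return tag.rsplit("}", 1)[-1]


def matches(pred, record):
    """Evaluate a predicate element against a record (a dict).

    Only True, SimplePredicate (equal and the four numeric comparisons)
    and SimpleSetPredicate (isIn / isNotIn) are handled in this sketch."""
    kind = local(pred.tag)
    if kind == "True":
        return True
    x = record.get(pred.get("field"))
    if x is None:
        return False  # treat a missing field as "no match"
    if kind == "SimplePredicate":
        op, v = pred.get("operator"), pred.get("value")
        if op == "equal":
            return str(x) == v
        x, v = float(x), float(v)
        return {"lessThan": x < v, "lessOrEqual": x <= v,
                "greaterThan": x > v, "greaterOrEqual": x >= v}[op]
    if kind == "SimpleSetPredicate":
        # Array items are whitespace-separated; shlex strips the quotes
        items = shlex.split(pred.find("p:Array", NS).text)
        inside = str(x) in items
        return inside if pred.get("booleanOperator") == "isIn" else not inside
    raise ValueError(f"unsupported predicate: {kind}")


def predict(node, record):
    """Follow the first matching child Node at each level and return the
    score of the deepest matched Node that has one (None if none does)."""
    if not matches(node[0], record):
        return None
    score = node.get("score")
    while True:
        node = next((c for c in node.findall("p:Node", NS)
                     if matches(c[0], record)), None)
        if node is None:
            return score
        score = node.get("score", score)


root = ET.fromstring(TREE_XML.strip()).find("p:Node", NS)
obj = {"Region": "someregion", "Revenue": 100, "group": "disc", "volume": 0.5}
print(predict(root, obj))  # 734.25
```

The returned score is the attribute string ("734.25"). Convert it with float() if a number is needed.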

Feel free to contribute to or fork my repo.

Upvotes: 4
