Tuesday, November 25, 2014

Trident-ML: Text Classification using KLD

This post shows a very basic example of how to use the Kullback-Leibler Distance (KLD) text classification algorithm in Trident-ML to process data from a Storm spout.

Firstly, create a Maven project (e.g. with groupId="com.memeanalytics", artifactId="trident-text-classifier-kld"). The complete source code of the project can be downloaded from the link:

https://dl.dropboxusercontent.com/u/113201788/storm/trident-text-classifier-kld.tar.gz

To start, we need to configure the pom.xml file in the project.

Configure pom.xml:

Firstly, we need to add the clojars repository to the repositories section:

<repositories>
  <repository>
    <id>clojars</id>
    <url>http://clojars.org/repo</url>
  </repository>
</repositories>

Next we need to add the storm dependency to the dependencies section:

<dependency>
  <groupId>storm</groupId>
  <artifactId>storm</artifactId>
  <version>0.9.0.1</version>
  <scope>provided</scope>
</dependency>
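
The provided scope keeps storm out of the packaged jar, since the Storm cluster supplies it at runtime.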

Next we need to add the trident-ml dependency to the dependencies section (for text classification):

<dependency>
  <groupId>com.github.pmerienne</groupId>
  <artifactId>trident-ml</artifactId>
  <version>0.0.4</version>
</dependency>

Next we need to add the exec-maven-plugin to the build/plugins section (for executing the Maven project):

<plugin>
  <groupId>org.codehaus.mojo</groupId>
  <artifactId>exec-maven-plugin</artifactId>
  <version>1.2.1</version>
  <executions>
    <execution>
      <goals>
        <goal>exec</goal>
      </goals>
    </execution>
  </executions>
  <configuration>
    <includeProjectDependencies>true</includeProjectDependencies>
    <includePluginDependencies>false</includePluginDependencies>
    <executable>java</executable>
    <classpathScope>compile</classpathScope>
    <mainClass>com.memeanalytics.trident_text_classifier_kld.App</mainClass>
  </configuration>
</plugin>

Next we need to add the maven-assembly-plugin to the build/plugins section (for packaging the Maven project into a jar that can be submitted to a Storm cluster):

<plugin>
  <artifactId>maven-assembly-plugin</artifactId>
  <version>2.2.1</version>
  <configuration>
    <descriptorRefs>
      <descriptorRef>jar-with-dependencies</descriptorRef>
    </descriptorRefs>
    <archive>
      <manifest>
        <mainClass>com.memeanalytics.trident_text_classifier_kld.App</mainClass>
      </manifest>
    </archive>
  </configuration>
  <executions>
    <execution>
      <id>make-assembly</id>
      <phase>package</phase>
      <goals>
        <goal>single</goal>
      </goals>
    </execution>
  </executions>
</plugin>
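
With the assembly plugin in place, the project can also be packaged into a self-contained jar for submission to a Storm cluster. A rough sketch (the exact jar name depends on the version declared in your pom.xml):

> mvn package
> storm jar target/trident-text-classifier-kld-1.0-SNAPSHOT-jar-with-dependencies.jar com.memeanalytics.trident_text_classifier_kld.App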

Implement a Spout for the training data

Once the pom.xml update is completed, we can implement ReuterNewsSpout, the Storm spout that emits batches of training data to the Trident topology:

package com.memeanalytics.trident_text_classifier_kld;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import backtype.storm.task.TopologyContext;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
import storm.trident.operation.TridentCollector;
import storm.trident.spout.IBatchSpout;

public class ReuterNewsSpout implements IBatchSpout {

 private static final long serialVersionUID = 1L;
 private List<List<Object>> trainingData=new ArrayList<List<Object>>();
 private static Map<Integer, List<Object>> testingData=new HashMap<Integer, List<Object>>();
 
 private int batchSize=10;
 private int batchIndex=0;
 
 public ReuterNewsSpout()
 {
  try{
   loadReuterNews();
  }catch(FileNotFoundException ex)
  {
   ex.printStackTrace();
  }catch(IOException ex)
  {
   ex.printStackTrace();
  }
 }
 
 public static List<List<Object>> getTestingData()
 {
  List<List<Object>> result=new ArrayList<List<Object>>();
  for(Integer topic_index : testingData.keySet())
  {
   result.add(testingData.get(topic_index));
  }
  
  return result;
 }
 
 private void loadReuterNews() throws FileNotFoundException, IOException
 {
  Map<String, Integer> topics=new HashMap<String, Integer>();
  String filePath="src/test/resources/reuters.csv";
  FileInputStream inputStream=new FileInputStream(filePath);
  BufferedReader reader= new BufferedReader(new InputStreamReader(inputStream));
  String line;
  while((line = reader.readLine())!=null)
  {
   String topic = line.split(",")[0];
   if(!topics.containsKey(topic))
   {
    topics.put(topic, topics.size());
   }
   Integer topic_index=topics.get(topic);
   
   int index = line.indexOf(" - ");
   if(index==-1) continue;
   
   // skip past the " - " separator so it is not included in the text
   String text=line.substring(index + 3);
   
   // the first article seen for each topic is held out for testing;
   // subsequent articles of that topic become training data
   if(testingData.containsKey(topic_index))
   {
    List<Object> values=new ArrayList<Object>();
    values.add(topic_index);
    values.add(text);
    trainingData.add(values);
   }
   else 
   {
    testingData.put(topic_index, new Values(topic_index, text));
   }
  }
  reader.close();
 }
 public void open(Map conf, TopologyContext context) {
  // no resources to initialize
 }

 public void emitBatch(long batchId, TridentCollector collector) {
  // emit the next batch of up to batchSize training tuples,
  // stopping once all full batches have been sent
  int maxBatchIndex = trainingData.size() / batchSize;
  
  if(trainingData.size() > batchSize && batchIndex < maxBatchIndex)
  {
   for(int i=batchIndex * batchSize; i < trainingData.size() && i < (batchIndex+1) * batchSize; ++i)
   {
    collector.emit(trainingData.get(i));
   }
   batchIndex++;
  }
 }

 public void ack(long batchId) {
  // no batch tracking needed for this in-memory spout
 }

 public void close() {
  // nothing to clean up
 }

 public Map getComponentConfiguration() {
  return null;
 }

 public Fields getOutputFields() {
  return new Fields("label", "text");
 }

}


As can be seen above, ReuterNewsSpout implements IBatchSpout and emits a batch of 10 tuples at a time. Each tuple is a news article containing the fields ("label", "text"): the "label" field is an integer value representing the topic of the news article, while the "text" field is a string holding the text of the article. The training records are arranged so that the classifier learns to predict the topic of a news article given its text.
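
For reference, loadReuterNews() expects each line of reuters.csv to begin with the topic name followed by a comma, with the article text set off by a " - " separator. A purely hypothetical line of that shape:

earn,ACME CORP 1ST QTR NET - Acme Corp said first-quarter profit rose on higher sales...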

KLD Text Classification in a Trident topology using Trident-ML

Once we have the training data spout, we can build a Trident topology which uses the training data to train a classifier with the KLD classifier algorithm in Trident-ML. This is implemented in the main class shown below:

package com.memeanalytics.trident_text_classifier_kld;

import java.util.List;

import com.github.pmerienne.trident.ml.nlp.ClassifyTextQuery;
import com.github.pmerienne.trident.ml.nlp.KLDClassifier;
import com.github.pmerienne.trident.ml.nlp.TextClassifierUpdater;
import com.github.pmerienne.trident.ml.preprocessing.TextInstanceCreator;

import storm.trident.TridentState;
import storm.trident.TridentTopology;
import storm.trident.testing.MemoryMapState;
import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.LocalDRPC;
import backtype.storm.generated.AlreadyAliveException;
import backtype.storm.generated.InvalidTopologyException;
import backtype.storm.generated.StormTopology;
import backtype.storm.tuple.Fields;


public class App 
{
    public static void main( String[] args ) throws AlreadyAliveException, InvalidTopologyException
    {
        LocalDRPC drpc=new LocalDRPC();
        
        LocalCluster cluster=new LocalCluster();
        
        Config config=new Config();
        
        cluster.submitTopology("KLDDemo", config, buildTopology(drpc));
        
        try{
         Thread.sleep(20000);
        }catch(InterruptedException ex)
        {
         ex.printStackTrace();
        }
        
        List<List<Object>> testingData = ReuterNewsSpout.getTestingData();
        
        for(int i=0; i < testingData.size(); ++i)
        {
         List<Object> testingDataRecord=testingData.get(i);
         String drpc_args="";
         for(Object val : testingDataRecord){
          if(drpc_args.equals(""))
          {
           drpc_args+=val;
          }
          else
          {
           drpc_args+=(","+val);
          }
         }
         System.out.println(drpc.execute("predict", drpc_args));
        }
        
        cluster.killTopology("KLDDemo");
        cluster.shutdown();
        
        drpc.shutdown();
    }
    
    private static StormTopology buildTopology(LocalDRPC drpc)
    {
     ReuterNewsSpout spout=new ReuterNewsSpout();
     
     TridentTopology topology=new TridentTopology();
     
     TridentState classifierModel = topology
       .newStream("training", spout)
       .each(new Fields("label", "text"), new TextInstanceCreator<Integer>(), new Fields("instance"))
       .partitionPersist(new MemoryMapState.Factory(), new Fields("instance"),
         new TextClassifierUpdater("newsClassifier", new KLDClassifier(9)));
     
     topology
       .newDRPCStream("predict", drpc)
       .each(new Fields("args"), new DRPCArgsToInstance(), new Fields("instance"))
       .stateQuery(classifierModel, new Fields("instance"),
         new ClassifyTextQuery("newsClassifier"), new Fields("prediction"));
     return topology.build();
    }
}

The App class above also references DRPCArgsToInstance, a helper function that converts the DRPC arguments into a TextInstance<Integer>:

package com.memeanalytics.trident_text_classifier_kld;

import java.util.ArrayList;
import java.util.List;

import backtype.storm.tuple.Values;

import com.github.pmerienne.trident.ml.core.TextInstance;
import com.github.pmerienne.trident.ml.preprocessing.EnglishTokenizer;
import com.github.pmerienne.trident.ml.preprocessing.TextTokenizer;

import storm.trident.operation.BaseFunction;
import storm.trident.operation.TridentCollector;
import storm.trident.tuple.TridentTuple;

public class DRPCArgsToInstance extends BaseFunction{

 private static final long serialVersionUID = 1L;

 public void execute(TridentTuple tuple, TridentCollector collector) {
  // DRPC arguments arrive as a single "<label>,<text>" string;
  // split on the first comma only, since the text itself may contain commas
  String drpc_args=tuple.getString(0);
  String[] args=drpc_args.split(",", 2);
  Integer label=Integer.parseInt(args[0]);
  
  String text=args[1];
  
  // tokenize the text the same way the training stream does
  TextTokenizer textAnalyzer=new EnglishTokenizer();
  List<String> tokens=textAnalyzer.tokenize(text);
  
  TextInstance<Integer> instance=new TextInstance<Integer>(label, tokens);
  
  collector.emit(new Values(instance));
 }

}

As can be seen above, the Trident topology has a TextInstanceCreator<Integer> Trident operation which converts each raw ("label", "text") tuple into a TextInstance<Integer> object that can be consumed by TextClassifierUpdater. The TextClassifierUpdater object from Trident-ML updates the underlying classifierModel via the KLDClassifier training algorithm (the constructor argument 9 tells the classifier how many topic classes to expect).
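
To make the conversion concrete, below is a minimal sketch of what such a text instance looks like when built by hand (the headline is made up; EnglishTokenizer is the same tokenizer used by DRPCArgsToInstance above):

TextTokenizer tokenizer = new EnglishTokenizer();
// tokenize a made-up headline into terms
List<String> tokens = tokenizer.tokenize("Fed raises interest rates");
// the label 0 stands for whichever topic index the article belongs to
TextInstance<Integer> instance = new TextInstance<Integer>(0, tokens);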

The DRPC stream allows the user to pass a new testing instance to the classifierModel, which then returns a "prediction" field containing the predicted label of the testing instance. DRPCArgsToInstance is a BaseFunction operation which converts the arguments passed into LocalDRPC.execute() into a TextInstance<Integer> (note that you can set the label to null in the DRPCArgsToInstance.execute() method, since the label is what will be predicted). The instance is then passed to ClassifyTextQuery, which uses KLD and the classifierModel to determine the predicted label.
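
For example, a prediction for a single made-up article can be requested as follows (the returned string embeds the value of the "prediction" field):

// hypothetical query: "3" is a topic index, the rest is the article text
String result = drpc.execute("predict", "3,Oil prices climbed on renewed supply fears");
System.out.println(result);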

Once the coding is completed, we can run the project by navigating to the project root folder and running the following command:

> mvn compile exec:java
