presage  0.9.2~beta
ARPAPredictor.cpp
Go to the documentation of this file.
1 
2 /******************************************************
3  * Presage, an extensible predictive text entry system
4  * ---------------------------------------------------
5  *
6  * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
7 
8  This program is free software; you can redistribute it and/or modify
9  it under the terms of the GNU General Public License as published by
10  the Free Software Foundation; either version 2 of the License, or
11  (at your option) any later version.
12 
13  This program is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  GNU General Public License for more details.
17 
18  You should have received a copy of the GNU General Public License along
19  with this program; if not, write to the Free Software Foundation, Inc.,
20  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  *
22  **********(*)*/
23 
24 
25 #include "ARPAPredictor.h"
26 
27 
28 #include <sstream>
29 #include <algorithm>
30 #include <cmath>
31 
32 
33 #define OOV "<UNK>"
34 
35 
36 
38  : Predictor(config,
39  ct,
40  name,
41  "ARPAPredictor, a predictor relying on an ARPA language model",
42  "ARPAPredictor, long description."
43  ),
44  dispatcher (this)
45 {
46  LOGGER = PREDICTORS + name + ".LOGGER";
47  ARPAFILENAME = PREDICTORS + name + ".ARPAFILENAME";
48  VOCABFILENAME = PREDICTORS + name + ".VOCABFILENAME";
49  TIMEOUT = PREDICTORS + name + ".TIMEOUT";
50 
51  // build notification dispatch map
56 
59 }
60 
61 void ARPAPredictor::set_vocab_filename (const std::string& value)
62 {
63  logger << INFO << "VOCABFILENAME: " << value << endl;
64  vocabFilename = value;
65 }
66 
67 void ARPAPredictor::set_arpa_filename (const std::string& value)
68 {
69  logger << INFO << "ARPAFILENAME: " << value << endl;
70  arpaFilename = value;
71 }
72 
73 void ARPAPredictor::set_timeout (const std::string& value)
74 {
75  logger << INFO << "TIMEOUT: " << value << endl;
76  timeout = atoi(value.c_str());
77 }
78 
80 {
81  std::ifstream vocabFile;
82  vocabFile.open(vocabFilename.c_str());
83  if(!vocabFile)
84  logger << ERROR << "Error opening vocabulary file: " << vocabFilename << endl;
85 
86  assert(vocabFile);
87  std::string row;
88  int code = 0;
89  while(std::getline(vocabFile,row))
90  {
91  if(row[0]=='#')
92  continue;
93 
94  vocabCode[row]=code;
95  vocabDecode[code]=row;
96 
97  logger << DEBUG << "["<<row<<"] -> "<< code<<endl;
98 
99  code++;
100  }
101 
102  logger << DEBUG << "Loaded "<<code<<" words from vocabulary" <<endl;
103 
104 }
105 
107 {
108  std::ifstream arpaFile;
109  arpaFile.open(arpaFilename.c_str());
110 
111  if(!arpaFile)
112  logger << ERROR << "Error opening ARPA model file: " << arpaFilename << endl;
113 
114  assert(arpaFile);
115  std::string row;
116 
117  int currOrder = 0;
118 
119  unigramCount = 0;
120  bigramCount = 0;
121  trigramCount = 0;
122 
123  int lineNum =0;
124  bool startData = false;
125 
126  while(std::getline(arpaFile,row))
127  {
128  lineNum++;
129  if(row.empty())
130  continue;
131 
132  if(row == "\\end\\")
133  break;
134 
135  if(row == "\\data\\")
136  {
137  startData = true;
138  continue;
139  }
140 
141 
142  if( startData == true && currOrder == 0)
143  {
144  if( row.find("ngram 1")==0 )
145  {
146  unigramTot = atoi(row.substr(8).c_str());
147  logger << DEBUG << "tot unigram = "<<unigramTot<<endl;
148  continue;
149  }
150 
151  if( row.find("ngram 2")==0)
152  {
153  bigramTot = atoi(row.substr(8).c_str());
154  logger << DEBUG << "tot bigram = "<<bigramTot<<endl;
155  continue;
156  }
157 
158  if( row.find("ngram 3")==0)
159  {
160  trigramTot = atoi(row.substr(8).c_str());
161  logger << DEBUG << "tot trigram = "<<trigramTot<<endl;
162  continue;
163  }
164  }
165 
166  if( row == "\\1-grams:" && startData)
167  {
168  currOrder = 1;
169  std::cerr << std::endl << "ARPA loading unigrams:" << std::endl;
170  unigramProg = new ProgressBar<char>(std::cerr);
171  continue;
172  }
173 
174  if( row == "\\2-grams:" && startData)
175  {
176  currOrder = 2;
177  std::cerr << std::endl << std::endl << "ARPA loading bigrams:" << std::endl;
178  bigramProg = new ProgressBar<char>(std::cerr);
179  continue;
180  }
181 
182  if( row == "\\3-grams:" && startData)
183  {
184  currOrder = 3;
185  std::cerr << std::endl << std::endl << "ARPA loading trigrams:" << std::endl;
186  trigramProg = new ProgressBar<char>(std::cerr);
187  continue;
188  }
189 
190  if(currOrder == 0)
191  continue;
192 
193  switch(currOrder)
194  {
195  case 1: addUnigram(row);
196  break;
197 
198  case 2: addBigram(row);
199  break;
200 
201  case 3: addTrigram(row);
202  break;
203  }
204 
205  }
206 
207  std::cerr << std::endl << std::endl;
208 
209  logger << DEBUG << "loaded unigrams: "<< unigramCount << endl;
210  logger << DEBUG << "loaded bigrams: " << bigramCount << endl;
211  logger << DEBUG << "loaded trigrams: "<< trigramCount << endl;
212 }
213 
214 void ARPAPredictor::addUnigram(std::string row)
215 {
216  std::stringstream str(row);
217  float logProb = 0;
218  float logAlfa = 0;
219  std::string wd1Str;
220 
221  str >> logProb;
222  str >> wd1Str;
223  str >> logAlfa;
224 
225 
226  if(wd1Str != OOV )
227  {
228  int wd1 = vocabCode[wd1Str];
229 
230  unigramMap[wd1]= ARPAData(logProb,logAlfa);
231 
232  logger << DEBUG << "adding unigram ["<<wd1Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
233  }
234 
235 
236  unigramCount++;
237 
238  unigramProg->update((float)unigramCount/(float)unigramTot);
239 }
240 
241 void ARPAPredictor::addBigram(std::string row)
242 {
243  std::stringstream str(row);
244  float logProb = 0;
245  float logAlfa = 0;
246  std::string wd1Str;
247  std::string wd2Str;
248 
249  str >> logProb;
250  str >> wd1Str;
251  str >> wd2Str;
252  str >> logAlfa;
253 
254  if(wd1Str != OOV && wd2Str != OOV)
255  {
256  int wd1 = vocabCode[wd1Str];
257  int wd2 = vocabCode[wd2Str];
258 
259  bigramMap[BigramKey(wd1,wd2)]=ARPAData(logProb,logAlfa);
260 
261  logger << DEBUG << "adding bigram ["<<wd1Str<< "] ["<<wd2Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
262  }
263 
264  bigramCount++;
265  bigramProg->update((float)bigramCount/(float)bigramTot);
266 }
267 
268 void ARPAPredictor::addTrigram(std::string row)
269 {
270  std::stringstream str(row);
271  float logProb = 0;
272 
273  std::string wd1Str;
274  std::string wd2Str;
275  std::string wd3Str;
276 
277  str >> logProb;
278  str >> wd1Str;
279  str >> wd2Str;
280  str >> wd3Str;
281 
282  if(wd1Str != OOV && wd2Str != OOV && wd3Str != OOV)
283  {
284  int wd1 = vocabCode[wd1Str];
285  int wd2 = vocabCode[wd2Str];
286  int wd3 = vocabCode[wd3Str];
287 
288  trigramMap[TrigramKey(wd1,wd2,wd3)]=logProb;
289  logger << DEBUG << "adding trigram ["<<wd1Str<< "] ["<<wd2Str<< "] ["<<wd3Str<< "] -> "<<logProb <<endl;
290 
291  }
292 
293  trigramCount++;
294  trigramProg->update((float)trigramCount/(float)trigramTot);
295 }
296 
297 
299 {
300  delete unigramProg;
301  delete bigramProg;
302  delete trigramProg;
303 }
304 
305 bool ARPAPredictor::matchesPrefixAndFilter(std::string word, std::string prefix, const char** filter ) const
306 {
307  if(filter == 0)
308  return word.find(prefix)==0;
309 
310  for(int j = 0; filter[j] != 0; j++)
311  {
312  std::string pattern = prefix+std::string(filter[j]);
313  if(word.find(pattern)==0)
314  return true;
315  }
316 
317  return false;
318 }
319 
320 Prediction ARPAPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
321 {
322  logger << DEBUG << "predict()" << endl;
323  Prediction prediction;
324 
325  int cardinality = 3;
326  std::vector<std::string> tokens(cardinality);
327 
328  std::string prefix = Utility::strtolower(contextTracker->getToken(0));
329  std::string wd2Str = Utility::strtolower(contextTracker->getToken(1));
330  std::string wd1Str = Utility::strtolower(contextTracker->getToken(2));
331 
332  std::multimap< float, std::string, cmp > result;
333 
334  logger << DEBUG << "["<<wd1Str<<"]"<<" ["<<wd2Str<<"] "<<"["<<prefix<<"]"<<endl;
335 
336  //search for the past tokens in the vocabulary
337  std::map<std::string,int>::const_iterator wd1It,wd2It;
338  wd1It = vocabCode.find(wd1Str);
339  wd2It = vocabCode.find(wd2Str);
340 
346  //we have two valid past tokens available
347  if(wd1It!=vocabCode.end() && wd2It!=vocabCode.end())
348  {
349  //iterate over all vocab words
350  for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it) //cppcheck: Prefer prefix ++/-- operators for non-primitive types.
351  {
352  //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
353  if(matchesPrefixAndFilter(it->second,prefix,filter))
354  {
355  std::pair<const float,std::string> p (computeTrigramBackoff(wd1It->second,wd2It->second,it->first),
356  it->second);
357  result.insert(p);
358  }
359  }
360  }
361 
362  //we have one valid past token available
363  else if(wd2It!=vocabCode.end())
364  {
365  //iterate over all vocab words
366  for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
367  {
368  //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
369  if(matchesPrefixAndFilter(it->second,prefix,filter))
370  {
371  std::pair<const float,std::string> p(computeBigramBackoff(wd2It->second,it->first),
372  it->second);
373  result.insert(p);
374  }
375  }
376  }
377 
378  //we have no valid past token available
379  else
380  {
381  //iterate over all vocab words
382  for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
383  {
384  //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
385  if(matchesPrefixAndFilter(it->second,prefix,filter))
386  {
387  std::pair<const float,std::string> p (unigramMap.find(it->first)->second.logProb,
388  it->second);
389  result.insert(p);
390  }
391  }
392  }
393 
394 
395  size_t numSuggestions = 0;
396  for(std::multimap< float, std::string, cmp >::const_iterator it = result.begin();
397  it != result.end() && numSuggestions < max_partial_prediction_size;
398  ++it)
399  {
400  prediction.addSuggestion(Suggestion(it->second,exp(it->first)));
401  numSuggestions++;
402  }
403 
404  return prediction;
405 }
409 float ARPAPredictor::computeTrigramBackoff(int wd1,int wd2,int wd3) const
410 {
411  logger << DEBUG << "computing P( ["<<vocabDecode.find(wd3)->second<< "] | ["<<vocabDecode.find(wd1)->second<<"] ["<<vocabDecode.find(wd2)->second<<"] )"<<endl;
412 
413  //trigram exist
414  std::map<TrigramKey,float>::const_iterator trigramIt =trigramMap.find(TrigramKey(wd1,wd2,wd3));
415  if(trigramIt!=trigramMap.end())
416  {
417  logger << DEBUG << "trigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] ["<<vocabDecode.find(wd3)->second<< "] exists" <<endl;
418  logger << DEBUG << "returning "<<trigramIt->second <<endl;
419  return trigramIt->second;
420  }
421 
422  //bigram exist
423  std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
424  if(bigramIt!=bigramMap.end())
425  {
426  logger << DEBUG << "bigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] exists" <<endl;
427  float prob = bigramIt->second.logAlfa + computeBigramBackoff(wd2,wd3);
428  logger << DEBUG << "returning "<<prob<<endl;
429  return prob;
430  }
431 
432  //else
433  logger << DEBUG << "no bigram w1,w2 exist" <<endl;
434  float prob = computeBigramBackoff(wd2,wd3);
435  logger << DEBUG << "returning "<<prob<<endl;
436  return prob;
437 
438 }
439 
443 float ARPAPredictor::computeBigramBackoff(int wd1, int wd2) const
444 {
445  //bigram exist
446  std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
447  if(bigramIt!=bigramMap.end())
448  return bigramIt->second.logProb;
449 
450  //else
451  return unigramMap.find(wd1)->second.logAlfa +unigramMap.find(wd2)->second.logProb;
452 
453 }
454 
455 void ARPAPredictor::learn(const std::vector<std::string>& change)
456 {
457  logger << DEBUG << "learn() method called" << endl;
458  logger << DEBUG << "learn() method exited" << endl;
459 }
460 
462 {
463  logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
464  dispatcher.dispatch (var);
465 }
void loadVocabulary()
Logger< char > logger
Definition: predictor.h:87
std::map< TrigramKey, float > trigramMap
ProgressBar< char > * unigramProg
Dispatcher< ARPAPredictor > dispatcher
void dispatch(const Observable *var)
Definition: dispatcher.h:73
ARPAPredictor(Configuration *, ContextTracker *, const char *)
bool matchesPrefixAndFilter(std::string, std::string, const char **) const
void update(const double percentage)
Definition: progress.h:54
virtual Prediction predict(const size_t size, const char **filter) const
Generate prediction.
std::string ARPAFILENAME
#define OOV
virtual void update(const Observable *variable)
virtual void learn(const std::vector< std::string > &change)
ProgressBar< char > * trigramProg
virtual void set_logger(const std::string &level)
Definition: predictor.cpp:88
std::string LOGGER
void set_vocab_filename(const std::string &value)
std::string config
Definition: presageDemo.cpp:70
const std::string PREDICTORS
Definition: predictor.h:81
void addBigram(std::string)
void set_arpa_filename(const std::string &value)
void addTrigram(std::string)
const std::string name
Definition: predictor.h:77
void set_timeout(const std::string &value)
ProgressBar< char > * bigramProg
virtual std::string get_name() const =0
std::map< std::string, int > vocabCode
float computeBigramBackoff(int, int) const
void map(Observable *var, const mbr_func_ptr_t &ptr)
Definition: dispatcher.h:62
static char * strtolower(char *)
Definition: utility.cpp:42
std::map< int, std::string > vocabDecode
ContextTracker * contextTracker
Definition: predictor.h:83
std::string TIMEOUT
void addSuggestion(Suggestion)
Definition: prediction.cpp:90
Tracks user interaction and context.
std::string VOCABFILENAME
virtual std::string get_value() const =0
void addUnigram(std::string)
std::string vocabFilename
std::map< int, ARPAData > unigramMap
float computeTrigramBackoff(int, int, int) const
std::map< BigramKey, ARPAData > bigramMap
void createARPATable()
std::string arpaFilename
const Logger< _charT, _Traits > & endl(const Logger< _charT, _Traits > &lgr)
Definition: logger.h:278
std::string getToken(const int) const