presage  0.9.2~beta
smoothedNgramPredictor.cpp
Go to the documentation of this file.
1 
2 /******************************************************
3  * Presage, an extensible predictive text entry system
4  * ---------------------------------------------------
5  *
6  * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
7 
8  This program is free software; you can redistribute it and/or modify
9  it under the terms of the GNU General Public License as published by
10  the Free Software Foundation; either version 2 of the License, or
11  (at your option) any later version.
12 
13  This program is distributed in the hope that it will be useful,
14  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16  GNU General Public License for more details.
17 
18  You should have received a copy of the GNU General Public License along
19  with this program; if not, write to the Free Software Foundation, Inc.,
20  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  *
22  **********(*)*/
23 
24 
25 #include "smoothedNgramPredictor.h"
26 
27 #include <sstream>
28 #include <algorithm>
29 
30 
// Constructor: forwards configuration, context tracker and predictor name
// to the Predictor base class, zero-initializes members, and builds the
// fully-qualified configuration-variable names this predictor observes
// (e.g. "<PREDICTORS><name>.DBFILENAME").
//
// NOTE(review): the source extraction dropped the constructor signature
// line (original line 31) and the dispatcher map-building lines (original
// lines 51-56) -- restore them from the upstream file; they cannot be
// inferred safely from what is visible here.
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),                  // database connector is created lazily, once all prerequisites are known
      count_threshold (0),
      cardinality (0),         // n-gram order; derived from the number of DELTAS
      learn_mode_set (false),  // becomes true once set_learn() has run
      dispatcher (this)
{
    LOGGER = PREDICTORS + name + ".LOGGER";
    DBFILENAME = PREDICTORS + name + ".DBFILENAME";
    DELTAS = PREDICTORS + name + ".DELTAS";
    COUNT_THRESHOLD = PREDICTORS + name + ".COUNT_THRESHOLD";
    LEARN = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";

    // build notification dispatch map
}
58 
59 
60 
// Destructor: releases the database connector.  db may still be 0 if the
// connector was never opened; delete on a null pointer is a no-op.
// NOTE(review): the destructor signature line (original line 61) was
// dropped by the source extraction.
{
    delete db;
}
65 
66 
// Records the database filename received via configuration.
// NOTE(review): original line 72 is missing from this extraction --
// upstream it calls init_database_connector_if_ready() so the connector
// is (re)opened once all prerequisites are known.
void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

}
74 
75 
// Stores the log level to hand to the database connector when it is
// created (see init_database_connector_if_ready()).
// NOTE(review): the signature line (original line 76) was dropped by the
// extraction -- upstream:
//   void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}
80 
81 
// Parses the whitespace-separated list of interpolation weights
// ("deltas").  The number of deltas fixes the n-gram cardinality used by
// predict() and learn().
// NOTE(review): original line 95 is missing from this extraction --
// upstream it calls init_database_connector_if_ready().
void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    cardinality = 0;
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
        cardinality++;  // one weight per n-gram order
    }
    logger << INFO << "DELTAS: " << value << endl;
    logger << INFO << "CARDINALITY: " << cardinality << endl;

}
97 
98 
// Sets the minimum occurrence count a completion candidate must have to
// be returned by the database query.
// NOTE(review): original line 101 is missing from this extraction --
// presumably the assignment count_threshold = Utility::toInt (value);
// (Utility::toInt is referenced elsewhere in this file) -- confirm
// against upstream.
void SmoothedNgramPredictor::set_count_threshold (const std::string& value)
{
    logger << INFO << "COUNT_THRESHOLD: " << count_threshold << endl;
}
104 
105 
// Turns learning on or off.  learn_mode_set records that the mode is now
// known -- one of the preconditions for opening the database, since
// learning requires read/write access.
// NOTE(review): original line 113 is missing from this extraction --
// upstream it calls init_database_connector_if_ready().
void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    learn_mode = Utility::isYes (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

}
115 
116 
// (Re)creates the database connector once every prerequisite is known:
// the database filename, the n-gram cardinality, and the learn mode
// (which decides read-only vs read/write access).
// NOTE(review): the extraction dropped the function signature (original
// line 117) and the two `db = new SqliteDatabaseConnector (dbfilename,`
// allocation lines (original lines 134 and 139) -- restore from the
// upstream file.
{
    // we can only init the sqlite database connector once we know the
    // following:
    // - what database file we need to open
    // - what cardinality we expect the database file to be
    // - whether we need to open the database in read only or
    //   read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;  // discard any previously opened connector (0-safe)

        if (dbloglevel.empty ()) {
            // open database connector
                                            cardinality,
                                            learn_mode);
        } else {
            // open database connector with logger level
                                            cardinality,
                                            learn_mode,
                                            dbloglevel);
        }
    }
}
146 
147 
148 // convenience function to convert ngram to string
149 //
150 static std::string ngram_to_string(const Ngram& ngram)
151 {
152  const char separator[] = "|";
153  std::string result = separator;
154 
155  for (Ngram::const_iterator it = ngram.begin();
156  it != ngram.end();
157  it++)
158  {
159  result += *it + separator;
160  }
161 
162  return result;
163 }
164 
165 
181 unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
182 {
183  unsigned int result = 0;
184 
185  assert(offset <= 0); // TODO: handle this better
186  assert(ngram_size >= 0);
187 
188  if (ngram_size > 0) {
189  Ngram ngram(ngram_size);
190  copy(tokens.end() - ngram_size + offset , tokens.end() + offset, ngram.begin());
191  result = db->getNgramCount(ngram);
192  logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
193  } else {
194  result = db->getUnigramCountsSum();
195  logger << DEBUG << "unigram counts sum: " << result << endl;
196  }
197 
198  return result;
199 }
200 
// Generate a prediction: collect prefix-completion candidates from the
// highest-order n-gram table downward, score each candidate with the
// linearly-interpolated smoothed n-gram probability
//   P(w_i | context) = sum_k deltas[k] * freq_k,
// and return at most max_partial_prediction_size suggestions.
//
// NOTE(review): original line 252 was dropped by the extraction --
// upstream it passes count_threshold as the argument between filter and
// the limit in the getNgramLikeTable() call.
Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[k] corresponds to w_{i-k} in the generalized smoothed
    // n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, which in a well-constructed ngram database
    // contains all known tokens. However, this introduced a skew,
    // since the unigram counts will take precedence over the
    // higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower order n-gram tables if
    // initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        // NOTE(review): count_threshold belongs between filter and the
        // limit argument here (dropped original line 252).
        partial = db->getNgramLikeTable(prefix_ngram,
                                        filter,
                                        max_partial_prediction_size - prefixCompletionCandidates.size());

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                // NOTE(review): this inner k shadows the outer loop
                // variable k -- harmless here, but worth renaming.
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates, iterator it points to Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            // numerator: count of the (k+1)-gram ending in the candidate;
            // denominator: count of its k-gram context.
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // for some sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        // zero-probability candidates are dropped from the prediction
        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}
343 
// Learn the n-grams contained in change (the newly entered tokens), for
// every order from 1 up to this predictor's cardinality, then commit the
// updated counts to the database inside a single transaction.
//
// NOTE(review): the extraction dropped original lines 450 (upstream
// calls check_learn_consistency(ngram) after updateNgram), 462 (the
// catch clause of the try block below) and 464 (the
// db->rollbackTransaction() call) -- restore from the upstream file.
void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
        // learning is turned on

        // maps each ngram (as a token list) to the number of times it
        // occurred in change
        std::map<std::list<std::string>, int> ngramMap;

        // build up ngram map for all cardinalities
        // i.e. learn all ngrams and counts in memory
        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++)
        {
            int change_idx = 0;
            int change_size = change.size();

            std::list<std::string> ngram_list;

            // take care of first N-1 tokens
            for (int i = 0;
                 (i < curr_cardinality - 1 && change_idx < change_size);
                 i++)
            {
                ngram_list.push_back(change[change_idx]);
                change_idx++;
            }

            // slide a window of curr_cardinality tokens across change,
            // bumping each windowed ngram's count
            while (change_idx < change_size)
            {
                ngram_list.push_back(change[change_idx++]);
                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
                ngram_list.pop_front();
            }
        }

        // use (past stream - change) to learn token at the boundary
        // of change, i.e.
        //
        // if change is "bar foobar", then "bar" will only occur in a
        // 1-gram, since there are no tokens before it. By dipping into
        // the past stream, we gain additional context to learn a 2-gram
        // by getting extra tokens (assuming past stream ends with token
        // "foo"):
        //
        // <"foo", "bar"> will be learnt
        //
        // We do this till we build up to n equal to cardinality.
        //
        // First check that change is not empty (nothing to learn) and
        // that change and past stream match by sampling first and
        // last token in change and comparing them with corresponding
        // tokens from past stream
        //
        if (change.size() > 0 &&
            change.back() == contextTracker->getToken(1) &&
            change.front() == contextTracker->getToken(change.size()))
        {
            // create ngram list with first (oldest) token from change
            std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

            // prepend token to ngram list by grabbing extra tokens
            // from past stream (if there are any) till we have built
            // up to n==cardinality ngrams, and commit them to
            // ngramMap
            //
            for (int tk_idx = 1;
                 ngram_list.size() < cardinality;
                 tk_idx++)
            {
                // getExtraTokenToLearn returns tokens from
                // past stream that come before and are not in
                // change vector
                //
                std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                logger << DEBUG << "Adding extra token: " << extra_token << endl;

                if (extra_token.empty())
                {
                    // past stream exhausted; stop growing the context
                    break;
                }
                ngram_list.push_front(extra_token);

                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
            }
        }

        // then write out to language model database
        try
        {
            db->beginTransaction();

            std::map<std::list<std::string>, int>::const_iterator it;
            for (it = ngramMap.begin(); it != ngramMap.end(); it++)
            {
                // convert ngram from list to vector based Ngram
                Ngram ngram((it->first).begin(), (it->first).end());

                // update the counts
                int count = db->getNgramCount(ngram);
                if (count > 0)
                {
                    // ngram already in database, update count
                    db->updateNgram(ngram, count + it->second);
                    // NOTE(review): upstream calls
                    // check_learn_consistency(ngram); here (dropped
                    // original line 450).
                }
                else
                {
                    // ngram not in database, insert it
                    db->insertNgram(ngram, it->second);
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        // NOTE(review): the catch clause (dropped original line 462) and
        // the db->rollbackTransaction() call (dropped original line 464)
        // are missing; ex below is the caught exception.
        {
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}
472 
// Enforce the n-gram count invariant: an n-gram's count must never
// exceed the count of its (n-1)-gram context.  Walks the suffixes of
// ngram and increments any sub-ngram whose count has fallen behind.
// NOTE(review): the extraction dropped the signature line (original line
// 473) -- upstream:
//   void SmoothedNgramPredictor::check_learn_consistency (const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        // NOTE(review): i is size_t, so -i and -(i + 1) wrap to huge
        // unsigned values before narrowing back to the int offset
        // parameter of count() -- in practice this yields the intended
        // small negative offsets, but it is fragile; confirm upstream.
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t i = 0; i < sub_ngram.size(); i++) {
                    logger << sub_ngram[i] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
507 
// Observer callback: a configuration variable changed; route it to the
// matching setter through the dispatch map built in the constructor.
// NOTE(review): the extraction dropped the signature line (original line
// 508) -- upstream:
//   void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}
Logger< char > logger
Definition: predictor.h:87
static int toInt(const std::string)
Definition: utility.cpp:266
void dispatch(const Observable *var)
Definition: dispatcher.h:73
Dispatcher< SmoothedNgramPredictor > dispatcher
int getNgramCount(const Ngram ngram) const
void insertNgram(const Ngram ngram, const int count) const
virtual const char * what() const
void set_count_threshold(const std::string &value)
virtual void learn(const std::vector< std::string > &change)
int getUnigramCountsSum() const
void set_database_logger_level(const std::string &level)
virtual void set_logger(const std::string &level)
Definition: predictor.cpp:88
NgramTable getNgramLikeTable(const Ngram ngram, const char **filter, const int count_threshold, int limit=-1) const
std::string config
Definition: presageDemo.cpp:70
void set_deltas(const std::string &deltas)
const std::string PREDICTORS
Definition: predictor.h:81
static std::string ngram_to_string(const Ngram &ngram)
std::vector< double > deltas
const std::string name
Definition: predictor.h:77
static double toDouble(const std::string)
Definition: utility.cpp:258
unsigned int count(const std::vector< std::string > &tokens, int offset, int ngram_size) const
Builds the required n-gram and returns its count.
virtual void update(const Observable *variable)
std::vector< Ngram > NgramTable
void set_learn(const std::string &learn_mode)
std::string getExtraTokenToLearn(const int index, const std::vector< std::string > &change) const
void check_learn_consistency(const Ngram &name) const
virtual std::string get_name() const =0
SmoothedNgramPredictor(Configuration *, ContextTracker *, const char *)
int incrementNgramCount(const Ngram ngram) const
virtual void beginTransaction() const
void updateNgram(const Ngram ngram, const int count) const
void map(Observable *var, const mbr_func_ptr_t &ptr)
Definition: dispatcher.h:62
virtual void rollbackTransaction() const
ContextTracker * contextTracker
Definition: predictor.h:83
virtual void endTransaction() const
void addSuggestion(Suggestion)
Definition: prediction.cpp:90
static bool isYes(const char *)
Definition: utility.cpp:185
virtual Prediction predict(const size_t size, const char **filter) const
Generate prediction.
Tracks user interaction and context.
bool shouldLog() const
Definition: logger.h:149
virtual std::string get_value() const =0
Definition: ngram.h:33
void set_dbfilename(const std::string &filename)
const Logger< _charT, _Traits > & endl(const Logger< _charT, _Traits > &lgr)
Definition: logger.h:278
std::string getToken(const int) const