presage 0.9.1
smoothedNgramPredictor.cpp

/******************************************************
 * Presage, an extensible predictive text entry system
 * ---------------------------------------------------
 *
 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along
    with this program; if not, write to the Free Software Foundation, Inc.,
    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
                                                                             *
                                                                **********(*)*/

#include "smoothedNgramPredictor.h"

#include <sstream>
#include <algorithm>

SmoothedNgramPredictor::SmoothedNgramPredictor(Configuration* config, ContextTracker* ct, const char* name)
    : Predictor(config,
                ct,
                name,
                "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
                "SmoothedNgramPredictor, long description." ),
      db (0),
      cardinality (0),
      learn_mode_set (false),
      dispatcher (this)
{
    LOGGER = PREDICTORS + name + ".LOGGER";
    DBFILENAME = PREDICTORS + name + ".DBFILENAME";
    DELTAS = PREDICTORS + name + ".DELTAS";
    LEARN = PREDICTORS + name + ".LEARN";
    DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";

    // build notification dispatch map
    dispatcher.map (config->find (LOGGER), & SmoothedNgramPredictor::set_logger);
    dispatcher.map (config->find (DBFILENAME), & SmoothedNgramPredictor::set_dbfilename);
    dispatcher.map (config->find (DATABASE_LOGGER), & SmoothedNgramPredictor::set_database_logger_level);
    dispatcher.map (config->find (DELTAS), & SmoothedNgramPredictor::set_deltas);
    dispatcher.map (config->find (LEARN), & SmoothedNgramPredictor::set_learn);
}
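
// Note (illustrative; not part of the original source): the dispatch
// map above connects configuration variables to their setter methods.
// Assuming PREDICTORS expands to "Presage.Predictors." and the
// predictor instance is named "DefaultSmoothedNgramPredictor", a
// change to "Presage.Predictors.DefaultSmoothedNgramPredictor.DELTAS"
// is delivered to update() at the bottom of this file, which hands
// the observable to the dispatcher, which in turn calls set_deltas()
// with the new value.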


SmoothedNgramPredictor::~SmoothedNgramPredictor()
{
    delete db;
}


void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
    dbfilename = filename;
    logger << INFO << "DBFILENAME: " << dbfilename << endl;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::set_database_logger_level (const std::string& value)
{
    dbloglevel = value;
}


void SmoothedNgramPredictor::set_deltas (const std::string& value)
{
    std::stringstream ss_deltas(value);
    cardinality = 0;
    std::string delta;
    while (ss_deltas >> delta) {
        logger << DEBUG << "Pushing delta: " << delta << endl;
        deltas.push_back (Utility::toDouble (delta));
        cardinality++;
    }
    logger << INFO << "DELTAS: " << value << endl;
    logger << INFO << "CARDINALITY: " << cardinality << endl;

    init_database_connector_if_ready ();
}
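
// Worked example (illustrative; the values are hypothetical, not
// shipped defaults): given the configuration value "0.01 0.1 0.89",
// set_deltas() parses three whitespace-separated weights, so that
//
//     deltas      = { 0.01, 0.1, 0.89 }
//     cardinality = 3
//
// i.e. the predictor linearly interpolates unigram, bigram and
// trigram estimates, with most of the weight on the trigram.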


void SmoothedNgramPredictor::set_learn (const std::string& value)
{
    learn_mode = Utility::isYes (value);
    logger << INFO << "LEARN: " << value << endl;

    learn_mode_set = true;

    init_database_connector_if_ready ();
}


void SmoothedNgramPredictor::init_database_connector_if_ready ()
{
    // we can only init the sqlite database connector once we know the
    // following:
    //  - what database file we need to open
    //  - what cardinality we expect the database file to be
    //  - whether we need to open the database in read only or
    //    read/write mode (learning requires read/write access)
    //
    if (! dbfilename.empty()
        && cardinality > 0
        && learn_mode_set ) {

        delete db;

        if (dbloglevel.empty ()) {
            // open database connector
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode);
        } else {
            // open database connector with logger level
            db = new SqliteDatabaseConnector(dbfilename,
                                             cardinality,
                                             learn_mode,
                                             dbloglevel);
        }
    }
}
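
// Note (illustrative; not part of the original source): the database
// connector is created lazily by whichever of set_dbfilename(),
// set_deltas() and set_learn() fires last, since only then are
// dbfilename, cardinality and learn_mode all known. Re-running any of
// the setters deletes the old connector and opens a fresh one.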


// convenience function to convert ngram to string
//
static std::string ngram_to_string(const Ngram& ngram)
{
    const char separator[] = "|";
    std::string result = separator;

    for (Ngram::const_iterator it = ngram.begin();
         it != ngram.end();
         it++)
    {
        result += *it + separator;
    }

    return result;
}
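
// Example (illustrative; not part of the original source): for the
// 2-gram <"foo", "bar"> this returns "|foo|bar|", i.e. a leading
// separator followed by each token and a trailing separator.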


// Builds the required n-gram and returns its count.
//
unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
{
    unsigned int result = 0;

    assert(offset <= 0); // TODO: handle this better
    assert(ngram_size >= 0);

    if (ngram_size > 0) {
        Ngram ngram(ngram_size);
        copy(tokens.end() - ngram_size + offset, tokens.end() + offset, ngram.begin());
        result = db->getNgramCount(ngram);
        logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
    } else {
        result = db->getUnigramCountsSum();
        logger << DEBUG << "unigram counts sum: " << result << endl;
    }

    return result;
}
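
// Worked example (illustrative; not part of the original source):
// with tokens = { "the", "quick", "fox" },
//
//     count(tokens,  0, 2)  looks up the 2-gram <"quick", "fox">
//     count(tokens, -1, 2)  looks up the 2-gram <"the", "quick">
//     count(tokens,  0, 0)  returns the sum of all unigram counts
//
// offset shifts the n-gram window back from the end of tokens, and
// ngram_size selects how many trailing tokens make up the n-gram.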

Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
{
    logger << DEBUG << "predict()" << endl;

    // Result prediction
    Prediction prediction;

    // Cache all the needed tokens.
    // tokens[k] corresponds to w_{i-k} in the generalized smoothed
    // n-gram probability formula
    //
    std::vector<std::string> tokens(cardinality);
    for (int i = 0; i < cardinality; i++) {
        tokens[cardinality - 1 - i] = contextTracker->getToken(i);
        logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
    }
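
    // Illustrative example (not part of the original source): with
    // cardinality 3 and the user having typed "the quick bro", the
    // context tracker returns getToken(0) = "bro" (the partial word
    // being completed), getToken(1) = "quick" and getToken(2) = "the",
    // so tokens = { "the", "quick", "bro" }.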

    // Generate list of prefix completion candidates.
    //
    // The prefix completion candidates used to be obtained from the
    // _1_gram table, because in a well-constructed ngram database the
    // _1_gram table contains all known tokens. However, this
    // introduced a skew, since the unigram counts take precedence
    // over the higher-order counts.
    //
    // The current solution retrieves candidates from the highest
    // n-gram table, falling back on lower order n-gram tables if the
    // initial completion set is smaller than required.
    //
    std::vector<std::string> prefixCompletionCandidates;
    for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
        logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
        // create n-gram used to retrieve initial prefix completion table
        Ngram prefix_ngram(k);
        copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());

        if (logger.shouldLog()) {
            logger << DEBUG << "prefix_ngram: ";
            for (size_t r = 0; r < prefix_ngram.size(); r++) {
                logger << DEBUG << prefix_ngram[r] << ' ';
            }
            logger << DEBUG << endl;
        }

        // obtain initial prefix completion candidates
        db->beginTransaction();

        NgramTable partial;

        if (filter == 0) {
            partial = db->getNgramLikeTable(prefix_ngram, max_partial_prediction_size - prefixCompletionCandidates.size());
        } else {
            partial = db->getNgramLikeTableFiltered(prefix_ngram, filter, max_partial_prediction_size - prefixCompletionCandidates.size());
        }

        db->endTransaction();

        if (logger.shouldLog()) {
            logger << DEBUG << "partial prefixCompletionCandidates" << endl
                   << DEBUG << "----------------------------------" << endl;
            for (size_t j = 0; j < partial.size(); j++) {
                for (size_t k = 0; k < partial[j].size(); k++) {
                    logger << DEBUG << partial[j][k] << " ";
                }
                logger << endl;
            }
        }

        logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;

        // append newly discovered potential completions to prefix
        // completion candidates array to fill it up to
        // max_partial_prediction_size
        //
        std::vector<Ngram>::const_iterator it = partial.begin();
        while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
            // only add new candidates; iterator it points to an Ngram,
            // it->end() - 2 points to the token candidate
            //
            std::string candidate = *(it->end() - 2);
            if (find(prefixCompletionCandidates.begin(),
                     prefixCompletionCandidates.end(),
                     candidate) == prefixCompletionCandidates.end()) {
                prefixCompletionCandidates.push_back(candidate);
            }
            it++;
        }
    }
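
    // Illustrative example (not part of the original source):
    // continuing with tokens = { "the", "quick", "bro" }, the first
    // loop pass queries the _3_gram table for completions of
    // "the quick bro"; if it returns fewer than
    // max_partial_prediction_size candidates, later passes fall back
    // to "quick bro" in the _2_gram table and finally to "bro" in
    // the _1_gram table.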

    if (logger.shouldLog()) {
        logger << DEBUG << "prefixCompletionCandidates" << endl
               << DEBUG << "--------------------------" << endl;
        for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
            logger << DEBUG << prefixCompletionCandidates[j] << endl;
        }
    }

    // compute smoothed probabilities for all candidates
    //
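    // In code terms, the loop below computes (illustrative rendering
    // of the interpolation; the code is the authoritative version):
    //
    //     P(w_i | context) =
    //         sum over k in [0, cardinality) of
    //             deltas[k] * C(w_{i-k} ... w_i) / C(w_{i-k} ... w_{i-1})
    //
    // where C() is an n-gram count and, for k == 0, the denominator
    // is the sum of all unigram counts.
    //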
    db->beginTransaction();
    // getUnigramCountsSum is an expensive SQL query;
    // caching it here saves much time later inside the loop
    int unigrams_counts_sum = db->getUnigramCountsSum();
    for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
        // store w_i candidate at end of tokens
        tokens[cardinality - 1] = prefixCompletionCandidates[j];

        logger << DEBUG << "------------------" << endl;
        logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;

        double probability = 0;
        for (int k = 0; k < cardinality; k++) {
            double numerator = count(tokens, 0, k+1);
            // reuse cached unigrams_counts_sum to speed things up
            double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
            double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
            probability += deltas[k] * frequency;

            logger << DEBUG << "numerator: " << numerator << endl;
            logger << DEBUG << "denominator: " << denominator << endl;
            logger << DEBUG << "frequency: " << frequency << endl;
            logger << DEBUG << "delta: " << deltas[k] << endl;

            // sanity checks
            assert(numerator <= denominator);
            assert(frequency <= 1);
        }

        logger << DEBUG << "____________" << endl;
        logger << DEBUG << "probability: " << probability << endl;

        if (probability > 0) {
            prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
        }
    }
    db->endTransaction();

    logger << DEBUG << "Prediction:" << endl;
    logger << DEBUG << "-----------" << endl;
    logger << DEBUG << prediction << endl;

    return prediction;
}

void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
{
    logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;

    if (learn_mode) {
        // learning is turned on

        std::map<std::list<std::string>, int> ngramMap;

        // build up ngram map for all cardinalities,
        // i.e. learn all ngrams and counts in memory
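        //
        // Worked example (illustrative; not part of the original
        // source): with cardinality == 2 and change = <"foo", "bar">,
        // the loop below produces
        //
        //     ngramMap[<"foo">]        = 1
        //     ngramMap[<"bar">]        = 1
        //     ngramMap[<"foo", "bar">] = 1
        //
        // i.e. one entry per n-gram of each order up to cardinality
        // occurring in change, storing its occurrence count.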
        for (size_t curr_cardinality = 1;
             curr_cardinality < cardinality + 1;
             curr_cardinality++)
        {
            int change_idx = 0;
            int change_size = change.size();

            std::list<std::string> ngram_list;

            // take care of first N-1 tokens
            for (int i = 0;
                 (i < curr_cardinality - 1 && change_idx < change_size);
                 i++)
            {
                ngram_list.push_back(change[change_idx]);
                change_idx++;
            }

            while (change_idx < change_size)
            {
                ngram_list.push_back(change[change_idx++]);
                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
                ngram_list.pop_front();
            }
        }

        // use (past stream - change) to learn tokens at the boundary
        // of change, i.e.
        //
        // if change is "bar foobar", then "bar" will only occur in a
        // 1-gram, since there is no token before it. By dipping into
        // the past stream, we gain additional context to learn a
        // 2-gram from the extra tokens (assuming the past stream ends
        // with token "foo"):
        //
        //     <"foo", "bar"> will be learnt
        //
        // We do this till we have built n-grams up to n equal to
        // cardinality.
        //
        // First check that change is not empty (nothing to learn) and
        // that change and past stream match, by sampling the first
        // and last tokens in change and comparing them with the
        // corresponding tokens from the past stream
        //
        if (change.size() > 0 &&
            change.back() == contextTracker->getToken(1) &&
            change.front() == contextTracker->getToken(change.size()))
        {
            // create ngram list with the first (oldest) token from change
            std::list<std::string> ngram_list(change.begin(), change.begin() + 1);

            // prepend tokens to the ngram list by grabbing extra
            // tokens from the past stream (if there are any) till we
            // have built n-grams up to n == cardinality, and commit
            // them to ngramMap
            //
            for (int tk_idx = 1;
                 ngram_list.size() < cardinality;
                 tk_idx++)
            {
                // getExtraTokenToLearn returns tokens from the past
                // stream that come before and are not in the change
                // vector
                //
                std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
                logger << DEBUG << "Adding extra token: " << extra_token << endl;

                if (extra_token.empty())
                {
                    break;
                }
                ngram_list.push_front(extra_token);

                ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
            }
        }

        // then write out to language model database
        try
        {
            db->beginTransaction();

            std::map<std::list<std::string>, int>::const_iterator it;
            for (it = ngramMap.begin(); it != ngramMap.end(); it++)
            {
                // convert ngram from list to vector based Ngram
                Ngram ngram((it->first).begin(), (it->first).end());

                // update the counts
                int count = db->getNgramCount(ngram);
                if (count > 0)
                {
                    // ngram already in database, update count
                    db->updateNgram(ngram, count + it->second);
                    check_learn_consistency(ngram);
                }
                else
                {
                    // ngram not in database, insert it
                    db->insertNgram(ngram, it->second);
                }
            }

            db->endTransaction();
            logger << INFO << "Committed learning update to database" << endl;
        }
        catch (PresageException& ex)
        {
            db->rollbackTransaction();
            logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
            throw;
        }
    }

    logger << DEBUG << "end learn()" << endl;
}

void SmoothedNgramPredictor::check_learn_consistency(const Ngram& ngram) const
{
    // no need to begin a new transaction, as we'll be called from
    // within an existing transaction from learn()

    // BEWARE: if the previous sentence is not true, then performance
    // WILL suffer!

    size_t size = ngram.size();
    for (size_t i = 0; i < size; i++) {
        if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
            logger << INFO << "consistency adjustment needed!" << endl;

            int offset = -(i + 1);
            int sub_ngram_size = size - (i + 1);

            logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;

            Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
            copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());

            if (logger.shouldLog()) {
                logger << "ngram to be count adjusted is: ";
                for (size_t i = 0; i < sub_ngram.size(); i++) {
                    logger << sub_ngram[i] << ' ';
                }
                logger << endl;
            }

            db->incrementNgramCount(sub_ngram);
            logger << DEBUG << "consistency adjusted" << endl;
        }
    }
}
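
// Illustrative example (not part of the original source): the loop
// above enforces the invariant that an n-gram count never exceeds
// the count of its own prefix, e.g. C(<"foo", "bar">) <= C(<"foo">).
// If learning "foo bar" bumps C(<"foo", "bar">) to 5 while C(<"foo">)
// is still 4, the prefix's count is incremented to restore the
// invariant.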

void SmoothedNgramPredictor::update (const Observable* var)
{
    logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
    dispatcher.dispatch (var);
}