Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2018 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not email to scip@zib.de. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_exp3.c
17  * @brief methods for Exp.3 bandit selection
18  * @author Gregor Hendel
19  */
20 
21 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
22 
23 #include <assert.h>
24 
25 #include "scip/bandit_exp3.h"
26 
27 #define BANDIT_NAME "exp3"
28 #define NUMTOL 1e-6
29 
30 /*
31  * Data structures
32  */
33 
34 /** implementation specific data of Exp.3 bandit algorithm */
35 struct SCIP_BanditData
36 {
37  SCIP_Real* weights; /**< exponential weight for each arm */
38  SCIP_Real weightsum; /**< the sum of all weights */
39  SCIP_Real gamma; /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
40  SCIP_Real beta; /**< gain offset between 0 and 1 at every observation */
41 };
42 
43 /*
44  * Local methods
45  */
46 
47 /*
48  * Callback methods of bandit algorithm
49  */
50 
51 /** callback to free bandit specific data structures */
52 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
53 { /*lint --e{715}*/
54 
55  SCIP_BANDITDATA* banditdata;
56  int nactions;
57  assert(bandit != NULL);
58 
59  banditdata = SCIPbanditGetData(bandit);
60  assert(banditdata != NULL);
61  nactions = SCIPbanditGetNActions(bandit);
62 
63  BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
64 
65  BMSfreeBlockMemory(blkmem, &banditdata);
66 
67  SCIPbanditSetData(bandit, NULL);
68 
69  return SCIP_OKAY;
70 }
71 
72 /** selection callback for bandit selector */
73 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
74 { /*lint --e{715}*/
75 
76  SCIP_BANDITDATA* banditdata;
77  SCIP_RANDNUMGEN* rng;
78  SCIP_Real randnr;
79  SCIP_Real psum;
80  SCIP_Real gammaoverk;
81  SCIP_Real oneminusgamma;
82  SCIP_Real* weights;
83  SCIP_Real weightsum;
84  int i;
85  int nactions;
86 
87  assert(bandit != NULL);
88  assert(selection != NULL);
89 
90  banditdata = SCIPbanditGetData(bandit);
91  assert(banditdata != NULL);
92  rng = SCIPbanditGetRandnumgen(bandit);
93  assert(rng != NULL);
94  nactions = SCIPbanditGetNActions(bandit);
95 
96 
97  /* draw a random number between 0 and 1 */
98  randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
99 
100  /* initialize some local variables to speed up probability computations */
101  oneminusgamma = 1 - banditdata->gamma;
102  gammaoverk = banditdata->gamma / (SCIP_Real)nactions;
103  weightsum = banditdata->weightsum;
104  weights = banditdata->weights;
105  psum = 0.0;
106 
107  /* loop over probability distribution until rand is reached
108  * the loop terminates without looking at the last action,
109  * which is then selected automatically if the target probability
110  * is not reached earlier
111  */
112  for( i = 0; i < nactions - 1; ++i )
113  {
114  SCIP_Real prob;
115 
116  /* compute the probability for arm i as convex kombination of a uniform distribution and a weighted distribution */
117  prob = oneminusgamma * weights[i] / weightsum + gammaoverk;
118  psum += prob;
119 
120  /* break and select element if target probability is reached */
121  if( randnr <= psum )
122  break;
123  }
124 
125  /* select element i, which is the last action in case that the break statement hasn't been reached */
126  *selection = i;
127 
128  return SCIP_OKAY;
129 }
130 
131 /** update callback for bandit algorithm */
132 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
133 { /*lint --e{715}*/
134  SCIP_BANDITDATA* banditdata;
135  SCIP_Real eta;
136  SCIP_Real gainestim;
137  SCIP_Real beta;
138  SCIP_Real weightsum;
139  SCIP_Real newweightsum;
140  SCIP_Real* weights;
141  SCIP_Real oneminusgamma;
142  SCIP_Real gammaoverk;
143  int nactions;
144 
145  assert(bandit != NULL);
146 
147  banditdata = SCIPbanditGetData(bandit);
148  assert(banditdata != NULL);
149  nactions = SCIPbanditGetNActions(bandit);
150 
151  assert(selection >= 0);
152  assert(selection < nactions);
153 
154  /* the learning rate eta */
155  eta = 1.0 / (SCIP_Real)nactions;
156 
157  beta = banditdata->beta;
158  oneminusgamma = 1.0 - banditdata->gamma;
159  gammaoverk = banditdata->gamma * eta;
160  weights = banditdata->weights;
161  weightsum = banditdata->weightsum;
162  newweightsum = weightsum;
163 
164  /* if beta is zero, only the observation for the current arm needs an update */
165  if( EPSZ(beta, NUMTOL) )
166  {
167  SCIP_Real probai;
168  probai = oneminusgamma * weights[selection] / weightsum + gammaoverk;
169 
170  assert(probai > 0.0);
171 
172  gainestim = score / probai;
173  newweightsum -= weights[selection];
174  weights[selection] *= exp(eta * gainestim);
175  newweightsum += weights[selection];
176  }
177  else
178  {
179  int j;
180  newweightsum = 0.0;
181 
182  /* loop over all items and update their weights based on the influence of the beta parameter */
183  for( j = 0; j < nactions; ++j )
184  {
185  SCIP_Real probaj;
186  probaj = oneminusgamma * weights[j] / weightsum + gammaoverk;
187 
188  assert(probaj > 0.0);
189 
190  /* consider the score only for the chosen arm i, use constant beta offset otherwise */
191  if( j == selection )
192  gainestim = (score + beta) / probaj;
193  else
194  gainestim = beta / probaj;
195 
196  weights[j] *= exp(eta * gainestim);
197  newweightsum += weights[j];
198  }
199  }
200 
201  banditdata->weightsum = newweightsum;
202 
203  return SCIP_OKAY;
204 }
205 
206 /** reset callback for bandit algorithm */
207 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
208 { /*lint --e{715}*/
209  SCIP_BANDITDATA* banditdata;
210  SCIP_Real* weights;
211  int nactions;
212  int i;
213 
214  assert(bandit != NULL);
215 
216  banditdata = SCIPbanditGetData(bandit);
217  assert(banditdata != NULL);
218  nactions = SCIPbanditGetNActions(bandit);
219  weights = banditdata->weights;
220 
221  assert(nactions > 0);
222 
223  banditdata->weightsum = (1.0 + NUMTOL) * (SCIP_Real)nactions;
224 
225  /* in case of priorities, weights are normalized to sum up to nactions */
226  if( priorities != NULL )
227  {
228  SCIP_Real normalization;
229  SCIP_Real priosum;
230  priosum = 0.0;
231 
232  /* compute sum of priorities */
233  for( i = 0; i < nactions; ++i )
234  {
235  assert(priorities[i] >= 0);
236  priosum += priorities[i];
237  }
238 
239  /* if there are positive priorities, normalize the weights */
240  if( priosum > 0.0 )
241  {
242  normalization = nactions / priosum;
243  for( i = 0; i < nactions; ++i )
244  weights[i] = (priorities[i] * normalization) + NUMTOL;
245  }
246  else
247  {
248  /* use uniform distribution in case of all priorities being 0.0 */
249  for( i = 0; i < nactions; ++i )
250  weights[i] = 1.0 + NUMTOL;
251  }
252  }
253  else
254  {
255  /* use uniform distribution in case of unspecified priorities */
256  for( i = 0; i < nactions; ++i )
257  weights[i] = 1.0 + NUMTOL;
258  }
259 
260  return SCIP_OKAY;
261 }
262 
263 
264 /*
265  * bandit algorithm specific interface methods
266  */
267 
268 /** direct bandit creation method for the core where no SCIP pointer is available */
270  BMS_BLKMEM* blkmem, /**< block memory data structure */
271  BMS_BUFMEM* bufmem, /**< buffer memory */
272  SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3 */
273  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
274  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
275  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
276  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
277  int nactions, /**< the positive number of actions for this bandit algorithm */
278  unsigned int initseed /**< initial random seed */
279  )
280 {
281  SCIP_BANDITDATA* banditdata;
282 
283  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
284  assert(banditdata != NULL);
285 
286  banditdata->gamma = gammaparam;
287  banditdata->beta = beta;
288  assert(gammaparam >= 0 && gammaparam <= 1);
289  assert(beta >= 0 && beta <= 1);
290 
291  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
292 
293  SCIP_CALL( SCIPbanditCreate(exp3, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
294 
295  return SCIP_OKAY;
296 }
297 
298 /** creates and resets an Exp.3 bandit algorithm using \p scip pointer */
300  SCIP* scip, /**< SCIP data structure */
301  SCIP_BANDIT** exp3, /**< pointer to store bandit algorithm */
302  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
303  SCIP_Real gammaparam, /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
304  SCIP_Real beta, /**< gain offset between 0 and 1 at every observation */
305  int nactions, /**< the positive number of actions for this bandit algorithm */
306  unsigned int initseed /**< initial seed for random number generation */
307  )
308 {
309  SCIP_BANDITVTABLE* vtable;
310 
311  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
312  if( vtable == NULL )
313  {
314  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
315  return SCIP_INVALIDDATA;
316  }
317 
318  SCIP_CALL( SCIPbanditCreateExp3(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3,
319  priorities, gammaparam, beta, nactions, SCIPinitializeRandomSeed(scip, (int)(initseed % INT_MAX))) );
320 
321  return SCIP_OKAY;
322 }
323 
324 /** set gamma parameter of Exp.3 bandit algorithm to increase weight of uniform distribution */
326  SCIP_BANDIT* exp3, /**< bandit algorithm */
327  SCIP_Real gammaparam /**< weight between uniform (gamma ~ 1) and weight driven (gamma ~ 0) probability distribution */
328  )
329 {
330  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
331 
332  assert(gammaparam >= 0 && gammaparam <= 1);
333 
334  banditdata->gamma = gammaparam;
335 }
336 
337 /** set beta parameter of Exp.3 bandit algorithm to increase gain offset for actions that were not played */
339  SCIP_BANDIT* exp3, /**< bandit algorithm */
340  SCIP_Real beta /**< gain offset between 0 and 1 at every observation */
341  )
342 {
343  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
344 
345  assert(beta >= 0 && beta <= 1);
346 
347  banditdata->beta = beta;
348 }
349 
350 /** returns probability to play an action */
352  SCIP_BANDIT* exp3, /**< bandit algorithm */
353  int action /**< index of the requested action */
354  )
355 {
356  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3);
357 
358  assert(banditdata->weightsum > 0.0);
359  assert(SCIPbanditGetNActions(exp3) > 0);
360 
361  return (1.0 - banditdata->gamma) * banditdata->weights[action] / banditdata->weightsum + banditdata->gamma / (SCIP_Real)SCIPbanditGetNActions(exp3);
362 }
363 
364 /* include virtual function table for Exp.3 bandit algorithms */
366  SCIP* scip /**< SCIP data structure */
367  )
368 {
369  SCIP_BANDITVTABLE* vtable;
370 
372  SCIPbanditFreeExp3, SCIPbanditSelectExp3, SCIPbanditUpdateExp3, SCIPbanditResetExp3) );
373  assert(vtable != NULL);
374 
375  return SCIP_OKAY;
376 }
SCIP_RETCODE SCIPcreateBanditExp3(SCIP *scip, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:299
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3)
Definition: bandit_exp3.c:132
void SCIPsetBetaExp3(SCIP_BANDIT *exp3, SCIP_Real beta)
Definition: bandit_exp3.c:338
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:53
SCIPInterval exp(const SCIPInterval &x)
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:64
#define BANDIT_NAME
Definition: bandit_exp3.c:27
SCIP_RETCODE SCIPbanditCreateExp3(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3, SCIP_Real *priorities, SCIP_Real gammaparam, SCIP_Real beta, int nactions, unsigned int initseed)
Definition: bandit_exp3.c:269
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:180
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3)
Definition: bandit_exp3.c:73
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip.c:46746
#define SCIPerrorMessage
Definition: pub_message.h:45
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip.c:46731
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3)
Definition: bandit_exp3.c:207
#define SCIP_CALL(x)
Definition: def.h:350
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3)
Definition: bandit_exp3.c:52
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:190
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:446
SCIP_RETCODE SCIPincludeBanditvtableExp3(SCIP *scip)
Definition: bandit_exp3.c:365
SCIP_Real SCIPgetProbabilityExp3(SCIP_BANDIT *exp3, int action)
Definition: bandit_exp3.c:351
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:435
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:448
unsigned int SCIPinitializeRandomSeed(SCIP *scip, int initialseedvalue)
Definition: scip.c:25905
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:9388
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
#define NUMTOL
Definition: bandit_exp3.c:28
internal methods for Exp.3 bandit algorithm
void SCIPsetGammaExp3(SCIP_BANDIT *exp3, SCIP_Real gammaparam)
Definition: bandit_exp3.c:325
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:32
#define SCIP_Real
Definition: def.h:149
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:265
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:255
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:433
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:419
#define SCIP_ALLOC(x)
Definition: def.h:361
#define EPSZ(x, eps)
Definition: def.h:179
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:32