Scippy

SCIP

Solving Constraint Integer Programs

bandit_ucb.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (C) 2002-2018 Konrad-Zuse-Zentrum */
7 /* fuer Informationstechnik Berlin */
8 /* */
9 /* SCIP is distributed under the terms of the ZIB Academic License. */
10 /* */
11 /* You should have received a copy of the ZIB Academic License */
12 /* along with SCIP; see the file COPYING. If not email to scip@zib.de. */
13 /* */
14 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
15 
16 /**@file bandit_ucb.c
17  * @brief methods for UCB bandit selection
18  * @author Gregor Hendel
19  */
20 
21 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
22 
23 #include <assert.h>
24 
25 #include "scip/bandit_ucb.h"
26 #include "blockmemshell/memory.h"
27 
28 #define BANDIT_NAME "ucb"
29 #define NUMEPS 1e-6
30 
31 /*
32  * Data structures
33  */
34 
35 /** implementation specific data of UCB bandit algorithm */
36 struct SCIP_BanditData
37 {
38  int nselections; /**< counter for the number of selections */
39  int* counter; /**< array of counters how often every action has been chosen */
40  int* startperm; /**< indices for starting permutation */
41  SCIP_Real* meanscores; /**< array of average scores for the actions */
42  SCIP_Real alpha; /**< parameter to increase confidence width */
43 };
44 
45 
46 /*
47  * Local methods
48  */
49 
50 /** data reset method */
51 static
53  BMS_BUFMEM* bufmem, /**< buffer memory */
54  SCIP_BANDIT* ucb, /**< ucb bandit algorithm */
55  SCIP_BANDITDATA* banditdata, /**< UCB bandit data structure */
56  SCIP_Real* priorities, /**< priorities for start permutation, or NULL */
57  int nactions /**< number of actions */
58  )
59 {
60  int i;
61  SCIP_RANDNUMGEN* rng;
62 
63  assert(bufmem != NULL);
64  assert(ucb != NULL);
65  assert(nactions > 0);
66 
67  /* clear counters and scores */
68  BMSclearMemoryArray(banditdata->counter, nactions);
69  BMSclearMemoryArray(banditdata->meanscores, nactions);
70  banditdata->nselections = 0;
71 
72  rng = SCIPbanditGetRandnumgen(ucb);
73  assert(rng != NULL);
74 
75  /* initialize start permutation as identity */
76  for( i = 0; i < nactions; ++i )
77  banditdata->startperm[i] = i;
78 
79  /* prepare the start permutation in decreasing order of priority */
80  if( priorities != NULL )
81  {
82  SCIP_Real* prioritycopy;
83 
84  SCIP_ALLOC( BMSduplicateBufferMemoryArray(bufmem, &prioritycopy, priorities, nactions) );
85 
86  /* randomly wiggle priorities a little bit to make them unique */
87  for( i = 0; i < nactions; ++i )
88  prioritycopy[i] += SCIPrandomGetReal(rng, -NUMEPS, NUMEPS);
89 
90  SCIPsortDownRealInt(prioritycopy, banditdata->startperm, nactions);
91 
92  BMSfreeBufferMemoryArray(bufmem, &prioritycopy);
93  }
94  else
95  {
96  /* use a random start permutation */
97  SCIPrandomPermuteIntArray(rng, banditdata->startperm, 0, nactions);
98  }
99 
100  return SCIP_OKAY;
101 }
102 
103 
104 /*
105  * Callback methods of bandit algorithm
106  */
107 
108 /** callback to free bandit specific data structures */
109 SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
110 { /*lint --e{715}*/
111 
112  SCIP_BANDITDATA* banditdata;
113  int nactions;
114  assert(bandit != NULL);
115 
116  banditdata = SCIPbanditGetData(bandit);
117  assert(banditdata != NULL);
118  nactions = SCIPbanditGetNActions(bandit);
119 
120  BMSfreeBlockMemoryArray(blkmem, &banditdata->counter, nactions);
121  BMSfreeBlockMemoryArray(blkmem, &banditdata->startperm, nactions);
122  BMSfreeBlockMemoryArray(blkmem, &banditdata->meanscores, nactions);
123  BMSfreeBlockMemory(blkmem, &banditdata);
124 
125  SCIPbanditSetData(bandit, NULL);
126 
127  return SCIP_OKAY;
128 }
129 
130 /** selection callback for bandit selector */
131 SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
132 { /*lint --e{715}*/
133 
134  SCIP_BANDITDATA* banditdata;
135  int nactions;
136  int* counter;
137 
138  assert(bandit != NULL);
139  assert(selection != NULL);
140 
141  banditdata = SCIPbanditGetData(bandit);
142  assert(banditdata != NULL);
143  nactions = SCIPbanditGetNActions(bandit);
144 
145  counter = banditdata->counter;
146  /* select the next uninitialized action from the start permutation */
147  if( banditdata->nselections < nactions )
148  {
149  *selection = banditdata->startperm[banditdata->nselections];
150  assert(counter[*selection] == 0);
151  }
152  else
153  {
154  /* select the action with the highest upper confidence bound */
155  SCIP_Real* meanscores;
156  SCIP_Real widthfactor;
157  SCIP_Real maxucb;
158  int i;
160  meanscores = banditdata->meanscores;
161 
162  assert(rng != NULL);
163  assert(meanscores != NULL);
164 
165  /* compute the confidence width factor that is common for all actions */
166  /* cppcheck-suppress unpreciseMathCall */
167  widthfactor = banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections);
168  widthfactor = sqrt(widthfactor);
169  maxucb = -1.0;
170 
171  /* loop over the actions and determine the maximum upper confidence bound.
172  * The upper confidence bound of an action is the sum of its mean score
173  * plus a confidence term that decreases with increasing number of observations of
174  * this action.
175  */
176  for( i = 0; i < nactions; ++i )
177  {
178  SCIP_Real uppercb;
179  SCIP_Real rootcount;
180  assert(counter[i] > 0);
181 
182  /* compute the upper confidence bound for action i */
183  uppercb = meanscores[i];
184  rootcount = sqrt((SCIP_Real)counter[i]);
185  uppercb += widthfactor / rootcount;
186  assert(uppercb > 0);
187 
188  /* update maximum, breaking ties uniformly at random */
189  if( EPSGT(uppercb, maxucb, NUMEPS) || (EPSEQ(uppercb, maxucb, NUMEPS) && SCIPrandomGetReal(rng, 0.0, 1.0) >= 0.5) )
190  {
191  maxucb = uppercb;
192  *selection = i;
193  }
194  }
195  }
196 
197  assert(*selection >= 0);
198  assert(*selection < nactions);
199 
200  return SCIP_OKAY;
201 }
202 
203 /** update callback for bandit algorithm */
204 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
205 { /*lint --e{715}*/
206  SCIP_BANDITDATA* banditdata;
207  SCIP_Real delta;
208 
209  assert(bandit != NULL);
210 
211  banditdata = SCIPbanditGetData(bandit);
212  assert(banditdata != NULL);
213  assert(selection >= 0);
214  assert(selection < SCIPbanditGetNActions(bandit));
215 
216  /* increase the mean by the incremental formula: A_n = A_n-1 + 1/n (a_n - A_n-1) */
217  delta = score - banditdata->meanscores[selection];
218  ++banditdata->counter[selection];
219  banditdata->meanscores[selection] += delta / (SCIP_Real)banditdata->counter[selection];
220 
221  banditdata->nselections++;
222 
223  return SCIP_OKAY;
224 }
225 
226 /** reset callback for bandit algorithm */
227 SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
228 { /*lint --e{715}*/
229  SCIP_BANDITDATA* banditdata;
230  int nactions;
231 
232  assert(bufmem != NULL);
233  assert(bandit != NULL);
234 
235  banditdata = SCIPbanditGetData(bandit);
236  assert(banditdata != NULL);
237  nactions = SCIPbanditGetNActions(bandit);
238 
239  /* call the data reset for the given priorities */
240  SCIP_CALL( dataReset(bufmem, bandit, banditdata, priorities, nactions) );
241 
242  return SCIP_OKAY;
243 }
244 
245 /*
246  * bandit algorithm specific interface methods
247  */
248 
249 /** returns the upper confidence bound of a selected action */
251  SCIP_BANDIT* ucb, /**< UCB bandit algorithm */
252  int action /**< index of the queried action */
253  )
254 {
255  SCIP_Real uppercb;
256  SCIP_BANDITDATA* banditdata;
257  int nactions;
258 
259  assert(ucb != NULL);
260  banditdata = SCIPbanditGetData(ucb);
261  nactions = SCIPbanditGetNActions(ucb);
262  assert(action < nactions);
263 
264  /* since only scores between 0 and 1 are allowed, 1.0 is a sure upper confidence bound */
265  if( banditdata->nselections < nactions )
266  return 1.0;
267 
268  /* the bandit algorithm must have picked every action once */
269  assert(banditdata->counter[action] > 0);
270  uppercb = banditdata->meanscores[action];
271 
272  /* cppcheck-suppress unpreciseMathCall */
273  uppercb += sqrt(banditdata->alpha * LOG1P((SCIP_Real)banditdata->nselections) / (SCIP_Real)banditdata->counter[action]);
274 
275  return uppercb;
276 }
277 
278 /** return start permutation of the UCB bandit algorithm */
280  SCIP_BANDIT* ucb /**< UCB bandit algorithm */
281  )
282 {
283  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(ucb);
284 
285  assert(banditdata != NULL);
286 
287  return banditdata->startperm;
288 }
289 
290 /** internal method to create and reset UCB bandit algorithm */
292  BMS_BLKMEM* blkmem, /**< block memory */
293  BMS_BUFMEM* bufmem, /**< buffer memory */
294  SCIP_BANDITVTABLE* vtable, /**< virtual function table for UCB bandit algorithm */
295  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
296  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
297  SCIP_Real alpha, /**< parameter to increase confidence width */
298  int nactions, /**< the positive number of actions for this bandit algorithm */
299  unsigned int initseed /**< initial random seed */
300  )
301 {
302  SCIP_BANDITDATA* banditdata;
303 
304  if( alpha < 0.0 )
305  {
306  SCIPerrorMessage("UCB requires nonnegative alpha parameter, have %f\n", alpha);
307  return SCIP_INVALIDDATA;
308  }
309 
310  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
311  assert(banditdata != NULL);
312 
313  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->counter, nactions) );
314  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->startperm, nactions) );
315  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->meanscores, nactions) );
316 
317  banditdata->alpha = alpha;
318 
319  SCIP_CALL( SCIPbanditCreate(ucb, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
320 
321  return SCIP_OKAY;
322 }
323 
324 /** create and reset UCB bandit algorithm */
326  SCIP* scip, /**< SCIP data structure */
327  SCIP_BANDIT** ucb, /**< pointer to store bandit algorithm */
328  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
329  SCIP_Real alpha, /**< parameter to increase confidence width */
330  int nactions, /**< the positive number of actions for this bandit algorithm */
331  unsigned int initseed /**< initial random number seed */
332  )
333 {
334  SCIP_BANDITVTABLE* vtable;
335 
336  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
337  if( vtable == NULL )
338  {
339  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
340  return SCIP_INVALIDDATA;
341  }
342 
343  SCIP_CALL( SCIPbanditCreateUcb(SCIPblkmem(scip), SCIPbuffer(scip), vtable, ucb,
344  priorities, alpha, nactions, SCIPinitializeRandomSeed(scip, (int)(initseed % INT_MAX))) );
345 
346  return SCIP_OKAY;
347 }
348 
349 /** include virtual function table for UCB bandit algorithms */
351  SCIP* scip /**< SCIP data structure */
352  )
353 {
354  SCIP_BANDITVTABLE* vtable;
355 
357  SCIPbanditFreeUcb, SCIPbanditSelectUcb, SCIPbanditUpdateUcb, SCIPbanditResetUcb) );
358  assert(vtable != NULL);
359 
360  return SCIP_OKAY;
361 }
SCIP_RETCODE SCIPcreateBanditUcb(SCIP *scip, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:325
#define EPSEQ(x, y, eps)
Definition: def.h:174
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:53
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:64
void SCIPsortDownRealInt(SCIP_Real *realarray, int *intarray, int len)
#define BMSduplicateBufferMemoryArray(mem, ptr, source, num)
Definition: memory.h:708
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:180
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip.c:46746
#define SCIPerrorMessage
Definition: pub_message.h:45
internal methods for UCB bandit algorithm
SCIPInterval sqrt(const SCIPInterval &x)
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateUcb)
Definition: bandit_ucb.c:204
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip.c:46731
SCIP_DECL_BANDITRESET(SCIPbanditResetUcb)
Definition: bandit_ucb.c:227
#define SCIP_CALL(x)
Definition: def.h:350
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:190
#define NUMEPS
Definition: bandit_ucb.c:29
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:446
SCIP_RETCODE SCIPincludeBanditvtableUcb(SCIP *scip)
Definition: bandit_ucb.c:350
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:435
SCIP_DECL_BANDITSELECT(SCIPbanditSelectUcb)
Definition: bandit_ucb.c:131
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:448
void SCIPrandomPermuteIntArray(SCIP_RANDNUMGEN *randnumgen, int *array, int begin, int end)
Definition: misc.c:9407
int * SCIPgetStartPermutationUcb(SCIP_BANDIT *ucb)
Definition: bandit_ucb.c:279
unsigned int SCIPinitializeRandomSeed(SCIP *scip, int initialseedvalue)
Definition: scip.c:25905
#define BANDIT_NAME
Definition: bandit_ucb.c:28
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:9388
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:47
#define EPSGT(x, y, eps)
Definition: def.h:177
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:32
SCIP_RETCODE SCIPbanditCreateUcb(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **ucb, SCIP_Real *priorities, SCIP_Real alpha, int nactions, unsigned int initseed)
Definition: bandit_ucb.c:291
#define SCIP_Real
Definition: def.h:149
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:265
static SCIP_RETCODE dataReset(BMS_BUFMEM *bufmem, SCIP_BANDIT *ucb, SCIP_BANDITDATA *banditdata, SCIP_Real *priorities, int nactions)
Definition: bandit_ucb.c:52
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:255
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:433
#define BMSclearMemoryArray(ptr, num)
Definition: memory.h:112
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:419
#define SCIP_ALLOC(x)
Definition: def.h:361
SCIP_Real SCIPgetConfidenceBoundUcb(SCIP_BANDIT *ucb, int action)
Definition: bandit_ucb.c:250
#define BMSfreeBufferMemoryArray(mem, ptr)
Definition: memory.h:713
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:32
SCIP_DECL_BANDITFREE(SCIPbanditFreeUcb)
Definition: bandit_ucb.c:109
memory allocation routines