Scippy

SCIP

Solving Constraint Integer Programs

bandit_exp3ix.c
Go to the documentation of this file.
1 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
2 /* */
3 /* This file is part of the program and library */
4 /* SCIP --- Solving Constraint Integer Programs */
5 /* */
6 /* Copyright (c) 2002-2024 Zuse Institute Berlin (ZIB) */
7 /* */
8 /* Licensed under the Apache License, Version 2.0 (the "License"); */
9 /* you may not use this file except in compliance with the License. */
10 /* You may obtain a copy of the License at */
11 /* */
12 /* http://www.apache.org/licenses/LICENSE-2.0 */
13 /* */
14 /* Unless required by applicable law or agreed to in writing, software */
15 /* distributed under the License is distributed on an "AS IS" BASIS, */
16 /* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
17 /* See the License for the specific language governing permissions and */
18 /* limitations under the License. */
19 /* */
20 /* You should have received a copy of the Apache-2.0 license */
21 /* along with SCIP; see the file LICENSE. If not visit scipopt.org. */
22 /* */
23 /* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
24 
25 /**@file bandit_exp3ix.c
26  * @ingroup OTHER_CFILES
27  * @brief methods for Exp.3-IX bandit selection
28  * @author Antonia Chmiela
29  */
30 
31 /*---+----1----+----2----+----3----+----4----+----5----+----6----+----7----+----8----+----9----+----0----+----1----+----2*/
32 
33 #include "scip/bandit.h"
34 #include "scip/bandit_exp3ix.h"
35 #include "scip/pub_bandit.h"
36 #include "scip/pub_message.h"
37 #include "scip/pub_misc.h"
38 #include "scip/scip_bandit.h"
39 #include "scip/scip_mem.h"
40 #include "scip/scip_randnumgen.h"
41 
42 #define BANDIT_NAME "exp3ix" /**< key under which the virtual function table of this bandit algorithm is looked up via SCIPfindBanditvtable() */
43 
44 /*
45  * Data structures
46  */
47 
48 /** implementation specific data of Exp.3 bandit algorithm */
49 struct SCIP_BanditData
50 {
51  SCIP_Real* weights; /**< exponential weight for each arm */
52  SCIP_Real weightsum; /**< the sum of all weights */
53  int iter; /**< current iteration counter to compute parameters gamma_t and eta_t */
54 };
55 
56 /*
57  * Local methods
58  */
59 
60 /*
61  * Callback methods of bandit algorithm
62  */
63 
64 /** callback to free bandit specific data structures */
65 SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
66 { /*lint --e{715}*/
67  SCIP_BANDITDATA* banditdata;
68  int nactions;
69  assert(bandit != NULL);
70 
71  banditdata = SCIPbanditGetData(bandit);
72  assert(banditdata != NULL);
73  nactions = SCIPbanditGetNActions(bandit);
74 
75  BMSfreeBlockMemoryArray(blkmem, &banditdata->weights, nactions);
76 
77  BMSfreeBlockMemory(blkmem, &banditdata);
78 
79  SCIPbanditSetData(bandit, NULL);
80 
81  return SCIP_OKAY;
82 }
83 
84 /** selection callback for bandit selector */
85 SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
86 { /*lint --e{715}*/
87  SCIP_BANDITDATA* banditdata;
88  SCIP_RANDNUMGEN* rng;
89  SCIP_Real* weights;
90  SCIP_Real weightsum;
91  int i;
92  int nactions;
93  SCIP_Real psum;
94  SCIP_Real randnr;
95 
96  assert(bandit != NULL);
97  assert(selection != NULL);
98 
99  banditdata = SCIPbanditGetData(bandit);
100  assert(banditdata != NULL);
101  rng = SCIPbanditGetRandnumgen(bandit);
102  assert(rng != NULL);
103  nactions = SCIPbanditGetNActions(bandit);
104 
105  /* initialize some local variables to speed up probability computations */
106  weightsum = banditdata->weightsum;
107  weights = banditdata->weights;
108 
109  /* draw a random number between 0 and 1 */
110  randnr = SCIPrandomGetReal(rng, 0.0, 1.0);
111 
112  /* loop over probability distribution until rand is reached
113  * the loop terminates without looking at the last action,
114  * which is then selected automatically if the target probability
115  * is not reached earlier
116  */
117  psum = 0.0;
118  for( i = 0; i < nactions - 1; ++i )
119  {
120  SCIP_Real prob;
121 
122  /* compute the probability for arm i */
123  prob = weights[i] / weightsum;
124  psum += prob;
125 
126  /* break and select element if target probability is reached */
127  if( randnr <= psum )
128  break;
129  }
130 
131  /* select element i, which is the last action in case that the break statement hasn't been reached */
132  *selection = i;
133 
134  return SCIP_OKAY;
135 }
136 
137 /** compute gamma_t */
138 static
140  int nactions, /**< the positive number of actions for this bandit algorithm */
141  int t /**< current iteration */
142  )
143 {
144  return sqrt(log((SCIP_Real)nactions) / (4.0 * (SCIP_Real)t * (SCIP_Real)nactions));
145 }
146 
147 /** update callback for bandit algorithm */
148 SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
149 { /*lint --e{715}*/
150  SCIP_BANDITDATA* banditdata;
151  SCIP_Real etaparam;
152  SCIP_Real lossestim;
153  SCIP_Real prob;
154  SCIP_Real weightsum;
155  SCIP_Real newweightsum;
156  SCIP_Real* weights;
157  SCIP_Real gammaparam;
158  int nactions;
159 
160  assert(bandit != NULL);
161 
162  banditdata = SCIPbanditGetData(bandit);
163  assert(banditdata != NULL);
164  nactions = SCIPbanditGetNActions(bandit);
165 
166  assert(selection >= 0);
167  assert(selection < nactions);
168 
169  weights = banditdata->weights;
170  weightsum = banditdata->weightsum;
171  newweightsum = weightsum;
172  gammaparam = SCIPcomputeGamma(nactions, banditdata->iter);
173  etaparam = 2.0 * gammaparam;
174 
175  /* probability of selection */
176  prob = weights[selection] / weightsum;
177 
178  /* estimated loss */
179  lossestim = (1.0 - score) / (prob + gammaparam);
180  assert(lossestim >= 0);
181 
182  /* update the observation for the current arm */
183  newweightsum -= weights[selection];
184  weights[selection] *= exp(-etaparam * lossestim);
185  newweightsum += weights[selection];
186 
187  banditdata->weightsum = newweightsum;
188 
189  /* increase iteration counter */
190  banditdata->iter += 1;
191 
192  return SCIP_OKAY;
193 }
194 
195 /** reset callback for bandit algorithm */
196 SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
197 { /*lint --e{715}*/
198  SCIP_BANDITDATA* banditdata;
199  SCIP_Real* weights;
200  int nactions;
201  int i;
202 
203  assert(bandit != NULL);
204 
205  banditdata = SCIPbanditGetData(bandit);
206  assert(banditdata != NULL);
207  nactions = SCIPbanditGetNActions(bandit);
208  weights = banditdata->weights;
209 
210  assert(nactions > 0);
211 
212  /* initialize all weights with 1.0 */
213  for( i = 0; i < nactions; ++i )
214  weights[i] = 1.0;
215 
216  banditdata->weightsum = (SCIP_Real)nactions;
217 
218  /* set iteration counter to 1 */
219  banditdata->iter = 1;
220 
221  return SCIP_OKAY;
222 }
223 
224 
225 /*
226  * bandit algorithm specific interface methods
227  */
228 
229 /** direct bandit creation method for the core where no SCIP pointer is available */
231  BMS_BLKMEM* blkmem, /**< block memory data structure */
232  BMS_BUFMEM* bufmem, /**< buffer memory */
233  SCIP_BANDITVTABLE* vtable, /**< virtual function table for callback functions of Exp.3-IX */
234  SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
235  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
236  int nactions, /**< the positive number of actions for this bandit algorithm */
237  unsigned int initseed /**< initial random seed */
238  )
239 {
240  SCIP_BANDITDATA* banditdata;
241 
242  SCIP_ALLOC( BMSallocBlockMemory(blkmem, &banditdata) );
243  assert(banditdata != NULL);
244 
245  banditdata->iter = 1;
246 
247  SCIP_ALLOC( BMSallocBlockMemoryArray(blkmem, &banditdata->weights, nactions) );
248 
249  SCIP_CALL( SCIPbanditCreate(exp3ix, vtable, blkmem, bufmem, priorities, nactions, initseed, banditdata) );
250 
251  return SCIP_OKAY;
252 }
253 
254 /** creates and resets an Exp.3-IX bandit algorithm using \p scip pointer */
256  SCIP* scip, /**< SCIP data structure */
257  SCIP_BANDIT** exp3ix, /**< pointer to store bandit algorithm */
258  SCIP_Real* priorities, /**< nonnegative priorities for each action, or NULL if not needed */
259  int nactions, /**< the positive number of actions for this bandit algorithm */
260  unsigned int initseed /**< initial seed for random number generation */
261  )
262 {
263  SCIP_BANDITVTABLE* vtable;
264 
265  vtable = SCIPfindBanditvtable(scip, BANDIT_NAME);
266  if( vtable == NULL )
267  {
268  SCIPerrorMessage("Could not find virtual function table for %s bandit algorithm\n", BANDIT_NAME);
269  return SCIP_INVALIDDATA;
270  }
271 
272  SCIP_CALL( SCIPbanditCreateExp3IX(SCIPblkmem(scip), SCIPbuffer(scip), vtable, exp3ix,
273  priorities, nactions, SCIPinitializeRandomSeed(scip, initseed)) );
274 
275  return SCIP_OKAY;
276 }
277 
278 /** returns probability to play an action */
280  SCIP_BANDIT* exp3ix, /**< bandit algorithm */
281  int action /**< index of the requested action */
282  )
283 {
284  SCIP_BANDITDATA* banditdata = SCIPbanditGetData(exp3ix);
285 
286  assert(banditdata->weightsum > 0.0);
287  assert(SCIPbanditGetNActions(exp3ix) > 0);
288 
289  return banditdata->weights[action] / banditdata->weightsum;
290 }
291 
292 /** include virtual function table for Exp.3-IX bandit algorithms */
294  SCIP* scip /**< SCIP data structure */
295  )
296 {
297  SCIP_BANDITVTABLE* vtable;
298 
300  SCIPbanditFreeExp3IX, SCIPbanditSelectExp3IX, SCIPbanditUpdateExp3IX, SCIPbanditResetExp3IX) );
301  assert(vtable != NULL);
302 
303  return SCIP_OKAY;
304 }
#define NULL
Definition: def.h:267
public methods for memory management
internal methods for bandit algorithms
enum SCIP_Retcode SCIP_RETCODE
Definition: type_retcode.h:63
SCIP_BANDITDATA * SCIPbanditGetData(SCIP_BANDIT *bandit)
Definition: bandit.c:190
internal methods for Exp.3-IX bandit algorithm
SCIP_DECL_BANDITRESET(SCIPbanditResetExp3IX)
SCIP_DECL_BANDITUPDATE(SCIPbanditUpdateExp3IX)
BMS_BUFMEM * SCIPbuffer(SCIP *scip)
Definition: scip_mem.c:72
#define SCIPerrorMessage
Definition: pub_message.h:64
static SCIP_Real SCIPcomputeGamma(int nactions, int t)
BMS_BLKMEM * SCIPblkmem(SCIP *scip)
Definition: scip_mem.c:57
SCIP_RETCODE SCIPbanditCreateExp3IX(BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_BANDITVTABLE *vtable, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
SCIP_BANDITVTABLE * SCIPfindBanditvtable(SCIP *scip, const char *name)
Definition: scip_bandit.c:80
#define SCIP_CALL(x)
Definition: def.h:380
#define BANDIT_NAME
Definition: bandit_exp3ix.c:42
void SCIPbanditSetData(SCIP_BANDIT *bandit, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:200
#define BMSfreeBlockMemory(mem, ptr)
Definition: memory.h:465
public data structures and miscellaneous methods
#define BMSallocBlockMemoryArray(mem, ptr, num)
Definition: memory.h:454
SCIP_RETCODE SCIPcreateBanditExp3IX(SCIP *scip, SCIP_BANDIT **exp3ix, SCIP_Real *priorities, int nactions, unsigned int initseed)
SCIP_RETCODE SCIPincludeBanditvtable(SCIP *scip, SCIP_BANDITVTABLE **banditvtable, const char *name, SCIP_DECL_BANDITFREE((*banditfree)), SCIP_DECL_BANDITSELECT((*banditselect)), SCIP_DECL_BANDITUPDATE((*banditupdate)), SCIP_DECL_BANDITRESET((*banditreset)))
Definition: scip_bandit.c:48
#define BMSfreeBlockMemoryArray(mem, ptr, num)
Definition: memory.h:467
public methods for bandit algorithms
SCIP_Real SCIPrandomGetReal(SCIP_RANDNUMGEN *randnumgen, SCIP_Real minrandval, SCIP_Real maxrandval)
Definition: misc.c:10130
public methods for bandit algorithms
SCIP_DECL_BANDITFREE(SCIPbanditFreeExp3IX)
Definition: bandit_exp3ix.c:65
struct SCIP_BanditData SCIP_BANDITDATA
Definition: type_bandit.h:56
SCIP_Real SCIPgetProbabilityExp3IX(SCIP_BANDIT *exp3ix, int action)
public methods for random numbers
SCIP_RETCODE SCIPincludeBanditvtableExp3IX(SCIP *scip)
public methods for message output
#define SCIP_Real
Definition: def.h:173
int SCIPbanditGetNActions(SCIP_BANDIT *bandit)
Definition: bandit.c:303
SCIP_RANDNUMGEN * SCIPbanditGetRandnumgen(SCIP_BANDIT *bandit)
Definition: bandit.c:293
SCIP_DECL_BANDITSELECT(SCIPbanditSelectExp3IX)
Definition: bandit_exp3ix.c:85
#define BMSallocBlockMemory(mem, ptr)
Definition: memory.h:451
unsigned int SCIPinitializeRandomSeed(SCIP *scip, unsigned int initialseedvalue)
struct BMS_BlkMem BMS_BLKMEM
Definition: memory.h:437
#define SCIP_ALLOC(x)
Definition: def.h:391
SCIP_RETCODE SCIPbanditCreate(SCIP_BANDIT **bandit, SCIP_BANDITVTABLE *banditvtable, BMS_BLKMEM *blkmem, BMS_BUFMEM *bufmem, SCIP_Real *priorities, int nactions, unsigned int initseed, SCIP_BANDITDATA *banditdata)
Definition: bandit.c:42