-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(core): precomputed bandits #128
Conversation
let bandits = configuration | ||
.bandits | ||
.as_ref() | ||
.map(|bandits| { | ||
configuration | ||
.flags | ||
.compiled | ||
.flags | ||
.iter() | ||
.filter_map(|(flag_key, flag)| { | ||
let flag = flag.as_ref().ok()?; | ||
|
||
// Skip non-string variations as they can't be bandits. | ||
if flag.variation_type != VariationType::String { | ||
return None; | ||
} | ||
|
||
let flag_bandits: HashMap</* variation_key: */ Str, PrecomputedBandit> = | ||
if let Some(ValueWire::String(precomputed_variation_value)) = flags | ||
.get(flag_key) | ||
.map(|assignment| &assignment.variation_value) | ||
{ | ||
// If precomputing flag resolved to a value, we only need to evaluate a | ||
// single bandit. | ||
let bandit_key = &configuration | ||
.flags | ||
.compiled | ||
.flag_to_bandit_associations | ||
.get(flag_key)? | ||
.get(precomputed_variation_value)? | ||
.key; | ||
let bandit_model = bandits.bandits.get(bandit_key)?; | ||
|
||
let bandit_evaluation = bandit_model | ||
.model_data | ||
.evaluate(flag_key, subject_key, subject_attributes, actions) | ||
.ok()?; | ||
|
||
let selected_action = &actions[&bandit_evaluation.action_key]; | ||
let precomputed_bandit = PrecomputedBandit { | ||
bandit_key: bandit_key.clone(), | ||
action: bandit_evaluation.action_key, | ||
action_probability: bandit_evaluation.action_weight, | ||
optimality_gap: bandit_evaluation.optimality_gap, | ||
model_version: bandit_model.model_version.clone(), | ||
action_numeric_attributes: selected_action.numeric.clone(), | ||
action_categorical_attributes: selected_action.categorical.clone(), | ||
}; | ||
|
||
[(precomputed_variation_value.clone(), precomputed_bandit)] | ||
.into_iter() | ||
.collect() | ||
} else { | ||
// If precomputed flag did not resolve to a value, we need to precompute all | ||
// bandits for the flag in case the user supplies a bandit variation as | ||
// default variation. | ||
configuration | ||
.flags | ||
.compiled | ||
.flag_to_bandit_associations | ||
.get(flag_key)? | ||
.iter() | ||
.filter_map(|(variation_value, bandit_variation)| { | ||
let bandit_key = &bandit_variation.key; | ||
let bandit_model = bandits.bandits.get(bandit_key)?; | ||
|
||
let bandit_evaluation = bandit_model | ||
.model_data | ||
.evaluate( | ||
flag_key, | ||
subject_key, | ||
subject_attributes, | ||
actions, | ||
) | ||
.ok()?; | ||
|
||
let selected_action = &actions[&bandit_evaluation.action_key]; | ||
let precomputed_bandit = PrecomputedBandit { | ||
bandit_key: bandit_key.clone(), | ||
action: bandit_evaluation.action_key, | ||
action_probability: bandit_evaluation.action_weight, | ||
optimality_gap: bandit_evaluation.optimality_gap, | ||
model_version: bandit_model.model_version.clone(), | ||
action_numeric_attributes: selected_action.numeric.clone(), | ||
action_categorical_attributes: selected_action | ||
.categorical | ||
.clone(), | ||
}; | ||
|
||
Some((variation_value.clone(), precomputed_bandit)) | ||
}) | ||
.collect() | ||
}; | ||
|
||
Some((flag_key.clone(), flag_bandits)) | ||
}) | ||
.collect() | ||
}) | ||
.unwrap_or_default(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The diff in this file is a bit large because I renamed the file.
This highlighted piece is the main change, though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the note; yep I recognize the existing stuff 😄
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is another important file to review
let salt: Str = { | ||
let bytes = rand::thread_rng().gen::<[u8; 16]>(); | ||
base64::prelude::BASE64_STANDARD_NO_PAD | ||
.encode(&bytes) | ||
.into() | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is to make it use salt in text mode instead of binary as before
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @typotter
impl std::borrow::Borrow<str> for Str { | ||
fn borrow(&self) -> &str { | ||
self.as_str() | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This allows us using &str
for querying HashMap<Str, _>
.
Quite a bit of changes in this PR is changing String
-> Str
to enable quick cloning.
I start leaning to using Str
whenever we use strings because it's much cheaper to clone and has small string optimization (small strings are stored without heap allocation), so it's a good fit for our use case. These optimizations also make it much faster for converting to/from other languages (e.g., Python). The only downside is that occasional String -> Str
conversion might be a bit more costly but there are few
@@ -21,6 +21,7 @@ magnus = ["dep:magnus", "dep:serde_magnus"] | |||
vendored = ["reqwest/native-tls-vendored"] | |||
|
|||
[dependencies] | |||
base64 = "0.22.1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base64
is already a transitive dependency of reqwest
, so it's not expanding our dependency set
65a4c6c
to
1ae1fde
Compare
# Add WASM target | ||
- run: rustup target add wasm32-wasi | ||
# Build WASM target separately | ||
- run: cargo build --verbose --target wasm32-wasi | ||
working-directory: fastly-edge-assignments |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for tidying this up
@@ -76,7 +76,7 @@ impl Evaluator { | |||
flag_key: &str, | |||
subject_key: &Str, | |||
subject_attributes: &ContextAttributes, | |||
actions: &HashMap<String, ContextAttributes>, | |||
actions: &HashMap<Str, ContextAttributes>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the consistent use of Str
let salt: Str = { | ||
let bytes = rand::thread_rng().gen::<[u8; 16]>(); | ||
base64::prelude::BASE64_STANDARD_NO_PAD | ||
.encode(&bytes) | ||
.into() | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @typotter
// Skip non-string variations as they can't be bandits. | ||
if flag.variation_type != VariationType::String { | ||
return None; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is true but this optimization could put us into a tricky spot if the code ships in clients and this fact changes. cc @aarsilv is this ever changing?
// Skip non-string variations as they can't be bandits. | |
if flag.variation_type != VariationType::String { | |
return None; | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we look up bandits by variation key and it requires a string, changing this would require quite a bit of changes anyway. e.g., get_bandit_action
doesn't work with non-string variations, too
So this piece here is less of an optimization but rather to ensure correctness and consistent behavior (e.g., so we don't accidentally interpret JSON as a bandit)
let bandits = configuration | ||
.bandits | ||
.as_ref() | ||
.map(|bandits| { | ||
configuration | ||
.flags | ||
.compiled | ||
.flags | ||
.iter() | ||
.filter_map(|(flag_key, flag)| { | ||
let flag = flag.as_ref().ok()?; | ||
|
||
// Skip non-string variations as they can't be bandits. | ||
if flag.variation_type != VariationType::String { | ||
return None; | ||
} | ||
|
||
let flag_bandits: HashMap</* variation_key: */ Str, PrecomputedBandit> = | ||
if let Some(ValueWire::String(precomputed_variation_value)) = flags | ||
.get(flag_key) | ||
.map(|assignment| &assignment.variation_value) | ||
{ | ||
// If precomputing flag resolved to a value, we only need to evaluate a | ||
// single bandit. | ||
let bandit_key = &configuration | ||
.flags | ||
.compiled | ||
.flag_to_bandit_associations | ||
.get(flag_key)? | ||
.get(precomputed_variation_value)? | ||
.key; | ||
let bandit_model = bandits.bandits.get(bandit_key)?; | ||
|
||
let bandit_evaluation = bandit_model | ||
.model_data | ||
.evaluate(flag_key, subject_key, subject_attributes, actions) | ||
.ok()?; | ||
|
||
let selected_action = &actions[&bandit_evaluation.action_key]; | ||
let precomputed_bandit = PrecomputedBandit { | ||
bandit_key: bandit_key.clone(), | ||
action: bandit_evaluation.action_key, | ||
action_probability: bandit_evaluation.action_weight, | ||
optimality_gap: bandit_evaluation.optimality_gap, | ||
model_version: bandit_model.model_version.clone(), | ||
action_numeric_attributes: selected_action.numeric.clone(), | ||
action_categorical_attributes: selected_action.categorical.clone(), | ||
}; | ||
|
||
[(precomputed_variation_value.clone(), precomputed_bandit)] | ||
.into_iter() | ||
.collect() | ||
} else { | ||
// If precomputed flag did not resolve to a value, we need to precompute all | ||
// bandits for the flag in case the user supplies a bandit variation as | ||
// default variation. | ||
configuration | ||
.flags | ||
.compiled | ||
.flag_to_bandit_associations | ||
.get(flag_key)? | ||
.iter() | ||
.filter_map(|(variation_value, bandit_variation)| { | ||
let bandit_key = &bandit_variation.key; | ||
let bandit_model = bandits.bandits.get(bandit_key)?; | ||
|
||
let bandit_evaluation = bandit_model | ||
.model_data | ||
.evaluate( | ||
flag_key, | ||
subject_key, | ||
subject_attributes, | ||
actions, | ||
) | ||
.ok()?; | ||
|
||
let selected_action = &actions[&bandit_evaluation.action_key]; | ||
let precomputed_bandit = PrecomputedBandit { | ||
bandit_key: bandit_key.clone(), | ||
action: bandit_evaluation.action_key, | ||
action_probability: bandit_evaluation.action_weight, | ||
optimality_gap: bandit_evaluation.optimality_gap, | ||
model_version: bandit_model.model_version.clone(), | ||
action_numeric_attributes: selected_action.numeric.clone(), | ||
action_categorical_attributes: selected_action | ||
.categorical | ||
.clone(), | ||
}; | ||
|
||
Some((variation_value.clone(), precomputed_bandit)) | ||
}) | ||
.collect() | ||
}; | ||
|
||
Some((flag_key.clone(), flag_bandits)) | ||
}) | ||
.collect() | ||
}) | ||
.unwrap_or_default(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the note; yep I recognize the existing stuff 😄
ObfuscatedPrecomputedAssignment::from(v), | ||
) | ||
}) | ||
.collect(), | ||
bandits: config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you have the bandit obfuscation as part of a unit test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a small test to compute bandits against bandit-flags-v1.json
+bandit-models-v1.json
and I verified the output.
Overall, I'm more reliant on types here rather than tests. Having obfuscated type clearly define where md5s and base64s go makes it trivial to ensure correctness. (If the type requires an md5, there's only one way to get — hash salt+string — regular strings and base64 don't fit.)
Once we stabilize formats, I do want to add a couple of test cases to sdk-test-data though
No description provided.