こちらはエムスリー Advent Calendar 2023 11日目の記事です。

A/Bテストをしまくっている、機械学習エンジニアの農見(@rookzeno)です。皆さんA/Bテストをしてますでしょうか。エムスリーでは色々な施策の効果を見るために沢山のA/Bテストをしています。そのためA/Bテストを簡易にできるような設計を作ることも大事なことです。
AI・機械学習チームには、Goで書かれた機械学習関連の機能を各サービスに提供するAPIサーバがあり、こちらのYAMLファイルを設定するだけでA/Bテストが出来るようにしました。
rules:-name: modelArandom_seed:42threshold:50ctrl:weight:0test:weight:1-name: modelBweight:1
このYAMLファイルをどのようにGoのAPIで使ってるかを今回は解説します。
このYAMLファイルをGoのAPIでどう扱っているかという話の前に、AIチームのMLプロダクトの全体構成について説明します。
AIチームでは「バッチで学習・推論して結果をDBに保存しておき、APIはDBの参照のみ行う(リクエスト時に推論をしない)」という構成をよく採用しています。今回はこの構成であることを前提としたコードになりますが、リクエスト時に推論する場合でも同じやり方はできると思います。

AI・機械学習チーム流MLOpsの歴史 - エムスリーテックブログより
rules:-name: modelArandomseed:42threshold:50ctrl:weight:0test:weight:1-name: modelBweight:1
まずこのYAMLファイルが何を示しているのかを説明します。ctrl50%ではmodelA × 0 + modelB × 1、test50%ではmodelA × 1 + modelB × 1のアンサンブルを行うという設定です。randomseedという設定があると思いますが、これはユーザーを分ける関数のseedになります。ctrlとtestには有意差がないように分ける必要があるので、適切なrandomseedを毎回選ぶ必要があります。エムスリーでは毎回そのA/Bテストにとって最適な任意の関数とrandomseedを選んでA/Bテストをしています。
このテーブルの例で具体的に説明します。
| id | modelA | modelB | ctrl | test |
|---|---|---|---|---|
| 1 | 100 | 10 | 10 | 110 |
| 2 | 0 | 15 | 15 | 15 |
modelAはid1に100点、id2に0点をつけています。modelBはid1に10点、id2に15点つけてます。この時ctrlではmodelBのみなのでid1に10点、id2に15点となりid2>id1なのでid2,id1という順番で表示します。一方でtestではmodelAとmodelBの足し算なのでid1に110点、id2に15点となりid1,id2という順番で出すことになります。このように新たなモデルを追加するとレコメンド結果が変わりその効果を見るのがA/Bテストです。
ここからGoのコードでどのように処理してるかを見ていきます。まずはYAMLファイルを読み込むところからです。
YAMLファイルをGoで読み込みには以下のように書けばいいです。
import ("context""fmt""io""log/slog""os""gopkg.in/yaml.v3")type Configstruct { Rules []ruleConfig`yaml:"rules"`}type ruleConfigstruct { WeightValue`yaml:",inline"` Namestring`yaml:"name"` RandomSeed *int`yaml:"randomseed"` Threshold *int`yaml:"threshold"` Ctrl WeightValue`yaml:"ctrl"` Test WeightValue`yaml:"test"`}type WeightValuestruct { Weight *float64`yaml:"weight"`}func ReadYaml(ctx context.Context, rio.Reader) (Config,error) {var config Config err := yaml.NewDecoder(r).Decode(&config)if err !=nil {return config, fmt.Errorf("failed to decode yaml: %v", err) }return config,nil}func main() { logger := slog.New(slog.NewJSONHandler(os.Stdout,nil)) configFile, err := os.Open("config.yaml")if err !=nil { logger.Error("Cannot open config file", err) } config, err := ReadYaml(context.Background(), configFile) configFile.Close()if err !=nil { logger.Error("Cannot create config.Config", err) } fmt.Println(config)}
ReadYamlという関数でYAMLファイルを読み込んでます。Configという構造体を作ってタグをつけると、yaml.NewDecoder(r).Decode(&config)で自動的にパースして構造体に入れてくれるので便利です。
これでYAMLファイルをconfigという構造体に入れることができました。
このconfigを使ってDBに入ったデータを取ってきます。
import ("database/sql""fmt""strings""github.com/jmoiron/sqlx")type Contentstruct { IDstring`db:"id"` Score sql.NullFloat64`db:"score"`}type Rulestruct { Namestring Weightfloat64}type DBstruct { pool *sqlx.DB}// テスト判定 簡易化のためuserIDにrandomSeed値を掛ける方法でやってますが、好きな方法でやってくださいfunc isTest(userIDint, randomSeedint, thresholdint)bool {return userID*randomSeed%100 < threshold}// user_idが属するruleのみを取得するfunc (c Config) GetRules(userIDint) []Rule { result :=make([]Rule,0,len(c.Rules))for _, r :=range c.Rules { r, ok := r.GetRule(userID)if !ok {continue } result =append(result, r) }return result}func (r ruleConfig) GetRule(userIDint) (Rule,bool) { rule := Rule{ Name: r.Name, }if r.Threshold ==nil { rule.Weight = *r.Weightreturn rule,true } target := r.Ctrlif isTest(userID, *r.RandomSeed , *r.Threshold) { target = r.Test }if target.Weight ==nil {return Rule{},false } rule.Weight = *target.Weightreturn rule,true}func (d *DB) LoadScores(userIDint, config Config) ([]Content,error) { rules := config.GetRules(userID) sqls :=make([]string,0,len(rules)) args :=make([]any,0)// rulesに書かれているscoreをUNION ALLで全て出すfor _, c :=range rules { sqls =append(sqls, fmt.Sprintf(`SELECT id, score * ? as score FROM %s_score WHERE user_id = ?`, c.Name)) args =append(args, c.Weight, userID) } sql := strings.Join(sqls," UNION ALL ")// scoreを足し算する sql = fmt.Sprintf("SELECT id, SUM(score) AS score FROM (%s) GROUP BY id", sql)var contents []Content err := d.pool.Select(&contents, sql, args...)if err !=nil {returnnil, err }return contents,nil}
こちらは大きく分けて2段階に分かれています。最初がconfig構造体からuserに対するRuleのスライス(rules)を取得する部分。次がSQLにする部分です。
configにはtestやctrl等書いてありますが、ユーザー単位に落とすときにはその情報は必要ないです。なのでGetRulesでユーザーがtestかctrlどっちになるかを見て、NameとWeightのみをもつrulesにします。
rulesが出来たら後はSQLにするだけです。Goで見ると少し複雑ですが、SQLで書くとこんな感じです。
SELECT id,SUM(score)AS scoreFROM (SELECT id, score * weightFROM modelA_scoreWHERE user_id = ?UNIONALLSELECT id, score * weightFROM modelB_scoreWHERE user_id = ?)GROUPBY id
rulesに入っているモデルのscoreを全部出してgroupbyでsumしてるだけです。
これでYAMLファイルでA/Bテストができました。めでたしめでたし。
良い所
悪い所
この方法の悪い所としてはバッチ側で1つのテーブルを作成してABする場合よりもDBの負荷が上がってしまうことですが、API側にロジックを持つことで、バッチ側の複雑性を下げることが出来ます。更に、今なんのモデルを試しているかをYAMLファイルを見るだけでわかるので良いかなと思ってます。
AI・機械学習チームでは、A/Bテストしやすい環境を整える事も大事にしています。環境を整えるのが好きな人はもちろん、A/Bテストするための高精度なモデルを作る人も募集しています!
引用をストックしました
引用するにはまずログインしてください
引用をストックできませんでした。再度お試しください
限定公開記事のため引用できません。