Movatterモバイル変換


[0]ホーム

URL:


エムスリーテックブログ

エムスリー(m3)のエンジニア・開発メンバーによる技術ブログです

機械学習モデルのA/BテストをしやすくするGo言語のAPI設計

こちらはエムスリー Advent Calendar 2023 11日目の記事です。

DALL-Eでサムネ作るの楽勝だぜとなりそうでならない

Overview

A/Bテストをしまくっている、機械学習エンジニアの農見(@rookzeno)です。皆さんA/Bテストをしてますでしょうか。エムスリーでは色々な施策の効果を見るために沢山のA/Bテストをしています。そのためA/Bテストを簡易にできるような設計を作ることも大事なことです。

AI・機械学習チームには、Goで書かれた機械学習関連の機能を各サービスに提供するAPIサーバがあり、こちらのYAMLファイルを設定するだけでA/Bテストが出来るようにしました。

rules:-name: modelArandom_seed:42threshold:50ctrl:weight:0test:weight:1-name: modelBweight:1

このYAMLファイルをどのようにGoのAPIで使ってるかを今回は解説します。

はじめに : 全体の構成について

このYAMLファイルをGoのAPIでどう扱っているかという話の前に、AIチームのMLプロダクトの全体構成について説明します。

AIチームでは「バッチで学習・推論して結果をDBに保存しておき、APIはDBの参照のみ行う(リクエスト時に推論をしない)」という構成をよく採用しています。今回はこの構成であることを前提としたコードになりますが、リクエスト時に推論する場合でも同じやり方はできると思います。

AIチームあるある構成(現代)

AI・機械学習チーム流MLOpsの歴史 - エムスリーテックブログより

0. YAMLファイルの解説

rules:-name: modelArandomseed:42threshold:50ctrl:weight:0test:weight:1-name: modelBweight:1

まずこのYAMLファイルが何を示しているのかを説明します。ctrl50%ではmodelA × 0 + modelB × 1、test50%ではmodelA × 1 + modelB × 1のアンサンブルを行うという設定です。randomseedという設定があると思いますが、これはユーザーを分ける関数のseedになります。ctrlとtestには有意差がないように分ける必要があるので、適切なrandomseedを毎回選ぶ必要があります。エムスリーでは毎回そのA/Bテストにとって最適な任意の関数とrandomseedを選んでA/Bテストをしています。

このテーブルの例で具体的に説明します。

id modelA modelBctrltest
1 1001010110
2 0151515

modelAはid1に100点、id2に0点をつけています。modelBはid1に10点、id2に15点つけてます。この時ctrlではmodelBのみなのでid1に10点、id2に15点となりid2>id1なのでid2,id1という順番で表示します。一方でtestではmodelAとmodelBの足し算なのでid1に110点、id2に15点となりid1,id2という順番で出すことになります。このように新たなモデルを追加するとレコメンド結果が変わりその効果を見るのがA/Bテストです。

1. YAMLファイルを読み込む

ここからGoのコードでどのように処理してるかを見ていきます。まずはYAMLファイルを読み込むところからです。

YAMLファイルをGoで読み込みには以下のように書けばいいです。

import ("context""fmt""io""log/slog""os""gopkg.in/yaml.v3")type Configstruct {    Rules      []ruleConfig`yaml:"rules"`}type ruleConfigstruct {    WeightValue`yaml:",inline"`    Namestring`yaml:"name"`    RandomSeed   *int`yaml:"randomseed"`    Threshold    *int`yaml:"threshold"`    Ctrl         WeightValue`yaml:"ctrl"`    Test         WeightValue`yaml:"test"`}type WeightValuestruct {    Weight       *float64`yaml:"weight"`}func ReadYaml(ctx context.Context, rio.Reader) (Config,error) {var config Config    err := yaml.NewDecoder(r).Decode(&config)if err !=nil {return config, fmt.Errorf("failed to decode yaml: %v", err)    }return config,nil}func main() {    logger := slog.New(slog.NewJSONHandler(os.Stdout,nil))    configFile, err := os.Open("config.yaml")if err !=nil {        logger.Error("Cannot open config file", err)    }    config, err := ReadYaml(context.Background(), configFile)    configFile.Close()if err !=nil {        logger.Error("Cannot create config.Config", err)    }    fmt.Println(config)}

ReadYamlという関数でYAMLファイルを読み込んでます。Configという構造体を作ってタグをつけると、yaml.NewDecoder(r).Decode(&config)で自動的にパースして構造体に入れてくれるので便利です。

これでYAMLファイルをconfigという構造体に入れることができました。

このconfigを使ってDBに入ったデータを取ってきます。

2. DBからデータを取り出す

import ("database/sql""fmt""strings""github.com/jmoiron/sqlx")type Contentstruct {    IDstring`db:"id"`    Score sql.NullFloat64`db:"score"`}type Rulestruct {    Namestring    Weightfloat64}type DBstruct {    pool *sqlx.DB}// テスト判定 簡易化のためuserIDにrandomSeed値を掛ける方法でやってますが、好きな方法でやってくださいfunc isTest(userIDint, randomSeedint, thresholdint)bool {return userID*randomSeed%100 < threshold}// user_idが属するruleのみを取得するfunc (c Config) GetRules(userIDint) []Rule {    result :=make([]Rule,0,len(c.Rules))for _, r :=range c.Rules {        r, ok := r.GetRule(userID)if !ok {continue        }        result =append(result, r)    }return result}func (r ruleConfig) GetRule(userIDint) (Rule,bool) {    rule := Rule{        Name: r.Name,    }if r.Threshold ==nil {        rule.Weight = *r.Weightreturn rule,true    }    target := r.Ctrlif isTest(userID, *r.RandomSeed , *r.Threshold) {        target = r.Test    }if target.Weight ==nil {return Rule{},false    }    rule.Weight = *target.Weightreturn rule,true}func (d *DB) LoadScores(userIDint, config Config) ([]Content,error) {    rules := config.GetRules(userID)    sqls :=make([]string,0,len(rules))    args :=make([]any,0)// rulesに書かれているscoreをUNION ALLで全て出すfor _, c :=range rules {        sqls =append(sqls, fmt.Sprintf(`SELECT id, score * ? as score FROM %s_score WHERE user_id = ?`, c.Name))        args =append(args, c.Weight, userID)    }    sql := strings.Join(sqls," UNION ALL ")// scoreを足し算する    sql = fmt.Sprintf("SELECT id,  SUM(score) AS score FROM (%s)  GROUP BY id", sql)var contents []Content    err := d.pool.Select(&contents, sql, args...)if err !=nil {returnnil, err    }return contents,nil}

こちらは大きく分けて2段階に分かれています。最初がconfig構造体からuserに対するRuleのスライス(rules)を取得する部分。次がSQLにする部分です。

configにはtestやctrl等書いてありますが、ユーザー単位に落とすときにはその情報は必要ないです。なのでGetRulesでユーザーがtestかctrlどっちになるかを見て、NameとWeightのみをもつrulesにします。

rulesが出来たら後はSQLにするだけです。Goで見ると少し複雑ですが、SQLで書くとこんな感じです。

SELECT id,SUM(score)AS scoreFROM (SELECT id, score * weightFROM modelA_scoreWHERE user_id = ?UNIONALLSELECT id, score * weightFROM modelB_scoreWHERE user_id = ?)GROUPBY id

rulesに入っているモデルのscoreを全部出してgroupbyでsumしてるだけです。

これでYAMLファイルでA/Bテストができました。めでたしめでたし。

3. この方法の良い所と悪い所

  • 良い所

    • YAMLファイルを見るだけでテスト内容がわかる
    • バッチ側が独立しているので好き勝手にモデルを作成して試すことができる
  • 悪い所

    • テーブル数が増えるにつれてDBの負荷が上がる

この方法の悪い所としてはバッチ側で1つのテーブルを作成してABする場合よりもDBの負荷が上がってしまうことですが、API側にロジックを持つことで、バッチ側の複雑性を下げることが出来ます。更に、今なんのモデルを試しているかをYAMLファイルを見るだけでわかるので良いかなと思ってます。

We are hiring!

AI・機械学習チームでは、A/Bテストしやすい環境を整える事も大事にしています。環境を整えるのが好きな人はもちろん、A/Bテストするための高精度なモデルを作る人も募集しています!

jobs.m3.com

検索
注目記事

引用をストックしました

引用するにはまずログインしてください

引用をストックできませんでした。再度お試しください

限定公開記事のため引用できません。

読者です読者をやめる読者になる読者になる

[8]ページ先頭

©2009-2025 Movatter.jp