@@ -53,6 +53,8 @@ type Field struct {
5353Name string
5454Type pyType
5555Comment string
56+ // EmbedFields contains the embedded fields that require scanning.
57+ EmbedFields []Field
5658}
5759
5860type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105107call := & pyast.Call {
106108Func :v .Annotation (),
107109}
108- for i ,f := range v .Struct .Fields {
109- call .Keywords = append (call .Keywords ,& pyast.Keyword {
110- Arg :f .Name ,
111- Value :subscriptNode (
110+ rowIndex := 0 // We need to keep track of the index in the row variable.
111+ for _ ,f := range v .Struct .Fields {
112+
113+ var valueNode * pyast.Node
114+ // Check if we are using sqlc.embed, if so we need to create a new object.
115+ if len (f .EmbedFields )> 0 {
116+ // We keep this separate so we can easily add all arguments.
117+ embed_call := & pyast.Call {Func :f .Type .Annotation ()}
118+
119+ // Now add all field Initializers for the embedded model that index into the original row.
120+ for i ,embedField := range f .EmbedFields {
121+ embed_call .Keywords = append (embed_call .Keywords ,& pyast.Keyword {
122+ Arg :embedField .Name ,
123+ Value :subscriptNode (
124+ rowVar ,
125+ constantInt (rowIndex + i ),
126+ ),
127+ })
128+ }
129+
130+ valueNode = & pyast.Node {
131+ Node :& pyast.Node_Call {
132+ Call :embed_call ,
133+ },
134+ }
135+
136+ rowIndex += len (f .EmbedFields )
137+ }else {
138+ valueNode = subscriptNode (
112139rowVar ,
113- constantInt (i ),
114- ),
115- })
140+ constantInt (rowIndex ),
141+ )
142+ rowIndex ++
143+ }
144+
145+ call .Keywords = append (call .Keywords ,& pyast.Keyword {Arg :f .Name ,Value :valueNode })
116146}
117147return & pyast.Node {
118148Node :& pyast.Node_Call {
@@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string {
336366type pyColumn struct {
337367id int32
338368* plugin.Column
369+ embed * pyEmbed
370+ }
371+
372+ type pyEmbed struct {
373+ modelType string
374+ modelName string
375+ fields []Field
376+ }
377+
378+ // Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
379+ // look through all the structs and attempt to find a matching one to embed
380+ // We need the name of the struct and its field names.
381+ func newGoEmbed (embed * plugin.Identifier ,structs []Struct ,defaultSchema string )* pyEmbed {
382+ if embed == nil {
383+ return nil
384+ }
385+
386+ for _ ,s := range structs {
387+ embedSchema := defaultSchema
388+ if embed .Schema != "" {
389+ embedSchema = embed .Schema
390+ }
391+
392+ // compare the other attributes
393+ if embed .Catalog != s .Table .Catalog || embed .Name != s .Table .Name || embedSchema != s .Table .Schema {
394+ continue
395+ }
396+
397+ fields := make ([]Field ,len (s .Fields ))
398+ for i ,f := range s .Fields {
399+ fields [i ]= f
400+ }
401+
402+ return & pyEmbed {
403+ modelType :s .Name ,
404+ modelName :s .Name ,
405+ fields :fields ,
406+ }
407+ }
408+
409+ return nil
339410}
340411
341412func columnsToStruct (req * plugin.CodeGenRequest ,name string ,columns []pyColumn )* Struct {
@@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn
359430if suffix > 0 {
360431fieldName = fmt .Sprintf ("%s_%d" ,fieldName ,suffix )
361432}
362- gs .Fields = append (gs .Fields ,Field {
433+
434+ f := Field {
363435Name :fieldName ,
364436Type :makePyType (req ,c .Column ),
365- })
437+ }
438+
439+ if c .embed != nil {
440+ f .Type = pyType {
441+ InnerType :"models." + modelName (c .embed .modelType ,req .Settings ),
442+ IsArray :false ,
443+ IsNull :false ,
444+ }
445+ f .EmbedFields = c .embed .fields
446+ }
447+
448+ gs .Fields = append (gs .Fields ,f )
366449seen [colName ]++
367450}
368451return & gs
@@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([]
476559columns = append (columns ,pyColumn {
477560id :int32 (i ),
478561Column :c ,
562+ embed :newGoEmbed (c .EmbedTable ,structs ,req .Catalog .DefaultSchema ),
479563})
480564}
481565gs = columnsToStruct (req ,query .Name + "Row" ,columns )