@@ -3242,6 +3242,174 @@ def test_to_dataframe_w_bqstorage_snapshot(self):
32423242with pytest .raises (ValueError ):
32433243row_iterator .to_dataframe (bqstorage_client )
32443244
3245+ @unittest .skipIf (pandas is None ,"Requires `pandas`" )
3246+ @unittest .skipIf (
3247+ bigquery_storage_v1beta1 is None ,"Requires `google-cloud-bigquery-storage`"
3248+ )
3249+ @unittest .skipIf (pyarrow is None ,"Requires `pyarrow`" )
3250+ def test_to_dataframe_concat_categorical_dtype_w_pyarrow (self ):
3251+ from google .cloud .bigquery import schema
3252+ from google .cloud .bigquery import table as mut
3253+ from google .cloud .bigquery_storage_v1beta1 import reader
3254+
3255+ arrow_fields = [
3256+ # Not alphabetical to test column order.
3257+ pyarrow .field ("col_str" ,pyarrow .utf8 ()),
3258+ # The backend returns strings, and without other info, pyarrow contains
3259+ # string data in categorical columns, too (and not maybe the Dictionary
3260+ # type that corresponds to pandas.Categorical).
3261+ pyarrow .field ("col_category" ,pyarrow .utf8 ()),
3262+ ]
3263+ arrow_schema = pyarrow .schema (arrow_fields )
3264+
3265+ # create a mock BQ storage client
3266+ bqstorage_client = mock .create_autospec (
3267+ bigquery_storage_v1beta1 .BigQueryStorageClient
3268+ )
3269+ bqstorage_client .transport = mock .create_autospec (
3270+ big_query_storage_grpc_transport .BigQueryStorageGrpcTransport
3271+ )
3272+ session = bigquery_storage_v1beta1 .types .ReadSession (
3273+ streams = [{"name" :"/projects/proj/dataset/dset/tables/tbl/streams/1234" }],
3274+ arrow_schema = {"serialized_schema" :arrow_schema .serialize ().to_pybytes ()},
3275+ )
3276+ bqstorage_client .create_read_session .return_value = session
3277+
3278+ mock_rowstream = mock .create_autospec (reader .ReadRowsStream )
3279+ bqstorage_client .read_rows .return_value = mock_rowstream
3280+
3281+ # prepare the iterator over mocked rows
3282+ mock_rows = mock .create_autospec (reader .ReadRowsIterable )
3283+ mock_rowstream .rows .return_value = mock_rows
3284+ page_items = [
3285+ [
3286+ pyarrow .array (["foo" ,"bar" ,"baz" ]),# col_str
3287+ pyarrow .array (["low" ,"medium" ,"low" ]),# col_category
3288+ ],
3289+ [
3290+ pyarrow .array (["foo_page2" ,"bar_page2" ,"baz_page2" ]),# col_str
3291+ pyarrow .array (["medium" ,"high" ,"low" ]),# col_category
3292+ ],
3293+ ]
3294+
3295+ mock_pages = []
3296+
3297+ for record_list in page_items :
3298+ page_record_batch = pyarrow .RecordBatch .from_arrays (
3299+ record_list ,schema = arrow_schema
3300+ )
3301+ mock_page = mock .create_autospec (reader .ReadRowsPage )
3302+ mock_page .to_arrow .return_value = page_record_batch
3303+ mock_pages .append (mock_page )
3304+
3305+ type(mock_rows ).pages = mock .PropertyMock (return_value = mock_pages )
3306+
3307+ schema = [
3308+ schema .SchemaField ("col_str" ,"IGNORED" ),
3309+ schema .SchemaField ("col_category" ,"IGNORED" ),
3310+ ]
3311+
3312+ row_iterator = mut .RowIterator (
3313+ _mock_client (),
3314+ None ,# api_request: ignored
3315+ None ,# path: ignored
3316+ schema ,
3317+ table = mut .TableReference .from_string ("proj.dset.tbl" ),
3318+ selected_fields = schema ,
3319+ )
3320+
3321+ # run the method under test
3322+ got = row_iterator .to_dataframe (
3323+ bqstorage_client = bqstorage_client ,
3324+ dtypes = {
3325+ "col_category" :pandas .core .dtypes .dtypes .CategoricalDtype (
3326+ categories = ["low" ,"medium" ,"high" ],ordered = False ,
3327+ ),
3328+ },
3329+ )
3330+
3331+ # Are the columns in the expected order?
3332+ column_names = ["col_str" ,"col_category" ]
3333+ self .assertEqual (list (got ),column_names )
3334+
3335+ # Have expected number of rows?
3336+ total_pages = len (mock_pages )# we have a single stream, thus these two equal
3337+ total_rows = len (page_items [0 ][0 ])* total_pages
3338+ self .assertEqual (len (got .index ),total_rows )
3339+
3340+ # Are column types correct?
3341+ expected_dtypes = [
3342+ pandas .core .dtypes .dtypes .np .dtype ("O" ),# the default for string data
3343+ pandas .core .dtypes .dtypes .CategoricalDtype (
3344+ categories = ["low" ,"medium" ,"high" ],ordered = False ,
3345+ ),
3346+ ]
3347+ self .assertEqual (list (got .dtypes ),expected_dtypes )
3348+
3349+ # And the data in the categorical column?
3350+ self .assertEqual (
3351+ list (got ["col_category" ]),
3352+ ["low" ,"medium" ,"low" ,"medium" ,"high" ,"low" ],
3353+ )
3354+
3355+ # Don't close the client if it was passed in.
3356+ bqstorage_client .transport .channel .close .assert_not_called ()
3357+
3358+ @unittest .skipIf (pandas is None ,"Requires `pandas`" )
3359+ def test_to_dataframe_concat_categorical_dtype_wo_pyarrow (self ):
3360+ from google .cloud .bigquery .schema import SchemaField
3361+
3362+ schema = [
3363+ SchemaField ("col_str" ,"STRING" ),
3364+ SchemaField ("col_category" ,"STRING" ),
3365+ ]
3366+ row_data = [
3367+ [u"foo" ,u"low" ],
3368+ [u"bar" ,u"medium" ],
3369+ [u"baz" ,u"low" ],
3370+ [u"foo_page2" ,u"medium" ],
3371+ [u"bar_page2" ,u"high" ],
3372+ [u"baz_page2" ,u"low" ],
3373+ ]
3374+ path = "/foo"
3375+
3376+ rows = [{"f" : [{"v" :field }for field in row ]}for row in row_data [:3 ]]
3377+ rows_page2 = [{"f" : [{"v" :field }for field in row ]}for row in row_data [3 :]]
3378+ api_request = mock .Mock (
3379+ side_effect = [{"rows" :rows ,"pageToken" :"NEXTPAGE" }, {"rows" :rows_page2 }]
3380+ )
3381+
3382+ row_iterator = self ._make_one (_mock_client (),api_request ,path ,schema )
3383+
3384+ with mock .patch ("google.cloud.bigquery.table.pyarrow" ,None ):
3385+ got = row_iterator .to_dataframe (
3386+ dtypes = {
3387+ "col_category" :pandas .core .dtypes .dtypes .CategoricalDtype (
3388+ categories = ["low" ,"medium" ,"high" ],ordered = False ,
3389+ ),
3390+ },
3391+ )
3392+
3393+ self .assertIsInstance (got ,pandas .DataFrame )
3394+ self .assertEqual (len (got ),6 )# verify the number of rows
3395+ expected_columns = [field .name for field in schema ]
3396+ self .assertEqual (list (got ),expected_columns )# verify the column names
3397+
3398+ # Are column types correct?
3399+ expected_dtypes = [
3400+ pandas .core .dtypes .dtypes .np .dtype ("O" ),# the default for string data
3401+ pandas .core .dtypes .dtypes .CategoricalDtype (
3402+ categories = ["low" ,"medium" ,"high" ],ordered = False ,
3403+ ),
3404+ ]
3405+ self .assertEqual (list (got .dtypes ),expected_dtypes )
3406+
3407+ # And the data in the categorical column?
3408+ self .assertEqual (
3409+ list (got ["col_category" ]),
3410+ ["low" ,"medium" ,"low" ,"medium" ,"high" ,"low" ],
3411+ )
3412+
32453413
32463414class TestPartitionRange (unittest .TestCase ):
32473415def _get_target_class (self ):