Text-summarization

Intro

First, we need to install the blurr module for Transformers integration.

reticulate::py_install('ohmeow-blurr', pip = TRUE)

Dataset

Load the required packages and read the sample dataset from the link:

library(fastai)
library(magrittr)
library(zeallot)

cnndm_df = data.table::fread('https://raw.githubusercontent.com/ohmeow/blurr/master/nbs/cnndm_sample.csv')

Pretrained model

Get the pretrained BART model from transformers:

transformers = transformers()
BartForConditionalGeneration = transformers$BartForConditionalGeneration

pretrained_model_name = "facebook/bart-large-cnn"
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name,
                                                                  model_cls = BartForConditionalGeneration)

Datablock

Create the DataBlock and inspect one batch:

before_batch_tfm = HF_SummarizationBeforeBatchTransform(hf_arch,
                                                        hf_tokenizer,
                                                        max_length = c(256, 130))
blocks = list(HF_Text2TextBlock(before_batch_tfms = before_batch_tfm,
                                input_return_type = HF_SummarizationInput),
              noop())

dblock = DataBlock(blocks = blocks,
                   get_x = ColReader('article'),
                   get_y = ColReader('highlights'),
                   splitter = RandomSplitter())

dls = dblock %>% dataloaders(cnndm_df, bs = 2)

dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0,    36, 16256,    43,   480,  2193,     7,    62,  ...,  7648,  2072,     5,  1272,  2713,    30,     5,     2],
        [    0,    36, 16256,    43,   480,  6871,    24,  1655,  ..., 24120,    50,  8930,    66,    13,   244,     4,     2]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')

[[1]]$decoder_input_ids
tensor([[    0,  1625,  4452,     7,    62,     7,   158,   135,  ...,   632,  7648,   479,     2,     1,     1,     1,     1,     1,     1,     1],
        [    0,  5178,    35,    83,  1443,  2470,   161,    97,  ...,    18,  2236,    15,     5,  7158,   486,   479]], device='cuda:0')

[[1]]$labels
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,  ...,  7648,   479,     2,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,  ...,  2236,    15,     5,  7158,   486,   479,     2]], device='cuda:0')


[[2]]
tensor([[ 1625,  4452,     7,    62,     7,   158,   135,     9,  ...,  7648,   479,     2,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
        [ 5178,    35,    83,  1443,  2470,   161,    97,  1283,  ...,  2236,    15,     5,  7158,   486,   479,     2]], device='cuda:0')
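
In the batch printed above, each row of decoder_input_ids starts with 0 and is padded with 1, while the matching row of labels holds the same tokens moved one position to the left, with the padding replaced by -100. The base R sketch below reproduces that relationship on a toy sequence. It assumes BART's special token ids (0 for <s>, 2 for </s>, 1 for <pad>); the helper variables are illustrative only and this is not blurr's actual implementation.

```r
# Toy tokenized summary, already padded: <s> tokens... </s> <pad> <pad>
pad_id       <- 1L     # BART <pad> id
ignore_index <- -100L  # value that marks positions to skip in the loss
target <- c(0L, 1625L, 4452L, 479L, 2L, 1L, 1L)

# The decoder sees everything except the last position ...
decoder_input_ids <- target[-length(target)]

# ... and is asked to predict the sequence shifted left by one token.
labels <- target[-1]

# Padding carries no information, so it is masked out of the labels.
labels[labels == pad_id] <- ignore_index

print(decoder_input_ids)
print(labels)
```

At every position the decoder input is the token that precedes the label, which is the usual teacher-forcing setup for sequence-to-sequence training.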

Train

Define a Learner object and train the model:

text_gen_kwargs = hf_config$task_specific_params['summarization'][[1]]
text_gen_kwargs['max_length'] = 130L
text_gen_kwargs['min_length'] = 30L

text_gen_kwargs

model = HF_BaseModelWrapper(hf_model)
model_cb = HF_SummarizationModelCallback(text_gen_kwargs = text_gen_kwargs)

learn = Learner(dls,
                model,
                opt_func = partial(Adam),
                loss_func = CrossEntropyLossFlat(), #HF_PreCalculatedLoss()
                cbs = model_cb,
                splitter = partial(summarization_splitter, arch = hf_arch)) #.to_native_fp16() #.to_fp16()

learn$create_opt()
learn$freeze()

learn %>% fit_one_cycle(1, lr_max = 4e-5)
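
The labels in the batch use -100 for padded positions. PyTorch's cross-entropy loss, which CrossEntropyLossFlat builds on, skips targets equal to -100 by default, so padding should not contribute to the loss. The base R sketch below illustrates that idea on toy numbers. The function masked_cross_entropy is a hypothetical helper written for this illustration; it is not part of fastai or blurr.

```r
# Cross-entropy averaged over the positions whose label is not ignore_index.
# logits: matrix with one row per position and one column per vocabulary token
# labels: 0-based token ids, with ignore_index marking padded positions
masked_cross_entropy <- function(logits, labels, ignore_index = -100L) {
  keep   <- labels != ignore_index
  logits <- logits[keep, , drop = FALSE]
  labels <- labels[keep]

  # Numerically stable log-softmax for each row
  shifted   <- logits - apply(logits, 1, max)
  log_probs <- shifted - log(rowSums(exp(shifted)))

  # Pick the log-probability of the correct token (R indexes from 1)
  picked <- log_probs[cbind(seq_along(labels), labels + 1L)]
  -mean(picked)
}

# Four positions, vocabulary of three tokens; the last position is padding
logits <- matrix(c(2, 0, 0,
                   0, 2, 0,
                   0, 0, 2,
                   5, 5, 5), nrow = 4, byrow = TRUE)
labels <- c(0L, 1L, 2L, -100L)

loss <- masked_cross_entropy(logits, labels)
print(loss)
```

Changing the logits in the last row leaves the result unchanged, because that position is masked out before the average is taken.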

Conclusion

Summarize a new article with the trained model:

test_article = c("About 10 men armed with pistols and small machine guns raided a casino in Switzerland
and made off into France with several hundred thousand Swiss francs in the early hours
of Sunday morning, police said. The men, dressed in black clothes and black ski masks,
split into two groups during the raid on the Grand Casino Basel, Chief Inspector Peter
Gill told CNN. One group tried to break into the casino's vault on the lower level
but could not get in, but they did rob the cashier of the money that was not secured,
he said. The second group of armed robbers entered the upper level where the roulette
and blackjack tables are located and robbed the cashier there, he said. As the thieves
were leaving the casino, a woman driving by and unaware of what was occurring unknowingly
blocked the armed robbers' vehicles. A gunman pulled the woman from her vehicle, beat
her, and took off for the French border. The other gunmen followed into France, which
is only about 100 meters (yards) from the casino, Gill said. There were about 600 people
in the casino at the time of the robbery. There were no serious injuries, although one
guest on the Casino floor was kicked in the head by one of the robbers when he moved,
the police officer said. Swiss authorities are working closely with French authorities,
Gill said. The robbers spoke French and drove vehicles with French license plates.
CNN's Andreena Narayan contributed to this report.")

outputs = learn$blurr_summarize(test_article, num_return_sequences = 3L)

cat(outputs)
 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, police chief says .

 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
Swiss authorities are working closely with French authorities, an officer says.

 Robbers made off with several hundred thousand Swiss francs in the early hours of Sunday morning, police say .
The men, dressed in black clothes and black ski masks, split into two groups during the raid on the Grand Casino Basel .
There were no serious injuries, although one guest was kicked in the head by one of the robbers when he moved, police officer says .
A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death .
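
Generated summaries are worth checking against the source. The output above says the woman "was beaten to death", while the article says a gunman "beat her" and that there were no serious injuries. As a rough sanity check, the base R sketch below lists the words of a summary sentence that never occur in the article. The tokenize helper and the shortened article and summary_line strings are stand-ins introduced for this illustration; a word-overlap check like this is crude and only flags candidates for manual review.

```r
# Excerpt of the article and one generated sentence, copied from the example above
article <- paste(
  "As the thieves were leaving the casino, a woman driving by and unaware of",
  "what was occurring unknowingly blocked the armed robbers' vehicles.",
  "A gunman pulled the woman from her vehicle, beat her, and took off for the",
  "French border. There were no serious injuries."
)
summary_line <- "A woman driving by unknowingly blocked the robbers' vehicles and was beaten to death ."

# Lower-case a text and split it into its unique words
tokenize <- function(x) {
  words <- unlist(strsplit(tolower(x), "[^a-z]+"))
  unique(words[nchar(words) > 0])
}

# Words used in the summary that never appear in the article
unsupported <- setdiff(tokenize(summary_line), tokenize(article))
print(unsupported)
```

Here the check returns "beaten", "to" and "death", which points straight at the claim the article does not support.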
