Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd9d1089

Browse files
authored
Merge pull request#80 from Sergio0694/dev
Batch normalization, APIs adjustments
2 parentse9170bd +5936f85 commitd9d1089

File tree

39 files changed

+2130
-308
lines changed

39 files changed

+2130
-308
lines changed

‎NeuralNetwork.NET/APIs/CuDnnNetworkLayers.cs‎

Lines changed: 10 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
usingSystem;
2-
usingSystem.Linq;
3-
usingJetBrains.Annotations;
1+
usingJetBrains.Annotations;
42
usingNeuralNetworkNET.APIs.Delegates;
53
usingNeuralNetworkNET.APIs.Enums;
64
usingNeuralNetworkNET.APIs.Structs;
7-
usingNeuralNetworkNET.Extensions;
5+
usingNeuralNetworkNET.cuDNN;
86
usingNeuralNetworkNET.Networks.Layers.Cuda;
97

108
namespaceNeuralNetworkNET.APIs
@@ -17,22 +15,7 @@ public static class CuDnnNetworkLayers
1715
/// <summary>
1816
/// Gets whether or not the Cuda acceleration is supported on the current system
1917
/// </summary>
20-
publicstaticboolIsCudaSupportAvailable
21-
{
22-
get
23-
{
24-
try
25-
{
26-
// Calling this directly would could a crash in the <Module> loader due to the missing .dll files
27-
returnCuDnnSupportHelper.IsGpuAccelerationSupported();
28-
}
29-
catch(TypeInitializationException)
30-
{
31-
// Missing .dll file
32-
returnfalse;
33-
}
34-
}
35-
}
18+
publicstaticboolIsCudaSupportAvailable=>CuDnnService.IsAvailable;
3619

3720
/// <summary>
3821
/// Creates a new fully connected layer with the specified number of input and output neurons, and the given activation function
@@ -132,41 +115,14 @@ public static LayerFactory Convolutional(
132115
publicstaticLayerFactoryInception(InceptionInfoinfo,BiasInitializationModebiasMode=BiasInitializationMode.Zero)
133116
=> input=>newCuDnnInceptionLayer(input,info,biasMode);
134117

135-
#region Feature helper
136-
137118
/// <summary>
138-
///A private class that is used to createa newstandalone type that contains the actual test method (decoupling is needed to &lt;Module&gt; loading crashes)
119+
///Createsa newbatch normalization layer
139120
/// </summary>
140-
privatestaticclassCuDnnSupportHelper
141-
{
142-
/// <summary>
143-
/// Checks whether or not the Cuda features are currently supported
144-
/// </summary>
145-
publicstaticboolIsGpuAccelerationSupported()
146-
{
147-
try
148-
{
149-
// CUDA test
150-
Alea.Gpugpu=Alea.Gpu.Default;
151-
if(gpu==null)returnfalse;
152-
if(!Alea.cuDNN.Dnn.IsAvailable)returnfalse;// cuDNN
153-
using(Alea.DeviceMemory<float>sample_gpu=gpu.AllocateDevice<float>(1024))
154-
{
155-
Alea.deviceptr<float>ptr=sample_gpu.Ptr;
156-
voidKernel(inti)=>ptr[i]=i;
157-
Alea.Parallel.GpuExtension.For(gpu,0,1024,Kernel);// JIT test
158-
float[]sample=Alea.Gpu.CopyToHost(sample_gpu);
159-
returnEnumerable.Range(0,1024).Select<int,float>(i=>i).ToArray().ContentEquals(sample);
160-
}
161-
}
162-
catch
163-
{
164-
// Missing .dll or other errors
165-
returnfalse;
166-
}
167-
}
168-
}
169-
170-
#endregion
121+
/// <param name="mode">The normalization mode to use for the new layer</param>
122+
/// <param name="activation">The desired activation function to use in the network layer</param>
123+
[PublicAPI]
124+
[Pure,NotNull]
125+
publicstaticLayerFactoryBatchNormalization(NormalizationModemode,ActivationTypeactivation)
126+
=> input=>newCuDnnBatchNormalizationLayer(input,mode,activation);
171127
}
172128
}

‎NeuralNetwork.NET/APIs/DatasetLoader.cs‎

Lines changed: 69 additions & 24 deletions
Large diffs are not rendered by default.

‎NeuralNetwork.NET/APIs/Datasets/Cifar10.cs‎

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
usingNeuralNetworkNET.Extensions;
1010
usingNeuralNetworkNET.Helpers;
1111
usingNeuralNetworkNET.SupervisedLearning.Progress;
12+
usingSixLabors.ImageSharp;
13+
usingSixLabors.ImageSharp.Advanced;
14+
usingSixLabors.ImageSharp.PixelFormats;
1215

1316
namespaceNeuralNetworkNET.APIs.Datasets
1417
{
@@ -25,11 +28,14 @@ public static class Cifar10
2528
// 32*32 RGB images
2629
privateconstintSampleSize=3072;
2730

31+
// A single 32*32 image
32+
privateconstintImageSize=1024;
33+
2834
privateconstStringDatasetURL="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz";
2935

3036
[NotNull,ItemNotNull]
3137
privatestaticreadonlyIReadOnlyList<String>TrainingBinFilenames=Enumerable.Range(1,5).Select(i=>$"data_batch_{i}.bin").ToArray();
32-
38+
3339
privateconstStringTestBinFilename="test_batch.bin";
3440

3541
#endregion
@@ -38,12 +44,13 @@ public static class Cifar10
3844
/// Downloads the CIFAR-10 training datasets and returns a new <see cref="ITestDataset"/> instance
3945
/// </summary>
4046
/// <param name="size">The desired dataset batch size</param>
47+
/// <param name="callback">The optional progress calback</param>
4148
/// <param name="token">An optional cancellation token for the operation</param>
4249
[PublicAPI]
4350
[Pure,ItemCanBeNull]
44-
publicstaticasyncTask<ITrainingDataset>GetTrainingDatasetAsync(intsize,CancellationTokentoken=default)
51+
publicstaticasyncTask<ITrainingDataset>GetTrainingDatasetAsync(intsize,[CanBeNull]IProgress<HttpProgress>callback=null,CancellationTokentoken=default)
4552
{
46-
IReadOnlyDictionary<String,Func<Stream>>map=awaitDatasetsDownloader.GetArchiveAsync(DatasetURL,token);
53+
IReadOnlyDictionary<String,Func<Stream>>map=awaitDatasetsDownloader.GetArchiveAsync(DatasetURL,callback,token);
4754
if(map==null)returnnull;
4855
IReadOnlyList<(float[],float[])>[]data=newIReadOnlyList<(float[],float[])>[TrainingBinFilenames.Count];
4956
Parallel.For(0,TrainingBinFilenames.Count, i=>data[i]=ParseSamples(map[TrainingBinFilenames[i]],TrainingSamplesInBinFiles)).AssertCompleted();
@@ -54,25 +61,45 @@ public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, Can
5461
/// Downloads the CIFAR-10 test datasets and returns a new <see cref="ITestDataset"/> instance
5562
/// </summary>
5663
/// <param name="progress">The optional progress callback to use</param>
64+
/// <param name="callback">The optional progress calback</param>
5765
/// <param name="token">An optional cancellation token for the operation</param>
5866
[PublicAPI]
5967
[Pure,ItemCanBeNull]
60-
publicstaticasyncTask<ITestDataset>GetTestDatasetAsync([CanBeNull]Action<TrainingProgressEventArgs>progress=null,CancellationTokentoken=default)
68+
publicstaticasyncTask<ITestDataset>GetTestDatasetAsync([CanBeNull]Action<TrainingProgressEventArgs>progress=null,[CanBeNull]IProgress<HttpProgress>callback=null,CancellationTokentoken=default)
6169
{
62-
IReadOnlyDictionary<String,Func<Stream>>map=awaitDatasetsDownloader.GetArchiveAsync(DatasetURL,token);
70+
IReadOnlyDictionary<String,Func<Stream>>map=awaitDatasetsDownloader.GetArchiveAsync(DatasetURL,callback,token);
6371
if(map==null)returnnull;
6472
IReadOnlyList<(float[],float[])>data=ParseSamples(map[TestBinFilename],TrainingSamplesInBinFiles);
6573
returnDatasetLoader.Test(data,progress);
6674
}
6775

76+
/// <summary>
77+
/// Downloads and exports the full CIFAR-10 dataset (both training and test samples) to the target directory
78+
/// </summary>
79+
/// <param name="directory">The target directory</param>
80+
/// <param name="token">The cancellation token for the operation</param>
81+
[PublicAPI]
82+
publicstaticasyncTask<bool>ExportDatasetAsync([NotNull]DirectoryInfodirectory,CancellationTokentoken=default)
83+
{
84+
IReadOnlyDictionary<String,Func<Stream>>map=awaitDatasetsDownloader.GetArchiveAsync(DatasetURL,null,token);
85+
if(map==null)returnfalse;
86+
if(!directory.Exists)directory.Create();
87+
ParallelLoopResultresult=Parallel.ForEach(TrainingBinFilenames.Concat(new[]{TestBinFilename}),(name,state)=>
88+
{
89+
ExportSamples(directory,(name,map[name]),TrainingSamplesInBinFiles,token);
90+
if(token.IsCancellationRequested)state.Stop();
91+
});
92+
returnresult.IsCompleted&&!token.IsCancellationRequested;
93+
}
94+
6895
#region Tools
6996

7097
/// <summary>
7198
/// Parses a CIFAR-10 .bin file
7299
/// </summary>
73100
/// <param name="factory">A <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
74101
/// <param name="count">The number of samples to parse</param>
75-
privatestaticunsafeIReadOnlyList<(float[],float[])>ParseSamples(Func<Stream>factory,intcount)
102+
privatestaticunsafeIReadOnlyList<(float[],float[])>ParseSamples([NotNull]Func<Stream>factory,intcount)
76103
{
77104
using(Streamstream=factory())
78105
{
@@ -89,8 +116,12 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
89116
fixed(float*px=x)
90117
{
91118
stream.Read(temp,0,SampleSize);
92-
for(intj=0;j<SampleSize;j++)
119+
for(intj=0;j<ImageSize;j++)
120+
{
93121
px[j]=ptemp[j]/255f;// Normalized samples
122+
px[j]=ptemp[j+ImageSize]/255f;
123+
px[j]=ptemp[j+2*ImageSize]/255f;
124+
}
94125
}
95126
data[i]=(x,y);
96127
}
@@ -99,6 +130,38 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
99130
}
100131
}
101132

133+
/// <summary>
134+
/// Exports a CIFAR-10 .bin file
135+
/// </summary>
136+
/// <param name="folder">The target folder to use to save the images</param>
137+
/// <param name="source">The source filename and a <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
138+
/// <param name="count">The number of samples to parse</param>
139+
/// <param name="token">A token for the operation</param>
140+
privatestaticunsafevoidExportSamples([NotNull]DirectoryInfofolder,(StringName,Func<Stream>Factory)source,intcount,CancellationTokentoken)
141+
{
142+
using(Streamstream=source.Factory())
143+
{
144+
byte[]temp=newbyte[SampleSize];
145+
fixed(byte*ptemp=temp)
146+
{
147+
for(inti=0;i<count;i++)
148+
{
149+
if(token.IsCancellationRequested)return;
150+
intlabel=stream.ReadByte();
151+
stream.Read(temp,0,SampleSize);
152+
using(Image<Rgb24>image=newImage<Rgb24>(32,32))
153+
fixed(Rgb24*p0=&image.DangerousGetPinnableReferenceToPixelBuffer())
154+
{
155+
for(intj=0;j<ImageSize;j++)
156+
p0[j]=newRgb24(ptemp[j],ptemp[j+ImageSize],ptemp[j+2*ImageSize]);
157+
using(FileStreamfile=File.OpenWrite(Path.Combine(folder.FullName,$"[{source.Name}][{i}][{label}].bmp")))
158+
image.SaveAsBmp(file);
159+
}
160+
}
161+
}
162+
}
163+
}
164+
102165
#endregion
103166
}
104167
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp