Reputation: 1
I am working on classifying the ImageNet dataset with the AlexNet architecture, in the context of distributed systems for data streams. I am using the DeepLearning4j library. I have a problem loading the ImageNet data from a path on our HPC. My current, ordinary (non-Spark) loading code is:
FileSplit fileSplit= new FileSplit(new File("/scratch/imagenet/ILSVRC2012/train"), NativeImageLoader.ALLOWED_FORMATS);
int imageHeightWidth = 224; //224x224 pixel input
int imageChannels = 3; //RGB
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(imageHeightWidth, imageHeightWidth, imageChannels, labelMaker);
System.out.println("initialization");
rr.initialize(fileSplit);
System.out.println("iterator");
DataSetIterator iter = new RecordReaderDataSetIterator.Builder(rr, minibatch)
        .classification(1, 1000)
        .preProcessor(new ImagePreProcessingScaler()) //For normalization of image values 0-255 to 0-1
        .build();
System.out.println("data list creator");
List<DataSet> dataList = new ArrayList<>();
while (iter.hasNext()){
    dataList.add(iter.next());
}
And this is my attempt to load the dataset via Spark. labelsList contains all the labels of the ImageNet dataset, but I didn't copy them all here:
JavaSparkContext sc = SparkContext.initSparkContext(useSparkLocal);
//load data just one time
System.out.println("load data");
List<String> labelsList = Arrays.asList("kit fox, Vulpes macrotis " , "English setter " , "Australian terrier ");
String folder= "/scratch/imagenet/ILSVRC2012/train/*";
File f = new File(folder);
String path = f.getPath();
path=folder+"/*";
JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
int imageHeightWidth = 224; //224x224 pixel input
int imageChannels = 3; //RGB
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(imageHeightWidth, imageHeightWidth, imageChannels, labelMaker);
System.out.println("initialization");
rr.setLabels(labelsList);
RecordReaderFunction rrf = new org.datavec.spark.functions.RecordReaderFunction(rr);
JavaRDD<List<Writable>> rdd = origData.map(rrf);
JavaRDD<DataSet> data = rdd.map(new DataVecDataSetFunction(1, 1000, false));
List<DataSet> collected = data.collect();
By the way, the train directory contains 1000 folders (n01440764, n01755581, n02012849, n02097658, ...) which hold the images. I need this parallelization because loading the data alone took around 26 hours, which is not efficient. Could you help me correct my attempt?
Upvotes: 0
Views: 105
Reputation: 3205
For Spark, I would recommend pre-vectorizing all of the data and then loading the NDArrays themselves directly. We cover this approach in our examples: https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-distributed-training-examples/
After that, load the pre-created datasets in a map call, ideally with the batches set up relative to the number of workers you have available. DataSets have save(..) and load(..) methods you can use.
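To make the point about sizing batches relative to the number of workers concrete, here is a minimal plain-Java sketch. It does not use DL4J or Spark; the class BatchPlan, its methods, and the choice of 8 workers with 4 batches each are hypothetical values picked for illustration, not a recommendation from the library.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: plans how many files go into each pre-created batch.
final class BatchPlan {

    // Splits the paths into consecutive groups of at most batchSize entries.
    static List<List<String>> partition(List<String> paths, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<String>> batches = new ArrayList<>();
        for (int i = 0; i < paths.size(); i += batchSize) {
            int end = Math.min(i + batchSize, paths.size());
            batches.add(new ArrayList<>(paths.subList(i, end)));
        }
        return batches;
    }

    // Picks a batch size that yields at least workers * batchesPerWorker
    // batches, so that every worker has several batches to process.
    static int batchSizeFor(int totalFiles, int workers, int batchesPerWorker) {
        int targetBatches = Math.max(1, workers * batchesPerWorker);
        return Math.max(1, totalFiles / targetBatches);
    }

    public static void main(String[] args) {
        List<String> paths = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            paths.add("img_" + i + ".JPEG"); // placeholder file names
        }
        int batchSize = batchSizeFor(paths.size(), 8, 4);
        List<List<String>> batches = partition(paths, batchSize);
        System.out.println(batchSize + " files per batch, " + batches.size() + " batches");
    }
}
```

In practice the batch size is also bounded by what fits in memory during training, so a calculation like this is only a starting point.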
To implement this, consider using: SparkDataUtils.createFileBatchesSpark(JavaRDD<String> filePaths, final String rootOutputDir, final int batchSize, @NonNull final org.apache.hadoop.conf.Configuration hadoopConfig)
This takes the file paths, an output directory on HDFS, a preconfigured batch size, and a Hadoop configuration for accessing your cluster.
Here is a snippet from the relevant Javadoc to get you started on some of the concepts:
JavaSparkContext sc = ...
SparkDl4jMultiLayer net = ...
String baseFileBatchDir = ...
JavaRDD<String> paths = org.deeplearning4j.spark.util.SparkUtils.listPaths(sc, baseFileBatchDir);

//Image record reader:
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
rr.setLabels(<labels here>);

//Create DataSetLoader:
int batchSize = 32;
int numClasses = 1000;
DataSetLoader loader = new RecordReaderFileBatchLoader(rr, batchSize, 1, numClasses);

//Fit the network
net.fitPaths(paths, loader);
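One detail about the labels passed to rr.setLabels(...): ParentPathLabelGenerator takes the label from the name of the image's parent directory, so the label list presumably needs to contain the folder names (n01440764, ...) rather than human-readable class names. The following plain-Java sketch shows one way to build such a list from the directory layout. The class SynsetLabels and its methods are hypothetical helpers, not part of DL4J, and the temporary directory only stands in for the real train folder.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper: derives class labels from the directory layout.
final class SynsetLabels {

    // Returns the sorted names of the immediate subdirectories of trainRoot.
    static List<String> fromDirectories(Path trainRoot) throws IOException {
        try (Stream<Path> children = Files.list(trainRoot)) {
            return children
                    .filter(Files::isDirectory)
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    // The label of an image is the name of its parent directory.
    static String labelOf(Path image) {
        return image.getParent().getFileName().toString();
    }

    public static void main(String[] args) throws IOException {
        // Temporary stand-in for the real train directory.
        Path root = Files.createTempDirectory("train");
        Files.createDirectories(root.resolve("n01755581"));
        Files.createDirectories(root.resolve("n01440764"));
        Path image = Files.createFile(root.resolve("n01440764").resolve("example.JPEG"));

        List<String> labels = fromDirectories(root);
        System.out.println(labels + " index=" + labels.indexOf(labelOf(image)));
    }
}
```

Sorting the names keeps the label-to-index mapping identical on every worker, which matters when the list is built independently in several places.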
Upvotes: 0