Commit ba8eecab authored by Lukasz Waskiewicz

refs #1376 optimization of particle processing

parent f2cf1146
......@@ -118,11 +118,13 @@ public abstract class ProcessBase implements Serializable {
saveDataFrame(frame, table, mode);
}
public void saveDataFrame(DataFrame frame, SparkTableBase<?> table, SaveMode mode) {
String resultTableName = table.getResultTable();
Path path = new Path(outputDirectory, resultTableName);
protected void saveDataFrame(DataFrame frame, SparkTableBase<?> table, SaveMode mode) {
saveDataFrame(frame, table.getResultTable(), table.getPartitionBy(), mode);
}
protected void saveDataFrame(DataFrame frame, String targetDirectory, String[] partitionBy, SaveMode mode) {
Path path = new Path(outputDirectory, targetDirectory);
DataFrameWriter writer = frame.write().mode(mode);
String[] partitionBy = table.getPartitionBy();
if (partitionBy != null && partitionBy.length > 0) {
writer = writer.partitionBy(partitionBy);
}
......
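The split above is what the ProcessParticles changes below rely on: the table-based overload now delegates to a new variant that takes an explicit target directory and partition columns, so intermediate results can be written without a SparkTableBase. Assembled from the later hunk, the expected call for the temporary particle frame looks roughly like this (Spark 1.x DataFrame API):

```java
// Build a DataFrame from the processed particles and write it under
// <outputDirectory>/partitionedParticles, partitioned by the "type" column.
DataFrame processedParticlesFrame =
        sqlContext.createDataFrame(processedParticles, ProcessedParticle.class);
saveDataFrame(processedParticlesFrame, PARTITIONED_PARTICLES_TEMP_DIR,
        new String[] { "type" }, SaveMode.Overwrite);
```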
......@@ -60,9 +60,13 @@ public class TablesRegistrator {
}
public <O> JavaRDD<O> mapTable(final SparkTableBase<?> table, JavaRDD<GenericRowWithSchema> result) {
return mapTable(table.getInputClass(), result);
}
protected <O> JavaRDD<O> mapTable(final Class<?> inputClass, JavaRDD<GenericRowWithSchema> result) {
@SuppressWarnings("unchecked")
JavaRDD<O> mapped = result.map(rowWithSchema -> {
O resultObject = (O) Mapper.INSTANCE.mapObject(rowWithSchema, rowWithSchema.schema(), table.getInputClass());
O resultObject = (O) Mapper.INSTANCE.mapObject(rowWithSchema, rowWithSchema.schema(), inputClass);
return resultObject;
}).setName(result.name() + "_obj");
return mapped;
......
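Assembled from the hunk above, the refactored pair keeps the public, table-based entry point and adds a protected overload keyed on the target class, so row RDDs that do not originate from a registered SparkTableBase (such as the temporary type-partitioned particles used below) can still be mapped to objects:

```java
public <O> JavaRDD<O> mapTable(final SparkTableBase<?> table, JavaRDD<GenericRowWithSchema> result) {
    return mapTable(table.getInputClass(), result);
}

protected <O> JavaRDD<O> mapTable(final Class<?> inputClass, JavaRDD<GenericRowWithSchema> result) {
    @SuppressWarnings("unchecked")
    JavaRDD<O> mapped = result.map(rowWithSchema -> {
        O resultObject = (O) Mapper.INSTANCE.mapObject(rowWithSchema, rowWithSchema.schema(), inputClass);
        return resultObject;
    }).setName(result.name() + "_obj");
    return mapped;
}
```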
......@@ -6,7 +6,6 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
......@@ -16,6 +15,7 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema;
import org.apache.spark.sql.types.StructField;
......@@ -41,6 +41,8 @@ import scala.Tuple4;
public class ProcessParticles extends ProcessBase {
private static final int PARTICLE_TYPES_COUNT = 6;
private static final String PARTITIONED_PARTICLES_TEMP_DIR = "partitionedParticles";
private static final long serialVersionUID = -6069305744873068035L;
public ProcessParticles(Path rootPath, Path outputDirectory, double boxSize, int partitionsCount, SparkConf sparkConf) throws Exception {
......@@ -97,18 +99,19 @@ public class ProcessParticles extends ProcessBase {
"snapshot_id=" + snapshotId);
Broadcast<Long> snapshotIdValue = jsc.broadcast(snapshotId);
JavaRDD<GadgetSnapshotFileMetadata> snapshots = (JavaRDD<GadgetSnapshotFileMetadata>) rdds.get(SparkTable.SNAPSHOT_FILEMETADATA);
snapshots = snapshots.filter(s -> snapshotIdValue.getValue().equals(s.getSnapshotNumber()));
JavaPairRDD<Long, Tuple2<Long, Long[]>> fileCounts = mapToPair(snapshots, s -> {
Long sum = s.getNpart().stream().collect(Collectors.summingLong(l -> l));
return new Tuple2<>(s.getFileId(), new Tuple2<>(sum, s.getNpart().toArray(new Long[6])));
JavaRDD<GadgetSnapshotFileMetadata> snapshotsFileMetadata = (JavaRDD<GadgetSnapshotFileMetadata>) rdds.get(SparkTable.SNAPSHOT_FILEMETADATA);
snapshotsFileMetadata = snapshotsFileMetadata.filter(s -> snapshotIdValue.getValue().equals(s.getSnapshotNumber()));
JavaPairRDD<Long, Long[]> fileCounts = mapToPair(snapshotsFileMetadata, s -> {
return new Tuple2<>(s.getFileId(), s.getNpart().toArray(new Long[PARTICLE_TYPES_COUNT]));
});
long offset = 0;
long[] offsets = new long[PARTICLE_TYPES_COUNT];
Map<Long, FileNPartInfo> fileStartMap = new HashMap<Long, FileNPartInfo>();
for (Tuple2<Long, Tuple2<Long, Long[]>> fileCount : fileCounts.sortByKey().collect()) {
fileStartMap.put(fileCount._1, new FileNPartInfo(offset, fileCount._2._2));
offset += fileCount._2._1;
for (Tuple2<Long, Long[]> fileCount : fileCounts.sortByKey().collect()) {
fileStartMap.put(fileCount._1, new FileNPartInfo(offsets.clone(), fileCount._2)); // clone: each file keeps its own start offsets
for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
offsets[i] = offsets[i] + fileCount._2[i];
}
}
final Broadcast<Map<Long, FileNPartInfo>> fileStartMapBroadcast = jsc.broadcast(fileStartMap);
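Assembled, the new bookkeeping replaces the single running offset with one running offset per particle type; each file records a snapshot of the running array before its own counts are added (hence the offsets.clone() above, so the broadcast map does not end up with every file sharing the final totals):

```java
// Per-file particle counts, one entry per particle type.
JavaPairRDD<Long, Long[]> fileCounts = mapToPair(snapshotsFileMetadata,
        s -> new Tuple2<>(s.getFileId(), s.getNpart().toArray(new Long[PARTICLE_TYPES_COUNT])));

long[] offsets = new long[PARTICLE_TYPES_COUNT];
Map<Long, FileNPartInfo> fileStartMap = new HashMap<Long, FileNPartInfo>();
for (Tuple2<Long, Long[]> fileCount : fileCounts.sortByKey().collect()) {
    // snapshot the running offsets for this file, then advance them by its counts
    fileStartMap.put(fileCount._1, new FileNPartInfo(offsets.clone(), fileCount._2));
    for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
        offsets[i] = offsets[i] + fileCount._2[i];
    }
}
final Broadcast<Map<Long, FileNPartInfo>> fileStartMapBroadcast = jsc.broadcast(fileStartMap);
```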
......@@ -119,10 +122,6 @@ public class ProcessParticles extends ProcessBase {
Map<Long, FileNPartInfo> currentStartMap = fileStartMapBroadcast.getValue();
FileNPartInfo fileNPartInfo = currentStartMap.get(v.getFileId());
Long startValue = fileNPartInfo.getOffset();
if (startValue == null) {
System.out.println("No information for fileid: " + v.getFileId());
}
Long[] types = fileNPartInfo.getTypes();
Long fileOrd = v.getFileOrdinal();
Integer type = 0;
......@@ -134,120 +133,160 @@ public class ProcessParticles extends ProcessBase {
type++;
}
Long startValue = fileNPartInfo.getOffsets()[type];
ProcessedParticle processedParticle = new ProcessedParticle(v, startValue, type);
return processedParticle;
});
if (!fofData.isEmpty()) {
processParticlesWithFofs(fofData, haloData, processedParticles);
fofData = fofData.persist(StorageLevel.MEMORY_AND_DISK());
haloData = haloData.persist(StorageLevel.MEMORY_AND_DISK());
DataFrame processedParticlesFrame = sqlContext.createDataFrame(processedParticles, ProcessedParticle.class);
saveDataFrame(processedParticlesFrame, PARTITIONED_PARTICLES_TEMP_DIR, new String[] { "type" }, SaveMode.Overwrite);
Map<String, Integer> fofFieldIndexes = getIndexes(fofData.first().schema());
Broadcast<Map<String, Integer>> fofIndexes = jsc.broadcast(fofFieldIndexes);
long[] fofCounts = fofData.aggregate(new long[PARTICLE_TYPES_COUNT], (count, fof) -> {
for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
count[i] = count[i] + fof.getLong(fofIndexes.getValue().get("npart_type_" + i));
}
return count;
}, (a, b) -> {
for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
a[i] = a[i] + b[i];
}
return a;
});
for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
if (fofCounts[i] > 0) {
processParticlesWithFofs(fofData, haloData, i);
} else {
processParticlesWithoutFofs(i);
}
}
final String processedParticlesTempDir = outputDirectory.toUri().getPath() + "/" + PARTITIONED_PARTICLES_TEMP_DIR;
haloData.unpersist();
fofData.unpersist();
shell.run(new String[] { "-rm", "-r", processedParticlesTempDir });
} else {
processParticlesWithoutFofs(processedParticles);
}
}
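Assembled, the new driver logic writes the processed particles once into the type-partitioned temporary directory (see the saveDataFrame call above), counts FoF members per particle type with a single aggregate, and then processes each of the six types independently:

```java
long[] fofCounts = fofData.aggregate(new long[PARTICLE_TYPES_COUNT],
        (count, fof) -> {   // seqOp: add one FoF row's per-type member counts
            for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
                count[i] = count[i] + fof.getLong(fofIndexes.getValue().get("npart_type_" + i));
            }
            return count;
        },
        (a, b) -> {         // combOp: merge the partial sums of two partitions
            for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
                a[i] = a[i] + b[i];
            }
            return a;
        });
for (int i = 0; i < PARTICLE_TYPES_COUNT; i++) {
    if (fofCounts[i] > 0) {
        processParticlesWithFofs(fofData, haloData, i);   // this type has FoF members
    } else {
        processParticlesWithoutFofs(i);                   // no FoF data for this type
    }
}
```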
private void processParticlesWithFofs(JavaRDD<GenericRowWithSchema> fofData, JavaRDD<GenericRowWithSchema> haloData,
JavaRDD<ProcessedParticle> processedParticles) throws Exception {
processedParticles = processedParticles.persist(StorageLevel.MEMORY_AND_DISK());
private void processParticlesWithFofs(JavaRDD<GenericRowWithSchema> fofData, JavaRDD<GenericRowWithSchema> haloData, final Integer processedType)
throws Exception {
JavaRDD<GenericRowWithSchema> particles = tablesRegistrator.registerTable(outputDirectory, PARTITIONED_PARTICLES_TEMP_DIR, null,
"type=" + processedType);
JavaRDD<ProcessedParticle> processedParticles = tablesRegistrator.mapTable(ProcessedParticle.class, particles);
int particleParallelism = tablesRegistrator.getParallelism(SparkTable.PARTICLE);
Broadcast<Double> boxSize = jsc.broadcast(this.boxSize);
Broadcast<Integer> partitionsCount = jsc.broadcast(this.partitionsCount);
for (int type = 0; type < 6; type++) {
final int processedType = type;
JavaRDD<ProcessedParticle> filteredParticles = filter(processedParticles, p -> p.getType().equals(processedType));
JavaPairRDD<Long, ProcessedParticle> particlesByOrdinal = keyBy(filteredParticles, p -> p.getOrdinal(), "id");
Map<String, Integer> fofFieldIndexes = getIndexes(fofData.first().schema());
Broadcast<Map<String, Integer>> fofIndexes = jsc.broadcast(fofFieldIndexes);
JavaPairRDD<Long, Long> fofParticleCounts = mapToPair(fofData, row -> {
final GenericRowWithSchema fof = (GenericRowWithSchema) row;
Long id = fof.getLong(fofIndexes.getValue().get("fof_id"));
Long particlesCount = fof.getLong(fofIndexes.getValue().get("npart_type_" + processedType));
return new Tuple2<>(id, particlesCount);
});
JavaPairRDD<Long, Iterable<Tuple2<Long, Long>>> groupedFofs = fofParticleCounts.groupBy(t -> (long) (t._1 % 1e12 / 5e3));
groupedFofs = groupedFofs.persist(StorageLevel.MEMORY_AND_DISK());
JavaPairRDD<Long, Long> groups = groupedFofs.mapValues(fofs -> {
long count = 0;
for (Tuple2<Long, Long> fof : fofs) {
count += fof._2;
}
return count;
});
Map<Long, Long> groupStart = new HashMap<Long, Long>();
long sumCount = 0;
for (Tuple2<Long, Long> groupCount : groups.sortByKey().collect()) {
groupStart.put(groupCount._1, sumCount);
sumCount += groupCount._2;
JavaRDD<ProcessedParticle> filteredParticles = filter(processedParticles, p -> p.getType().equals(processedType));
JavaPairRDD<Long, ProcessedParticle> particlesByOrdinal = keyBy(filteredParticles, p -> p.getOrdinal(), "id");
Map<String, Integer> fofFieldIndexes = getIndexes(fofData.first().schema());
Broadcast<Map<String, Integer>> fofIndexes = jsc.broadcast(fofFieldIndexes);
JavaPairRDD<Long, Long> fofParticleCounts = mapToPair(fofData, fof -> {
Long id = fof.getLong(fofIndexes.getValue().get("fof_id"));
Long particlesCount = fof.getLong(fofIndexes.getValue().get("npart_type_" + processedType));
return new Tuple2<>(id, particlesCount);
});
JavaPairRDD<Long, Iterable<Tuple2<Long, Long>>> groupedFofs = fofParticleCounts.groupBy(t -> (long) (t._1 % 1e12 / 5e3));
groupedFofs = groupedFofs.persist(StorageLevel.MEMORY_AND_DISK());
JavaPairRDD<Long, Long> groups = groupedFofs.mapValues(fofs -> {
long count = 0;
for (Tuple2<Long, Long> fof : fofs) {
count += fof._2;
}
return count;
});
Map<Long, Long> groupStart = new HashMap<Long, Long>();
long sumCount = 0;
for (Tuple2<Long, Long> groupCount : groups.sortByKey().collect()) {
groupStart.put(groupCount._1, sumCount);
sumCount += groupCount._2;
}
Broadcast<Map<Long, Long>> groupStartBroadcast = jsc.broadcast(groupStart);
JavaPairRDD<Long, Tuple2<Long, Long>> fofOffsets = groupedFofs.flatMapToPair(fofGroup -> {
List<Tuple2<Long, Tuple2<Long, Long>>> localFofOffsets = new ArrayList<>();
long fofOffset = groupStartBroadcast.getValue().get(fofGroup._1);
List<Tuple2<Long, Long>> sortedGroups = FluentIterable.from(fofGroup._2).toSortedList((a, b) -> a._1().compareTo(b._1()));
for (Tuple2<Long, Long> fof : sortedGroups) {
localFofOffsets.add(new Tuple2<>(fof._1, new Tuple2<>(fofOffset, fof._2)));
fofOffset += fof._2;
}
return localFofOffsets;
});
Broadcast<Map<Long, Long>> groupStartBroadcast = jsc.broadcast(groupStart);
JavaPairRDD<Long, Tuple2<Long, Long>> fofOffsets = groupedFofs.flatMapToPair(fofGroup -> {
List<Tuple2<Long, Tuple2<Long, Long>>> localFofOffsets = new ArrayList<>();
long fofOffset = groupStartBroadcast.getValue().get(fofGroup._1);
List<Tuple2<Long, Long>> sortedGroups = FluentIterable.from(fofGroup._2).toSortedList((a, b) -> a._1().compareTo(b._1()));
for (Tuple2<Long, Long> fof : sortedGroups) {
localFofOffsets.add(new Tuple2<>(fof._1, new Tuple2<>(fofOffset, fof._2)));
fofOffset += fof._2;
}
return localFofOffsets;
});
Map<String, Integer> haloFieldIndexes = getIndexes(haloData.first().schema());
Broadcast<Map<String, Integer>> haloIndexes = jsc.broadcast(haloFieldIndexes);
JavaPairRDD<Long, Tuple2<Long, Long>> haloParticleCounts = mapToPair(haloData, row -> {
final GenericRowWithSchema halo = (GenericRowWithSchema) row;
Long id = halo.getLong(haloIndexes.getValue().get("subhalo_id"));
Long fofId = halo.getLong(haloIndexes.getValue().get("fof_id"));
Long particlesCount = halo.getLong(haloIndexes.getValue().get("npart_type_" + processedType));
return new Tuple2<>(fofId, new Tuple2<>(id, particlesCount));
});
Map<String, Integer> haloFieldIndexes = getIndexes(haloData.first().schema());
Broadcast<Map<String, Integer>> haloIndexes = jsc.broadcast(haloFieldIndexes);
JavaPairRDD<Long, Tuple2<Long, Long>> haloParticleCounts = mapToPair(haloData, row -> {
final GenericRowWithSchema halo = (GenericRowWithSchema) row;
Long id = halo.getLong(haloIndexes.getValue().get("subhalo_id"));
Long fofId = halo.getLong(haloIndexes.getValue().get("fof_id"));
Long particlesCount = halo.getLong(haloIndexes.getValue().get("npart_type_" + processedType));
return new Tuple2<>(fofId, new Tuple2<>(id, particlesCount));
});
JavaRDD<Tuple4<Long, Long, Long, Long>> fofs = fofOffsets.leftOuterJoin(haloParticleCounts.groupByKey()).flatMap(tuple -> {
Long groupStartNpart = tuple._2._1._1;
List<Tuple4<Long, Long, Long, Long>> result = new ArrayList<>();
long currentOffset = groupStartNpart;
Optional<Iterable<Tuple2<Long, Long>>> fofHalos = tuple._2._2;
long fofOffset = tuple._2._1._2;
Long nextOffset = currentOffset;
if (fofHalos.isPresent()) {
List<Tuple2<Long, Long>> subhalos = FluentIterable.from(fofHalos.get()).toSortedList((a, b) -> a._1().compareTo(b._1()));
for (Tuple2<Long, Long> halo : subhalos) {
nextOffset += halo._2;
fofOffset -= halo._2;
Tuple4<Long, Long, Long, Long> haloCounts = new Tuple4<>(tuple._1, halo._1, currentOffset, nextOffset);
result.add(haloCounts);
currentOffset = nextOffset;
}
}
if (fofOffset > 0) {
nextOffset += fofOffset;
Tuple4<Long, Long, Long, Long> haloCounts = new Tuple4<>(tuple._1, null, currentOffset, nextOffset);
JavaRDD<Tuple4<Long, Long, Long, Long>> fofs = fofOffsets.leftOuterJoin(haloParticleCounts.groupByKey()).flatMap(tuple -> {
Long groupStartNpart = tuple._2._1._1;
List<Tuple4<Long, Long, Long, Long>> result = new ArrayList<>();
long currentOffset = groupStartNpart;
Optional<Iterable<Tuple2<Long, Long>>> fofHalos = tuple._2._2;
long fofOffset = tuple._2._1._2;
Long nextOffset = currentOffset;
if (fofHalos.isPresent()) {
List<Tuple2<Long, Long>> subhalos = FluentIterable.from(fofHalos.get()).toSortedList((a, b) -> a._1().compareTo(b._1()));
for (Tuple2<Long, Long> halo : subhalos) {
nextOffset += halo._2;
fofOffset -= halo._2;
Tuple4<Long, Long, Long, Long> haloCounts = new Tuple4<>(tuple._1, halo._1, currentOffset, nextOffset);
result.add(haloCounts);
currentOffset = nextOffset;
}
return result;
});
}
if (fofOffset > 0) {
nextOffset += fofOffset;
Tuple4<Long, Long, Long, Long> haloCounts = new Tuple4<>(tuple._1, null, currentOffset, nextOffset);
result.add(haloCounts);
}
return result;
});
JavaPairRDD<Long, Tuple2<Long, Long>> ordinalFofs = flatMapToPair(fofs, tuple -> {
List<Tuple2<Long, Tuple2<Long, Long>>> particlesToFofs = new ArrayList<>();
for (long i = tuple._3(); i < tuple._4(); i++) {
particlesToFofs.add(new Tuple2<>(i, new Tuple2<>(tuple._1(), tuple._2())));
}
return particlesToFofs;
});
JavaPairRDD<Long, Tuple2<ProcessedParticle, Optional<Tuple2<Long, Long>>>> particlesWithHalos = leftOuterJoin(particlesByOrdinal, ordinalFofs,
particleParallelism);
JavaRDD<pl.edu.icm.cocos.spark.job.model.output.Particle> result = map(particlesWithHalos, data -> {
ProcessedParticle particle = data._2._1;
Long fofId = data._2._2.transform(x -> x._1).orNull();
Long haloId = data._2._2.isPresent() ? data._2._2.get()._2 : null;
int boxIndex = BoxUtils.calculateBoxIndex(particle.getPos_x(), particle.getPos_y(), particle.getPos_z(), boxSize.getValue(),
partitionsCount.getValue());
return new pl.edu.icm.cocos.spark.job.model.output.Particle(particle, fofId, haloId, boxIndex, particle.getType());
});
saveData(result);
groupedFofs.unpersist();
}
processedParticles.unpersist();
JavaPairRDD<Long, Tuple2<Long, Long>> ordinalFofs = flatMapToPair(fofs, tuple -> {
List<Tuple2<Long, Tuple2<Long, Long>>> particlesToFofs = new ArrayList<>();
for (long i = tuple._3(); i < tuple._4(); i++) {
particlesToFofs.add(new Tuple2<>(i, new Tuple2<>(tuple._1(), tuple._2())));
}
return particlesToFofs;
});
JavaPairRDD<Long, Tuple2<ProcessedParticle, Optional<Tuple2<Long, Long>>>> particlesWithHalos = leftOuterJoin(particlesByOrdinal, ordinalFofs,
particleParallelism);
JavaRDD<pl.edu.icm.cocos.spark.job.model.output.Particle> result = map(particlesWithHalos, data -> {
ProcessedParticle particle = data._2._1;
Long fofId = data._2._2.transform(x -> x._1).orNull();
Long haloId = data._2._2.isPresent() ? data._2._2.get()._2 : null;
int boxIndex = BoxUtils.calculateBoxIndex(particle.getPos_x(), particle.getPos_y(), particle.getPos_z(), boxSize.getValue(),
partitionsCount.getValue());
return new pl.edu.icm.cocos.spark.job.model.output.Particle(particle, fofId, haloId, boxIndex, particle.getType());
});
saveData(result);
groupedFofs.unpersist();
}
private void processParticlesWithoutFofs(final Integer processedType) throws Exception {
JavaRDD<GenericRowWithSchema> particles = tablesRegistrator.registerTable(outputDirectory, PARTITIONED_PARTICLES_TEMP_DIR, null,
"type=" + processedType);
JavaRDD<ProcessedParticle> processedParticles = tablesRegistrator.mapTable(ProcessedParticle.class, particles);
processParticlesWithoutFofs(processedParticles);
}
private void processParticlesWithoutFofs(JavaRDD<ProcessedParticle> processedParticles) throws Exception {
......
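The offset bookkeeping inside processParticlesWithFofs is easier to follow outside Spark. The sketch below reproduces the same arithmetic locally for one particle type: FoFs are bucketed into groups with the same (fof_id % 1e12 / 5e3) expression, a prefix sum over the group totals yields each group's first particle ordinal, and FoFs sorted by id within a group receive consecutive ordinal ranges (the resulting (fofId, offset, count) tuples are what the Spark version expands into per-ordinal pairs and left-joins with the particles keyed by ordinal). The ids and counts are illustrative only:

```java
import java.util.*;

public class FofOffsetSketch {
    public static void main(String[] args) {
        // fofId -> particle count for the processed type (made-up values)
        Map<Long, Long> fofParticleCounts = new TreeMap<>();
        fofParticleCounts.put(1_000_000_000_001L, 4L);
        fofParticleCounts.put(1_000_000_000_002L, 2L);
        fofParticleCounts.put(1_000_000_005_001L, 3L);

        // 1) bucket FoFs into groups, as in groupBy(t -> (long) (t._1 % 1e12 / 5e3))
        Map<Long, List<Map.Entry<Long, Long>>> groups = new TreeMap<>();
        for (Map.Entry<Long, Long> fof : fofParticleCounts.entrySet()) {
            long group = (long) (fof.getKey() % 1e12 / 5e3);
            groups.computeIfAbsent(group, g -> new ArrayList<>()).add(fof);
        }

        // 2) prefix sum over group totals: first particle ordinal of each group
        Map<Long, Long> groupStart = new LinkedHashMap<>();
        long sumCount = 0;
        for (Map.Entry<Long, List<Map.Entry<Long, Long>>> group : groups.entrySet()) {
            groupStart.put(group.getKey(), sumCount);
            sumCount += group.getValue().stream().mapToLong(Map.Entry::getValue).sum();
        }

        // 3) within a group, FoFs sorted by id get consecutive ordinal ranges
        for (Map.Entry<Long, List<Map.Entry<Long, Long>>> group : groups.entrySet()) {
            long fofOffset = groupStart.get(group.getKey());
            for (Map.Entry<Long, Long> fof : group.getValue()) {
                System.out.println("fof " + fof.getKey() + " -> ordinals ["
                        + fofOffset + ", " + (fofOffset + fof.getValue()) + ")");
                fofOffset += fof.getValue();
            }
        }
    }
}
```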
......@@ -6,18 +6,18 @@ public class FileNPartInfo implements Serializable {
private static final long serialVersionUID = -5464645343758644703L;
private final Long offset;
private final long[] offsets;
private final Long[] types;
public FileNPartInfo(Long offset, Long[] types) {
public FileNPartInfo(long[] offsets, Long[] types) {
super();
this.offset = offset;
this.offsets = offsets;
this.types = types;
}
public Long getOffset() {
return offset;
public long[] getOffsets() {
return offsets;
}
public Long[] getTypes() {
......
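Assembled, the value object now carries one start offset per particle type instead of a single file offset. The field comments are my reading of the data; note that the constructor keeps the array reference it is given, which is why the caller above passes offsets.clone():

```java
import java.io.Serializable;

public class FileNPartInfo implements Serializable {
    private static final long serialVersionUID = -5464645343758644703L;

    // first global particle ordinal in this file, per particle type
    private final long[] offsets;
    // particle counts (npart) per type for this file
    private final Long[] types;

    public FileNPartInfo(long[] offsets, Long[] types) {
        super();
        this.offsets = offsets;
        this.types = types;
    }

    public long[] getOffsets() {
        return offsets;
    }

    public Long[] getTypes() {
        return types;
    }
}
```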
......@@ -30,6 +30,9 @@ public class ProcessedParticle implements Serializable {
private Integer type;
public ProcessedParticle(){
}
public ProcessedParticle(GadgetParticle particle, Long startOffset, Integer type){
this.id = particle.getId();
this.pos_x = particle.getPos_x();
......