Hi Josh,
Sorry for the confusion regarding the runDataAugmentation() function. We have an open issue to clarify this: "Clarify runDataAugmentation() params" (Issue #2072, SCIInstitute/ShapeWorks on GitHub).
There are two ways to call runDataAugmentation().
The first is:
DataAugmentationUtils.runDataAugmentation(out_dir, img_list,
world_point_list, num_samples,
num_dim, percent_variability,
sampler_type, mixture_num)
This generates image/particle pairs in the world coordinate system and assumes the images in img_list have been groomed/aligned so that they are in the world coordinate system.
The second is:
DataAugmentationUtils.runDataAugmentation(out_dir, img_list,
local_point_list, num_samples,
num_dim, percent_variability,
sampler_type, mixture_num,
world_point_list)
This generates image/particle pairs in the local coordinate system and assumes the images in img_list are the original (unaligned) images. In this case world_point_list must also be provided so that PCA is done in the world coordinate system. New samples are generated by sampling the world PCA subspace and then mapping each sample to local points using the world-to-local transform of the closest real example. In the future we could add noise to this transform as an additional form of augmentation, but this is not currently included.
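To make that last step a bit more concrete, here is a minimal numpy sketch of mapping one generated world-space shape into local coordinates. It is not the ShapeWorks implementation: it assumes each real example's world-to-local transform is available as a 4x4 homogeneous matrix, and it measures "closest" as Euclidean distance in the PCA embedding. The helper names closest_example and map_to_local are hypothetical.

```python
import numpy as np


def closest_example(sample_embedding, embedded_matrix):
    """Index of the real example nearest to the sample in PCA space.

    sample_embedding: (num_dim,) PCA scores of the generated shape
    embedded_matrix:  (M, num_dim) PCA scores of the M real shapes
    (Euclidean distance in the embedding is an assumption of this sketch.)
    """
    dists = np.linalg.norm(embedded_matrix - sample_embedding, axis=1)
    return int(np.argmin(dists))


def map_to_local(world_points, world_to_local):
    """Apply a 4x4 homogeneous world->local transform to (N, 3) particles."""
    # Append a 1 to each point so the transform's translation is applied too
    homog = np.hstack([world_points, np.ones((world_points.shape[0], 1))])
    return (homog @ world_to_local.T)[:, :3]
```

With a list of per-example transforms (hypothetical name transforms), a generated shape would be mapped as map_to_local(world_pts, transforms[closest_example(z, embedded)]). Reusing the closest example's transform keeps the generated local-space particles consistent with the pose of a real, unaligned image.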
I believe that in your case you can simply use the first way of calling the method, passing the world particles as the third parameter.
If you don't need augmented images, you could try the following code instead of calling runDataAugmentation(); it will be much faster:
import os
import numpy as np
from DataAugmentationUtils import Utils
from DataAugmentationUtils import Embedder
from DataAugmentationUtils import Sampler

# Stack the world particle files into one data matrix
point_matrix = Utils.create_data_matrix(world_point_list)
# PCA embedding of the world particles
PointEmbedder = Embedder.PCA_Embbeder(point_matrix, num_dim, percent_variability)
embedded_matrix = PointEmbedder.getEmbeddedMatrix()
# Fit a sampler to the embedded shapes
PointSampler = Sampler.Gaussian_Sampler()  # or Sampler.Mixture_Sampler() or Sampler.KDE_Sampler()
PointSampler.fit(embedded_matrix)
for index in range(1, num_samples + 1):
    # Draw a new embedding and project it back to particle coordinates
    sampled_embedding, base_index = PointSampler.sample()
    gen_points = PointEmbedder.project(sampled_embedding)
    # Write one .particles file per generated shape
    out_path = os.path.join(out_dir, 'sample_' + Utils.pad_index(index) + '.particles')
    np.savetxt(out_path, gen_points)
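If it helps to see what the snippet above does without the ShapeWorks helpers, here is a rough numpy-only sketch of the same idea: PCA via SVD, a single multivariate Gaussian fit to the PCA scores, and projection of new samples back to particles. It is only an approximation of what Embedder/Sampler do. It assumes the particles are already loaded into an (M, N, 3) array, uses a fixed num_dim rather than percent_variability, and the function names and the 4-digit file padding are hypothetical (Utils.pad_index may pad differently).

```python
import os

import numpy as np


def pca_gaussian_samples(world_shapes, num_dim, num_samples, seed=0):
    """Generate shapes by Gaussian sampling in a PCA subspace.

    world_shapes: (M, N, 3) corresponding particles in world coordinates
    (needs several shapes, M > 1, so a covariance can be estimated).
    Returns an array of shape (num_samples, N, 3).
    """
    M, N, _ = world_shapes.shape
    X = world_shapes.reshape(M, -1)           # one flattened shape per row
    mean = X.mean(axis=0)
    # Rows of Vt are the principal directions of the centered data
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    components = Vt[:num_dim]                 # (num_dim, 3N)
    embedded = (X - mean) @ components.T      # (M, num_dim) PCA scores
    # Fit one Gaussian to the scores and draw new scores from it
    cov = np.atleast_2d(np.cov(embedded, rowvar=False))
    rng = np.random.default_rng(seed)
    z = rng.multivariate_normal(embedded.mean(axis=0), cov, size=num_samples)
    # Project the sampled scores back to particle coordinates
    return (mean + z @ components).reshape(num_samples, N, 3)


def save_particles(shapes, out_dir):
    """Write each (N, 3) shape to out_dir/sample_XXXX.particles."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, pts in enumerate(shapes, start=1):
        path = os.path.join(out_dir, f"sample_{i:04d}.particles")
        np.savetxt(path, pts)
        paths.append(path)
    return paths
```

For example, save_particles(pca_gaussian_samples(shapes, num_dim=5, num_samples=100), out_dir) would write 100 particle files. Swapping the single Gaussian for a mixture or KDE only changes how z is drawn; the PCA and projection steps stay the same.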
Let me know if this works or if you have further questions!
Jadie