Skip to content
Snippets Groups Projects
Commit c1ff9d27 authored by Terézia Slanináková's avatar Terézia Slanináková
Browse files

Adjusted load_tiny

parent dc50ef42
No related branches found
No related tags found
No related merge requests found
Pipeline #169834 passed
...@@ -318,14 +318,15 @@ class CoPhIRDataLoader(Dataloader): ...@@ -318,14 +318,15 @@ class CoPhIRDataLoader(Dataloader):
self._normalize self._normalize
) )
def load_tiny_descriptors(self) -> Tuple[pd.DataFrame, pd.DataFrame]: def load_tiny_descriptors(self, n_objects=10) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self.load_dataset( return self.load_dataset(
self.descriptors, self.descriptors,
self.object_ids, self.object_ids,
self._shuffle, self._shuffle,
self._shuffle_seed, self._shuffle_seed,
self._normalize, self._normalize,
tiny=True tiny=True,
n_objects=n_objects
) )
def load_dataset( def load_dataset(
...@@ -335,7 +336,8 @@ class CoPhIRDataLoader(Dataloader): ...@@ -335,7 +336,8 @@ class CoPhIRDataLoader(Dataloader):
shuffle, shuffle,
shuffle_seed, shuffle_seed,
normalize, normalize,
tiny=False tiny=False,
n_objects=None
) -> Tuple[pd.DataFrame, pd.DataFrame]: ) -> Tuple[pd.DataFrame, pd.DataFrame]:
""" Loads the CoPhIR dataset from the disk into the memory. """ Loads the CoPhIR dataset from the disk into the memory.
The resulting DataFrame is expected to have `self.dataset_size` rows The resulting DataFrame is expected to have `self.dataset_size` rows
...@@ -369,7 +371,7 @@ class CoPhIRDataLoader(Dataloader): ...@@ -369,7 +371,7 @@ class CoPhIRDataLoader(Dataloader):
sep=r'[,|;]', sep=r'[,|;]',
engine='python', engine='python',
dtype=np.int32, dtype=np.int32,
skiprows=999_990, skiprows=1_000_000-n_objects,
usecols=[i for i in range(284) if i != 218 and i != 219] usecols=[i for i in range(284) if i != 218 and i != 219]
) )
else: else:
...@@ -383,7 +385,7 @@ class CoPhIRDataLoader(Dataloader): ...@@ -383,7 +385,7 @@ class CoPhIRDataLoader(Dataloader):
) )
df_orig = df_orig.fillna(0) df_orig = df_orig.fillna(0)
if tiny: if tiny:
df_objects = pd.read_csv(object_ids, skiprows=999_990, header=None, dtype=np.uint32) df_objects = pd.read_csv(object_ids, skiprows=1_000_000-n_objects, header=None, dtype=np.uint32)
else: else:
df_objects = pd.read_csv(object_ids, header=None, dtype=np.uint32) df_objects = pd.read_csv(object_ids, header=None, dtype=np.uint32)
...@@ -443,14 +445,16 @@ class ProfisetDataLoader(Dataloader): ...@@ -443,14 +445,16 @@ class ProfisetDataLoader(Dataloader):
def load_descriptors(self) -> pd.DataFrame: def load_descriptors(self) -> pd.DataFrame:
return self.load_dataset(self.descriptors, self.object_ids, self._shuffle, self._shuffle_seed) return self.load_dataset(self.descriptors, self.object_ids, self._shuffle, self._shuffle_seed)
def load_tiny_descriptors(self) -> Tuple[pd.DataFrame, pd.DataFrame]: def load_tiny_descriptors(self, n_objects=10) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self.load_dataset( return self.load_dataset(
self.descriptors, self.descriptors,
self.object_ids, self.object_ids,
self._shuffle, self._shuffle,
self._shuffle_seed, self._shuffle_seed,
self._normalize, self._normalize,
tiny=True) tiny=True,
n_objects=n_objects
)
def load_dataset( def load_dataset(
self, self,
...@@ -460,7 +464,8 @@ class ProfisetDataLoader(Dataloader): ...@@ -460,7 +464,8 @@ class ProfisetDataLoader(Dataloader):
shuffle_seed, shuffle_seed,
normalize=None, normalize=None,
tiny=False, tiny=False,
num_tiny=None num_tiny=None,
n_objects=None
) -> pd.DataFrame: ) -> pd.DataFrame:
""" Loads the Profiset dataset from the disk into the memory. """ Loads the Profiset dataset from the disk into the memory.
The resulting DataFrame is expected to have `self.dataset_size` rows The resulting DataFrame is expected to have `self.dataset_size` rows
...@@ -477,8 +482,10 @@ class ProfisetDataLoader(Dataloader): ...@@ -477,8 +482,10 @@ class ProfisetDataLoader(Dataloader):
self.LOG.info(f'Loading Profiset/MoCap dataset from {descriptors}.') self.LOG.info(f'Loading Profiset/MoCap dataset from {descriptors}.')
time_start = time.time() time_start = time.time()
if tiny: if tiny:
if n_objects is None:
n_objects = 10
if num_tiny is None: if num_tiny is None:
num_tiny = 999_990 num_tiny = 1_000_000 - n_objects
df = pd.read_csv( df = pd.read_csv(
descriptors, descriptors,
header=None, header=None,
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment