Compare commits

..

No commits in common. "1a1262b0b104274b5348d1cc68a706014fb2adbb" and "72e3de945a9ca45fa19b8a4761f113ea09e04258" have entirely different histories.

3 changed files with 16 additions and 35 deletions

View File

@ -11,7 +11,6 @@ class Hyperparameters:
learning_rate=0.004,
l2_rate=0.001,
log_directory="logs",
max_length=80,
):
self.data_file = data_file
self.label_file = label_file
@ -23,4 +22,3 @@ class Hyperparameters:
self.learning_rate = learning_rate
self.l2_rate = l2_rate
self.log_directory = log_directory
self.max_length = max_length

View File

@ -17,8 +17,7 @@ def build_model(hyperparams) -> Model:
"""
model = Sequential(
[
Input(shape=(None, hyperparams.max_length, len(BASES))),
Masking(mask_value=-1),
Input(shape=(None, len(BASES))),
Conv1D(
filters=16,
kernel_size=5,

View File

@ -22,6 +22,19 @@ def align_sequences(sequence, label) -> Tuple[str, str]:
return aligned_seq, aligned_label
def generate_example(sequence, label) -> bytes:
"""
Create a binary-string for each sequence containing the sequence and the bases' counts
"""
aligned_seq, aligned_label = align_sequences(sequence, label)
schema = {
"sequence": Feature(int64_list=Int64List(value=encode_sequence(aligned_seq))),
"label": Feature(int64_list=Int64List(value=encode_sequence(aligned_label))),
}
example = Example(features=Features(feature=schema))
return example.SerializeToString()
def encode_sequence(sequence) -> List[int]:
"""
Encode the DNA sequence using the indices of the BASES constant
@ -30,30 +43,6 @@ def encode_sequence(sequence) -> List[int]:
return encoded_sequence
def prepare_sequences(sequence, label):
"""
Align and encode the sequences to obtain a fixed length output in order to perform batching
"""
encoded_sequences = []
aligned_seq, aligned_label = align_sequences(sequence, label)
for item in [aligned_seq, aligned_label]:
encoded_sequences.append(encode_sequence(item))
return encoded_sequences[0], encoded_sequences[1]
def generate_example(sequence, label) -> bytes:
"""
Create a binary-string for each sequence containing the sequence and the bases' counts
"""
processed_seq, processed_label = prepare_sequences(sequence, label)
schema = {
"sequence": Feature(int64_list=Int64List(value=processed_seq)),
"label": Feature(int64_list=Int64List(value=processed_label)),
}
example = Example(features=Features(feature=schema))
return example.SerializeToString()
def read_fastq(hyperparams) -> List[bytes]:
"""
Parses a data and a label FASTQ files and generates a List of serialized Examples
@ -115,13 +104,8 @@ def read_dataset(filepath, hyperparams) -> TFRecordDataset:
data_input = TFRecordDataset(filenames=filepath)
dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
batched_dataset = shuffled_dataset.padded_batch(
batch_size=hyperparams.batch_size,
padded_shapes=(
[hyperparams.max_length, len(BASES)],
[hyperparams.max_length, len(BASES)],
),
padding_values=-1.0,
batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
count=hyperparams.epochs
)
return batched_dataset