Compare commits
2 Commits
72e3de945a
...
1a1262b0b1
Author | SHA1 | Date |
---|---|---|
coolneng | 1a1262b0b1 | |
coolneng | 70363a82a0 |
|
@ -11,6 +11,7 @@ class Hyperparameters:
|
||||||
learning_rate=0.004,
|
learning_rate=0.004,
|
||||||
l2_rate=0.001,
|
l2_rate=0.001,
|
||||||
log_directory="logs",
|
log_directory="logs",
|
||||||
|
max_length=80,
|
||||||
):
|
):
|
||||||
self.data_file = data_file
|
self.data_file = data_file
|
||||||
self.label_file = label_file
|
self.label_file = label_file
|
||||||
|
@ -22,3 +23,4 @@ class Hyperparameters:
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.l2_rate = l2_rate
|
self.l2_rate = l2_rate
|
||||||
self.log_directory = log_directory
|
self.log_directory = log_directory
|
||||||
|
self.max_length = max_length
|
||||||
|
|
|
@ -17,7 +17,8 @@ def build_model(hyperparams) -> Model:
|
||||||
"""
|
"""
|
||||||
model = Sequential(
|
model = Sequential(
|
||||||
[
|
[
|
||||||
Input(shape=(None, len(BASES))),
|
Input(shape=(None, hyperparams.max_length, len(BASES))),
|
||||||
|
Masking(mask_value=-1),
|
||||||
Conv1D(
|
Conv1D(
|
||||||
filters=16,
|
filters=16,
|
||||||
kernel_size=5,
|
kernel_size=5,
|
||||||
|
|
|
@ -22,19 +22,6 @@ def align_sequences(sequence, label) -> Tuple[str, str]:
|
||||||
return aligned_seq, aligned_label
|
return aligned_seq, aligned_label
|
||||||
|
|
||||||
|
|
||||||
def generate_example(sequence, label) -> bytes:
    """
    Align a sequence with its label, encode both and pack them into a serialized Example
    """
    seq_aligned, label_aligned = align_sequences(sequence, label)
    # Both entries are stored as int64 lists under the "sequence" and "label" keys
    seq_feature = Feature(int64_list=Int64List(value=encode_sequence(seq_aligned)))
    label_feature = Feature(int64_list=Int64List(value=encode_sequence(label_aligned)))
    features = Features(feature={"sequence": seq_feature, "label": label_feature})
    return Example(features=features).SerializeToString()
|
|
||||||
|
|
||||||
|
|
||||||
def encode_sequence(sequence) -> List[int]:
|
def encode_sequence(sequence) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Encode the DNA sequence using the indices of the BASES constant
|
Encode the DNA sequence using the indices of the BASES constant
|
||||||
|
@ -43,6 +30,30 @@ def encode_sequence(sequence) -> List[int]:
|
||||||
return encoded_sequence
|
return encoded_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_sequences(sequence, label):
    """
    Align the sequence with its label and encode each of them, returning the encoded pair
    """
    seq_aligned, label_aligned = align_sequences(sequence, label)
    # The sequence is encoded before the label, and the pair is returned in that order
    seq_encoded = encode_sequence(seq_aligned)
    label_encoded = encode_sequence(label_aligned)
    return seq_encoded, label_encoded
|
||||||
|
|
||||||
|
|
||||||
|
def generate_example(sequence, label) -> bytes:
    """
    Build a serialized Example holding the prepared sequence and its prepared label
    """
    seq_values, label_values = prepare_sequences(sequence, label)
    # Both entries are stored as int64 lists under the "sequence" and "label" keys
    features = Features(
        feature={
            "sequence": Feature(int64_list=Int64List(value=seq_values)),
            "label": Feature(int64_list=Int64List(value=label_values)),
        }
    )
    return Example(features=features).SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
def read_fastq(hyperparams) -> List[bytes]:
|
def read_fastq(hyperparams) -> List[bytes]:
|
||||||
"""
|
"""
|
||||||
Parses a data and a label FASTQ files and generates a List of serialized Examples
|
Parses a data and a label FASTQ files and generates a List of serialized Examples
|
||||||
|
@ -104,8 +115,13 @@ def read_dataset(filepath, hyperparams) -> TFRecordDataset:
|
||||||
data_input = TFRecordDataset(filenames=filepath)
|
data_input = TFRecordDataset(filenames=filepath)
|
||||||
dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
|
dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
|
||||||
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
|
shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
|
||||||
batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
|
batched_dataset = shuffled_dataset.padded_batch(
|
||||||
count=hyperparams.epochs
|
batch_size=hyperparams.batch_size,
|
||||||
|
padded_shapes=(
|
||||||
|
[hyperparams.max_length, len(BASES)],
|
||||||
|
[hyperparams.max_length, len(BASES)],
|
||||||
|
),
|
||||||
|
padding_values=-1.0,
|
||||||
)
|
)
|
||||||
return batched_dataset
|
return batched_dataset
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue