Convert sequence and label to VarLenFeature

This commit is contained in:
coolneng 2021-06-14 19:33:42 +02:00
parent c6d0d5959d
commit 19ed847d12
Signed by: coolneng
GPG Key ID: 9893DA236405AF57
1 changed files with 30 additions and 22 deletions

View File

@ -3,27 +3,31 @@ from typing import List, Tuple
from Bio.motifs import create from Bio.motifs import create
from Bio.SeqIO import parse from Bio.SeqIO import parse
from numpy.random import random from numpy.random import random
from tensorflow import float32, int64 from tensorflow import Tensor, int64, stack
from tensorflow.sparse import to_dense
from tensorflow.data import TFRecordDataset from tensorflow.data import TFRecordDataset
from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example from tensorflow.io import (
from tensorflow.train import Example, Feature, Features, FloatList, Int64List FixedLenFeature,
TFRecordWriter,
VarLenFeature,
parse_single_example,
)
from tensorflow.train import Example, Feature, Features, Int64List
from constants import * from constants import *
def generate_example(sequence, label, weight_matrix) -> bytes: def generate_example(sequence, label, base_counts) -> bytes:
""" """
Create a binary-string for each sequence containing the sequence and the bases' frequency Create a binary-string for each sequence containing the sequence and the bases' counts
""" """
schema = { schema = {
"sequence": Feature( "A_counts": Feature(int64_list=Int64List(value=[sum(base_counts["A"])])),
int64_list=Int64List(value=list(encode_sequence(sequence))) "C_counts": Feature(int64_list=Int64List(value=[sum(base_counts["C"])])),
), "G_counts": Feature(int64_list=Int64List(value=[sum(base_counts["G"])])),
"label": Feature(int64_list=Int64List(value=list(encode_sequence(label)))), "T_counts": Feature(int64_list=Int64List(value=[sum(base_counts["T"])])),
"A_counts": Feature(float_list=FloatList(value=weight_matrix["A"])), "sequence": Feature(int64_list=Int64List(value=encode_sequence(sequence))),
"C_counts": Feature(float_list=FloatList(value=weight_matrix["C"])), "label": Feature(int64_list=Int64List(value=encode_sequence(label))),
"G_counts": Feature(float_list=FloatList(value=weight_matrix["G"])),
"T_counts": Feature(float_list=FloatList(value=weight_matrix["T"])),
} }
example = Example(features=Features(feature=schema)) example = Example(features=Features(feature=schema))
return example.SerializeToString() return example.SerializeToString()
@ -48,7 +52,7 @@ def read_fastq(data_file, label_file) -> List[bytes]:
example = generate_example( example = generate_example(
sequence=str(element.seq), sequence=str(element.seq),
label=str(label.seq), label=str(label.seq),
weight_matrix=motifs.pwm, base_counts=motifs.counts,
) )
examples.append(example) examples.append(example)
return examples return examples
@ -73,19 +77,23 @@ def create_dataset(
test.write(element) test.write(element)
def process_input(byte_string) -> Example: def process_input(byte_string) -> Tuple[Tensor, Tensor]:
""" """
Parse a byte-string into an Example object Parse a byte-string into an Example object
""" """
schema = { schema = {
"sequence": FixedLenFeature(shape=[], dtype=int64), "A_counts": FixedLenFeature(shape=[1], dtype=int64),
"label": FixedLenFeature(shape=[], dtype=int64), "C_counts": FixedLenFeature(shape=[1], dtype=int64),
"A_counts": FixedLenFeature(shape=[], dtype=float32), "G_counts": FixedLenFeature(shape=[1], dtype=int64),
"C_counts": FixedLenFeature(shape=[], dtype=float32), "T_counts": FixedLenFeature(shape=[1], dtype=int64),
"G_counts": FixedLenFeature(shape=[], dtype=float32), "sequence": VarLenFeature(dtype=int64),
"T_counts": FixedLenFeature(shape=[], dtype=float32), "label": VarLenFeature(dtype=int64),
} }
return parse_single_example(byte_string, features=schema) parsed_features = parse_single_example(byte_string, features=schema)
parsed_features["sequence"] = to_dense(parsed_features["sequence"])
parsed_features["label"] = to_dense(parsed_features["label"])
features = list(parsed_features.values())[:-1]
return stack(features, axis=-1), parsed_features["label"]
def read_dataset(filepath) -> TFRecordDataset: def read_dataset(filepath) -> TFRecordDataset: