A multimodal autonomous disease diagnosis and treatment method based on LSTM-CLIP, comprising an electronic medical record (EMR) preprocessing module, a Transformer text encoder module, an image encoder module, an image feature extraction network module, an LSTM recurrent neural network module, and a reinforcement-learning-based interaction module. Specifically:

- EMR preprocessing module: collects the patient's textual medical records and imaging records and preprocesses them into a data form the neural networks can take as input.
- Encoder modules: split into an image encoder and a text encoder. The image encoder encodes the input image into a high-dimensional vector carrying semantic information; likewise, the text encoder extracts features from the patient's medical-record text and encodes them into a high-dimensional vector carrying the record's semantics.
- Feature extraction network module: an optional add-on that can be inserted into the decision network according to the needs of the application.
- LSTM recurrent neural network module: the backbone of the entire decision network, and the agent that interacts with the reinforcement learning environment. It performs temporal modeling on the high-dimensional features extracted by the encoders and issues the corresponding diagnosis and treatment actions.
- Reinforcement learning interaction module: built on the value-network-based DDQN algorithm, it provides the system's reinforcement learning environment. It receives ratings fed back by patients and generates rewards from them to guide the agent's decision network toward the appropriate treatment actions (a minimal sketch of the DDQN update follows this list).
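
To make the interaction module concrete, here is a minimal sketch of the Double DQN (DDQN) target and loss computation. The state/action sizes, the stand-in QNet, and the batch of random transitions are illustrative assumptions, not the system's actual networks.

import torch
import torch.nn.functional as F
from torch import nn

STATE_DIM, N_ACTIONS, GAMMA = 1024, 10, 0.99  # assumed sizes, for illustration only

class QNet(nn.Module):
    """Stand-in value network; the real system uses the LSTM decision backbone."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(STATE_DIM, 256), nn.ReLU(),
                                 nn.Linear(256, N_ACTIONS))
    def forward(self, s):
        return self.net(s)

online_net, target_net = QNet(), QNet()
target_net.load_state_dict(online_net.state_dict())

# One batch of transitions: state, chosen action, patient-feedback reward, next state.
s = torch.randn(32, STATE_DIM)
a = torch.randint(0, N_ACTIONS, (32,))
r = torch.randn(32)            # reward derived from the patient's rating
s2 = torch.randn(32, STATE_DIM)
done = torch.zeros(32)

# DDQN decoupling: the online net selects the next action, the target net evaluates it.
with torch.no_grad():
    a_star = online_net(s2).argmax(dim=1, keepdim=True)
    q_next = target_net(s2).gather(1, a_star).squeeze(1)
    y = r + GAMMA * (1 - done) * q_next

q = online_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
loss = F.smooth_l1_loss(q, y)
loss.backward()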

DEMO: the baseline backbone network flows as sketched below. (Advantages: small model, low hardware requirements, does not need much data, works well on small datasets, quick to get results. Disadvantages: low ceiling; on large datasets it falls short of many deep learning models.)
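
A minimal sketch of that baseline backbone, assuming per-visit image/text features from the two encoders, simple additive fusion, and an LSTM followed by a linear action head; the dimensions and the fusion rule are assumptions for illustration only.

import torch
from torch import nn

class BaselineBackbone(nn.Module):
    def __init__(self, feat_dim=1024, hidden=512, n_actions=10):
        super().__init__()
        self.lstm = nn.LSTM(feat_dim, hidden, batch_first=True)
        self.head = nn.Linear(hidden, n_actions)

    def forward(self, img_feats, txt_feats):
        # img_feats, txt_feats: [batch, T, feat_dim] from the two encoders
        fused = img_feats + txt_feats   # simple additive fusion (assumed)
        out, _ = self.lstm(fused)       # temporal modeling over T visits
        return self.head(out[:, -1])    # action logits from the last time step

logits = BaselineBackbone()(torch.randn(2, 5, 1024), torch.randn(2, 5, 1024))
print(logits.shape)  # torch.Size([2, 10])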

A DNS-like hierarchical resolution flow for disease handling, analogous to domain-name resolution: initial screening of the disease category for department triage -> the corresponding specialty's screening network, which orders examinations -> clinical decision network -> clinical back-end evaluation network. A minimal routing sketch follows.
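
A minimal sketch of this DNS-like resolution, assuming a department-level triage classifier that routes each case down to a hypothetical per-specialty screening network; the department names, the two-level depth, and the linear stand-in networks are illustrative only.

import torch
from torch import nn

DEPARTMENTS = ["cardiology", "neurology", "oncology"]  # assumed departments

triage_net = nn.Linear(1024, len(DEPARTMENTS))                   # department-level triage
specialty_nets = {d: nn.Linear(1024, 20) for d in DEPARTMENTS}   # per-specialty screening

def resolve(case_feat: torch.Tensor):
    # Resolve the case to a department, then to that department's screening net.
    dept = DEPARTMENTS[triage_net(case_feat).argmax(dim=-1).item()]
    exam_logits = specialty_nets[dept](case_feat)  # which examinations to order
    return dept, exam_logits

dept, exams = resolve(torch.randn(1, 1024))
print(dept, exams.shape)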

Core components: screening network, decision network, temporal network, and reinforcement learning network.

Multiple multimodal CLIP networks are ensembled as classifiers (a small ensembling sketch follows), and the deep reinforcement learning state/reward interaction loop learns from real clinical practice.
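
A minimal sketch of the classifier ensembling, with simple linear stand-ins for the fine-tuned CLIP heads and softmax-probability averaging; the member count, class count, and averaging rule are assumptions.

import torch
from torch import nn

N_CLASSES = 20
members = [nn.Linear(1024, N_CLASSES) for _ in range(3)]  # stand-ins for CLIP classifier heads

def ensemble_predict(case_feat: torch.Tensor) -> torch.Tensor:
    # Average the softmax probabilities of all ensemble members.
    probs = torch.stack([m(case_feat).softmax(dim=-1) for m in members])
    return probs.mean(dim=0)

pred = ensemble_predict(torch.randn(1, 1024))
print(pred.argmax(dim=-1))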

# by kayak.gw

import numpy as np
import torch
from torch import nn
from collections import OrderedDict

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

class Text_Encoder(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # Causal attention mask over the token sequence. PyTorch uses an
        # additive attention mask, so fill the upper triangle with -inf.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower triangle
        return mask

    @property
    def dtype(self):
        # Return a dtype (not a tensor type class like torch.FloatTensor),
        # so .type(self.dtype) works regardless of device.
        return torch.float32

    def forward(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x

def save_model():
    model = Text_Encoder(1024, 77, 49408, 512, 8, 12)
    torch.save(model.state_dict(), "text_encoder.pth")

def load_model(model_path):
    return torch.load(model_path)

def text_encoder(pretrained=False, **kwargs):
    """Construct the CLIP-style text encoder, optionally loading pretrained weights."""
    model = Text_Encoder(1024, 77, 49408, 512, 8, 12)
    if pretrained:
        print('Loading text model')
        model.load_state_dict(torch.load(r'..\model\transformer_block_weight.pth'))
        print('Loaded successfully')
    return model

'''
# DEMO: save the text encoder weights, then inspect them
save_model()
model = load_model("text_encoder.pth")
for k, v in model.items():
    print(v.shape)

# DEMO: fuse image and text features and run a linear classification head
from images_encoder import ModifiedResNet

img = torch.rand(1, 3, 224, 224)
model1 = ModifiedResNet(heads=32, layers=(3, 4, 6, 3), output_dim=1024)
img_f = model1(img)

model = Text_Encoder(1024, 77, 49408, 512, 8, 12)
model.load_state_dict(torch.load(r'transformer_block_weight.pth'))
a = torch.randint(0, 49408, (1, 77))  # random token ids; torch.rand().type(LongTensor) would give all zeros
x = model(a)
print(x.shape)

out = torch.add(img_f, x)
fc = nn.Linear(1024, 10)
print(fc(out))
'''
