# scient
**scient**一个用python实现科学计算相关算法的包,包括自然语言、图像、神经网络、优化算法、机器学习、图计算等模块。
**scient**源码和编译安装包可以在`Python package index`获取。
The source code and binary installers for the latest released version are available at the [Python package index].
[https://pypi.org/project/scient](https://pypi.org/project/scient)
可以用`pip`安装`scient`。
You can install `scient` like this:
```
pip install scient
```
也可以用`setup.py`安装。
Or in the `scient` directory, execute:
```
python setup.py install
```
## scient.image
图像相关算法模块,包括边缘检测、图像相似度计算、图像质量评价、图像特征提取等。
### scient.image.friqa
全参考图像质量评价模块,包括峰值信噪比(PSNR),结构相似度(SSIM),直方图相似度(HistSim)。
#### scient.image.friqa.psnr(image1,image2,max_pix=255)
Parameters
----------
image1 : numpy.array 2D or 3D,参考图像
image2 : numpy.array 2D or 3D,待评价图像
max_pix : int, optional default is 255, 像素值的最大值,默认值是255.
Returns
-------
float
Algorithms
-------
PSNR(Peak Signal to Noise Ratio),峰值信噪比,是一种评价图像的客观标准,单位dB。图像在经过压缩之后,会在某种程度与原始图像不同,PSNR值用来衡量经过处理后的图像品质是否令人满意。
$$
PSNR=10 \cdot \log _ {10} ( \frac { MAX _ I ^ 2 } { MSE }) = 20 \cdot \log _ {10} ( \frac { MAX _ I } { MSE })
$$
其中,$MAX _ I$是图像像素值的最大值,一般每个采样点用8位表示,那么$MAX _ I$就是255。
$MSE$是待评价图像与参考图像的均方误差,$MSE$越小,PSNR越大;PSNR越大,待评价图像质量越好。
* PSNR高于40dB说明待评价图像质量极好,非常接近原始图像;
* PSNR在30—40dB说明待评价图像质量是较好,虽然有明显失真但可以接受;
* PSNR在20—30dB说明待评价图像质量差;
* PSNR低于20dB说明待评价图像质量不可接受。
PSNR缺点:基于对应像素点间的误差,即基于误差敏感的图像质量评价。由于并未考虑到人眼的视觉特性(人眼对空间频率较低的对比差异敏感度较高,人眼对亮度对比差异的敏感度较色度高,人眼对一个 区域的感知结果会受到其周围邻近区域的影响等),因而经常出现评价结果与人的主观感觉不一致的情况。
Examples
-------
```
import os
from scient.image import friqa
import numpy
from PIL import Image
ref_image='test/data/I10.BMP'
images=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']
#读取图像文件
ref_image=Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',ref_image))
images=[Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',i)) for i in images]
#计算psnr
for i in images:
print(friqa.psnr(numpy.array(ref_image),numpy.array(i)))
```
运行结果
```
100
32.436263852012544
31.184291262813648
30.272831107297733
29.3584810257951
```
#### scient.image.friqa.ssim(image1,image2,k1=0.01,k2=0.03,block_size=(8, 8),max_pix=255)
Parameters
----------
image1 : numpy.array 2D
image2 : numpy.array 2D
k1 : float, optional,k1<<1,避免分母为0造成不稳定. The default is 0.01.
k2 : float, optional,k2<<1,避免分母为0造成不稳定. The default is 0.03.
block_size : tuple, optional,将图像分成多个block,采用gaussian加权计算所有block的均值、方差、协方差,进而计算所有block的ssim,最后的ssim取所有block的平均值. The default is (8,8).
max_pix : int, optional default is 255, 像素值的最大值,默认值是255.
Returns
-------
float
Algorithms
-------
SSIM(Structural Similarity),结构相似度,用于衡量两个图像相似程度,或检测图像的失真程度。
SSIM基于样本之间的亮度(luminance,像素平均值)、对比度(contrast,像素标准差)和结构(structure,像素减均值除以标准差)计算。
$$
SSIM(x,y)=f(l(x,y),c(x,y),s(x,y))
$$
$l(x,y)$为亮度对比函数,是关于图像的平均灰度$μ_x,μ_y$的函数;
$$
l(x,y)=\frac { 2μ_x μ_y + C1 } { μ_x^2 μ_y^2 + C1 }
$$
$$
μ_x=\frac { 1 } { N } \sum^{N}_{i=1}{x_i}
$$
$$
C1=(K_1 L)^2
$$
像素值的最大值,默认值是255. K1<<1。
$c(x,y)$为对比度对比函数,是关于图像的标准差$σ_x,σ_y$的函数;
$$
c(x,y)=\frac { 2σ_x σ_y + C2 } { σ_x^2 σ_y^2 + C2 }
$$
$$
σ_x=(\frac { 1 } { N-1 } \sum^{N}_{i=1}{(x_i-μ_x)^2})^{\frac { 1 } { 2 }}
$$
$$
C2=(K_2 L)^2
$$
K2<<1
$s(x,y)$为结构对比函数,是关于图像的标准化$\frac { x-μ_x } { σ_x },\frac { y-μ_y } { σ_y }$的函数;
$$
s(x,y)=\frac { σ_{xy} + C3 } { σ_x σ_y + C3 }
$$
$$
σ_{xy}=\frac { 1 } { N-1 } (\sum^{N}_{i=1}{(x_i-μ_x)(y_i-μ_y)})
$$
$$
SSIM(x,y)=[l(x,y)]^α[c(x,y)]^β[s(x,y)]^γ
$$
α,β,γ取1,令$C_3=\frac { C_2 } { 2 }$,可将SSIM简化为:
$$
SSIM(x,y)=\frac { (2μ_x μ_y + C1)(2σ_{xy} + C2) } { (μ_x^2 μ_y^2 + C1)(σ_x^2 σ_y^2 + C2) }
$$
SSIM取值范围为[0,1],值越大表示图像质量越好。
SSIM具有:对称性,ssim(x,y)==ssim(y,x);
有界性,ssim(x,y)<=1;
最大值唯一性,当且仅当x==y时,ssim(x,y)==1。
SSIM缺点:对于图像出现位移、缩放、旋转(皆属于非结构性的失真)的情况无法有效的判断。
Examples
-------
```
import os
from scient.image import friqa
import numpy
from PIL import Image
ref_image='test/data/I10.BMP'
images=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']
#读取图像文件
ref_image=Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',ref_image))
images=[Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',i)) for i in images]
#计算ssim
for i in images:
print(friqa.ssim(numpy.array(ref_image.convert("L")),numpy.array(i.convert("L"))))
```
运行结果
```
1.0
0.8568124416229375
0.6810351495300123
0.5575398637742431
0.5072153083460104
```
### scient.image.feature
图像特征提取模块,包括BRISQUE,基于累积概率的锐化因子(CPB),曝光度。
#### scient.image.feature.brisque(image)
Parameters
----------
image : numpy.array 2D
Returns
-------
tuple
('gdd_α','gdd_σ',
'aggd_α1','aggd_η1','aggd_σl1','aggd_σr1',
'aggd_α2','aggd_η2','aggd_σl2','aggd_σr2',
'aggd_α3','aggd_η3','aggd_σl3','aggd_σr3',
'aggd_α4','aggd_η4','aggd_σl4','aggd_σr4')
Algorithms
-------
BRISQUE(Blind/Referenceless Image Spatial QUality Evaluator),是一种无参考的空间域图像质量评估算法。先计算Mean Subtracted Contrast Normalized Coefficients(MSCN系数),MSCN系数反映了由于失真的存在而改变的特征统计,可以用来作为图像失真的统计特征。再用MSCN系数估计Generalized Gaussian Distribution(GDD)的参数α、σ,以及Asymmetric Generalized Gaussian Distribution(AGGD)在Horizontal Neighbour, Vertical Neighbour, On Diagonal Neighbour, Off Diagonal Neighbour上的参数α、η、σl、σr,将GDD的两个参数和AGGD的16个参数作为输出的特征。
MSCN系数:
$$
MSCN(i,j)=\frac { I(i,j)-μ(i,j) } { σ(i,j)+C }
$$
$$
μ(i,j)=\sum^{K}_{k=-K}{\sum^{L}_{l=-L}{w_{k,l}I_{k,l}(i,j)}}
$$
$$
σ(i,j)=\sqrt{\sum^{K}_{k=-K}{\sum^{L}_{l=-L}{w_{k,l}(I_{k,l}(i,j)-μ(i,j))^2}}}
$$
其中$I(i,j)$表示原始图像i行j列元素的值。
Generalized Gaussian Distribution:
$$
f(x;α,σ^2)=\frac {α} {2βΓ(1/α)} e^{-(\frac {|x|}{β})^α}
$$
$$
β=σ\sqrt{\frac{Γ(1/α)}{Γ(3/α)}}
$$
$$
Γ(α)=\int^{\infty}_{0}{t^{α-1}e^{-t}dt} α>0
$$
Neighbours:
$$
HorizontalNeighbour(i,j)=MSCN(i,j)MSCN(i,j+1)
$$
$$
VerticalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j)
$$
$$
OnDiagonalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j+1)
$$
$$
OffDiagonalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j-1)
$$
Asymmetric Generalized Gaussian Distribution:
$$
f(x;α,σ_l^2,σ_r^2)=
\frac {α}{(β_l+β_r)Γ(1/α)}e^{-(\frac {-x}{β_l})^α} x<0
\frac {α}{(β_l+β_r)Γ(1/α)}e^{-(\frac {x}{β_r})^α} x>=0
β_l=σ_l\sqrt{\frac{Γ(1/α)}{Γ(3/α)}} \\
β_r=σ_r\sqrt{\frac{Γ(1/α)}{Γ(3/α)}}
$$
Examples
-------
```
import os
from scient.image import feature
import numpy
from PIL import Image
images=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']
#读取图像文件
images=[Image.open(os.path.join(os.path.dirname(feature.__file__),'..',i)) for i in images]
#计算brisque
brisques=[]
for i in images:
brisques.append(feature.brisque(numpy.array(i.convert('L'))))
print(brisques)
```
运行结果
```
[(2.8390000000000026, 0.5387382509471336, 0.8180000000000005, 0.1597336483186561, 0.19928197982139934, 0.4696747920784309, 0.8640000000000005, 0.17081167501931036, 0.1703080506100513, 0.440894038756712, 0.8610000000000007, -0.002437981115828319, 0.2983089768677447, 0.2943996123553127, 0.8670000000000007, 0.03657370089459203, 0.2641503963750437, 0.32229688865209727), (2.179000000000002, 0.3755805588864052, 0.6610000000000005, 0.2105638785869636, 0.06573065885425396, 0.3546433105372317, 0.7250000000000005, 0.2035633011201771, 0.04895566298941261, 0.2895746994148656, 0.7110000000000005, 0.09196294223642214, 0.10660221933416321, 0.22150476223116147, 0.7220000000000004, 0.10061626044729756, 0.09951649928883519, 0.22307536755643081), (1.489000000000001, 0.19567592119387475, 0.4370000000000002, 0.16656579278574843, 0.005144811587270607, 0.1595102390164801, 0.4400000000000002, 0.14819323960693676, 0.007946536338563829, 0.14400949152877282, 0.46900000000000025, 0.1304195444573072, 0.010840852166168865, 0.12285748598680354, 0.47300000000000025, 0.12785146234621667, 0.011051488263507676, 0.11939877242752284), (1.2570000000000008, 0.1189807661854071, 0.2940000000000001, 0.09858069094224381, 0.0033503171775502846, 0.1003980673321924, 0.2960000000000001, 0.09662228540309649, 0.0037953392707882772, 0.09854664422093222, 0.3160000000000001, 0.08840261656054116, 0.004225987220008733, 0.08029184471742051, 0.3180000000000001, 0.08631426420092875, 0.004399447310061135, 0.07751730107145516), (1.203000000000001, 0.14103130545847511, 0.3270000000000001, 0.10623288442963101, 0.008919473174326557, 0.12226537626029133, 0.3280000000000001, 0.06853644417080812, 0.02378947796849877, 0.10143999168472712, 0.33900000000000013, 0.05689116726400874, 0.02385946076111514, 0.08256978072093775, 0.33900000000000013, 0.05450324427873719, 0.02492368706293601, 0.0813272014967197)]
```
### scient.image.hash
图像hash模块,包括均值hash(mean hash),差值hash(diff hash),感知hash(percept hash)。
#### scient.image.hash.percept(image,hash_size=64)
Parameters
----------
image : numpy.array 2D
hash_size : 输出hash值的长度
Returns
-------
list
Algorithms
-------
先将图片缩放成hash_size*hash_size大小,然后对图像进行离散余弦变换DCT,并输出左上角h*w=hash_size的mean hash。
DCT是一种特殊的傅立叶变换,将图片从像素域变换为频率域,并且DCT矩阵从左上角到右下角代表越来越高频率的系数,图片的主要信息保留左上角的低频区域。
一维DCT变换:
$$
F(x)=c(x)\sum^{N-1}_{i=0}{f(i)cos(\frac {(i+0.5)π}{N}x) }
$$
$$
c(x)=\left\{\begin{matrix}\sqrt{\frac{1}{N}} ,x=0\\\sqrt{\frac{2}{N}} ,x!=0 \end{matrix}\right.
$$
f(i)为原始的信号,F(x)是DCT变换后的系数,N为原始信号的点数,c(x)是补偿系数。
二维DCT变换:
$$
F(x,y)=c(x)c(y)sum^{N-1}_{i=0}{sum^{N-1}_{j=0}{f(i,j)cos(\frac {(i+0.5)π}{N}x)cos(\frac {(j+0.5)π}{N}y)}}
$$
$$
c(x)=\left\{\begin{matrix}\sqrt{\frac{1}{N}} ,x=0\\\sqrt{\frac{2}{N}} ,x!=0 \end{matrix}\right.
$$
二维DCT变换也可表示为:
$$
F=AfA^T
$$
$$
A(i,j)=c(i)cos(\frac {(j+0.5)π}{N}i)
$$
此形式更方便计算。DCT变换是对称的,因此可以对经过DCT变换的图片进行还原操作。
Examples
-------
计算图像感知相似度时,首先计算图像的PHASH值,再采用海明(hamming)距离相似度计算图片PHASH值的相似度。
```
#采用感知hash计算图片的感知相似度
import os
from scient.image import hash
from scient.algorithms import similar
import numpy
from PIL import Image
ref_image='test/data/I10.BMP'
images=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']
#读取图像文件
ref_image=Image.open(os.path.join(os.path.dirname(hash.__file__),'..',ref_image))
images=[Image.open(os.path.join(os.path.dirname(hash.__file__),'..',i)) for i in images]
#计算感知hash
phash=hash.percept(numpy.array(ref_image.convert("L")))
phashs=[hash.percept(numpy.array(i.convert("L"))) for i in images]
#计算感知相似度
for i in phashs:
print(similar.hamming(i,phash))
```
运行结果
```
1.0
0.9384615384615385
0.8615384615384616
0.8153846153846154
0.6
```
## scient.neuralnet
神经网络相关算法模块,包括attention、transformer、bert、lstm、resnet、crf、dataset、fit等。
### scient.neuralnet.fit
神经网络训练模块,将torch构建的神经网络模型的训练方式简化为model.fit(),使torch神经网络模型训练更简捷,更优雅。
使用步骤:
(1)基于torch构建模型model,采用torch.utils.data.DataLoader加载训练数据集train_loader、验证数据集eval_loader(可选);
(2)采用fit.set()设置模型训练参数,参数详情:
* optimizer=None: 优化器,可以用类似torch.optim模块内的优化器来定义;
* scheduler=None: 优化器的调度器,可以用类似torch.optim.lr_scheduler模块内的调度器来定义;
* loss_func=None: 损失函数,可以用类似torch.nn.CrossEntropyLoss()来定义;
* grad_func=None: 梯度操作函数,可进行如梯度裁剪的操作;
* perform_func=None: 模型性能函数,模型传入预测值和实际值,用以评估模型性能;
* n_iter=10: 模型在数据集上迭代训练的次数;
- 如果n_iter为int,表示模型在数据集上迭代训练n_iter后停止;
- 如果n_iter为(int,int),表示模型在数据集上迭代训练的最小min_iter和最大max_iter次数, 如果迭代次数超过min_iter且eval的perform_func比上一个iter大,结束训练。n_iter为(int,int)时,必须提供eval_loader,且perform_func必须是一个数值,且值越大模型性能越好;
* device=None: 模型训练的设备,如device = torch.device('cuda' if torch.cuda.is_available() else 'cpu');
* n_batch_step: 每n个batch更新一次optimizer的梯度,以节省显存及计算量;
* n_batch_plot: 每n个batch更新一下损失曲线,训练过程中会实时绘制损失曲线;
* save_path: 每个iter完成后模型保存路径,模型名为“模型类名_iter_i.checkpoint”,保存的内容为{'model_state_dict':model.state_dict(),'optimizer_state_dict':optimizer.state_dict(),'batch_loss':batch_loss},如果训练时未提供eval_loader,batch_loss=train_batch_loss, 否则batch_loss=[train_batch_loss,eval_batch_loss]
(3)采用model.fit(train_loader,eval_loader,mode=('input','target'))训练模型:
* train_loader: 训练数据集
* eval_loader: 验证数据集
* mode: 数据集包含的内容,分四种情况:
- mode=('input','target'), loader data item is one input and one target;
- mode='input', loader data item is only one input;
- mode=('inputs','target'), loader data item is a list of input and one target;
- mode='inputs', loader data item is a list of input.
- mode中不包含target时,不能使用perform_func
Examples
-------
首先构建模型model、训练数据加载器train_loader、验证数据加载器eval_loader:
```
import os
import torch
from scient.neuralnet import resnet, fit
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
# 数据转换(归一化和数据增强)
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_tfms = tt.Compose([tt.RandomCrop(160, padding=4, padding_mode='reflect'),
tt.RandomHorizontalFlip(),
tt.ToTensor(),
tt.Normalize(*stats,inplace=True)])
valid_tfms = tt.Compose([tt.Resize([160,160]),tt.ToTensor(), tt.Normalize(*stats)])
# 创建ImageFolder对象
data_train = ImageFolder(os.path.join(os.path.dirname(fit.__file__),'..','test/data/imagewoof/train'), train_tfms)
data_eval = ImageFolder(os.path.join(os.path.dirname(fit.__file__),'..','test/data/imagewoof/val'), valid_tfms)
# 设置批量大小
batch_size = 2
# 创建训练集和验证集的数据加载器
train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True)
eval_loader = torch.utils.data.DataLoader(data_eval, batch_size=batch_size, shuffle=False)
#resnet50模型
model=resnet.ResNet50(n_class=3)
```
然后设置模型训练参数、训练模型:
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
运行结果:
```
train iter 0: avg_batch_loss=1.27345: 100%|██████████| 60/60 [00:04<00:00, 12.92it/s]
eval iter 0: avg_batch_loss=1.33363: 100%|██████████| 8/8 [00:00<00:00, 59.44it/s]
train iter 1: avg_batch_loss=1.24023: 100%|██████████| 60/60 [00:04<00:00, 13.39it/s]
eval iter 1: avg_batch_loss=1.08319: 100%|██████████| 8/8 [00:00<00:00, 58.83it/s]
train iter 2: batch_loss=1.42699 avg_batch_loss=1.16666: 63%|██████▎ | 38/60 [00:02<00:01, 13.37it/s]
```
Examples: 训练时不使用eval_loader
-------
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device)
#训练
model.fit(train_loader=train_loader,mode=('input','target'))
```
运行结果:
```
train iter 0: avg_batch_loss=1.07998: 100%|██████████| 60/60 [00:04<00:00, 12.27it/s]
train iter 1: avg_batch_loss=1.16323: 100%|██████████| 60/60 [00:04<00:00, 12.95it/s]
train iter 2: batch_loss=0.61398 avg_batch_loss=1.00838: 67%|██████▋ | 40/60 [00:03<00:01, 13.06it/s]
```
Examples: 使用scheduler在训练过程中改变学习率等optimizer参数
-------
```
#设置训练参数
n_iter=5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0001, epochs=n_iter,steps_per_epoch=len(train_loader))
loss_func=torch.nn.CrossEntropyLoss()
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=n_iter,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
Examples: 使用perform_func在训练过程中评估模型性能
-------
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
def perform_func(y_hat,y):#perform_func的输入是预测值y_hat和实际值y
y_hat,y=torch.concat(y_hat),torch.concat(y)#先将y_hat和y分别concat,由于y_hat和y是按loader分批计算和收集的,所以y_hat和y是batch_size大小的多个对象组成的list
_,y_hat=y_hat.max(axis=1)#该模型输出值y_hat最大值对应的索引是预测的类别
return round((y_hat==y).sum().item()/len(y),4)#输出准确率,并保留4位小数
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,perform_func=perform_func,n_iter=5,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
运行结果,可以在每个iter结束后得到perform的值:
```
train iter 0: avg_batch_loss=1.27428 perform=0.3417: 100%|██████████| 60/60 [00:04<00:00, 12.34it/s]
eval iter 0: avg_batch_loss=1.09305 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 55.59it/s]
train iter 1: avg_batch_loss=1.09102 perform=0.4417: 100%|██████████| 60/60 [00:04<00:00, 13.30it/s]
eval iter 1: avg_batch_loss=1.18128 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 60.46it/s]
train iter 2: avg_batch_loss=1.24860 perform=0.3583: 100%|██████████| 60/60 [00:04<00:00, 13.19it/s]
eval iter 2: avg_batch_loss=1.23469 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 60.57it/s]
```
Examples: 使用grad_func在训练过程对梯度进行裁剪
-------
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
def grad_func(x):#grad_func的输入是model.parameters(),该操作在loss.backward()后起作用
torch.nn.utils.clip_grad_value_(x, 0.1)
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,grad_func=grad_func,n_iter=5,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
Examples: 使用n_batch_step在小显存上模拟大batch_size的训练
-------
该功能实现了多次反向误差传播并累积梯度后,再让optimizer进行梯度下降优化。
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device,n_batch_step=5)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
Examples: 当训练到一定迭代次数后,如果模型在验证集上性能下降,提前停止训练
-------
用n_iter=(min_iter,max_iter)设置模型的最小和最大训练迭代次数,当模型训练迭次数超过min_iter时,判断本次迭代训练模型性能是否优于上次迭代训练模型性能,如果不优于上次,则停止训练。过功能可防止过多的训练导致过拟合。该功能需要在eval_loader上计算perform,因此eval_loader不能为空,且perform_func输出必须为一个数值,该数值越大表示模型越优。
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
def perform_func(y_hat,y):#perform_func的输入是预测值y_hat和实际值y
y_hat,y=torch.concat(y_hat),torch.concat(y)#先将y_hat和y分别concat,由于y_hat和y是按loader分批计算和收集的,所以y_hat和y是batch_size大小的多个对象组成的list
_,y_hat=y_hat.max(axis=1)#该模型输出值y_hat最大值对应的索引是预测的类别
return round((y_hat==y).sum().item()/len(y),4)#输出准确率,并保留4位小数
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,perform_func=perform_func,n_iter=(5,20),device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
运行结果,可以看到模型运行到第iter 6停止,并提示性能最佳的模型是iter 4:
```
train iter 0: avg_batch_loss=1.17016 perform=0.375: 100%|██████████| 60/60 [00:04<00:00, 12.60it/s]
eval iter 0: avg_batch_loss=1.48805 perform=0.3333: 100%|██████████| 8/8 [00:00<00:00, 60.23it/s]
train iter 1: avg_batch_loss=1.17200 perform=0.3833: 100%|██████████| 60/60 [00:04<00:00, 13.08it/s]
eval iter 1: avg_batch_loss=1.18933 perform=0.2667: 100%|██████████| 8/8 [00:00<00:00, 59.39it/s]
train iter 2: avg_batch_loss=1.09923 perform=0.4333: 100%|██████████| 60/60 [00:04<00:00, 13.14it/s]
eval iter 2: avg_batch_loss=1.32449 perform=0.3333: 100%|██████████| 8/8 [00:00<00:00, 60.92it/s]
train iter 3: avg_batch_loss=1.20507 perform=0.4083: 100%|██████████| 60/60 [00:05<00:00, 11.66it/s]
eval iter 3: avg_batch_loss=1.23331 perform=0.2667: 100%|██████████| 8/8 [00:00<00:00, 57.59it/s]
train iter 4: avg_batch_loss=1.09205 perform=0.4167: 100%|██████████| 60/60 [00:04<00:00, 12.87it/s]
eval iter 4: avg_batch_loss=1.11206 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 59.80it/s]
train iter 5: avg_batch_loss=1.10706 perform=0.4583: 100%|██████████| 60/60 [00:04<00:00, 12.94it/s]
eval iter 5: avg_batch_loss=1.07162 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 39.96it/s]
train iter 6: avg_batch_loss=1.15846 perform=0.4333: 100%|██████████| 60/60 [00:04<00:00, 12.34it/s]
eval iter 6: avg_batch_loss=1.16467 perform=0.4: 100%|██████████| 8/8 [00:00<00:00, 58.85it/s]early stop and the best model is iter 4, the perform is 0.4
```
Examples: 训练过程中实时显示loss曲线,并在每一个iter完成后保存模型
-------
设置n_batch_plot和save_path,保存的模型以checkpoint为后缀名,可以用torch.load打开保存的模型,模型里保存了3项内容:model_state_dict、optimizer_state_dict、batch_loss
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
loss_func=torch.nn.CrossEntropyLoss()
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device,n_batch_plot=5,save_path='d:/')
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
#打开并查看保存的模型
checkpoint=torch.load('D:/ResNet_iter_2.checkpoint')
checkpoint.keys()
checkpoint['batch_loss']
checkpoint['model_state_dict']
```
Examples: 如果型输出结果本身就是损失,可以省略loss_func
-------
先定义一个输出为loss的模型
```
#模型
class output_loss(torch.nn.Module):
def __init__(self):
super(output_loss,self).__init__()
self.model=resnet.ResNet50(n_class=3)
self.loss_func=torch.nn.CrossEntropyLoss()
def forward(self,x,y):
y_hat=self.model(x)
return self.loss_func(y_hat,y)#输出为loss无需在训练过程中计算loss
model=output_loss()
```
然后设置模型训练参数时,省略loss_func,因为此时loader的input和target都要输入到模型的forward中,因此可以将其看成inputs=[input,target],在训练时mode='inputs'
```
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
model=fit.set(model,optimizer=optimizer,n_iter=10,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode='inputs')
```
Examples: 如果型输出结果有多个值,只用其中的1个值计算损失,需要自定义loss_func
-------
先定义一个输出为多个值的模型
```
#模型
class output_multi(torch.nn.Module):
def __init__(self):
super(output_multi,self).__init__()
self.model=resnet.ResNet50(n_class=3)
def forward(self,x):
y_hat=self.model(x)
return y_hat,x#输出为两个值,只使用y_hat计算损失
model=output_multi()
```
然后设置模型训练参数时,对loss_func进行修改,用其中需要参与loss计算的部分计算loss
```
loss_func_=torch.nn.CrossEntropyLoss()
def loss_func(y_hat,y):
return loss_func_(y_hat[0],y)#指定用输型输出的第0个值计算loss
#设置训练参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
model=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=10,device=device)
#训练
model.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))
```
Raw data
{
"_id": null,
"home_page": null,
"name": "scient",
"maintainer": null,
"docs_url": null,
"requires_python": ">=3.10",
"maintainer_email": null,
"keywords": "science compute, image, natural language, machine learning, neural network, optimize algorithm, graphic algorithm",
"author": "scient",
"author_email": "yaomsn@live.cn",
"download_url": "https://files.pythonhosted.org/packages/54/6b/fd89d1fa785c0f7e6326b3e648912789929a232105604263943ff808240b/scient-0.8.0.tar.gz",
"platform": "any",
"description": "# scient\r\n\r\n**scient**\u4e00\u4e2a\u7528python\u5b9e\u73b0\u79d1\u5b66\u8ba1\u7b97\u76f8\u5173\u7b97\u6cd5\u7684\u5305\uff0c\u5305\u62ec\u81ea\u7136\u8bed\u8a00\u3001\u56fe\u50cf\u3001\u795e\u7ecf\u7f51\u7edc\u3001\u4f18\u5316\u7b97\u6cd5\u3001\u673a\u5668\u5b66\u4e60\u3001\u56fe\u8ba1\u7b97\u7b49\u6a21\u5757\u3002\r\n\r\n**scient**\u6e90\u7801\u548c\u7f16\u8bd1\u5b89\u88c5\u5305\u53ef\u4ee5\u5728`Python package index`\u83b7\u53d6\u3002\r\n\r\nThe source code and binary installers for the latest released version are available at the [Python package index].\r\n\r\n[https://pypi.org/project/scient](https://pypi.org/project/scient)\r\n\r\n\u53ef\u4ee5\u7528`pip`\u5b89\u88c5`scient`\u3002\r\n\r\nYou can install `scient` like this:\r\n \r\n```\r\npip install scient\r\n```\r\n\r\n\u4e5f\u53ef\u4ee5\u7528`setup.py`\u5b89\u88c5\u3002\r\n\r\nOr in the `scient` directory, execute:\r\n\r\n```\r\npython setup.py install\r\n```\r\n\r\n## scient.image\r\n\r\n\u56fe\u50cf\u76f8\u5173\u7b97\u6cd5\u6a21\u5757\uff0c\u5305\u62ec\u8fb9\u7f18\u68c0\u6d4b\u3001\u56fe\u50cf\u76f8\u4f3c\u5ea6\u8ba1\u7b97\u3001\u56fe\u50cf\u8d28\u91cf\u8bc4\u4ef7\u3001\u56fe\u50cf\u7279\u5f81\u63d0\u53d6\u7b49\u3002\r\n\r\n### scient.image.friqa\r\n\r\n\u5168\u53c2\u8003\u56fe\u50cf\u8d28\u91cf\u8bc4\u4ef7\u6a21\u5757\uff0c\u5305\u62ec\u5cf0\u503c\u4fe1\u566a\u6bd4\uff08PSNR\uff09\uff0c\u7ed3\u6784\u76f8\u4f3c\u5ea6\uff08SSIM\uff09\uff0c\u76f4\u65b9\u56fe\u76f8\u4f3c\u5ea6\uff08HistSim\uff09\u3002\r\n\r\n#### scient.image.friqa.psnr(image1,image2,max_pix=255)\r\n\r\nParameters\r\n----------\r\nimage1 : numpy.array 2D or 3D\uff0c\u53c2\u8003\u56fe\u50cf\r\n\r\nimage2 : numpy.array 2D or 3D\uff0c\u5f85\u8bc4\u4ef7\u56fe\u50cf\r\n\r\nmax_pix : int, optional default is 255, \u50cf\u7d20\u503c\u7684\u6700\u5927\u503c\uff0c\u9ed8\u8ba4\u503c\u662f255.\r\n\r\nReturns\r\n-------\r\nfloat\r\n\r\nAlgorithms\r\n-------\r\nPSNR(Peak Signal to Noise Ratio)\uff0c\u5cf0\u503c\u4fe1\u566a\u6bd4\uff0c\u662f\u4e00\u79cd\u8bc4\u4ef7\u56fe\u50cf\u7684\u5ba2\u89c2\u6807\u51c6\uff0c\u5355\u4f4ddB\u3002\u56fe\u50cf\u5728\u7ecf\u8fc7\u538b\u7f29\u4e4b\u540e\uff0c\u4f1a\u5728\u67d0\u79cd\u7a0b\u5ea6\u4e0e\u539f\u59cb\u56fe\u50cf\u4e0d\u540c\uff0cPSNR\u503c\u7528\u6765\u8861\u91cf\u7ecf\u8fc7\u5904\u7406\u540e\u7684\u56fe\u50cf\u54c1\u8d28\u662f\u5426\u4ee4\u4eba\u6ee1\u610f\u3002\r\n\r\n$$\r\nPSNR=10 \\cdot \\log _ {10} ( \\frac { MAX _ I ^ 2 } { MSE }) = 20 \\cdot \\log _ {10} ( \\frac { MAX _ I } { MSE })\r\n$$\r\n\r\n\u5176\u4e2d\uff0c$MAX _ I$\u662f\u56fe\u50cf\u50cf\u7d20\u503c\u7684\u6700\u5927\u503c\uff0c\u4e00\u822c\u6bcf\u4e2a\u91c7\u6837\u70b9\u75288\u4f4d\u8868\u793a\uff0c\u90a3\u4e48$MAX _ I$\u5c31\u662f255\u3002\r\n\r\n$MSE$\u662f\u5f85\u8bc4\u4ef7\u56fe\u50cf\u4e0e\u53c2\u8003\u56fe\u50cf\u7684\u5747\u65b9\u8bef\u5dee\uff0c$MSE$\u8d8a\u5c0f\uff0cPSNR\u8d8a\u5927\uff1bPSNR\u8d8a\u5927\uff0c\u5f85\u8bc4\u4ef7\u56fe\u50cf\u8d28\u91cf\u8d8a\u597d\u3002\r\n\r\n* PSNR\u9ad8\u4e8e40dB\u8bf4\u660e\u5f85\u8bc4\u4ef7\u56fe\u50cf\u8d28\u91cf\u6781\u597d,\u975e\u5e38\u63a5\u8fd1\u539f\u59cb\u56fe\u50cf\uff1b\r\n* PSNR\u572830\u201440dB\u8bf4\u660e\u5f85\u8bc4\u4ef7\u56fe\u50cf\u8d28\u91cf\u662f\u8f83\u597d\uff0c\u867d\u7136\u6709\u660e\u663e\u5931\u771f\u4f46\u53ef\u4ee5\u63a5\u53d7\uff1b\r\n* PSNR\u572820\u201430dB\u8bf4\u660e\u5f85\u8bc4\u4ef7\u56fe\u50cf\u8d28\u91cf\u5dee\uff1b\r\n* PSNR\u4f4e\u4e8e20dB\u8bf4\u660e\u5f85\u8bc4\u4ef7\u56fe\u50cf\u8d28\u91cf\u4e0d\u53ef\u63a5\u53d7\u3002\r\n\r\n\r\nPSNR\u7f3a\u70b9\uff1a\u57fa\u4e8e\u5bf9\u5e94\u50cf\u7d20\u70b9\u95f4\u7684\u8bef\u5dee\uff0c\u5373\u57fa\u4e8e\u8bef\u5dee\u654f\u611f\u7684\u56fe\u50cf\u8d28\u91cf\u8bc4\u4ef7\u3002\u7531\u4e8e\u5e76\u672a\u8003\u8651\u5230\u4eba\u773c\u7684\u89c6\u89c9\u7279\u6027\uff08\u4eba\u773c\u5bf9\u7a7a\u95f4\u9891\u7387\u8f83\u4f4e\u7684\u5bf9\u6bd4\u5dee\u5f02\u654f\u611f\u5ea6\u8f83\u9ad8\uff0c\u4eba\u773c\u5bf9\u4eae\u5ea6\u5bf9\u6bd4\u5dee\u5f02\u7684\u654f\u611f\u5ea6\u8f83\u8272\u5ea6\u9ad8\uff0c\u4eba\u773c\u5bf9\u4e00\u4e2a \u533a\u57df\u7684\u611f\u77e5\u7ed3\u679c\u4f1a\u53d7\u5230\u5176\u5468\u56f4\u90bb\u8fd1\u533a\u57df\u7684\u5f71\u54cd\u7b49\uff09\uff0c\u56e0\u800c\u7ecf\u5e38\u51fa\u73b0\u8bc4\u4ef7\u7ed3\u679c\u4e0e\u4eba\u7684\u4e3b\u89c2\u611f\u89c9\u4e0d\u4e00\u81f4\u7684\u60c5\u51b5\u3002\r\n\r\nExamples\r\n-------\r\n\r\n```\r\nimport os\r\nfrom scient.image import friqa\r\nimport numpy\r\nfrom PIL import Image\r\n\r\nref_image='test/data/I10.BMP'\r\nimages=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']\r\n\r\n#\u8bfb\u53d6\u56fe\u50cf\u6587\u4ef6\r\nref_image=Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',ref_image))\r\nimages=[Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',i)) for i in images]\r\n\r\n#\u8ba1\u7b97psnr\r\nfor i in images:\r\n print(friqa.psnr(numpy.array(ref_image),numpy.array(i)))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\r\n\r\n```\r\n100\r\n32.436263852012544\r\n31.184291262813648\r\n30.272831107297733\r\n29.3584810257951\r\n```\r\n \r\n \r\n#### scient.image.friqa.ssim(image1,image2,k1=0.01,k2=0.03,block_size=(8, 8),max_pix=255)\r\n\r\nParameters\r\n----------\r\nimage1 : numpy.array 2D\r\n\r\nimage2 : numpy.array 2D\r\n\r\nk1 : float, optional\uff0ck1<<1,\u907f\u514d\u5206\u6bcd\u4e3a0\u9020\u6210\u4e0d\u7a33\u5b9a. The default is 0.01.\r\n\r\nk2 : float, optional\uff0ck2<<1,\u907f\u514d\u5206\u6bcd\u4e3a0\u9020\u6210\u4e0d\u7a33\u5b9a. The default is 0.03.\r\n\r\nblock_size : tuple, optional\uff0c\u5c06\u56fe\u50cf\u5206\u6210\u591a\u4e2ablock,\u91c7\u7528gaussian\u52a0\u6743\u8ba1\u7b97\u6240\u6709block\u7684\u5747\u503c\u3001\u65b9\u5dee\u3001\u534f\u65b9\u5dee,\u8fdb\u800c\u8ba1\u7b97\u6240\u6709block\u7684ssim,\u6700\u540e\u7684ssim\u53d6\u6240\u6709block\u7684\u5e73\u5747\u503c. The default is (8,8).\r\n\r\nmax_pix : int, optional default is 255, \u50cf\u7d20\u503c\u7684\u6700\u5927\u503c\uff0c\u9ed8\u8ba4\u503c\u662f255.\r\n\r\nReturns\r\n-------\r\nfloat\r\n\r\nAlgorithms\r\n-------\r\n\r\nSSIM(Structural Similarity)\uff0c\u7ed3\u6784\u76f8\u4f3c\u5ea6\uff0c\u7528\u4e8e\u8861\u91cf\u4e24\u4e2a\u56fe\u50cf\u76f8\u4f3c\u7a0b\u5ea6\uff0c\u6216\u68c0\u6d4b\u56fe\u50cf\u7684\u5931\u771f\u7a0b\u5ea6\u3002\r\nSSIM\u57fa\u4e8e\u6837\u672c\u4e4b\u95f4\u7684\u4eae\u5ea6(luminance,\u50cf\u7d20\u5e73\u5747\u503c)\u3001\u5bf9\u6bd4\u5ea6(contrast,\u50cf\u7d20\u6807\u51c6\u5dee)\u548c\u7ed3\u6784(structure,\u50cf\u7d20\u51cf\u5747\u503c\u9664\u4ee5\u6807\u51c6\u5dee)\u8ba1\u7b97\u3002\r\n\r\n$$\r\nSSIM(x,y)=f(l(x,y),c(x,y),s(x,y))\r\n$$\r\n\r\n$l(x,y)$\u4e3a\u4eae\u5ea6\u5bf9\u6bd4\u51fd\u6570\uff0c\u662f\u5173\u4e8e\u56fe\u50cf\u7684\u5e73\u5747\u7070\u5ea6$\u03bc_x,\u03bc_y$\u7684\u51fd\u6570\uff1b\r\n\r\n$$\r\nl(x,y)=\\frac { 2\u03bc_x \u03bc_y + C1 } { \u03bc_x^2 \u03bc_y^2 + C1 }\r\n$$\r\n\r\n$$\r\n\u03bc_x=\\frac { 1 } { N } \\sum^{N}_{i=1}{x_i}\r\n$$\r\n\r\n$$\r\nC1=(K_1 L)^2\r\n$$\r\n\r\n\u50cf\u7d20\u503c\u7684\u6700\u5927\u503c\uff0c\u9ed8\u8ba4\u503c\u662f255. K1<<1\u3002\r\n\r\n$c(x,y)$\u4e3a\u5bf9\u6bd4\u5ea6\u5bf9\u6bd4\u51fd\u6570\uff0c\u662f\u5173\u4e8e\u56fe\u50cf\u7684\u6807\u51c6\u5dee$\u03c3_x,\u03c3_y$\u7684\u51fd\u6570\uff1b\r\n\r\n$$\r\nc(x,y)=\\frac { 2\u03c3_x \u03c3_y + C2 } { \u03c3_x^2 \u03c3_y^2 + C2 }\r\n$$\r\n\r\n$$\r\n\u03c3_x=(\\frac { 1 } { N-1 } \\sum^{N}_{i=1}{(x_i-\u03bc_x)^2})^{\\frac { 1 } { 2 }}\r\n$$\r\n\r\n$$\r\nC2=(K_2 L)^2\r\n$$\r\n\r\nK2<<1\r\n\r\n$s(x,y)$\u4e3a\u7ed3\u6784\u5bf9\u6bd4\u51fd\u6570\uff0c\u662f\u5173\u4e8e\u56fe\u50cf\u7684\u6807\u51c6\u5316$\\frac { x-\u03bc_x } { \u03c3_x },\\frac { y-\u03bc_y } { \u03c3_y }$\u7684\u51fd\u6570\uff1b\r\n\r\n$$\r\ns(x,y)=\\frac { \u03c3_{xy} + C3 } { \u03c3_x \u03c3_y + C3 }\r\n$$\r\n\r\n$$\r\n\u03c3_{xy}=\\frac { 1 } { N-1 } (\\sum^{N}_{i=1}{(x_i-\u03bc_x)(y_i-\u03bc_y)})\r\n$$\r\n\r\n$$\r\nSSIM(x,y)=[l(x,y)]^\u03b1[c(x,y)]^\u03b2[s(x,y)]^\u03b3\r\n$$\r\n\r\n\u03b1,\u03b2,\u03b3\u53d61\uff0c\u4ee4$C_3=\\frac { C_2 } { 2 }$\uff0c\u53ef\u5c06SSIM\u7b80\u5316\u4e3a\uff1a\r\n\r\n$$\r\nSSIM(x,y)=\\frac { (2\u03bc_x \u03bc_y + C1)(2\u03c3_{xy} + C2) } { (\u03bc_x^2 \u03bc_y^2 + C1)(\u03c3_x^2 \u03c3_y^2 + C2) }\r\n$$\r\n\r\nSSIM\u53d6\u503c\u8303\u56f4\u4e3a[0,1]\uff0c\u503c\u8d8a\u5927\u8868\u793a\u56fe\u50cf\u8d28\u91cf\u8d8a\u597d\u3002\r\nSSIM\u5177\u6709\uff1a\u5bf9\u79f0\u6027\uff0cssim(x,y)==ssim(y,x);\r\n \u6709\u754c\u6027,ssim(x,y)<=1;\r\n \u6700\u5927\u503c\u552f\u4e00\u6027\uff0c\u5f53\u4e14\u4ec5\u5f53x==y\u65f6\uff0cssim(x,y)==1\u3002\r\nSSIM\u7f3a\u70b9\uff1a\u5bf9\u4e8e\u56fe\u50cf\u51fa\u73b0\u4f4d\u79fb\u3001\u7f29\u653e\u3001\u65cb\u8f6c\uff08\u7686\u5c5e\u4e8e\u975e\u7ed3\u6784\u6027\u7684\u5931\u771f\uff09\u7684\u60c5\u51b5\u65e0\u6cd5\u6709\u6548\u7684\u5224\u65ad\u3002\r\n\r\nExamples\r\n-------\r\n\r\n```\r\nimport os\r\nfrom scient.image import friqa\r\nimport numpy\r\nfrom PIL import Image\r\n\r\nref_image='test/data/I10.BMP'\r\nimages=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']\r\n\r\n#\u8bfb\u53d6\u56fe\u50cf\u6587\u4ef6\r\nref_image=Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',ref_image))\r\nimages=[Image.open(os.path.join(os.path.dirname(friqa.__file__),'..',i)) for i in images]\r\n\r\n#\u8ba1\u7b97ssim\r\nfor i in images:\r\n print(friqa.ssim(numpy.array(ref_image.convert(\"L\")),numpy.array(i.convert(\"L\"))))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\r\n\r\n```\r\n1.0\r\n0.8568124416229375\r\n0.6810351495300123\r\n0.5575398637742431\r\n0.5072153083460104\r\n```\r\n\r\n\r\n### scient.image.feature\r\n\r\n\u56fe\u50cf\u7279\u5f81\u63d0\u53d6\u6a21\u5757\uff0c\u5305\u62ecBRISQUE\uff0c\u57fa\u4e8e\u7d2f\u79ef\u6982\u7387\u7684\u9510\u5316\u56e0\u5b50\uff08CPB\uff09\uff0c\u66dd\u5149\u5ea6\u3002\r\n\r\n#### scient.image.feature.brisque(image)\r\n\r\nParameters\r\n----------\r\nimage : numpy.array 2D\r\n\r\nReturns\r\n-------\r\ntuple\r\n('gdd_\u03b1','gdd_\u03c3',\r\n'aggd_\u03b11','aggd_\u03b71','aggd_\u03c3l1','aggd_\u03c3r1',\r\n'aggd_\u03b12','aggd_\u03b72','aggd_\u03c3l2','aggd_\u03c3r2',\r\n'aggd_\u03b13','aggd_\u03b73','aggd_\u03c3l3','aggd_\u03c3r3',\r\n'aggd_\u03b14','aggd_\u03b74','aggd_\u03c3l4','aggd_\u03c3r4')\r\n\r\nAlgorithms\r\n-------\r\n BRISQUE\uff08Blind/Referenceless Image Spatial QUality Evaluator\uff09\uff0c\u662f\u4e00\u79cd\u65e0\u53c2\u8003\u7684\u7a7a\u95f4\u57df\u56fe\u50cf\u8d28\u91cf\u8bc4\u4f30\u7b97\u6cd5\u3002\u5148\u8ba1\u7b97Mean Subtracted Contrast Normalized Coefficients\uff08MSCN\u7cfb\u6570\uff09\uff0cMSCN\u7cfb\u6570\u53cd\u6620\u4e86\u7531\u4e8e\u5931\u771f\u7684\u5b58\u5728\u800c\u6539\u53d8\u7684\u7279\u5f81\u7edf\u8ba1\uff0c\u53ef\u4ee5\u7528\u6765\u4f5c\u4e3a\u56fe\u50cf\u5931\u771f\u7684\u7edf\u8ba1\u7279\u5f81\u3002\u518d\u7528MSCN\u7cfb\u6570\u4f30\u8ba1Generalized Gaussian Distribution\uff08GDD\uff09\u7684\u53c2\u6570\u03b1\u3001\u03c3\uff0c\u4ee5\u53caAsymmetric Generalized Gaussian Distribution\uff08AGGD\uff09\u5728Horizontal Neighbour, Vertical Neighbour, On Diagonal Neighbour, Off Diagonal Neighbour\u4e0a\u7684\u53c2\u6570\u03b1\u3001\u03b7\u3001\u03c3l\u3001\u03c3r\uff0c\u5c06GDD\u7684\u4e24\u4e2a\u53c2\u6570\u548cAGGD\u768416\u4e2a\u53c2\u6570\u4f5c\u4e3a\u8f93\u51fa\u7684\u7279\u5f81\u3002\r\n\r\nMSCN\u7cfb\u6570\uff1a\r\n\r\n$$\r\nMSCN(i,j)=\\frac { I(i,j)-\u03bc(i,j) } { \u03c3(i,j)+C }\r\n$$\r\n\r\n$$\r\n\u03bc(i,j)=\\sum^{K}_{k=-K}{\\sum^{L}_{l=-L}{w_{k,l}I_{k,l}(i,j)}}\r\n$$\r\n\r\n$$\r\n\u03c3(i,j)=\\sqrt{\\sum^{K}_{k=-K}{\\sum^{L}_{l=-L}{w_{k,l}(I_{k,l}(i,j)-\u03bc(i,j))^2}}}\r\n$$\r\n\r\n\u5176\u4e2d$I(i,j)$\u8868\u793a\u539f\u59cb\u56fe\u50cfi\u884cj\u5217\u5143\u7d20\u7684\u503c\u3002\r\n\r\nGeneralized Gaussian Distribution\uff1a\r\n\r\n$$\r\nf(x;\u03b1,\u03c3^2)=\\frac {\u03b1} {2\u03b2\u0393(1/\u03b1)} e^{-(\\frac {|x|}{\u03b2})^\u03b1}\r\n$$\r\n\r\n$$\r\n\u03b2=\u03c3\\sqrt{\\frac{\u0393(1/\u03b1)}{\u0393(3/\u03b1)}}\r\n$$\r\n\r\n$$\r\n\u0393(\u03b1)=\\int^{\\infty}_{0}{t^{\u03b1-1}e^{-t}dt} \u03b1>0\r\n$$\r\n\r\nNeighbours:\r\n\r\n$$\r\nHorizontalNeighbour(i,j)=MSCN(i,j)MSCN(i,j+1)\r\n$$\r\n\r\n$$\r\nVerticalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j)\r\n$$\r\n\r\n$$\r\nOnDiagonalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j+1)\r\n$$\r\n\r\n$$\r\nOffDiagonalNeighbour(i,j)=MSCN(i,j)MSCN(i+1,j-1)\r\n$$\r\n\r\nAsymmetric Generalized Gaussian Distribution:\r\n\r\n$$\r\nf(x;\u03b1,\u03c3_l^2,\u03c3_r^2)=\r\n\\frac {\u03b1}{(\u03b2_l+\u03b2_r)\u0393(1/\u03b1)}e^{-(\\frac {-x}{\u03b2_l})^\u03b1} x<0\r\n\\frac {\u03b1}{(\u03b2_l+\u03b2_r)\u0393(1/\u03b1)}e^{-(\\frac {x}{\u03b2_r})^\u03b1} x>=0\r\n\u03b2_l=\u03c3_l\\sqrt{\\frac{\u0393(1/\u03b1)}{\u0393(3/\u03b1)}} \\\\\r\n\u03b2_r=\u03c3_r\\sqrt{\\frac{\u0393(1/\u03b1)}{\u0393(3/\u03b1)}}\r\n$$\r\n\r\nExamples\r\n-------\r\n\r\n```\r\nimport os\r\nfrom scient.image import feature\r\nimport numpy\r\nfrom PIL import Image\r\n\r\nimages=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']\r\n\r\n#\u8bfb\u53d6\u56fe\u50cf\u6587\u4ef6\r\nimages=[Image.open(os.path.join(os.path.dirname(feature.__file__),'..',i)) for i in images]\r\n\r\n#\u8ba1\u7b97brisque\r\nbrisques=[]\r\nfor i in images:\r\n brisques.append(feature.brisque(numpy.array(i.convert('L'))))\r\nprint(brisques)\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\r\n\r\n```\r\n[(2.8390000000000026, 0.5387382509471336, 0.8180000000000005, 0.1597336483186561, 0.19928197982139934, 0.4696747920784309, 0.8640000000000005, 0.17081167501931036, 0.1703080506100513, 0.440894038756712, 0.8610000000000007, -0.002437981115828319, 0.2983089768677447, 0.2943996123553127, 0.8670000000000007, 0.03657370089459203, 0.2641503963750437, 0.32229688865209727), (2.179000000000002, 0.3755805588864052, 0.6610000000000005, 0.2105638785869636, 0.06573065885425396, 0.3546433105372317, 0.7250000000000005, 0.2035633011201771, 0.04895566298941261, 0.2895746994148656, 0.7110000000000005, 0.09196294223642214, 0.10660221933416321, 0.22150476223116147, 0.7220000000000004, 0.10061626044729756, 0.09951649928883519, 0.22307536755643081), (1.489000000000001, 0.19567592119387475, 0.4370000000000002, 0.16656579278574843, 0.005144811587270607, 0.1595102390164801, 0.4400000000000002, 0.14819323960693676, 0.007946536338563829, 0.14400949152877282, 0.46900000000000025, 0.1304195444573072, 0.010840852166168865, 0.12285748598680354, 0.47300000000000025, 0.12785146234621667, 0.011051488263507676, 0.11939877242752284), (1.2570000000000008, 0.1189807661854071, 0.2940000000000001, 0.09858069094224381, 0.0033503171775502846, 0.1003980673321924, 0.2960000000000001, 0.09662228540309649, 0.0037953392707882772, 0.09854664422093222, 0.3160000000000001, 0.08840261656054116, 0.004225987220008733, 0.08029184471742051, 0.3180000000000001, 0.08631426420092875, 0.004399447310061135, 0.07751730107145516), (1.203000000000001, 0.14103130545847511, 0.3270000000000001, 0.10623288442963101, 0.008919473174326557, 0.12226537626029133, 0.3280000000000001, 0.06853644417080812, 0.02378947796849877, 0.10143999168472712, 0.33900000000000013, 0.05689116726400874, 0.02385946076111514, 0.08256978072093775, 0.33900000000000013, 0.05450324427873719, 0.02492368706293601, 0.0813272014967197)]\r\n```\r\n\r\n\r\n### scient.image.hash\r\n\r\n\u56fe\u50cfhash\u6a21\u5757\uff0c\u5305\u62ec\u5747\u503chash(mean hash),\u5dee\u503chash(diff hash),\u611f\u77e5hash(percept hash)\u3002\r\n\r\n#### scient.image.hash.percept(image,hash_size=64)\r\n\r\nParameters\r\n----------\r\nimage : numpy.array 2D\r\n\r\nhash_size : \u8f93\u51fahash\u503c\u7684\u957f\u5ea6\r\n\r\nReturns\r\n-------\r\nlist\r\n\r\nAlgorithms\r\n-------\r\n\u5148\u5c06\u56fe\u7247\u7f29\u653e\u6210hash_size*hash_size\u5927\u5c0f\uff0c\u7136\u540e\u5bf9\u56fe\u50cf\u8fdb\u884c\u79bb\u6563\u4f59\u5f26\u53d8\u6362DCT\uff0c\u5e76\u8f93\u51fa\u5de6\u4e0a\u89d2h*w=hash_size\u7684mean hash\u3002\r\nDCT\u662f\u4e00\u79cd\u7279\u6b8a\u7684\u5085\u7acb\u53f6\u53d8\u6362\uff0c\u5c06\u56fe\u7247\u4ece\u50cf\u7d20\u57df\u53d8\u6362\u4e3a\u9891\u7387\u57df\uff0c\u5e76\u4e14DCT\u77e9\u9635\u4ece\u5de6\u4e0a\u89d2\u5230\u53f3\u4e0b\u89d2\u4ee3\u8868\u8d8a\u6765\u8d8a\u9ad8\u9891\u7387\u7684\u7cfb\u6570\uff0c\u56fe\u7247\u7684\u4e3b\u8981\u4fe1\u606f\u4fdd\u7559\u5de6\u4e0a\u89d2\u7684\u4f4e\u9891\u533a\u57df\u3002\r\n\r\n\u4e00\u7ef4DCT\u53d8\u6362\uff1a\r\n\r\n$$\r\nF(x)=c(x)\\sum^{N-1}_{i=0}{f(i)cos(\\frac {(i+0.5)\u03c0}{N}x) }\r\n$$\r\n\r\n$$\r\nc(x)=\\left\\{\\begin{matrix}\\sqrt{\\frac{1}{N}} ,x=0\\\\\\sqrt{\\frac{2}{N}} ,x!=0 \\end{matrix}\\right.\r\n$$\r\n\r\nf(i)\u4e3a\u539f\u59cb\u7684\u4fe1\u53f7\uff0cF(x)\u662fDCT\u53d8\u6362\u540e\u7684\u7cfb\u6570\uff0cN\u4e3a\u539f\u59cb\u4fe1\u53f7\u7684\u70b9\u6570\uff0cc(x)\u662f\u8865\u507f\u7cfb\u6570\u3002\r\n\r\n\u4e8c\u7ef4DCT\u53d8\u6362\uff1a\r\n\r\n$$\r\nF(x,y)=c(x)c(y)sum^{N-1}_{i=0}{sum^{N-1}_{j=0}{f(i,j)cos(\\frac {(i+0.5)\u03c0}{N}x)cos(\\frac {(j+0.5)\u03c0}{N}y)}}\r\n$$\r\n\r\n$$\r\nc(x)=\\left\\{\\begin{matrix}\\sqrt{\\frac{1}{N}} ,x=0\\\\\\sqrt{\\frac{2}{N}} ,x!=0 \\end{matrix}\\right.\r\n$$\r\n\r\n\u4e8c\u7ef4DCT\u53d8\u6362\u4e5f\u53ef\u8868\u793a\u4e3a\uff1a\r\n\r\n$$\r\nF=AfA^T\r\n$$\r\n\r\n$$\r\nA(i,j)=c(i)cos(\\frac {(j+0.5)\u03c0}{N}i)\r\n$$\r\n\r\n\u6b64\u5f62\u5f0f\u66f4\u65b9\u4fbf\u8ba1\u7b97\u3002DCT\u53d8\u6362\u662f\u5bf9\u79f0\u7684\uff0c\u56e0\u6b64\u53ef\u4ee5\u5bf9\u7ecf\u8fc7DCT\u53d8\u6362\u7684\u56fe\u7247\u8fdb\u884c\u8fd8\u539f\u64cd\u4f5c\u3002\r\n\r\nExamples\r\n-------\r\n\r\n\u8ba1\u7b97\u56fe\u50cf\u611f\u77e5\u76f8\u4f3c\u5ea6\u65f6\uff0c\u9996\u5148\u8ba1\u7b97\u56fe\u50cf\u7684PHASH\u503c\uff0c\u518d\u91c7\u7528\u6d77\u660e\uff08hamming\uff09\u8ddd\u79bb\u76f8\u4f3c\u5ea6\u8ba1\u7b97\u56fe\u7247PHASH\u503c\u7684\u76f8\u4f3c\u5ea6\u3002\r\n```\r\n#\u91c7\u7528\u611f\u77e5hash\u8ba1\u7b97\u56fe\u7247\u7684\u611f\u77e5\u76f8\u4f3c\u5ea6\r\nimport os\r\nfrom scient.image import hash\r\nfrom scient.algorithms import similar\r\nimport numpy\r\nfrom PIL import Image\r\n\r\nref_image='test/data/I10.BMP'\r\nimages=['test/data/I10.BMP','test/data/i10_23_3.bmp','test/data/i10_23_4.bmp','test/data/i10_23_5.bmp','test/data/i10_24_5.bmp']\r\n\r\n#\u8bfb\u53d6\u56fe\u50cf\u6587\u4ef6\r\nref_image=Image.open(os.path.join(os.path.dirname(hash.__file__),'..',ref_image))\r\nimages=[Image.open(os.path.join(os.path.dirname(hash.__file__),'..',i)) for i in images]\r\n\r\n#\u8ba1\u7b97\u611f\u77e5hash\r\nphash=hash.percept(numpy.array(ref_image.convert(\"L\")))\r\nphashs=[hash.percept(numpy.array(i.convert(\"L\"))) for i in images]\r\n\r\n#\u8ba1\u7b97\u611f\u77e5\u76f8\u4f3c\u5ea6\r\nfor i in phashs:\r\n print(similar.hamming(i,phash))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\r\n\r\n```\r\n1.0\r\n0.9384615384615385\r\n0.8615384615384616\r\n0.8153846153846154\r\n0.6\r\n```\r\n\r\n## scient.neuralnet\r\n\r\n\u795e\u7ecf\u7f51\u7edc\u76f8\u5173\u7b97\u6cd5\u6a21\u5757\uff0c\u5305\u62ecattention\u3001transformer\u3001bert\u3001lstm\u3001resnet\u3001crf\u3001dataset\u3001fit\u7b49\u3002\r\n\r\n### scient.neuralnet.fit\r\n\r\n\u795e\u7ecf\u7f51\u7edc\u8bad\u7ec3\u6a21\u5757\uff0c\u5c06torch\u6784\u5efa\u7684\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u7684\u8bad\u7ec3\u65b9\u5f0f\u7b80\u5316\u4e3amodel.fit()\uff0c\u4f7ftorch\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u8bad\u7ec3\u66f4\u7b80\u6377\uff0c\u66f4\u4f18\u96c5\u3002\r\n\r\n\u4f7f\u7528\u6b65\u9aa4\uff1a\r\n\r\n\uff081\uff09\u57fa\u4e8etorch\u6784\u5efa\u6a21\u578bmodel\uff0c\u91c7\u7528torch.utils.data.DataLoader\u52a0\u8f7d\u8bad\u7ec3\u6570\u636e\u96c6train_loader\u3001\u9a8c\u8bc1\u6570\u636e\u96c6eval_loader(\u53ef\u9009)\uff1b\r\n\r\n\uff082\uff09\u91c7\u7528fit.set()\u8bbe\u7f6e\u6a21\u578b\u8bad\u7ec3\u53c2\u6570\uff0c\u53c2\u6570\u8be6\u60c5\uff1a\r\n* optimizer=None: \u4f18\u5316\u5668\uff0c\u53ef\u4ee5\u7528\u7c7b\u4f3ctorch.optim\u6a21\u5757\u5185\u7684\u4f18\u5316\u5668\u6765\u5b9a\u4e49\uff1b\r\n* scheduler=None: \u4f18\u5316\u5668\u7684\u8c03\u5ea6\u5668\uff0c\u53ef\u4ee5\u7528\u7c7b\u4f3ctorch.optim.lr_scheduler\u6a21\u5757\u5185\u7684\u8c03\u5ea6\u5668\u6765\u5b9a\u4e49\uff1b\r\n* loss_func=None: \u635f\u5931\u51fd\u6570\uff0c\u53ef\u4ee5\u7528\u7c7b\u4f3ctorch.nn.CrossEntropyLoss()\u6765\u5b9a\u4e49\uff1b\r\n* grad_func=None: \u68af\u5ea6\u64cd\u4f5c\u51fd\u6570\uff0c\u53ef\u8fdb\u884c\u5982\u68af\u5ea6\u88c1\u526a\u7684\u64cd\u4f5c\uff1b\r\n* perform_func=None: \u6a21\u578b\u6027\u80fd\u51fd\u6570\uff0c\u6a21\u578b\u4f20\u5165\u9884\u6d4b\u503c\u548c\u5b9e\u9645\u503c\uff0c\u7528\u4ee5\u8bc4\u4f30\u6a21\u578b\u6027\u80fd\uff1b\r\n* n_iter=10: \u6a21\u578b\u5728\u6570\u636e\u96c6\u4e0a\u8fed\u4ee3\u8bad\u7ec3\u7684\u6b21\u6570\uff1b\r\n - \u5982\u679cn_iter\u4e3aint,\u8868\u793a\u6a21\u578b\u5728\u6570\u636e\u96c6\u4e0a\u8fed\u4ee3\u8bad\u7ec3n_iter\u540e\u505c\u6b62\uff1b\r\n - \u5982\u679cn_iter\u4e3a(int,int),\u8868\u793a\u6a21\u578b\u5728\u6570\u636e\u96c6\u4e0a\u8fed\u4ee3\u8bad\u7ec3\u7684\u6700\u5c0fmin_iter\u548c\u6700\u5927max_iter\u6b21\u6570, \u5982\u679c\u8fed\u4ee3\u6b21\u6570\u8d85\u8fc7min_iter\u4e14eval\u7684perform_func\u6bd4\u4e0a\u4e00\u4e2aiter\u5927\uff0c\u7ed3\u675f\u8bad\u7ec3\u3002n_iter\u4e3a(int,int)\u65f6\uff0c\u5fc5\u987b\u63d0\u4f9beval_loader\uff0c\u4e14perform_func\u5fc5\u987b\u662f\u4e00\u4e2a\u6570\u503c\uff0c\u4e14\u503c\u8d8a\u5927\u6a21\u578b\u6027\u80fd\u8d8a\u597d\uff1b\r\n* device=None: \u6a21\u578b\u8bad\u7ec3\u7684\u8bbe\u5907\uff0c\u5982device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\uff1b\r\n* n_batch_step: \u6bcfn\u4e2abatch\u66f4\u65b0\u4e00\u6b21optimizer\u7684\u68af\u5ea6\uff0c\u4ee5\u8282\u7701\u663e\u5b58\u53ca\u8ba1\u7b97\u91cf\uff1b\r\n* n_batch_plot: \u6bcfn\u4e2abatch\u66f4\u65b0\u4e00\u4e0b\u635f\u5931\u66f2\u7ebf\uff0c\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u4f1a\u5b9e\u65f6\u7ed8\u5236\u635f\u5931\u66f2\u7ebf\uff1b\r\n* save_path: \u6bcf\u4e2aiter\u5b8c\u6210\u540e\u6a21\u578b\u4fdd\u5b58\u8def\u5f84\uff0c\u6a21\u578b\u540d\u4e3a\u201c\u6a21\u578b\u7c7b\u540d_iter_i.checkpoint\u201d,\u4fdd\u5b58\u7684\u5185\u5bb9\u4e3a{'model_state_dict':model.state_dict(),'optimizer_state_dict':optimizer.state_dict(),'batch_loss':batch_loss}\uff0c\u5982\u679c\u8bad\u7ec3\u65f6\u672a\u63d0\u4f9beval_loader\uff0cbatch_loss=train_batch_loss, \u5426\u5219batch_loss=[train_batch_loss,eval_batch_loss]\r\n\r\n\uff083\uff09\u91c7\u7528model.fit(train_loader,eval_loader,mode=('input','target'))\u8bad\u7ec3\u6a21\u578b\uff1a\r\n* train_loader: \u8bad\u7ec3\u6570\u636e\u96c6\r\n* eval_loader: \u9a8c\u8bc1\u6570\u636e\u96c6\r\n* mode: \u6570\u636e\u96c6\u5305\u542b\u7684\u5185\u5bb9\uff0c\u5206\u56db\u79cd\u60c5\u51b5\uff1a\r\n - mode=('input','target'), loader data item is one input and one target;\r\n - mode='input', loader data item is only one input;\r\n - mode=('inputs','target'), loader data item is a list of input and one target;\r\n - mode='inputs', loader data item is a list of input.\r\n - mode\u4e2d\u4e0d\u5305\u542btarget\u65f6\uff0c\u4e0d\u80fd\u4f7f\u7528perform_func\r\n \r\nExamples\r\n-------\r\n\r\n\u9996\u5148\u6784\u5efa\u6a21\u578bmodel\u3001\u8bad\u7ec3\u6570\u636e\u52a0\u8f7d\u5668train_loader\u3001\u9a8c\u8bc1\u6570\u636e\u52a0\u8f7d\u5668eval_loader\uff1a\r\n\r\n```\r\nimport os\r\nimport torch\r\nfrom scient.neuralnet import resnet, fit\r\nimport torchvision.transforms as tt\r\nfrom torchvision.datasets import ImageFolder\r\n\r\n# \u6570\u636e\u8f6c\u6362\uff08\u5f52\u4e00\u5316\u548c\u6570\u636e\u589e\u5f3a\uff09\r\nstats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))\r\ntrain_tfms = tt.Compose([tt.RandomCrop(160, padding=4, padding_mode='reflect'), \r\n tt.RandomHorizontalFlip(), \r\n tt.ToTensor(), \r\n tt.Normalize(*stats,inplace=True)])\r\nvalid_tfms = tt.Compose([tt.Resize([160,160]),tt.ToTensor(), tt.Normalize(*stats)])\r\n\r\n# \u521b\u5efaImageFolder\u5bf9\u8c61\r\ndata_train = ImageFolder(os.path.join(os.path.dirname(fit.__file__),'..','test/data/imagewoof/train'), train_tfms)\r\ndata_eval = ImageFolder(os.path.join(os.path.dirname(fit.__file__),'..','test/data/imagewoof/val'), valid_tfms)\r\n\r\n# \u8bbe\u7f6e\u6279\u91cf\u5927\u5c0f\r\nbatch_size = 2\r\n\r\n# \u521b\u5efa\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u7684\u6570\u636e\u52a0\u8f7d\u5668\r\ntrain_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True)\r\neval_loader = torch.utils.data.DataLoader(data_eval, batch_size=batch_size, shuffle=False)\r\n\r\n#resnet50\u6a21\u578b\r\nmodel=resnet.ResNet50(n_class=3)\r\n```\r\n\r\n\u7136\u540e\u8bbe\u7f6e\u6a21\u578b\u8bad\u7ec3\u53c2\u6570\u3001\u8bad\u7ec3\u6a21\u578b\uff1a\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\uff1a\r\n\r\n```\r\ntrain iter 0: avg_batch_loss=1.27345: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.92it/s] \r\neval iter 0: avg_batch_loss=1.33363: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 59.44it/s] \r\ntrain iter 1: avg_batch_loss=1.24023: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 13.39it/s] \r\neval iter 1: avg_batch_loss=1.08319: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 58.83it/s] \r\ntrain iter 2: batch_loss=1.42699 avg_batch_loss=1.16666: 63%|\u2588\u2588\u2588\u2588\u2588\u2588\u258e | 38/60 [00:02<00:01, 13.37it/s]\r\n```\r\n\r\nExamples: \u8bad\u7ec3\u65f6\u4e0d\u4f7f\u7528eval_loader\r\n-------\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,mode=('input','target'))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\uff1a\r\n\r\n```\r\ntrain iter 0: avg_batch_loss=1.07998: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.27it/s] \r\ntrain iter 1: avg_batch_loss=1.16323: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.95it/s] \r\ntrain iter 2: batch_loss=0.61398 avg_batch_loss=1.00838: 67%|\u2588\u2588\u2588\u2588\u2588\u2588\u258b | 40/60 [00:03<00:01, 13.06it/s]\r\n```\r\n\r\nExamples: \u4f7f\u7528scheduler\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u6539\u53d8\u5b66\u4e60\u7387\u7b49optimizer\u53c2\u6570\r\n-------\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\nn_iter=5\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nscheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0001, epochs=n_iter,steps_per_epoch=len(train_loader))\r\nloss_func=torch.nn.CrossEntropyLoss()\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=n_iter,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\nExamples: \u4f7f\u7528perform_func\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u8bc4\u4f30\u6a21\u578b\u6027\u80fd\r\n-------\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\ndef perform_func(y_hat,y):#perform_func\u7684\u8f93\u5165\u662f\u9884\u6d4b\u503cy_hat\u548c\u5b9e\u9645\u503cy\r\n y_hat,y=torch.concat(y_hat),torch.concat(y)#\u5148\u5c06y_hat\u548cy\u5206\u522bconcat\uff0c\u7531\u4e8ey_hat\u548cy\u662f\u6309loader\u5206\u6279\u8ba1\u7b97\u548c\u6536\u96c6\u7684\uff0c\u6240\u4ee5y_hat\u548cy\u662fbatch_size\u5927\u5c0f\u7684\u591a\u4e2a\u5bf9\u8c61\u7ec4\u6210\u7684list\r\n _,y_hat=y_hat.max(axis=1)#\u8be5\u6a21\u578b\u8f93\u51fa\u503cy_hat\u6700\u5927\u503c\u5bf9\u5e94\u7684\u7d22\u5f15\u662f\u9884\u6d4b\u7684\u7c7b\u522b\r\n return round((y_hat==y).sum().item()/len(y),4)#\u8f93\u51fa\u51c6\u786e\u7387\uff0c\u5e76\u4fdd\u75594\u4f4d\u5c0f\u6570\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,perform_func=perform_func,n_iter=5,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\uff0c\u53ef\u4ee5\u5728\u6bcf\u4e2aiter\u7ed3\u675f\u540e\u5f97\u5230perform\u7684\u503c\uff1a\r\n\r\n```\r\ntrain iter 0: avg_batch_loss=1.27428 perform=0.3417: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.34it/s] \r\neval iter 0: avg_batch_loss=1.09305 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 55.59it/s] \r\ntrain iter 1: avg_batch_loss=1.09102 perform=0.4417: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 13.30it/s] \r\neval iter 1: avg_batch_loss=1.18128 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 60.46it/s] \r\ntrain iter 2: avg_batch_loss=1.24860 perform=0.3583: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 13.19it/s] \r\neval iter 2: avg_batch_loss=1.23469 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 60.57it/s] \r\n```\r\n\r\nExamples: \u4f7f\u7528grad_func\u5728\u8bad\u7ec3\u8fc7\u7a0b\u5bf9\u68af\u5ea6\u8fdb\u884c\u88c1\u526a\r\n-------\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\ndef grad_func(x):#grad_func\u7684\u8f93\u5165\u662fmodel.parameters(),\u8be5\u64cd\u4f5c\u5728loss.backward()\u540e\u8d77\u4f5c\u7528\r\n torch.nn.utils.clip_grad_value_(x, 0.1)\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,grad_func=grad_func,n_iter=5,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\nExamples: \u4f7f\u7528n_batch_step\u5728\u5c0f\u663e\u5b58\u4e0a\u6a21\u62df\u5927batch_size\u7684\u8bad\u7ec3\r\n-------\r\n\u8be5\u529f\u80fd\u5b9e\u73b0\u4e86\u591a\u6b21\u53cd\u5411\u8bef\u5dee\u4f20\u64ad\u5e76\u7d2f\u79ef\u68af\u5ea6\u540e\uff0c\u518d\u8ba9optimizer\u8fdb\u884c\u68af\u5ea6\u4e0b\u964d\u4f18\u5316\u3002\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device,n_batch_step=5)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\nExamples: \u5f53\u8bad\u7ec3\u5230\u4e00\u5b9a\u8fed\u4ee3\u6b21\u6570\u540e\uff0c\u5982\u679c\u6a21\u578b\u5728\u9a8c\u8bc1\u96c6\u4e0a\u6027\u80fd\u4e0b\u964d\uff0c\u63d0\u524d\u505c\u6b62\u8bad\u7ec3\r\n-------\r\n\r\n\u7528n_iter=(min_iter,max_iter)\u8bbe\u7f6e\u6a21\u578b\u7684\u6700\u5c0f\u548c\u6700\u5927\u8bad\u7ec3\u8fed\u4ee3\u6b21\u6570\uff0c\u5f53\u6a21\u578b\u8bad\u7ec3\u8fed\u6b21\u6570\u8d85\u8fc7min_iter\u65f6\uff0c\u5224\u65ad\u672c\u6b21\u8fed\u4ee3\u8bad\u7ec3\u6a21\u578b\u6027\u80fd\u662f\u5426\u4f18\u4e8e\u4e0a\u6b21\u8fed\u4ee3\u8bad\u7ec3\u6a21\u578b\u6027\u80fd\uff0c\u5982\u679c\u4e0d\u4f18\u4e8e\u4e0a\u6b21\uff0c\u5219\u505c\u6b62\u8bad\u7ec3\u3002\u8fc7\u529f\u80fd\u53ef\u9632\u6b62\u8fc7\u591a\u7684\u8bad\u7ec3\u5bfc\u81f4\u8fc7\u62df\u5408\u3002\u8be5\u529f\u80fd\u9700\u8981\u5728eval_loader\u4e0a\u8ba1\u7b97perform\uff0c\u56e0\u6b64eval_loader\u4e0d\u80fd\u4e3a\u7a7a\uff0c\u4e14perform_func\u8f93\u51fa\u5fc5\u987b\u4e3a\u4e00\u4e2a\u6570\u503c\uff0c\u8be5\u6570\u503c\u8d8a\u5927\u8868\u793a\u6a21\u578b\u8d8a\u4f18\u3002\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\ndef perform_func(y_hat,y):#perform_func\u7684\u8f93\u5165\u662f\u9884\u6d4b\u503cy_hat\u548c\u5b9e\u9645\u503cy\r\n y_hat,y=torch.concat(y_hat),torch.concat(y)#\u5148\u5c06y_hat\u548cy\u5206\u522bconcat\uff0c\u7531\u4e8ey_hat\u548cy\u662f\u6309loader\u5206\u6279\u8ba1\u7b97\u548c\u6536\u96c6\u7684\uff0c\u6240\u4ee5y_hat\u548cy\u662fbatch_size\u5927\u5c0f\u7684\u591a\u4e2a\u5bf9\u8c61\u7ec4\u6210\u7684list\r\n _,y_hat=y_hat.max(axis=1)#\u8be5\u6a21\u578b\u8f93\u51fa\u503cy_hat\u6700\u5927\u503c\u5bf9\u5e94\u7684\u7d22\u5f15\u662f\u9884\u6d4b\u7684\u7c7b\u522b\r\n return round((y_hat==y).sum().item()/len(y),4)#\u8f93\u51fa\u51c6\u786e\u7387\uff0c\u5e76\u4fdd\u75594\u4f4d\u5c0f\u6570\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,perform_func=perform_func,n_iter=(5,20),device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\n\u8fd0\u884c\u7ed3\u679c\uff0c\u53ef\u4ee5\u770b\u5230\u6a21\u578b\u8fd0\u884c\u5230\u7b2citer 6\u505c\u6b62\uff0c\u5e76\u63d0\u793a\u6027\u80fd\u6700\u4f73\u7684\u6a21\u578b\u662fiter 4\uff1a\r\n\r\n```\r\ntrain iter 0: avg_batch_loss=1.17016 perform=0.375: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.60it/s] \r\neval iter 0: avg_batch_loss=1.48805 perform=0.3333: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 60.23it/s] \r\ntrain iter 1: avg_batch_loss=1.17200 perform=0.3833: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 13.08it/s] \r\neval iter 1: avg_batch_loss=1.18933 perform=0.2667: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 59.39it/s] \r\ntrain iter 2: avg_batch_loss=1.09923 perform=0.4333: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 13.14it/s] \r\neval iter 2: avg_batch_loss=1.32449 perform=0.3333: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 60.92it/s] \r\ntrain iter 3: avg_batch_loss=1.20507 perform=0.4083: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:05<00:00, 11.66it/s] \r\neval iter 3: avg_batch_loss=1.23331 perform=0.2667: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 57.59it/s] \r\ntrain iter 4: avg_batch_loss=1.09205 perform=0.4167: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.87it/s] \r\neval iter 4: avg_batch_loss=1.11206 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 59.80it/s] \r\ntrain iter 5: avg_batch_loss=1.10706 perform=0.4583: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.94it/s] \r\neval iter 5: avg_batch_loss=1.07162 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 39.96it/s] \r\ntrain iter 6: avg_batch_loss=1.15846 perform=0.4333: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 60/60 [00:04<00:00, 12.34it/s] \r\neval iter 6: avg_batch_loss=1.16467 perform=0.4: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 8/8 [00:00<00:00, 58.85it/s]early stop and the best model is iter 4, the perform is 0.4\r\n```\r\n\r\nExamples: \u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u5b9e\u65f6\u663e\u793aloss\u66f2\u7ebf\uff0c\u5e76\u5728\u6bcf\u4e00\u4e2aiter\u5b8c\u6210\u540e\u4fdd\u5b58\u6a21\u578b\r\n-------\r\n\r\n\u8bbe\u7f6en_batch_plot\u548csave_path\uff0c\u4fdd\u5b58\u7684\u6a21\u578b\u4ee5checkpoint\u4e3a\u540e\u7f00\u540d\uff0c\u53ef\u4ee5\u7528torch.load\u6253\u5f00\u4fdd\u5b58\u7684\u6a21\u578b\uff0c\u6a21\u578b\u91cc\u4fdd\u5b58\u4e863\u9879\u5185\u5bb9\uff1amodel_state_dict\u3001optimizer_state_dict\u3001batch_loss\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nloss_func=torch.nn.CrossEntropyLoss()\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=5,device=device,n_batch_plot=5,save_path='d:/')\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n\r\n#\u6253\u5f00\u5e76\u67e5\u770b\u4fdd\u5b58\u7684\u6a21\u578b\r\ncheckpoint=torch.load('D:/ResNet_iter_2.checkpoint')\r\ncheckpoint.keys()\r\ncheckpoint['batch_loss']\r\ncheckpoint['model_state_dict']\r\n```\r\n\r\nExamples: \u5982\u679c\u578b\u8f93\u51fa\u7ed3\u679c\u672c\u8eab\u5c31\u662f\u635f\u5931\uff0c\u53ef\u4ee5\u7701\u7565loss_func\r\n-------\r\n\u5148\u5b9a\u4e49\u4e00\u4e2a\u8f93\u51fa\u4e3aloss\u7684\u6a21\u578b\r\n\r\n```\r\n#\u6a21\u578b\r\nclass output_loss(torch.nn.Module):\r\n def __init__(self):\r\n super(output_loss,self).__init__()\r\n\r\n self.model=resnet.ResNet50(n_class=3)\r\n self.loss_func=torch.nn.CrossEntropyLoss()\r\n\r\n def forward(self,x,y):\r\n y_hat=self.model(x)\r\n return self.loss_func(y_hat,y)#\u8f93\u51fa\u4e3aloss\u65e0\u9700\u5728\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u8ba1\u7b97loss\r\n \r\nmodel=output_loss()\r\n```\r\n\r\n\u7136\u540e\u8bbe\u7f6e\u6a21\u578b\u8bad\u7ec3\u53c2\u6570\u65f6\uff0c\u7701\u7565loss_func\uff0c\u56e0\u4e3a\u6b64\u65f6loader\u7684input\u548ctarget\u90fd\u8981\u8f93\u5165\u5230\u6a21\u578b\u7684forward\u4e2d\uff0c\u56e0\u6b64\u53ef\u4ee5\u5c06\u5176\u770b\u6210inputs=[input,target]\uff0c\u5728\u8bad\u7ec3\u65f6mode='inputs'\r\n\r\n```\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nmodel=fit.set(model,optimizer=optimizer,n_iter=10,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode='inputs')\r\n```\r\n\r\nExamples: \u5982\u679c\u578b\u8f93\u51fa\u7ed3\u679c\u6709\u591a\u4e2a\u503c\uff0c\u53ea\u7528\u5176\u4e2d\u76841\u4e2a\u503c\u8ba1\u7b97\u635f\u5931\uff0c\u9700\u8981\u81ea\u5b9a\u4e49loss_func\r\n-------\r\n\u5148\u5b9a\u4e49\u4e00\u4e2a\u8f93\u51fa\u4e3a\u591a\u4e2a\u503c\u7684\u6a21\u578b\r\n\r\n```\r\n#\u6a21\u578b\r\nclass output_multi(torch.nn.Module):\r\n def __init__(self):\r\n super(output_multi,self).__init__()\r\n\r\n self.model=resnet.ResNet50(n_class=3)\r\n\r\n def forward(self,x):\r\n y_hat=self.model(x)\r\n return y_hat,x#\u8f93\u51fa\u4e3a\u4e24\u4e2a\u503c\uff0c\u53ea\u4f7f\u7528y_hat\u8ba1\u7b97\u635f\u5931\r\n\r\nmodel=output_multi()\r\n```\r\n\r\n\u7136\u540e\u8bbe\u7f6e\u6a21\u578b\u8bad\u7ec3\u53c2\u6570\u65f6\uff0c\u5bf9loss_func\u8fdb\u884c\u4fee\u6539\uff0c\u7528\u5176\u4e2d\u9700\u8981\u53c2\u4e0eloss\u8ba1\u7b97\u7684\u90e8\u5206\u8ba1\u7b97loss\r\n\r\n```\r\nloss_func_=torch.nn.CrossEntropyLoss()\r\ndef loss_func(y_hat,y):\r\n return loss_func_(y_hat[0],y)#\u6307\u5b9a\u7528\u8f93\u578b\u8f93\u51fa\u7684\u7b2c0\u4e2a\u503c\u8ba1\u7b97loss\r\n\r\n#\u8bbe\u7f6e\u8bad\u7ec3\u53c2\u6570\r\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\r\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)\r\nmodel=fit.set(model,optimizer=optimizer,loss_func=loss_func,n_iter=10,device=device)\r\n\r\n#\u8bad\u7ec3\r\nmodel.fit(train_loader=train_loader,eval_loader=eval_loader,mode=('input','target'))\r\n```\r\n\r\n",
"bugtrack_url": null,
"license": null,
"summary": "A python package about science compute algorithm, include natural language, image, neural network, optimize algorithm, machine learning, graphic algorithm, etc.",
"version": "0.8.0",
"project_urls": null,
"split_keywords": [
"science compute",
" image",
" natural language",
" machine learning",
" neural network",
" optimize algorithm",
" graphic algorithm"
],
"urls": [
{
"comment_text": "",
"digests": {
"blake2b_256": "87c18f7f6dd18ddc1ecc9158c8d7a62bd83177aa3a98788c277afc8344ffe9fa",
"md5": "675ce138831dc2bcdbee3284b004c3ed",
"sha256": "d3b5879c880c4dea00f34d760c5e4b2a1c2041734052021b9544fbc5c4dccba6"
},
"downloads": -1,
"filename": "scient-0.8.0-cp312-cp312-win_amd64.whl",
"has_sig": false,
"md5_digest": "675ce138831dc2bcdbee3284b004c3ed",
"packagetype": "bdist_wheel",
"python_version": "cp312",
"requires_python": ">=3.10",
"size": 6880126,
"upload_time": "2024-08-15T04:08:40",
"upload_time_iso_8601": "2024-08-15T04:08:40.989775Z",
"url": "https://files.pythonhosted.org/packages/87/c1/8f7f6dd18ddc1ecc9158c8d7a62bd83177aa3a98788c277afc8344ffe9fa/scient-0.8.0-cp312-cp312-win_amd64.whl",
"yanked": false,
"yanked_reason": null
},
{
"comment_text": "",
"digests": {
"blake2b_256": "546bfd89d1fa785c0f7e6326b3e648912789929a232105604263943ff808240b",
"md5": "95f57c84697fc6d9dd8c5ffd44d48f05",
"sha256": "3efbc7cd8efc3409c005b812f407fc19a3093e2794ae4d0309a7c3b8a0dfc9a3"
},
"downloads": -1,
"filename": "scient-0.8.0.tar.gz",
"has_sig": false,
"md5_digest": "95f57c84697fc6d9dd8c5ffd44d48f05",
"packagetype": "sdist",
"python_version": "source",
"requires_python": ">=3.10",
"size": 7472184,
"upload_time": "2024-08-15T04:08:17",
"upload_time_iso_8601": "2024-08-15T04:08:17.516486Z",
"url": "https://files.pythonhosted.org/packages/54/6b/fd89d1fa785c0f7e6326b3e648912789929a232105604263943ff808240b/scient-0.8.0.tar.gz",
"yanked": false,
"yanked_reason": null
}
],
"upload_time": "2024-08-15 04:08:17",
"github": false,
"gitlab": false,
"bitbucket": false,
"codeberg": false,
"lcname": "scient"
}