云计算百科
云计算领域专业知识百科平台

flask搭建微服务器并训练CNN水果识别模型应用于网页

一. 搭建flask环境

概念

  • flask:一个轻量级 Web 应用框架,被设计为简单、灵活,能够快速启动一个 Web 项目。
  • CNN:深度学习模型,用于处理具有网格状拓扑结构的数据,如图像(2D网格)和视频(3D网格)。
  • PyTorch:开源的机器学习库,应用于如计算机视觉和自然语言处理等领域的深度学习。

flask环境搭建操作步骤: 

  • pycharm终端创建新的虚拟环境:python -m venv virtualName 。
  • 激活虚拟环境。
  • 在虚拟环境中安装flask。
  • 运行第一个前端网页。
  • 流程图例

    1.

    2.

    3.

    4.

    步骤4代码:

    from flask import Flask
    app = Flask(__name__)

    @app.route('/')
    def hello_world():
    return "<h1>hello world!</h1>"

    if __name__ == '__main__':
    app.run(debug=True)


    二. 训练水果模型

    水果识别CNN训练操作步骤: 

  • 准备数据集(kaggle官网可下载)。
  • 安装pyrorch。
  • 使用pytorch的nn模型定义参数。
  • 训练模型。
  • 得到训练好的pth模型。
  • 流程图例

    1.

    2.

    5.

    步骤3代码:

    import torch
    from torch import nn

    # 水果分类模型参数配置

    class NumberNet(nn.Module):
    def __init__(self, device, classes=10):
    super().__init__()
    if device is None:
    device = torch.device("cpu")
    if torch.cuda.is_available():
    device = torch.device("cuda:0")
    self.cnn = nn.Sequential(
    nn.Conv2d(3, 16, 3), # 100×100 -> 98×98
    nn.ReLU(),
    nn.MaxPool2d(2, 2), # 98×98 -> 49×49
    nn.Conv2d(16, 32, 3, padding=1), # 49×49 -> 49×49
    nn.ReLU(),
    nn.MaxPool2d(2, 2), # 49×49 -> 24×24
    nn.Conv2d(32, 64, 3, padding=1), # 24×24 -> 24×24
    nn.ReLU(),
    nn.MaxPool2d(2, 2), # 24×24 -> 12×12
    nn.Flatten(),
    nn.Dropout(),
    nn.Linear(64 * 12 * 12, 1024), # 调整线性层的输入特征数量
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(1024, classes),
    nn.LogSoftmax(dim=-1)
    )

    def forward(self, X):
    return self.cnn(X)

    步骤4代码:

    import torch
    from torch import nn
    from NumberNet import NumberNet
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    from torch.utils.data import random_split

    # 水果分类训练
    # 数据集配置
    # 假设 NumberNet 模型期望的输入是 3 通道彩色图像
    transform = transforms.Compose([
    transforms.ToTensor(), # 这将把 PIL 图像或 NumPy 数组转换为张量,并且范围从 [0, 255] 标准化到 [0.0, 1.0]
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 可选:标准化
    ])

    # 加载项目目录下的水果文件夹
    img_dataset = ImageFolder("../fruits", transform=transform)
    len_dataset = len(img_dataset)
    train_size = int(len_dataset * 0.8)
    valid_size = len_dataset – train_size
    train_dataset, valid_dataset = random_split(img_dataset, [train_size, valid_size])

    # 数据加载器
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1000, shuffle=True)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=1000)
    # batch_total 应该是 dataloader 的总批次数量,这里计算方式不正确
    batch_total = len(train_dataloader) # 应该直接使用 len(dataloader)

    # 使用conda或者cpu开始训练
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    epochs = 10
    model = NumberNet(device)
    criterion = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(epochs):
    losses = []
    for batch_num, (images, labels) in enumerate(train_dataloader, start=1): # 使用 enumerate 来获取批次编号
    adam.zero_grad()
    predict = model(images.to(device))
    loss = criterion(predict, labels.to(device))
    print(f"batch size: {batch_num} / {batch_total} — loss: {loss.item():.4f} ")
    losses.append(loss.item())
    loss.backward()
    adam.step()
    acc_list = []
    with torch.no_grad():
    for images, labels in valid_dataloader:
    predict = model(images.to(device))
    result = torch.argmax(predict, dim=-1)
    acc = (result == labels.to(device)).float().mean() # 使用 torch 的函数来计算准确率
    acc_list.append(acc.item())

    total_acc = sum(acc_list) / len(acc_list)
    total_loss = sum(losses) / batch_total
    print(f"epoch: {epoch + 1} / {epochs} — loss: {total_loss:.4f} — acc: {total_acc:.4f} ")

    # 保存模型参数,而不是整个模型
    torch.save(model, "../readyModel/model.pth")


     三. 将训练好的模型嵌入flask后端

    实现水果识别web操作步骤: 

  • 在虚拟化环境下创建.py后端启动文件,并且创建模型实例,同时将训练好的.pth文件放入代码对应的文件路径。
  • 创建index.html文件,作为后续前端文件。
  • 在前端代码和后端代码使用Jason进行路由。
  • 启动项目,实现功能。
  •  步骤1代码:

    from flask import Flask, render_template, request, jsonify
    import time
    import torch
    import cv2
    import numpy as np
    from FruitNet import FruitNet # 确保FruitNet定义是正确的

    app = Flask(__name__)

    # 定义设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 创建模型实例
    model = FruitNet(device=device, classes=5) # 确保类别数与训练时一致
    model.to(device)

    # 加载训练好的权重
    model.load_state_dict(torch.load("static/fruit_model.pth")) # 确保权重文件名为fruit_model.pth
    model.eval() # 设置模型为评估模式

    def predict_image(image_data):
    # 通过cv2加载图片数据
    img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)

    # 将图像从BGR转换为RGB格式(因为OpenCV默认加载的是BGR格式)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # 调整图片大小到100×100(与训练时的输入大小一致)
    img = cv2.resize(img, (100, 100))

    # 在第一个位置增加一个维度,形成batch大小为1
    img = np.expand_dims(img, 0)

    # 将numpy对象转化为pytorch的tensor对象
    img = torch.from_numpy(img)

    # 调整图像通道顺序
    img = torch.permute(img, [0, 3, 1, 2]) # 转换为 (batch_size, channels, height, width)

    # 测试最终的结果
    with torch.no_grad(): # 关闭梯度计算
    img = img.to(device).float() # 确保输入是float类型,并发送到指定设备
    predict = model(img)
    predicted_class = torch.argmax(predict, dim=-1).item()

    # 定义水果类别标签
    fruit_classes = ["Apple Golden 1", "Banana", "Pear Red", "Tomato Heart", "Watermelon"] # 根据你的数据集定义类别标签

    # 输出预测的水果种类
    predicted_fruit = fruit_classes[predicted_class]
    return predicted_fruit

     步骤2代码:

    <!DOCTYPE html>
    <html lang="en">
    <head>
    <meta charset="UTF-8">
    <title>水果识别</title>
    <link rel="stylesheet" href="./static/css/index.css">
    <script src="./static/js/jquery-3.7.1.min.js"></script>
    </head>
    <body>
    <div class="main">
    <div>
    <!– 显示上传的图片 –>
    <div class="upload-img">
    <img id="upload-img" src="" alt="请上传图片"/>
    </div>

    <!– 表单用于上传图片 –>
    <form id="upload-btn" action="/upload" method="post" enctype="multipart/form-data">
    <input style="margin-left: 120px" type="file" name="the_file" id="selectImg"> <br/>
    <input type="submit" value="识别该水果">
    </form>
    </div>

    <!– 显示识别结果 –>
    <div class="result">
    <h2 id="result-show"></h2>
    </div>
    </div>

    <script>
    // 将文件转为 Base64 用于图片预览
    function convertToBase64(file, callback) {
    const reader = new FileReader();
    reader.onload = function(e) {
    callback(e.target.result);
    };
    reader.readAsDataURL(file);
    }

    $(function(){
    // 处理图片选择后的显示
    $("#selectImg").change(function(ev){
    const file = $(this)[0].files[0];
    if (file) {
    convertToBase64(file, function(base64Img){
    $("#upload-img").attr("src", base64Img); // 更新图片预览
    });
    }
    });

    // 处理表单提交
    $('#upload-btn').submit(function(ev){
    ev.preventDefault(); // 阻止默认表单提交

    var formData = new FormData(this); // 获取表单数据
    $.ajax({
    url: '/upload', // 请求的后端地址
    type: 'POST',
    data: formData,
    contentType: false,
    processData: false,
    success: function(response){
    console.log('文件上传成功');
    console.log(response);

    // 更新识别结果
    $('#result-show').text('识别结果:' + response.result); // 显示识别结果
    },
    error: function(error){
    console.error('文件上传失败');
    console.error(error);
    }
    });
    });
    });
    </script>
    </body>
    </html>

     步骤3代码:

    <script>
    // 将文件转为 Base64 用于图片预览
    function convertToBase64(file, callback) {
    const reader = new FileReader();
    reader.onload = function(e) {
    callback(e.target.result);
    };
    reader.readAsDataURL(file);
    }

    $(function(){
    // 处理图片选择后的显示
    $("#selectImg").change(function(ev){
    const file = $(this)[0].files[0];
    if (file) {
    convertToBase64(file, function(base64Img){
    $("#upload-img").attr("src", base64Img); // 更新图片预览
    });
    }
    });

    // 处理表单提交
    $('#upload-btn').submit(function(ev){
    ev.preventDefault(); // 阻止默认表单提交

    var formData = new FormData(this); // 获取表单数据
    $.ajax({
    url: '/upload', // 请求的后端地址
    type: 'POST',
    data: formData,
    contentType: false,
    processData: false,
    success: function(response){
    console.log('文件上传成功');
    console.log(response);

    // 更新识别结果
    $('#result-show').text('识别结果:' + response.result); // 显示识别结果
    },
    error: function(error){
    console.error('文件上传失败');
    console.error(error);
    }
    });
    });
    });
    </script>
    @app.route("/")
    def home():
    return render_template("index.html")

    @app.route('/upload', methods=['POST'])
    def upload_file():
    if request.method == 'POST':
    f = request.files['the_file']
    # 保存图片到静态目录
    timestamp = time.strftime("%Y%m%d%H%M%S")
    file_path = f'./static/uploads/{timestamp}.png'
    f.save(file_path)

    # 读取保存后的图片数据并预测
    with open(file_path, 'rb') as image_file:
    image_data = image_file.read()

    predicted_fruit = predict_image(image_data)

    # 返回JSON数据
    return jsonify({
    'file_id': timestamp,
    'result': predicted_fruit,
    'img_path': f'/static/uploads/{timestamp}.png'
    })

      步骤4实现效果:

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » flask搭建微服务器并训练CNN水果识别模型应用于网页
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!