一、思路分析

websocket理解为http的升级版本即可

将所有用户抽象成对象User,User中应当包括一个连接和一个消息信道

数据处理器Hub:用于获取到某个用户发送的数据推送给每个用户

二、实现&核心代码

定义User:

type User struct {
    conn *websocket.Conn
    msg  chan []byte
}

定义数据处理器:

type Hub struct {
    //用户列表,保存所有用户
    userList map[*User]bool
    //注册chan,用户注册时添加到chan中
    register chan *User
    //注销chan,用户退出时添加到chan中,再从map中删除
    unregister chan *User
    //广播消息,将消息广播给所有连接
    broadcast chan []byte
}

数据处理器处理方法:

//处理中心处理获取到的信息

func (h *Hub) run() {
    for {
        select {
        //从注册chan中取数据
        case user := <-h.register:
            //取到数据后将数据添加到用户列表中
            h.userList[user] = true
        case user := <-h.unregister:
            //从注销列表中取数据,判断用户列表中是否存在这个用户,存在就删掉
            if _, ok := h.userList[user]; ok {
                delete(h.userList, user)
            }

        case data := <-h.broadcast:
            //从广播chan中取消息,然后遍历给每个用户,发送到用户的msg中
            for u := range h.userList {
                select {
                case u.msg <- data:
                default:
                    delete(h.userList, u)
                    close(u.msg)
                }
            }
        }
    }
}

定义websocket升级器:

//定义一个升级器,将普通的http连接升级为websocket连接

var up = &websocket.Upgrader{
    //定义读写缓冲区大小
    WriteBufferSize: 1024,
    ReadBufferSize:  1024,
    //校验请求
    CheckOrigin: func(r *http.Request) bool {
        //如果不是get请求,返回错误
        if r.Method != "GET" {
            fmt.Println("请求方式错误")
            return false
        }

        //如果路径中不包括chat,返回错误
        if r.URL.Path != "/chat" {
            fmt.Println("请求路径错误")
            return false
        }

        //还可以根据其他需求定制校验规则
        return true
    },
}

用户连接到服务的回调函数:用于处理用户的读写操作

func wsHandle(w http.ResponseWriter, r *http.Request) {
    //通过升级后的升级器得到链接
    conn, err := up.Upgrade(w, r, nil)
    if err != nil {
        fmt.Println("获取连接失败:", err)
        return
    }

    //连接成功后注册用户
    user := &User{
        conn: conn,
        msg:  make(chan []byte),
    }

    hub.register <- user
    defer func() {
        hub.unregister <- user
    }()

    //得到连接后,就可以开始读写数据了
    go read(user)
    write(user)

}

func read(user *User) {
    //从连接中循环读取信息
    for {
        _, msg, err := user.conn.ReadMessage()
        if err != nil {
            fmt.Println("用户退出:",user.conn.RemoteAddr().String())
            hub.unregister<-user
            break
        }

        //将读取到的信息传入websocket处理器中的broadcast中,
        hub.broadcast <- msg
    }
}

func write(user *User) {
    for data := range user.msg {
        err := user.conn.WriteMessage(1, data)
        if err != nil {
            fmt.Println("写入错误")
            break
        }
    }
}

启动程序:

func main() {
    //后台启动处理器
    go hub.run()
    http.HandleFunc("/chat", wsHandle)         //将chat请求交给wshandle处理
    http.ListenAndServe("127.0.0.1:8888", nil) //开始监听
}

三、完整代码,一个server.go文件,可以根据情况自己拆分

package main

import (
    "fmt"
    "github.com/gorilla/websocket"
    "net/http"
)

func main() {
    //后台启动处理器
    go hub.run()
    http.HandleFunc("/chat", wsHandle)         //将chat请求交给wshandle处理
    http.ListenAndServe("127.0.0.1:8888", nil) //开始监听
}

//定义一个websocket处理器,用于收集消息和广播消息

type Hub struct {
    //用户列表,保存所有用户
    userList map[*User]bool
    //注册chan,用户注册时添加到chan中
    register chan *User

    //注销chan,用户退出时添加到chan中,再从map中删除

    unregister chan *User

    //广播消息,将消息广播给所有连接
    broadcast chan []byte
}

//定义一个websocket连接对象,连接中包含每个连接的信息

type User struct {

    conn *websocket.Conn

    msg  chan []byte

}

//定义一个升级器,将普通的http连接升级为websocket连接

var up = &websocket.Upgrader{

    //定义读写缓冲区大小

    WriteBufferSize: 1024,

    ReadBufferSize:  1024,

    //校验请求

    CheckOrigin: func(r *http.Request) bool {

        //如果不是get请求,返回错误

        if r.Method != "GET" {

            fmt.Println("请求方式错误")

            return false

        }

        //如果路径中不包括chat,返回错误

        if r.URL.Path != "/chat" {

            fmt.Println("请求路径错误")

            return false

        }

        //还可以根据其他需求定制校验规则

        return true

    },

}

//初始化处理中心,以便调用

var hub = &Hub{

    userList:   make(map[*User]bool),

    register:   make(chan *User),

    unregister: make(chan *User),

    broadcast:  make(chan []byte),
}

func wsHandle(w http.ResponseWriter, r *http.Request) {

    //通过升级后的升级器得到链接

    conn, err := up.Upgrade(w, r, nil)

    if err != nil {

        fmt.Println("获取连接失败:", err)

        return

    }

    //连接成功后注册用户

    user := &User{

        conn: conn,
        msg:  make(chan []byte),
    }

    hub.register <- user

    defer func() {

        hub.unregister <- user

    }()

    //得到连接后,就可以开始读写数据了

    go read(user)

    write(user)

}

func read(user *User) {

    //从连接中循环读取信息

    for {

        _, msg, err := user.conn.ReadMessage()

        if err != nil {

            fmt.Println("用户退出:",user.conn.RemoteAddr().String())

            hub.unregister<-user

            break

        }

        //将读取到的信息传入websocket处理器中的broadcast中,

        hub.broadcast <- msg

    }

}

func write(user *User) {

    for data := range user.msg {

        err := user.conn.WriteMessage(1, data)

        if err != nil {

            fmt.Println("写入错误")

            break

        }
    }
}

//处理中心处理获取到的信息

func (h *Hub) run() {

    for {
        select {
        //从注册chan中取数据
        case user := <-h.register:
            //取到数据后将数据添加到用户列表中
            h.userList[user] = true
        case user := <-h.unregister:
            //从注销列表中取数据,判断用户列表中是否存在这个用户,存在就删掉
            if _, ok := h.userList[user]; ok {
                delete(h.userList, user)
            }

        case data := <-h.broadcast:

            //从广播chan中取消息,然后遍历给每个用户,发送到用户的msg中

            for u := range h.userList {
                select {
                case u.msg <- data:
                default:
                    delete(h.userList, u)
                    close(u.msg)
                }
            }

        }
    }

}

四、测试代码,使用go编写的一个客户端,多开模拟多个用户

package main

import (
    "bufio"
    "fmt"
    "github.com/gorilla/websocket"
    "io"
    "os"
    "sync"
)

var wg sync.WaitGroup

func main() {

    conn, _, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:8888/chat", nil)

    if err != nil {
        fmt.Println("错误信息:", err)
    }

    wg.Add(2)
    go read(conn)
    go writeM(conn)
    wg.Wait()
}

func read(conn *websocket.Conn) {

    defer wg.Done()
    for {
        _, msg, err := conn.ReadMessage()
        if err != nil {
            fmt.Println("错误信息:", err)
            break
        }

        if err == io.EOF {
            continue
        }

        fmt.Println("获取到的信息:", string(msg))
    }
}

func writeM(conn *websocket.Conn) {
    defer wg.Done()
    for {
        fmt.Print("请输入:")
        reader := bufio.NewReader(os.Stdin)
        data, _ := reader.ReadString('\n')
        conn.WriteMessage(1, []byte(data))
    }

}

发表评论

邮箱地址不会被公开。 必填项已用*标注