Remove Extensible parent from EventHandler

This also fixes SSL certificate support when m_sslinfo is not loaded

git-svn-id: http://svn.inspircd.org/repository/trunk/inspircd@12048 e03df62e-2008-0410-955e-edbf42e46eb7
This commit is contained in:
danieldg 2009-11-06 22:37:52 +00:00
parent a26502ff51
commit eaace5ed7c
9 changed files with 95 additions and 55 deletions

View File

@ -151,7 +151,7 @@ enum EventMask
* must have a file descriptor. What this file descriptor * must have a file descriptor. What this file descriptor
* is actually attached to is completely up to you. * is actually attached to is completely up to you.
*/ */
class CoreExport EventHandler : public Extensible class CoreExport EventHandler : public classbase
{ {
private: private:
/** Private state maintained by socket engine */ /** Private state maintained by socket engine */

View File

@ -77,13 +77,10 @@ static ssize_t gnutls_push_wrapper(gnutls_transport_ptr_t user_wrap, const void*
class issl_session class issl_session
{ {
public: public:
issl_session()
{
sess = NULL;
}
gnutls_session_t sess; gnutls_session_t sess;
issl_status status; issl_status status;
reference<ssl_cert> cert;
issl_session() : sess(NULL) {}
}; };
class CommandStartTLS : public SplitCommand class CommandStartTLS : public SplitCommand
@ -332,11 +329,15 @@ class ModuleSSLGnuTLS : public Module
void OnRequest(Request& request) void OnRequest(Request& request)
{ {
Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); if (strcmp("GET_SSL_CERT", request.id) == 0)
if (sslinfo) {
sslinfo->OnRequest(request); SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
} int fd = req.sock->GetFd();
issl_session* session = &sessions[fd];
req.cert = session->cert;
}
}
void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
{ {
@ -548,10 +549,11 @@ class ModuleSSLGnuTLS : public Module
void OnUserConnect(LocalUser* user) void OnUserConnect(LocalUser* user)
{ {
if (user->GetIOHook() == this) if (user->eh.GetIOHook() == this)
{ {
if (sessions[user->GetFd()].sess) if (sessions[user->GetFd()].sess)
{ {
SSLCertSubmission(user, this, ServerInstance->Modules->Find("m_sslinfo.so"), sessions[user->GetFd()].cert);
std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess)); std::string cipher = gnutls_kx_get_name(gnutls_kx_get(sessions[user->GetFd()].sess));
cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-"); cipher.append("-").append(gnutls_cipher_get_name(gnutls_cipher_get(sessions[user->GetFd()].sess))).append("-");
cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess))); cipher.append(gnutls_mac_get_name(gnutls_mac_get(sessions[user->GetFd()].sess)));
@ -562,23 +564,19 @@ class ModuleSSLGnuTLS : public Module
void CloseSession(issl_session* session) void CloseSession(issl_session* session)
{ {
if(session->sess) if (session->sess)
{ {
gnutls_bye(session->sess, GNUTLS_SHUT_WR); gnutls_bye(session->sess, GNUTLS_SHUT_WR);
gnutls_deinit(session->sess); gnutls_deinit(session->sess);
} }
session->sess = NULL; session->sess = NULL;
session->cert = NULL;
session->status = ISSL_NONE; session->status = ISSL_NONE;
} }
void VerifyCertificate(issl_session* session, Extensible* user) void VerifyCertificate(issl_session* session, StreamSocket* user)
{ {
if (!session->sess || !user) if (!session->sess || !user || session->cert)
return;
Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so");
if (!sslinfo)
return; return;
unsigned int status; unsigned int status;
@ -591,6 +589,7 @@ class ModuleSSLGnuTLS : public Module
size_t digest_size = sizeof(digest); size_t digest_size = sizeof(digest);
size_t name_size = sizeof(name); size_t name_size = sizeof(name);
ssl_cert* certinfo = new ssl_cert; ssl_cert* certinfo = new ssl_cert;
session->cert = certinfo;
/* This verification function uses the trusted CAs in the credentials /* This verification function uses the trusted CAs in the credentials
* structure. So you must have installed one or more CA certificates. * structure. So you must have installed one or more CA certificates.
@ -600,7 +599,7 @@ class ModuleSSLGnuTLS : public Module
if (ret < 0) if (ret < 0)
{ {
certinfo->error = std::string(gnutls_strerror(ret)); certinfo->error = std::string(gnutls_strerror(ret));
goto info_done; return;
} }
certinfo->invalid = (status & GNUTLS_CERT_INVALID); certinfo->invalid = (status & GNUTLS_CERT_INVALID);
@ -615,14 +614,14 @@ class ModuleSSLGnuTLS : public Module
if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509) if (gnutls_certificate_type_get(session->sess) != GNUTLS_CRT_X509)
{ {
certinfo->error = "No X509 keys sent"; certinfo->error = "No X509 keys sent";
goto info_done; return;
} }
ret = gnutls_x509_crt_init(&cert); ret = gnutls_x509_crt_init(&cert);
if (ret < 0) if (ret < 0)
{ {
certinfo->error = gnutls_strerror(ret); certinfo->error = gnutls_strerror(ret);
goto info_done; return;
} }
cert_list_size = 0; cert_list_size = 0;
@ -668,8 +667,6 @@ class ModuleSSLGnuTLS : public Module
info_done_dealloc: info_done_dealloc:
gnutls_x509_crt_deinit(cert); gnutls_x509_crt_deinit(cert);
info_done:
SSLCertSubmission(user, this, sslinfo, certinfo);
} }
void OnEvent(Event& ev) void OnEvent(Event& ev)

View File

@ -53,6 +53,7 @@ class issl_session
public: public:
SSL* sess; SSL* sess;
issl_status status; issl_status status;
reference<ssl_cert> cert;
int fd; int fd;
bool outbound; bool outbound;
@ -125,7 +126,7 @@ class ModuleSSLOpenSSL : public Module
// Needs the flag as it ignores a plain /rehash // Needs the flag as it ignores a plain /rehash
OnModuleRehash(NULL,"ssl"); OnModuleRehash(NULL,"ssl");
Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO }; Implementation eventlist[] = { I_On005Numeric, I_OnRehash, I_OnModuleRehash, I_OnHookIO, I_OnUserConnect };
ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation)); ServerInstance->Modules->Attach(eventlist, this, sizeof(eventlist)/sizeof(Implementation));
} }
@ -244,6 +245,17 @@ class ModuleSSLOpenSSL : public Module
delete[] sessions; delete[] sessions;
} }
void OnUserConnect(LocalUser* user)
{
if (user->eh.GetIOHook() == this)
{
if (sessions[user->GetFd()].sess)
{
SSLCertSubmission(user, this, ServerInstance->Modules->Find("m_sslinfo.so"), sessions[user->GetFd()].cert);
}
}
}
void OnCleanup(int target_type, void* item) void OnCleanup(int target_type, void* item)
{ {
if (target_type == TYPE_USER) if (target_type == TYPE_USER)
@ -264,14 +276,17 @@ class ModuleSSLOpenSSL : public Module
return Version("Provides SSL support for clients", VF_VENDOR); return Version("Provides SSL support for clients", VF_VENDOR);
} }
void OnRequest(Request& request) void OnRequest(Request& request)
{ {
Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so"); if (strcmp("GET_SSL_CERT", request.id) == 0)
if (sslinfo) {
sslinfo->OnRequest(request); SocketCertificateRequest& req = static_cast<SocketCertificateRequest&>(request);
} int fd = req.sock->GetFd();
issl_session* session = &sessions[fd];
req.cert = session->cert;
}
}
void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server) void OnStreamSocketAccept(StreamSocket* user, irc::sockets::sockaddrs* client, irc::sockets::sockaddrs* server)
{ {
@ -472,7 +487,7 @@ class ModuleSSLOpenSSL : public Module
return 0; return 0;
} }
bool Handshake(EventHandler* user, issl_session* session) bool Handshake(StreamSocket* user, issl_session* session)
{ {
int ret; int ret;
@ -537,17 +552,14 @@ class ModuleSSLOpenSSL : public Module
errno = EIO; errno = EIO;
} }
void VerifyCertificate(issl_session* session, Extensible* user) void VerifyCertificate(issl_session* session, StreamSocket* user)
{ {
if (!session->sess || !user) if (!session->sess || !user || session->cert)
return;
Module* sslinfo = ServerInstance->Modules->Find("m_sslinfo.so");
if (!sslinfo)
return; return;
X509* cert; X509* cert;
ssl_cert* certinfo = new ssl_cert; ssl_cert* certinfo = new ssl_cert;
session->cert = certinfo;
unsigned int n; unsigned int n;
unsigned char md[EVP_MAX_MD_SIZE]; unsigned char md[EVP_MAX_MD_SIZE];
const EVP_MD *digest = EVP_md5(); const EVP_MD *digest = EVP_md5();
@ -557,7 +569,6 @@ class ModuleSSLOpenSSL : public Module
if (!cert) if (!cert)
{ {
certinfo->error = "Could not get peer certificate: "+std::string(get_error()); certinfo->error = "Could not get peer certificate: "+std::string(get_error());
SSLCertSubmission(user, this, sslinfo, certinfo);
return; return;
} }
@ -592,7 +603,6 @@ class ModuleSSLOpenSSL : public Module
} }
X509_free(cert); X509_free(cert);
SSLCertSubmission(user, this, sslinfo, certinfo);
} }
}; };

View File

@ -128,7 +128,7 @@ bool TreeSocket::ComparePass(const Link& link, const std::string &theirs)
std::string fp; std::string fp;
if (GetIOHook()) if (GetIOHook())
{ {
SSLCertificateRequest req(this, Utils->Creator); SocketCertificateRequest req(this, Utils->Creator, GetIOHook());
if (req.cert) if (req.cert)
{ {
fp = req.cert->GetFingerprint(); fp = req.cert->GetFingerprint();

View File

@ -610,6 +610,14 @@ void ModuleSpanningTree::OnUserConnect(LocalUser* user)
params.push_back(":"+std::string(user->fullname)); params.push_back(":"+std::string(user->fullname));
Utils->DoOneToMany(ServerInstance->Config->GetSID(), "UID", params); Utils->DoOneToMany(ServerInstance->Config->GetSID(), "UID", params);
for(Extensible::ExtensibleStore::const_iterator i = user->GetExtList().begin(); i != user->GetExtList().end(); i++)
{
ExtensionItem* item = i->first;
std::string value = item->serialize(FORMAT_NETWORK, user, i->second);
if (!value.empty())
ProtoSendMetaData(this, user, item->key, value);
}
Utils->TreeRoot->SetUserCount(1); // increment by 1 Utils->TreeRoot->SetUserCount(1); // increment by 1
} }

View File

@ -220,7 +220,7 @@ void TreeSocket::SendChannelModes(TreeServer* Current)
this->WriteLine(data); this->WriteLine(data);
} }
for(ExtensibleStore::const_iterator i = c->second->GetExtList().begin(); i != c->second->GetExtList().end(); i++) for(Extensible::ExtensibleStore::const_iterator i = c->second->GetExtList().begin(); i != c->second->GetExtList().end(); i++)
{ {
ExtensionItem* item = i->first; ExtensionItem* item = i->first;
std::string value = item->serialize(FORMAT_NETWORK, c->second, i->second); std::string value = item->serialize(FORMAT_NETWORK, c->second, i->second);
@ -269,7 +269,7 @@ void TreeSocket::SendUsers(TreeServer* Current)
} }
} }
for(ExtensibleStore::const_iterator i = u->second->GetExtList().begin(); i != u->second->GetExtList().end(); i++) for(Extensible::ExtensibleStore::const_iterator i = u->second->GetExtList().begin(); i != u->second->GetExtList().end(); i++)
{ {
ExtensionItem* item = i->first; ExtensionItem* item = i->first;
std::string value = item->serialize(FORMAT_NETWORK, u->second, i->second); std::string value = item->serialize(FORMAT_NETWORK, u->second, i->second);

View File

@ -25,8 +25,10 @@ class SSLCertExt : public ExtensionItem {
} }
void set(Extensible* item, ssl_cert* value) void set(Extensible* item, ssl_cert* value)
{ {
value->refcount_inc();
ssl_cert* old = static_cast<ssl_cert*>(set_raw(item, value)); ssl_cert* old = static_cast<ssl_cert*>(set_raw(item, value));
delete old; if (old && old->refcount_dec())
delete old;
} }
std::string serialize(SerializeFormat format, const Extensible* container, void* item) const std::string serialize(SerializeFormat format, const Extensible* container, void* item) const
@ -61,7 +63,9 @@ class SSLCertExt : public ExtensionItem {
void free(void* item) void free(void* item)
{ {
delete static_cast<ssl_cert*>(item); ssl_cert* old = static_cast<ssl_cert*>(item);
if (old && old->refcount_dec())
delete old;
} }
}; };
@ -228,10 +232,10 @@ class ModuleSSLInfo : public Module
void OnRequest(Request& request) void OnRequest(Request& request)
{ {
if (strcmp("GET_CERT", request.id) == 0) if (strcmp("GET_USER_CERT", request.id) == 0)
{ {
SSLCertificateRequest& req = static_cast<SSLCertificateRequest&>(request); UserCertificateRequest& req = static_cast<UserCertificateRequest&>(request);
req.cert = cmd.CertExt.get(req.item); req.cert = cmd.CertExt.get(req.user);
} }
else if (strcmp("SET_CERT", request.id) == 0) else if (strcmp("SET_CERT", request.id) == 0)
{ {

View File

@ -34,7 +34,7 @@ class SSLMode : public ModeHandler
const UserMembList* userlist = channel->GetUsers(); const UserMembList* userlist = channel->GetUsers();
for(UserMembCIter i = userlist->begin(); i != userlist->end(); i++) for(UserMembCIter i = userlist->begin(); i != userlist->end(); i++)
{ {
SSLCertificateRequest req(i->first, creator); UserCertificateRequest req(i->first, creator);
req.Send(); req.Send();
if(!req.cert && !ServerInstance->ULine(i->first->server)) if(!req.cert && !ServerInstance->ULine(i->first->server))
{ {
@ -83,7 +83,7 @@ class ModuleSSLModes : public Module
{ {
if(chan && chan->IsModeSet('z')) if(chan && chan->IsModeSet('z'))
{ {
SSLCertificateRequest req(user, this); UserCertificateRequest req(user, this);
req.Send(); req.Send();
if (req.cert) if (req.cert)
{ {
@ -105,7 +105,7 @@ class ModuleSSLModes : public Module
{ {
if (mask[0] == 'z' && mask[1] == ':') if (mask[0] == 'z' && mask[1] == ':')
{ {
SSLCertificateRequest req(user, this); UserCertificateRequest req(user, this);
req.Send(); req.Send();
if (req.cert && InspIRCd::Match(req.cert->GetFingerprint(), mask.substr(2))) if (req.cert && InspIRCd::Match(req.cert->GetFingerprint(), mask.substr(2)))
return MOD_RES_DENY; return MOD_RES_DENY;

View File

@ -25,7 +25,7 @@
* in a unified manner. These classes are attached to ssl- * in a unified manner. These classes are attached to ssl-
* connected local users using SSLCertExt * connected local users using SSLCertExt
*/ */
class ssl_cert class ssl_cert : public refcountbase
{ {
public: public:
std::string dn; std::string dn;
@ -118,13 +118,34 @@ class ssl_cert
} }
}; };
struct SSLCertificateRequest : public Request /** Get certificate from a socket (only useful with an SSL module) */
struct SocketCertificateRequest : public Request
{ {
Extensible* const item; StreamSocket* const sock;
ssl_cert* cert; ssl_cert* cert;
SSLCertificateRequest(Extensible* e, Module* Me, Module* info = ServerInstance->Modules->Find("m_sslinfo.so")) SocketCertificateRequest(StreamSocket* ss, Module* Me, Module* hook)
: Request(Me, info, "GET_CERT"), item(e), cert(NULL) : Request(Me, hook, "GET_SSL_CERT"), sock(ss), cert(NULL)
{
Send();
}
std::string GetFingerprint()
{
if (cert)
return cert->GetFingerprint();
return "";
}
};
/** Get certificate from a user (requires m_sslinfo) */
struct UserCertificateRequest : public Request
{
User* const user;
ssl_cert* cert;
UserCertificateRequest(User* u, Module* Me, Module* info = ServerInstance->Modules->Find("m_sslinfo.so"))
: Request(Me, info, "GET_USER_CERT"), user(u), cert(NULL)
{ {
Send(); Send();
} }